diff --git a/DEPS b/DEPS index ab68dde..c48ac9c 100644 --- a/DEPS +++ b/DEPS
@@ -245,15 +245,15 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling Skia # and whatever else without interference from each other. - 'skia_revision': '3887123079922a846fa43260740d2b86de450d91', + 'skia_revision': '93e0041c6ca7e2e9fa8047ef67ed3804f2ccbef9', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling V8 # and whatever else without interference from each other. - 'v8_revision': '239e6102fad60fb3317e8f4349f69df760291ef1', + 'v8_revision': '8280a12b2bdf7c22235d44f27998d6de3b47a3bd', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling ANGLE # and whatever else without interference from each other. - 'angle_revision': '8ab13284ad8c9f05d1cce6b4fd6b3f51a2e1e9b3', + 'angle_revision': '40c5cb255c0a07bdab574aa076ee603e7d791ab3', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling SwiftShader # and whatever else without interference from each other. @@ -268,7 +268,7 @@ # # Note this revision should be updated with # third_party/boringssl/roll_boringssl.py, not roll-dep. - 'boringssl_revision': '3a667d10e94186fd503966f5638e134fe9fb4080', + 'boringssl_revision': '295b31324f8c557dcd3c1c831857e33a7f23bc52', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling google-toolbox-for-mac # and whatever else without interference from each other. @@ -288,7 +288,7 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling NaCl # and whatever else without interference from each other. - 'nacl_revision': '8694f994aeb5f8d7bb6b1bb5ec68e490366d9e61', + 'nacl_revision': '38e06bc9954cbe8582f557d525d2effbbc0d849e', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling freetype # and whatever else without interference from each other. @@ -320,7 +320,7 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling devtools-frontend # and whatever else without interference from each other. - 'devtools_frontend_revision': 'd3425b95d0ff161d2615702f7a686cc8702980a7', + 'devtools_frontend_revision': '7d228add956c220c44e95914472c0d56b969152c', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling libprotobuf-mutator # and whatever else without interference from each other. @@ -360,11 +360,11 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling feed # and whatever else without interference from each other. - 'dawn_revision': '8b4d03d30231aec2755322a9f2200f3ad05d51f7', + 'dawn_revision': '03b3594f79a2fb322dad39b757dbe706b905666c', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling feed # and whatever else without interference from each other. - 'quiche_revision': 'cb6ab3eb3fdf36c5db3a4a845f4af1c7c5d68b7e', + 'quiche_revision': 'b0fda3339a7e0c24d8bea7cfb6d9ea306e2686a1', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling ios_webkit # and whatever else without interference from each other. @@ -427,7 +427,7 @@ 'libcxx_revision': '79a2e924d96e2fc1e4b937c42efd08898fa472d7', # GN CIPD package version. - 'gn_version': 'git_revision:54284c12607e2818293157cd76d29d03a36bfd68', + 'gn_version': 'git_revision:f1b1412521b41e47118b29863224171e434a27a2', } # Only these hosts are allowed for dependencies in this DEPS file. @@ -805,7 +805,7 @@ 'packages': [ { 'package': 'chromium/third_party/androidx', - 'version': 'SBd1_afac7GmjsqtDzNI7TUwKr4KVbHuX3oqnduAW9EC', + 'version': 'hyrO0XEiBqf22_8YM-FkRhjqAyxUWeSZ_9LZqPJDkYIC', }, ], 'condition': 'checkout_android', @@ -1024,7 +1024,7 @@ # Tools used when building Chrome for Chrome OS. This affects both the Simple # Chrome workflow, as well as the chromeos-chrome ebuild. 'src/third_party/chromite': { - 'url': Var('chromium_git') + '/chromiumos/chromite.git' + '@' + '81093092ece1b06a493257f50a23ec429c01529a', + 'url': Var('chromium_git') + '/chromiumos/chromite.git' + '@' + 'd0f15301899309b439cf5710ef2752f8075d4dc3', 'condition': 'checkout_chromeos', }, @@ -1044,7 +1044,7 @@ }, 'src/third_party/depot_tools': - Var('chromium_git') + '/chromium/tools/depot_tools.git' + '@' + '44dda9648cce2a12c67aa96a498adfb3245d38e7', + Var('chromium_git') + '/chromium/tools/depot_tools.git' + '@' + '59e6796cd25bb2fd957d411a23ab33bea36e2732', 'src/third_party/devtools-frontend/src': Var('chromium_git') + '/devtools/devtools-frontend' + '@' + Var('devtools_frontend_revision'), @@ -1609,7 +1609,7 @@ 'src/third_party/usrsctp/usrsctplib': Var('chromium_git') + '/external/github.com/sctplab/usrsctp' + '@' + '62d7d0c928c9a040dce96aa2f16c00e7e67d59cb', - 'src/third_party/vulkan-deps': '{chromium_git}/vulkan-deps@57e7495dc76d1dd940f6603db80adcdd62c27d5a', + 'src/third_party/vulkan-deps': '{chromium_git}/vulkan-deps@9e7f44a63b2e5265935eeddb98159a0ddb8205ef', 'src/third_party/vulkan_memory_allocator': Var('chromium_git') + '/external/github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git' + '@' + '5e49f57a6e71a026a54eb42e366de09a4142d24e', @@ -1648,7 +1648,7 @@ Var('chromium_git') + '/external/github.com/gpuweb/cts.git' + '@' + 'fca7b339442bd70c5dc49bb33ee7f9466b560a97', 'src/third_party/webrtc': - Var('webrtc_git') + '/src.git' + '@' + 'd908d74fac46c58b80c1120b982d3f1b407292f3', + Var('webrtc_git') + '/src.git' + '@' + '651586c4e1b97a4615b8522a3a3d56bb4376e72f', 'src/third_party/libgifcodec': Var('skia_git') + '/libgifcodec' + '@'+ Var('libgifcodec_revision'), @@ -1730,7 +1730,7 @@ Var('chromium_git') + '/v8/v8.git' + '@' + Var('v8_revision'), 'src-internal': { - 'url': 'https://chrome-internal.googlesource.com/chrome/src-internal.git@82d92b1bf860b2e239f7a62a1eedbb306caa055a', + 'url': 'https://chrome-internal.googlesource.com/chrome/src-internal.git@5f4f6cba430b42f7ef03457af99737cf73aa908e', 'condition': 'checkout_src_internal', }, @@ -1760,7 +1760,7 @@ 'packages': [ { 'package': 'chromeos_internal/apps/help_app/app', - 'version': 'ALKqv_hdzRbzx6zAqVrRyRoC7_l0PzxpaXcMy8FY12gC', + 'version': 'iznE7Bnmb5EY-LvHVo0U6uawQvfDk5Rp7Q5FEh6DSd8C', }, ], 'condition': 'checkout_chromeos and checkout_src_internal', @@ -1771,7 +1771,7 @@ 'packages': [ { 'package': 'chromeos_internal/apps/media_app/app', - 'version': 'AbBYfsC4757NystoBycRj-9JENE23kFNcGpJfUoBw5IC', + 'version': 'ww1_CusDQQSjKCZTDoFyYzAl_CHJPeA2OSIxwVGUAPkC', }, ], 'condition': 'checkout_chromeos and checkout_src_internal',
diff --git a/ash/app_list/app_list_bubble_event_filter.cc b/ash/app_list/app_list_bubble_event_filter.cc index 226794f..5c814d79 100644 --- a/ash/app_list/app_list_bubble_event_filter.cc +++ b/ash/app_list/app_list_bubble_event_filter.cc
@@ -5,9 +5,12 @@ #include "ash/app_list/app_list_bubble_event_filter.h" #include "ash/bubble/bubble_utils.h" +#include "ash/shelf/hotseat_widget.h" +#include "ash/shelf/shelf.h" #include "ash/shell.h" #include "base/callback.h" #include "base/check.h" +#include "ui/aura/window.h" #include "ui/events/event.h" #include "ui/gfx/geometry/point.h" #include "ui/gfx/geometry/rect.h" @@ -66,6 +69,16 @@ return; } + // Ignore clicks in the shelf area containing app icons. + aura::Window* target = static_cast<aura::Window*>(event.target()); + if (target) { + Shelf* shelf = Shelf::ForWindow(target); + if (target == shelf->hotseat_widget()->GetNativeWindow() && + shelf->hotseat_widget()->EventTargetsShelfView(event)) { + return; + } + } + on_click_outside_.Run(); }
diff --git a/ash/app_list/app_list_bubble_presenter.cc b/ash/app_list/app_list_bubble_presenter.cc index b39195eb..0006a1a7 100644 --- a/ash/app_list/app_list_bubble_presenter.cc +++ b/ash/app_list/app_list_bubble_presenter.cc
@@ -22,6 +22,7 @@ #include "ash/shelf/shelf.h" #include "ash/shelf/shelf_navigation_widget.h" #include "ash/shell.h" +#include "ash/system/tray/tray_background_view.h" #include "ash/wm/container_finder.h" #include "base/bind.h" #include "base/check.h" @@ -302,7 +303,7 @@ // Check for widget because the code could be waiting for zero-state search // results before first show. if (bubble_widget_) - bubble_widget_->Hide(); + OnHideAnimationEnded(); } controller_->OnVisibilityChanged(/*visible=*/false, display_id); @@ -404,6 +405,10 @@ } void AppListBubblePresenter::OnHideAnimationEnded() { + // Hiding the launcher causes a window activation change. If the launcher is + // hiding because the user opened a system tray bubble, don't immediately + // close the bubble in response. + auto lock = TrayBackgroundView::DisableCloseBubbleOnWindowActivated(); bubble_widget_->Hide(); controller_->MaybeCloseAssistant();
diff --git a/ash/app_list/app_list_bubble_presenter_unittest.cc b/ash/app_list/app_list_bubble_presenter_unittest.cc index 8fe30dc..35f3e53 100644 --- a/ash/app_list/app_list_bubble_presenter_unittest.cc +++ b/ash/app_list/app_list_bubble_presenter_unittest.cc
@@ -20,7 +20,9 @@ #include "ash/shelf/shelf.h" #include "ash/shelf/shelf_navigation_widget.h" #include "ash/shell.h" +#include "ash/system/unified/unified_system_tray.h" #include "ash/test/ash_test_base.h" +#include "ash/test/layer_animation_stopped_waiter.h" #include "ash/test/test_widget_builder.h" #include "base/run_loop.h" #include "base/strings/string_number_conversions.h" @@ -490,6 +492,34 @@ EXPECT_TRUE(presenter->IsShowing()); } +// Regression test for https://crbug.com/1285443. +TEST_F(AppListBubblePresenterTest, CanOpenBubbleThenOpenSystemTray) { + // Enable animations. + base::test::ScopedFeatureList features( + features::kProductivityLauncherAnimation); + ui::ScopedAnimationDurationScaleMode duration( + ui::ScopedAnimationDurationScaleMode::NON_ZERO_DURATION); + + // Create a widget, which will activate itself when the launcher closes. + std::unique_ptr<views::Widget> widget = + TestWidgetBuilder().SetShow(true).BuildOwnsNativeWidget(); + + // Show the launcher. + AppListBubblePresenter* presenter = GetBubblePresenter(); + presenter->Show(GetPrimaryDisplay().id()); + + // Click on the system tray. + LeftClickOn(GetPrimaryUnifiedSystemTray()); + + // Wait for launcher animations to end. + LayerAnimationStoppedWaiter().Wait( + presenter->bubble_view_for_test()->layer()); + + // Launcher is closed and system tray is open. + EXPECT_FALSE(presenter->IsShowing()); + EXPECT_TRUE(GetPrimaryUnifiedSystemTray()->IsBubbleShown()); +} + TEST_F(AppListBubblePresenterTest, BubbleOpensInBottomLeftForBottomShelf) { GetPrimaryShelf()->SetAlignment(ShelfAlignment::kBottom);
diff --git a/ash/app_list/app_list_presenter_unittest.cc b/ash/app_list/app_list_presenter_unittest.cc index 42ada8a..621ec3fa 100644 --- a/ash/app_list/app_list_presenter_unittest.cc +++ b/ash/app_list/app_list_presenter_unittest.cc
@@ -3829,8 +3829,12 @@ EXPECT_EQ(SHELF_AUTO_HIDE_SHOWN, shelf->GetAutoHideState()); } -// TODO(crbug.com/1273162): Fix for ProductivityLauncher. -TEST_F(AppListPresenterTest, ClickingShelfArrowDoesNotHideAppList) { +TEST_P(AppListPresenterTest, ClickingShelfArrowDoesNotHideAppList) { + // Parameterize by ProductivityLauncher. + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatureState(features::kProductivityLauncher, + GetParam()); + // Add enough shelf items for the shelf to enter overflow. Shelf* const shelf = GetPrimaryShelf(); ScrollableShelfView* const scrollable_shelf_view =
diff --git a/ash/app_list/views/app_list_folder_view.cc b/ash/app_list/views/app_list_folder_view.cc index 579f467..9a5ee18 100644 --- a/ash/app_list/views/app_list_folder_view.cc +++ b/ash/app_list/views/app_list_folder_view.cc
@@ -630,6 +630,7 @@ contents_container_->AddChildView(std::make_unique<PagedAppsGridView>( contents_view, a11y_announcer_, this, /*folder_controller=*/nullptr, /*container_delegate=*/this)); + contents_container_->layer()->SetMasksToBounds(true); items_grid_view_ = items_grid_view; items_grid_view_->Init();
diff --git a/ash/app_list/views/app_list_item_view.h b/ash/app_list/views/app_list_item_view.h index 077ea40..6a14412 100644 --- a/ash/app_list/views/app_list_item_view.h +++ b/ash/app_list/views/app_list_item_view.h
@@ -232,8 +232,8 @@ private: friend class AppListItemViewProductivityLauncherTest; + friend class AppListMainViewTest; friend class test::AppsGridViewTest; - friend class test::AppListMainViewTest; class IconImageView; class AppNotificationIndicatorView;
diff --git a/ash/app_list/views/app_list_main_view_unittest.cc b/ash/app_list/views/app_list_main_view_unittest.cc index d90cd4c..7adb0cffc 100644 --- a/ash/app_list/views/app_list_main_view_unittest.cc +++ b/ash/app_list/views/app_list_main_view_unittest.cc
@@ -19,8 +19,10 @@ #include "ash/app_list/views/page_switcher.h" #include "ash/app_list/views/paged_apps_grid_view.h" #include "ash/app_list/views/search_box_view.h" +#include "ash/constants/ash_features.h" #include "ash/public/cpp/app_list/app_list_features.h" #include "ash/public/cpp/test/test_app_list_color_provider.h" +#include "ash/style/ash_color_provider.h" #include "base/test/scoped_feature_list.h" #include "ui/compositor/layer.h" #include "ui/compositor/scoped_animation_duration_scale_mode.h" @@ -34,17 +36,20 @@ #include "ui/views/widget/widget.h" namespace ash { -namespace test { - namespace { const int kInitialItems = 2; } // namespace -class AppListMainViewTest : public views::ViewsTestBase { +// Parameterized by ProductivityLauncher. +class AppListMainViewTest : public views::ViewsTestBase, + public testing::WithParamInterface<bool> { public: - AppListMainViewTest() = default; + AppListMainViewTest() { + feature_list_.InitWithFeatureState(features::kProductivityLauncher, + GetParam()); + } AppListMainViewTest(const AppListMainViewTest& other) = delete; AppListMainViewTest& operator=(const AppListMainViewTest& other) = delete; ~AppListMainViewTest() override = default; @@ -57,7 +62,7 @@ ui::ScopedAnimationDurationScaleMode::ZERO_DURATION); // Create, and show the app list is fullscreen apps grid state. - delegate_ = std::make_unique<AppListTestViewDelegate>(); + delegate_ = std::make_unique<test::AppListTestViewDelegate>(); app_list_view_ = new AppListView(delegate_.get()); app_list_view_->InitView(GetContext()); app_list_view_->Show(AppListViewState::kFullscreenAllApps, @@ -236,37 +241,43 @@ } protected: - TestAppListColorProvider color_provider_; // Needed by AppListView. + base::test::ScopedFeatureList feature_list_; + TestAppListColorProvider app_list_color_provider_; // Needed by AppListView. + AshColorProvider ash_color_provider_; // Needed by ContinueContainer. AppListView* app_list_view_ = nullptr; // Owned by native widget. - std::unique_ptr<AppListTestViewDelegate> delegate_; + std::unique_ptr<test::AppListTestViewDelegate> delegate_; private: std::unique_ptr<ui::ScopedAnimationDurationScaleMode> zero_duration_mode_; }; +INSTANTIATE_TEST_SUITE_P(ProductivityLauncher, + AppListMainViewTest, + testing::Bool()); + // Tests that the close button becomes invisible after close button is clicked. -TEST_F(AppListMainViewTest, CloseButtonInvisibleAfterCloseButtonClicked) { +TEST_P(AppListMainViewTest, CloseButtonInvisibleAfterCloseButtonClicked) { PressKeyInSearchBox(ui::VKEY_A); ClickButton(search_box_view()->close_button()); EXPECT_FALSE(search_box_view()->close_button()->GetVisible()); } // Tests that the search box becomes empty after close button is clicked. -TEST_F(AppListMainViewTest, SearchBoxEmptyAfterCloseButtonClicked) { +TEST_P(AppListMainViewTest, SearchBoxEmptyAfterCloseButtonClicked) { PressKeyInSearchBox(ui::VKEY_A); ClickButton(search_box_view()->close_button()); EXPECT_TRUE(search_box_view()->search_box()->GetText().empty()); } // Tests that the search box is no longer active after close button is clicked. -TEST_F(AppListMainViewTest, SearchBoxActiveAfterCloseButtonClicked) { +TEST_P(AppListMainViewTest, SearchBoxActiveAfterCloseButtonClicked) { PressKeyInSearchBox(ui::VKEY_A); ClickButton(search_box_view()->close_button()); EXPECT_FALSE(search_box_view()->is_search_box_active()); } // Tests changing the AppListModel when switching profiles. -TEST_F(AppListMainViewTest, ModelChanged) { +TEST_P(AppListMainViewTest, ModelChanged) { delegate_->GetTestModel()->PopulateApps(kInitialItems); EXPECT_EQ(kInitialItems, GetRootViewModel()->view_size()); @@ -283,7 +294,7 @@ // Tests dragging an item out of a single item folder and dropping it onto the // page switcher. Regression test for http://crbug.com/415530/. -TEST_F(AppListMainViewTest, DragReparentItemOntoPageSwitcher) { +TEST_P(AppListMainViewTest, DragReparentItemOntoPageSwitcher) { AppListItemView* folder_item_view = CreateAndOpenSingleItemFolder(); ASSERT_TRUE(folder_item_view); @@ -318,7 +329,7 @@ // Test that an interrupted drag while reparenting an item from a folder, when // canceled via the root grid, correctly forwards the cancelation to the drag // ocurring from the folder. -TEST_F(AppListMainViewTest, MouseDragItemOutOfFolderWithCancel) { +TEST_P(AppListMainViewTest, MouseDragItemOutOfFolderWithCancel) { CreateAndOpenSingleItemFolder(); AppListItemView* dragged = StartDragForReparent(0); @@ -341,7 +352,7 @@ // Test that dragging an app out of a single item folder and reparenting it // back into its original folder results in a cancelled reparent. This is a // regression test for http://crbug.com/429083. -TEST_F(AppListMainViewTest, ReparentSingleItemOntoSelf) { +TEST_P(AppListMainViewTest, ReparentSingleItemOntoSelf) { // Add a folder with 1 item. AppListItemView* folder_item_view = CreateAndOpenSingleItemFolder(); std::string folder_id = folder_item_view->item()->id(); @@ -369,5 +380,4 @@ EXPECT_EQ(1u, folder_item->item_list()->item_count()); } -} // namespace test } // namespace ash
diff --git a/ash/app_list/views/apps_container_view.cc b/ash/app_list/views/apps_container_view.cc index 469574f..3258af4 100644 --- a/ash/app_list/views/apps_container_view.cc +++ b/ash/app_list/views/apps_container_view.cc
@@ -29,6 +29,7 @@ #include "ash/public/cpp/app_list/app_list_switches.h" #include "ash/resources/vector_icons/vector_icons.h" #include "ash/search_box/search_box_constants.h" +#include "ash/shelf/gradient_layer_delegate.h" #include "ash/style/ash_color_provider.h" #include "ash/style/style_util.h" #include "base/bind.h" @@ -119,9 +120,8 @@ // Minimal horizontal distance from the page switcher to apps container bounds. constexpr int kPageSwitcherEndMargin = 16; -// The apps grid view's fadeout zone height, that contains a fadeout mask, and -// which is used as a margin for the `AppsGridView` contents. -constexpr int kGridFadeoutZoneHeight = 24; +// The vertical margin for the `AppsGridView` contents. +constexpr int kGridVerticalMargin = 24; // The space between sort ui controls (including sort button and redo button). constexpr int kSortUiControlSpacing = 10; @@ -143,6 +143,10 @@ // The width of the separator. constexpr int kSeparatorWidth = 240; +// The actual height of the fadeout gradient mask at the top and bottom of the +// `scrollable_container_`. +constexpr int kDefaultFadeoutMaskHeight = 16; + // SortUiControl --------------------------------------------------------------- class SortUiControl : public views::ImageButton { @@ -370,10 +374,11 @@ AppListViewDelegate* view_delegate = contents_view_->GetAppListMainView()->view_delegate(); + // The bounds of the |scrollable_container_| will visually clip the + // |continue_container_| and |apps_grid_view_| layers. + scrollable_container_->layer()->SetMasksToBounds(true); + if (features::IsProductivityLauncherEnabled()) { - // The bounds of the |scrollable_container_| will visually clip the - // |continue_container_| layer during page changes. - scrollable_container_->layer()->SetMasksToBounds(true); continue_container_ = scrollable_container_->AddChildView( std::make_unique<ContinueContainer>(this, view_delegate)); } else { @@ -395,9 +400,9 @@ /*container_delegate=*/this), 0); apps_grid_view_->Init(); - + apps_grid_view_->pagination_model()->AddObserver(this); if (features::IsProductivityLauncherEnabled()) - apps_grid_view_->pagination_model()->AddObserver(this); + apps_grid_view_->set_margin_for_gradient_mask(kDefaultFadeoutMaskHeight); // Page switcher should be initialized after AppsGridView. auto page_switcher = std::make_unique<PageSwitcher>( @@ -429,9 +434,7 @@ AppsContainerView::~AppsContainerView() { AppListModelProvider::Get()->RemoveObserver(this); - - if (features::IsProductivityLauncherEnabled()) - apps_grid_view_->pagination_model()->RemoveObserver(this); + apps_grid_view_->pagination_model()->RemoveObserver(this); // Make sure |page_switcher_| is deleted before |apps_grid_view_| because // |page_switcher_| uses the PaginationModel owned by |apps_grid_view_|. @@ -471,7 +474,7 @@ contents_view_->GetSearchBoxSize(AppListState::kStateApps)), 0, 0); // Subtracts apps grid view insets from space available for apps grid. - available_bounds.Inset(0, kGridFadeoutZoneHeight); + available_bounds.Inset(0, kGridVerticalMargin); return available_bounds; } @@ -610,6 +613,11 @@ // PaginationModelObserver: void AppsContainerView::SelectedPageChanged(int old_selected, int new_selected) { + // There is no |continue_container_| to translate when productivity launcher + // is not enabled, so return early. + if (!features::IsProductivityLauncherEnabled()) + return; + // |continue_container_| is hidden above the grid when not on the first page. gfx::Transform transform; gfx::Vector2dF translate; @@ -619,6 +627,11 @@ } void AppsContainerView::TransitionChanged() { + // There is no |continue_container_| to translate when productivity launcher + // is not enabled, so return early. + if (!features::IsProductivityLauncherEnabled()) + return; + auto* pagination_model = apps_grid_view_->pagination_model(); const PaginationModel::Transition& transition = pagination_model->transition(); @@ -646,6 +659,28 @@ } } +void AppsContainerView::TransitionStarted() { + MaybeCreateGradientMask(); +} + +void AppsContainerView::TransitionEnded() { + // TODO(crbug.com/1285184): Sometimes gradient mask is not removed because + // this function does not get called in some cases. + + // Gradient mask is no longer necessary once transition is finished. + MaybeRemoveGradientMask(); +} + +void AppsContainerView::ScrollStarted() { + MaybeCreateGradientMask(); +} + +void AppsContainerView::ScrollEnded() { + // Need to reset the mask because transition will not happen in some + // cases. (See https://crbug.com/1049275) + MaybeRemoveGradientMask(); +} + // PagedAppsGridView::ContainerDelegate: bool AppsContainerView::IsPointWithinPageFlipBuffer( const gfx::Point& point_in_apps_grid) const { @@ -674,6 +709,36 @@ point_in_parent.y() < kBottomDragBufferMax; } +void AppsContainerView::MaybeCreateGradientMask() { + if (features::IsBackgroundBlurEnabled()) { + if (!layer()->layer_mask_layer() && !gradient_layer_delegate_) { + gradient_layer_delegate_ = std::make_unique<GradientLayerDelegate>(); + UpdateGradientMaskBounds(); + } + if (gradient_layer_delegate_) { + scrollable_container_->layer()->SetMaskLayer( + gradient_layer_delegate_->layer()); + } + } +} + +void AppsContainerView::MaybeRemoveGradientMask() { + if (scrollable_container_->layer()->layer_mask_layer() && + !keep_gradient_mask_for_cardified_state_) { + scrollable_container_->layer()->SetMaskLayer(nullptr); + } +} + +void AppsContainerView::OnCardifiedStateStarted() { + keep_gradient_mask_for_cardified_state_ = true; + MaybeCreateGradientMask(); +} + +void AppsContainerView::OnCardifiedStateEnded() { + keep_gradient_mask_for_cardified_state_ = false; + MaybeRemoveGradientMask(); +} + // RecentAppsView::Delegate: void AppsContainerView::MoveFocusUpFromRecents() { DCHECK(!GetRecentApps()->children().empty()); @@ -848,19 +913,31 @@ GetContentsBounds(), contents_view_->GetSearchBoxSize(AppListState::kStateApps)); gfx::Rect grid_rect = rect; - grid_rect.Inset(margins.left(), kGridFadeoutZoneHeight, margins.right(), + grid_rect.Inset(margins.left(), kGridVerticalMargin, margins.right(), margins.bottom()); // The grid rect insets are added to calculated margins. Given that the // grid bounds rect should include insets, they have to be removed from // added margins. grid_rect.Inset(-grid_insets); - scrollable_container_->SetBoundsRect(grid_rect); + + gfx::Rect scrollable_bounds = grid_rect; + // With productivity launcher enabled, add space to the top of the + // `scrollable_container_` bounds to make room for the gradient mask to be + // placed above the continue section. + if (features::IsProductivityLauncherEnabled()) + scrollable_bounds.Inset(0, -kDefaultFadeoutMaskHeight, 0, 0); + scrollable_container_->SetBoundsRect(scrollable_bounds); + + if (gradient_layer_delegate_) + UpdateGradientMaskBounds(); + bool first_page_config_changed = false; if (features::IsProductivityLauncherEnabled()) { const int continue_container_height = continue_container_->GetPreferredSize().height(); - continue_container_->SetBoundsRect( - gfx::Rect(0, 0, grid_rect.width(), continue_container_height)); + continue_container_->SetBoundsRect(gfx::Rect(0, kDefaultFadeoutMaskHeight, + grid_rect.width(), + continue_container_height)); // Setting this offset prevents the app items in the grid from overlapping // with the continue section. first_page_config_changed = apps_grid_view_->ConfigureFirstPagePadding( @@ -872,7 +949,11 @@ // shown in the grid. UpdateTopLevelGridDimensions(); - const gfx::Rect apps_grid_bounds(grid_rect.size()); + gfx::Rect apps_grid_bounds(grid_rect.size()); + // Set the apps grid bounds y to make room for the top gradient mask. + if (features::IsProductivityLauncherEnabled()) + apps_grid_bounds.set_y(kDefaultFadeoutMaskHeight); + if (apps_grid_view_->bounds() != apps_grid_bounds) { apps_grid_view_->SetBoundsRect(apps_grid_bounds); } else if (first_page_config_changed) { @@ -1085,9 +1166,7 @@ ? 0 : kSuggestionChipContainerHeight + kSuggestionChipContainerTopMargin; - // NOTE: Use the fadeout zone height as min top margin to match the apps grid - // view's bottom margin. - return search_box_size.height() + kGridFadeoutZoneHeight + + return search_box_size.height() + kGridVerticalMargin + suggestion_chip_container_size; } @@ -1152,7 +1231,7 @@ // Productivity launcher does not have a preset number of rows per page. // Instead of adjusting the margins to fit a set number of rows, the grid // will change the number of rows to fit within the provided space. - vertical_margin = kGridFadeoutZoneHeight; + vertical_margin = kGridVerticalMargin; } else { vertical_margin = calculate_margin(GetIdealVerticalMargin(), available_height, @@ -1166,9 +1245,9 @@ const int min_horizontal_margin = GetMinHorizontalMarginForAppsGrid(); cached_container_margins_.margins = - gfx::Insets(std::max(vertical_margin, kGridFadeoutZoneHeight), + gfx::Insets(std::max(vertical_margin, kGridVerticalMargin), std::max(horizontal_margin, min_horizontal_margin), - std::max(vertical_margin, kGridFadeoutZoneHeight), + std::max(vertical_margin, kGridVerticalMargin), std::max(horizontal_margin, min_horizontal_margin)); cached_container_margins_.bounds_size = available_bounds.size(); cached_container_margins_.search_box_size = search_box_size; @@ -1390,4 +1469,21 @@ suggestion_chip_container_view_->SetBlurDisabled(false); } +void AppsContainerView::UpdateGradientMaskBounds() { + const gfx::Rect container_bounds = scrollable_container_->bounds(); + const gfx::Rect top_gradient_bounds(0, 0, container_bounds.width(), + kDefaultFadeoutMaskHeight); + const gfx::Rect bottom_gradient_bounds( + 0, container_bounds.height() - kDefaultFadeoutMaskHeight, + container_bounds.width(), kDefaultFadeoutMaskHeight); + + gradient_layer_delegate_->set_start_fade_zone({top_gradient_bounds, + /*fade_in=*/true, + /*is_horizontal=*/false}); + gradient_layer_delegate_->set_end_fade_zone({bottom_gradient_bounds, + /*fade_in=*/false, + /*is_horizonal=*/false}); + gradient_layer_delegate_->layer()->SetBounds(container_bounds); +} + } // namespace ash
diff --git a/ash/app_list/views/apps_container_view.h b/ash/app_list/views/apps_container_view.h index e942e32d..d23a730 100644 --- a/ash/app_list/views/apps_container_view.h +++ b/ash/app_list/views/apps_container_view.h
@@ -31,6 +31,7 @@ class FolderBackgroundView; class PageSwitcher; class SuggestionChipContainerView; +class GradientLayerDelegate; // AppsContainerView contains a root level AppsGridView to render the root level // app items, and a AppListFolderView to render the app items inside the @@ -149,11 +150,17 @@ // PaginationModelObserver: void SelectedPageChanged(int old_selected, int new_selected) override; void TransitionChanged() override; + void TransitionStarted() override; + void TransitionEnded() override; + void ScrollStarted() override; + void ScrollEnded() override; // PagedAppsGridView::ContainerDelegate: bool IsPointWithinPageFlipBuffer(const gfx::Point& point) const override; bool IsPointWithinBottomDragBuffer(const gfx::Point& point, int page_flip_zone_size) const override; + void OnCardifiedStateStarted() override; + void OnCardifiedStateEnded() override; // RecentAppsView::Delegate: void MoveFocusUpFromRecents() override; @@ -252,6 +259,22 @@ // Callback returned by DisableBlur(). void OnSuggestionChipsBlurDisablerReleased(); + // Updates the bounds of the gradient mask to fit the current bounds of the + // `scrollable_container_`. + void UpdateGradientMaskBounds(); + + // Creates a layer mask for gradient alpha when the feature is enabled. The + // gradient appears at the top and bottom of the 'scrollable_container_' to + // create a "fade out" effect when dragging the whole page. + void MaybeCreateGradientMask(); + + // Removes the gradient mask from being set as the mask layer. + void MaybeRemoveGradientMask(); + + // While true, the gradient mask will not be removed as a mask layer until + // cardified state ends. + bool keep_gradient_mask_for_cardified_state_ = false; + ContentsView* const contents_view_; // The app list config used to configure sizing and layout of apps grid items @@ -292,6 +315,8 @@ // arguments (otherwise the margins will be recalculated). CachedContainerMargins cached_container_margins_; + std::unique_ptr<GradientLayerDelegate> gradient_layer_delegate_; + base::WeakPtrFactory<AppsContainerView> weak_ptr_factory_{this}; };
diff --git a/ash/app_list/views/apps_grid_view.cc b/ash/app_list/views/apps_grid_view.cc index 86209f2..60d8fd8 100644 --- a/ash/app_list/views/apps_grid_view.cc +++ b/ash/app_list/views/apps_grid_view.cc
@@ -307,10 +307,7 @@ if (!folder_delegate_) DCHECK(folder_controller_); - // Clip any icons that are outside the grid view's bounds. These icons would - // otherwise be visible to the user when the grid view is off screen. SetPaintToLayer(ui::LAYER_NOT_DRAWN); - layer()->SetMasksToBounds(true); items_container_ = AddChildView(std::make_unique<views::View>()); items_container_->SetPaintToLayer();
diff --git a/ash/app_list/views/apps_grid_view_unittest.cc b/ash/app_list/views/apps_grid_view_unittest.cc index 83969bf..2d12141 100644 --- a/ash/app_list/views/apps_grid_view_unittest.cc +++ b/ash/app_list/views/apps_grid_view_unittest.cc
@@ -3757,7 +3757,12 @@ AppListItemView* folder_view = GetItemViewInTopLevelGrid(0); ASSERT_TRUE(folder_view->is_folder()); - ASSERT_FALSE(apps_grid_view_->layer()->layer_mask_layer()); + + views::View* scrollable_container = app_list_view_->app_list_main_view() + ->contents_view() + ->apps_container_view() + ->scrollable_container_for_test(); + ASSERT_FALSE(scrollable_container->layer()->layer_mask_layer()); // On the first page drag upwards, there should not be a page switch and the // layer mask should make the folder lose blur. @@ -3768,7 +3773,7 @@ EXPECT_TRUE(scroll_update_upwards.handled()); ASSERT_EQ(0, GetPaginationModel()->selected_page()); - ASSERT_TRUE(apps_grid_view_->layer()->layer_mask_layer()); + ASSERT_TRUE(scrollable_container->layer()->layer_mask_layer()); // Continue drag, now switching directions and release. There shouldn't be // any transition and the mask layer should've been reset. @@ -3778,7 +3783,7 @@ EXPECT_TRUE(scroll_end.handled()); EXPECT_FALSE(GetPaginationModel()->has_transition()); - EXPECT_FALSE(apps_grid_view_->layer()->layer_mask_layer()); + EXPECT_FALSE(scrollable_container->layer()->layer_mask_layer()); } TEST_F(AppsGridViewTest, PopulateAppsGridWithTwoApps) {
diff --git a/ash/app_list/views/paged_apps_grid_view.cc b/ash/app_list/views/paged_apps_grid_view.cc index a2930be..9bdd897 100644 --- a/ash/app_list/views/paged_apps_grid_view.cc +++ b/ash/app_list/views/paged_apps_grid_view.cc
@@ -144,67 +144,6 @@ } // namespace -// A layer delegate used for PagedAppsGridView's mask layer, with top and bottom -// gradient fading out zones. -class PagedAppsGridView::FadeoutLayerDelegate : public ui::LayerDelegate { - public: - explicit FadeoutLayerDelegate(int fadeout_mask_height) - : layer_(ui::LAYER_TEXTURED), fadeout_mask_height_(fadeout_mask_height) { - layer_.set_delegate(this); - layer_.SetFillsBoundsOpaquely(false); - } - FadeoutLayerDelegate(const FadeoutLayerDelegate&) = delete; - FadeoutLayerDelegate& operator=(const FadeoutLayerDelegate&) = delete; - - ~FadeoutLayerDelegate() override { layer_.set_delegate(nullptr); } - - ui::Layer* layer() { return &layer_; } - - private: - // ui::LayerDelegate: - // TODO(warx): using a mask is expensive. It would be more efficient to avoid - // the mask for the central area and only use it for top/bottom areas. - void OnPaintLayer(const ui::PaintContext& context) override { - const gfx::Size size = layer()->size(); - gfx::Rect top_rect(0, 0, size.width(), fadeout_mask_height_); - gfx::Rect bottom_rect(0, size.height() - fadeout_mask_height_, size.width(), - fadeout_mask_height_); - - views::PaintInfo paint_info = - views::PaintInfo::CreateRootPaintInfo(context, size); - const auto& prs = paint_info.paint_recording_size(); - - // Pass the scale factor when constructing PaintRecorder so the MaskLayer - // size is not incorrectly rounded (see https://crbug.com/921274). - ui::PaintRecorder recorder(context, paint_info.paint_recording_size(), - static_cast<float>(prs.width()) / size.width(), - static_cast<float>(prs.height()) / size.height(), - nullptr); - - gfx::Canvas* canvas = recorder.canvas(); - // Clear the canvas. - canvas->DrawColor(SK_ColorBLACK, SkBlendMode::kSrc); - // Draw top gradient zone. - cc::PaintFlags flags; - flags.setBlendMode(SkBlendMode::kSrc); - flags.setAntiAlias(false); - flags.setShader(gfx::CreateGradientShader( - gfx::Point(), gfx::Point(0, fadeout_mask_height_), SK_ColorTRANSPARENT, - SK_ColorBLACK)); - canvas->DrawRect(top_rect, flags); - // Draw bottom gradient zone. - flags.setShader(gfx::CreateGradientShader( - gfx::Point(0, size.height() - fadeout_mask_height_), - gfx::Point(0, size.height()), SK_ColorBLACK, SK_ColorTRANSPARENT)); - canvas->DrawRect(bottom_rect, flags); - } - void OnDeviceScaleFactorChanged(float old_device_scale_factor, - float new_device_scale_factor) override {} - - ui::Layer layer_; - const int fadeout_mask_height_; -}; - PagedAppsGridView::PagedAppsGridView( ContentsView* contents_view, AppListA11yAnnouncer* a11y_announcer, @@ -521,9 +460,6 @@ page_height * pages)); } - if (fadeout_layer_delegate_) - fadeout_layer_delegate_->layer()->SetBounds(layer()->bounds()); - CalculateIdealBounds(); for (int i = 0; i < view_model()->view_size(); ++i) { AppListItemView* view = GetItemViewAt(i); @@ -539,7 +475,6 @@ items_container()->layer()->StackAtBottom(background_card); } MaskContainerToBackgroundBounds(); - MaybeCreateGradientMask(); } views::ViewModelUtils::SetViewBoundsToIdealBounds(pulsing_blocks_model()); } @@ -783,7 +718,6 @@ // Drag ends and animation starts. presentation_time_recorder_.reset(); - MaybeCreateGradientMask(); CancelContextMenusOnCurrentPage(); } @@ -828,16 +762,11 @@ void PagedAppsGridView::TransitionEnded() { pagination_metrics_tracker_->Stop(); - - // Gradient mask is no longer necessary once transition is finished. - if (layer()->layer_mask_layer()) - layer()->SetMaskLayer(nullptr); } void PagedAppsGridView::ScrollStarted() { DCHECK(!presentation_time_recorder_); - MaybeCreateGradientMask(); if (IsTabletMode()) { presentation_time_recorder_ = CreatePresentationTimeHistogramRecorder( GetWidget()->GetCompositor(), kPageDragScrollInTabletHistogram, @@ -852,9 +781,6 @@ void PagedAppsGridView::ScrollEnded() { // Scroll can end without triggering state animation. presentation_time_recorder_.reset(); - // Need to reset the mask because transition will not happen in some - // cases. (See https://crbug.com/1049275) - layer()->SetMaskLayer(nullptr); } //////////////////////////////////////////////////////////////////////////////// @@ -936,24 +862,6 @@ return true; } -void PagedAppsGridView::MaybeCreateGradientMask() { - if (!IsInFolder() && features::IsBackgroundBlurEnabled()) { - // TODO(newcomer): Improve implementation of the mask layer so we can - // enable it on all devices https://crbug.com/765292. - if (!layer()->layer_mask_layer()) { - // Always create a new layer. The layer may be recreated by animation, - // and using the mask layer used by the detached layer can lead to - // crash. b/118822974. - if (!fadeout_layer_delegate_) { - fadeout_layer_delegate_ = - std::make_unique<FadeoutLayerDelegate>(GetFadeoutMaskHeight()); - fadeout_layer_delegate_->layer()->SetBounds(layer()->bounds()); - } - layer()->SetMaskLayer(fadeout_layer_delegate_->layer()); - } - } -} - bool PagedAppsGridView::IsValidPageFlipTarget(int page) const { if (pagination_model_.is_valid_page(page)) return true; @@ -1047,7 +955,7 @@ AppendBackgroundCard(); cardified_state_ = true; UpdateTilePadding(); - MaybeCreateGradientMask(); + container_delegate_->OnCardifiedStateStarted(); AnimateCardifiedState(); } @@ -1201,6 +1109,11 @@ if (bounds_animation_for_cardified_state_in_progress_ == 0) { animation_observers_.clear(); OnBoundsAnimatorDone(/*animator=*/nullptr); + + // Notify container that cardified state has ended once ending animations + // are complete. + if (!cardified_state_) + container_delegate_->OnCardifiedStateEnded(); } } @@ -1274,10 +1187,12 @@ void PagedAppsGridView::MaskContainerToBackgroundBounds() { DCHECK(!background_cards_.empty()); - // Mask apps grid container layer to the background card width. - layer()->SetClipRect(gfx::Rect(background_cards_[0]->bounds().x(), 0, - background_cards_[0]->bounds().width(), - layer()->bounds().height())); + // Mask apps grid container layer to the background card width. Optionally + // also include extra height to ensure the top gradient mask is shown as well. + layer()->SetClipRect( + gfx::Rect(background_cards_[0]->bounds().x(), -margin_for_gradient_mask_, + background_cards_[0]->bounds().width(), + layer()->bounds().height() + margin_for_gradient_mask_)); } void PagedAppsGridView::RemoveAllBackgroundCards() {
diff --git a/ash/app_list/views/paged_apps_grid_view.h b/ash/app_list/views/paged_apps_grid_view.h index e05b554..3f5fee0c 100644 --- a/ash/app_list/views/paged_apps_grid_view.h +++ b/ash/app_list/views/paged_apps_grid_view.h
@@ -55,6 +55,13 @@ virtual bool IsPointWithinBottomDragBuffer( const gfx::Point& point, int page_flip_zone_size) const = 0; + + // Triggered when cardified state begins before animations start. + virtual void OnCardifiedStateStarted() {} + + // Triggered when cardified state ends and the bounds animations for leaving + // cardified state have completed. + virtual void OnCardifiedStateEnded() {} }; PagedAppsGridView(ContentsView* contents_view, @@ -175,11 +182,13 @@ int GetFirstPageRowsForTesting() const { return max_rows_on_first_page_; } int GetRowsForTesting() const { return max_rows_; } + void set_margin_for_gradient_mask(int margin) { + margin_for_gradient_mask_ = margin; + } + private: friend class test::AppsGridViewTest; - class FadeoutLayerDelegate; - // Gets the leading padding for app list item grid on the first app list page. // Includes the space reserved for the continue seaction of the app list UI, // and additional vertical tile padding before the first row of apps when @@ -198,11 +207,6 @@ // handled by PagedAppsGridView. bool ShouldHandleDragEvent(const ui::LocatedEvent& event); - // Creates a layer mask for gradient alpha when the feature is enabled. The - // gradient appears at the top and bottom of the apps grid to create a - // "fade out" effect when dragging the whole page. - void MaybeCreateGradientMask(); - // Returns true if the page is the right target to flip to. bool IsValidPageFlipTarget(int page) const; @@ -282,10 +286,6 @@ // between-item drags that move the entire grid, not for app icon drags. gfx::PointF last_mouse_drag_point_; - // Implements a "fade out" gradient at the top and bottom of the grid. Used - // during page flip transitions and for cardified drags. - std::unique_ptr<FadeoutLayerDelegate> fadeout_layer_delegate_; - // Records smoothness of pagination animation. absl::optional<ui::ThroughputTracker> pagination_metrics_tracker_; @@ -326,6 +326,10 @@ std::vector<std::unique_ptr<ui::ImplicitAnimationObserver>> animation_observers_; + // A margin added to the height of the clip rect used for clipping the + // cardified state's background cards. + int margin_for_gradient_mask_ = 0; + base::WeakPtrFactory<PagedAppsGridView> weak_ptr_factory_{this}; };
diff --git a/ash/components/phonehub/BUILD.gn b/ash/components/phonehub/BUILD.gn index fbb1f56f..780020c 100644 --- a/ash/components/phonehub/BUILD.gn +++ b/ash/components/phonehub/BUILD.gn
@@ -220,6 +220,7 @@ sources = [ "browser_tabs_model_controller_unittest.cc", "browser_tabs_model_unittest.cc", + "camera_roll_item_unittest.cc", "camera_roll_manager_impl_unittest.cc", "camera_roll_thumbnail_decoder_impl_unittest.cc", "connection_scheduler_impl_unittest.cc",
diff --git a/ash/components/phonehub/camera_roll_item_unittest.cc b/ash/components/phonehub/camera_roll_item_unittest.cc new file mode 100644 index 0000000..adf394c --- /dev/null +++ b/ash/components/phonehub/camera_roll_item_unittest.cc
@@ -0,0 +1,80 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "ash/components/phonehub/camera_roll_item.h" + +#include "ash/components/phonehub/proto/phonehub_api.pb.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "ui/gfx/image/image.h" +#include "ui/gfx/image/image_skia.h" + +namespace ash { +namespace phonehub { + +namespace { + +const gfx::Image CreateTestImage() { + SkBitmap test_bitmap; + test_bitmap.allocN32Pixels(1, 1); + gfx::ImageSkia image_skia = gfx::ImageSkia::CreateFrom1xBitmap(test_bitmap); + image_skia.MakeThreadSafe(); + return gfx::Image(image_skia); +} + +} // namespace + +class CameraRollItemTest : public testing::Test { + protected: + CameraRollItemTest() = default; + CameraRollItemTest(const CameraRollItemTest&) = delete; + CameraRollItemTest& operator=(const CameraRollItemTest&) = delete; + ~CameraRollItemTest() override = default; +}; + +TEST_F(CameraRollItemTest, ItemsMatch) { + proto::CameraRollItemMetadata metadata_1; + metadata_1.set_key("key1"); + metadata_1.set_mime_type("image/png"); + metadata_1.set_last_modified_millis(123456789L); + metadata_1.set_file_size_bytes(123456789L); + metadata_1.set_file_name("FakeImage.png"); + + proto::CameraRollItemMetadata metadata_2; + metadata_2.set_key("key1"); + metadata_2.set_mime_type("image/png"); + metadata_2.set_last_modified_millis(123456789L); + metadata_2.set_file_size_bytes(123456789L); + metadata_2.set_file_name("FakeImage.png"); + + CameraRollItem item_1(metadata_1, CreateTestImage()); + CameraRollItem item_2(metadata_2, CreateTestImage()); + + EXPECT_TRUE(item_1 == item_2); + EXPECT_FALSE(item_1 != item_2); +} + +TEST_F(CameraRollItemTest, ItemsDoNotMatch) { + proto::CameraRollItemMetadata metadata_1; + metadata_1.set_key("key1"); + metadata_1.set_mime_type("image/png"); + metadata_1.set_last_modified_millis(123456789L); + metadata_1.set_file_size_bytes(123456789L); + metadata_1.set_file_name("FakeImage.png"); + + proto::CameraRollItemMetadata metadata_2; + metadata_2.set_key("key2"); + metadata_2.set_mime_type("video/mp4"); + metadata_2.set_last_modified_millis(987654321L); + metadata_2.set_file_size_bytes(987654321L); + metadata_2.set_file_name("FakeVideo.mp4"); + + CameraRollItem item_1(metadata_1, CreateTestImage()); + CameraRollItem item_2(metadata_2, CreateTestImage()); + + EXPECT_FALSE(item_1 == item_2); + EXPECT_TRUE(item_1 != item_2); +} + +} // namespace phonehub +} // namespace ash
diff --git a/ash/drag_drop/drag_drop_controller.cc b/ash/drag_drop/drag_drop_controller.cc index b55f11e..e6d5759 100644 --- a/ash/drag_drop/drag_drop_controller.cc +++ b/ash/drag_drop/drag_drop_controller.cc
@@ -241,8 +241,11 @@ !pending_long_tap_.get()) { // If drag cancel animation is running, this cleanup is done when the // animation completes. - if (drag_source_window_) + if (drag_source_window_) { + // A check to catch an UAF issue like crbug.com/1282480 on non asan build. + DCHECK(!drag_source_window_->is_destroying()); drag_source_window_->RemoveObserver(this); + } drag_source_window_ = nullptr; } @@ -703,9 +706,14 @@ void DragDropController::Cleanup() { for (aura::client::DragDropClientObserver& observer : observers_) observer.OnDragEnded(); - if (drag_window_) + + // Do not remove observer `the drag_window_1 is same as `drag_source_window_`. + // `drag_source_window_` is still necessary to process long tab and the + // observer will be reset when `drag_source_window_` is destroyed. + if (drag_window_ && drag_window_ != drag_source_window_) drag_window_->RemoveObserver(this); drag_window_ = nullptr; + drag_data_.reset(); allowed_operations_ = 0; tab_drag_drop_delegate_.reset();
diff --git a/ash/login/ui/public_account_menu_view.cc b/ash/login/ui/public_account_menu_view.cc index 11b51c53..7c7cacae 100644 --- a/ash/login/ui/public_account_menu_view.cc +++ b/ash/login/ui/public_account_menu_view.cc
@@ -62,7 +62,8 @@ PublicAccountMenuView::PublicAccountMenuView(const std::vector<Item>& items, const size_t selected_index, const OnSelect& on_select) - : views::Combobox(new PublicAccountComboboxModel(items, selected_index)), + : views::Combobox( + std::make_unique<PublicAccountComboboxModel>(items, selected_index)), items_(items), on_select_(on_select) { SetPreferredSize(
diff --git a/ash/quick_pair/common/fast_pair/fast_pair_metrics.cc b/ash/quick_pair/common/fast_pair/fast_pair_metrics.cc index 3ce4e034..c2b70c687 100644 --- a/ash/quick_pair/common/fast_pair/fast_pair_metrics.cc +++ b/ash/quick_pair/common/fast_pair/fast_pair_metrics.cc
@@ -140,6 +140,17 @@ "TotalConnectTime"; const char kDeviceMetadataFetchResult[] = "Bluetooth.ChromeOS.FastPair.DeviceMetadataFetcher.Result"; +const char kFootprintsFetcherDeleteResult[] = + "Bluetooth.ChromeOS.FastPair.FootprintsFetcher.Delete.Result"; +const char kFootprintsFetcherPostResult[] = + "Bluetooth.ChromeOS.FastPair.FootprintsFetcher.Post.Result"; +const char kFootprintsFetcherGetResult[] = + "Bluetooth.ChromeOS.FastPair.FootprintsFetcher.Get.Result"; +const char kFastPairRepositoryCacheResult[] = + "Bluetooth.ChromeOS.FastPair.FastPairRepository.Cache.Result"; +const char kHandshakeResult[] = "Bluetooth.ChromeOS.FastPair.Handshake.Result"; +const char kHandshakeFailureReason[] = + "Bluetooth.ChromeOS.FastPair.Handshake.FailureReason"; } // namespace @@ -385,5 +396,28 @@ base::UmaHistogramBoolean(kDeviceMetadataFetchResult, success); } +void RecordFootprintsFetcherDeleteResult(bool success) { + base::UmaHistogramBoolean(kFootprintsFetcherDeleteResult, success); +} + +void RecordFootprintsFetcherPostResult(bool success) { + base::UmaHistogramBoolean(kFootprintsFetcherPostResult, success); +} + +void RecordFootprintsFetcherGetResult(bool success) { + base::UmaHistogramBoolean(kFootprintsFetcherGetResult, success); +} + +void RecordFastPairRepositoryCacheResult(bool success) { + base::UmaHistogramBoolean(kFastPairRepositoryCacheResult, success); +} + +void RecordHandshakeResult(bool success) { + base::UmaHistogramBoolean(kHandshakeResult, success); +} +void RecordHandshakeFailureReason(HandshakeFailureReason failure_reason) { + base::UmaHistogramEnumeration(kHandshakeFailureReason, failure_reason); +} + } // namespace quick_pair } // namespace ash
diff --git a/ash/quick_pair/common/fast_pair/fast_pair_metrics.h b/ash/quick_pair/common/fast_pair/fast_pair_metrics.h index 9eb26eec..9a70582 100644 --- a/ash/quick_pair/common/fast_pair/fast_pair_metrics.h +++ b/ash/quick_pair/common/fast_pair/fast_pair_metrics.h
@@ -60,6 +60,19 @@ kMaxValue = kSystemPairingUi, }; +// These values are persisted to logs. Entries should not be renumbered and +// numeric values should never be reused. This enum should be kept in sync +// with the FastPairHandshakeFailureReason enum in +// src/tools/metrics/histograms/enums.xml. +enum class COMPONENT_EXPORT(QUICK_PAIR_COMMON) HandshakeFailureReason { + kFailedGattInit = 0, + kFailedCreateEncryptor = 1, + kFailedWriteResponse = 2, + kFailedDecryptResponse = 3, + kFailedIncorrectResponseType = 4, + kMaxValue = kFailedIncorrectResponseType, +}; + COMPONENT_EXPORT(QUICK_PAIR_COMMON) void AttemptRecordingFastPairEngagementFlow(const Device& device, FastPairEngagementFlowEvent event); @@ -179,6 +192,24 @@ COMPONENT_EXPORT(QUICK_PAIR_COMMON) void RecordDeviceMetadataFetchResult(bool success); +COMPONENT_EXPORT(QUICK_PAIR_COMMON) +void RecordFootprintsFetcherDeleteResult(bool success); + +COMPONENT_EXPORT(QUICK_PAIR_COMMON) +void RecordFootprintsFetcherPostResult(bool success); + +COMPONENT_EXPORT(QUICK_PAIR_COMMON) +void RecordFootprintsFetcherGetResult(bool success); + +COMPONENT_EXPORT(QUICK_PAIR_COMMON) +void RecordFastPairRepositoryCacheResult(bool success); + +COMPONENT_EXPORT(QUICK_PAIR_COMMON) +void RecordHandshakeResult(bool success); + +COMPONENT_EXPORT(QUICK_PAIR_COMMON) +void RecordHandshakeFailureReason(HandshakeFailureReason failure_reason); + } // namespace quick_pair } // namespace ash
diff --git a/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl.cc b/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl.cc index 73c5893c..fce55b61 100644 --- a/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl.cc +++ b/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl.cc
@@ -48,6 +48,8 @@ << ": Failed to init gatt client with failure = " << failure.value(); std::move(on_complete_callback_).Run(device_, failure.value()); + RecordHandshakeResult(/*success=*/false); + RecordHandshakeFailureReason(HandshakeFailureReason::kFailedGattInit); return; } @@ -68,6 +70,9 @@ << ": Fast Pair Data Encryptor failed to be created."; std::move(on_complete_callback_) .Run(device_, PairFailure::kDataEncryptorRetrieval); + RecordHandshakeResult(/*success=*/false); + RecordHandshakeFailureReason( + HandshakeFailureReason::kFailedCreateEncryptor); return; } @@ -98,6 +103,8 @@ QP_LOG(WARNING) << __func__ << ": Failed to write request: " << failure.value(); RecordWriteKeyBasedCharacteristicPairFailure(failure.value()); + RecordHandshakeResult(/*success=*/false); + RecordHandshakeFailureReason(HandshakeFailureReason::kFailedWriteResponse); std::move(on_complete_callback_).Run(device_, failure.value()); return; } @@ -118,6 +125,9 @@ std::move(on_complete_callback_) .Run(device_, PairFailure::kKeybasedPairingResponseDecryptFailure); RecordKeyBasedCharacteristicDecryptResult(/*success=*/false); + RecordHandshakeResult(/*success=*/false); + RecordHandshakeFailureReason( + HandshakeFailureReason::kFailedDecryptResponse); return; } @@ -127,6 +137,9 @@ std::move(on_complete_callback_) .Run(device_, PairFailure::kIncorrectKeyBasedPairingResponseType); RecordKeyBasedCharacteristicDecryptResult(/*success=*/false); + RecordHandshakeResult(/*success=*/false); + RecordHandshakeFailureReason( + HandshakeFailureReason::kFailedIncorrectResponseType); return; } @@ -138,6 +151,7 @@ device_->set_classic_address(device_address); completed_successfully_ = true; + RecordHandshakeResult(/*success=*/true); std::move(on_complete_callback_).Run(device_, absl::nullopt); }
diff --git a/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl_unittest.cc b/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl_unittest.cc index 57d9864..d53ab6eb 100644 --- a/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl_unittest.cc +++ b/ash/quick_pair/fast_pair_handshake/fast_pair_handshake_impl_unittest.cc
@@ -105,6 +105,9 @@ "Bluetooth.ChromeOS.FastPair.KeyBasedPairing.DecryptResult"; const char kTotalDataEncryptorCreateTimeMetric[] = "Bluetooth.ChromeOS.FastPair.FastPairDataEncryptor.CreateTime"; +const char kHandshakeResult[] = "Bluetooth.ChromeOS.FastPair.Handshake.Result"; +const char kHandshakeFailureReason[] = + "Bluetooth.ChromeOS.FastPair.Handshake.FailureReason"; } // namespace @@ -161,13 +164,19 @@ }; TEST_F(FastPairHandshakeImplTest, GattError) { + histogram_tester().ExpectTotalCount(kHandshakeResult, 0); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); fake_fast_pair_gatt_service_client()->RunOnGattClientInitializedCallback( PairFailure::kCreateGattConnection); EXPECT_EQ(failure_.value(), PairFailure::kCreateGattConnection); EXPECT_FALSE(handshake_->completed_successfully()); + histogram_tester().ExpectTotalCount(kHandshakeResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 1); } TEST_F(FastPairHandshakeImplTest, DataEncryptorCreateError) { + histogram_tester().ExpectTotalCount(kHandshakeResult, 0); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); histogram_tester().ExpectTotalCount(kWriteKeyBasedCharacteristicResultMetric, 0); histogram_tester().ExpectTotalCount(kDataEncryptorCreateResultMetric, 0); @@ -184,9 +193,13 @@ histogram_tester().ExpectTotalCount( kWriteKeyBasedCharacteristicPairFailureMetric, 0); histogram_tester().ExpectTotalCount(kTotalDataEncryptorCreateTimeMetric, 0); + histogram_tester().ExpectTotalCount(kHandshakeResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 1); } TEST_F(FastPairHandshakeImplTest, WriteResponseError) { + histogram_tester().ExpectTotalCount(kHandshakeResult, 0); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); histogram_tester().ExpectTotalCount(kDataEncryptorCreateResultMetric, 0); histogram_tester().ExpectTotalCount(kWriteKeyBasedCharacteristicResultMetric, 0); @@ -204,9 +217,13 @@ 1); histogram_tester().ExpectTotalCount( kWriteKeyBasedCharacteristicPairFailureMetric, 1); + histogram_tester().ExpectTotalCount(kHandshakeResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 1); } TEST_F(FastPairHandshakeImplTest, ParseResponseError) { + histogram_tester().ExpectTotalCount(kHandshakeResult, 0); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptTime, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptResult, 0); histogram_tester().ExpectTotalCount( @@ -222,9 +239,13 @@ kWriteKeyBasedCharacteristicPairFailureMetric, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptTime, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 1); } TEST_F(FastPairHandshakeImplTest, ParseResponseWrongType) { + histogram_tester().ExpectTotalCount(kHandshakeResult, 0); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptTime, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptResult, 0); fake_fast_pair_gatt_service_client()->RunOnGattClientInitializedCallback(); @@ -239,9 +260,13 @@ EXPECT_FALSE(handshake_->completed_successfully()); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptTime, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 1); } TEST_F(FastPairHandshakeImplTest, Success) { + histogram_tester().ExpectTotalCount(kHandshakeResult, 0); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptTime, 0); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptResult, 0); fake_fast_pair_gatt_service_client()->RunOnGattClientInitializedCallback(); @@ -255,6 +280,8 @@ EXPECT_TRUE(handshake_->completed_successfully()); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptTime, 1); histogram_tester().ExpectTotalCount(kKeyBasedCharacteristicDecryptResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeResult, 1); + histogram_tester().ExpectTotalCount(kHandshakeFailureReason, 0); } } // namespace quick_pair
diff --git a/ash/quick_pair/keyed_service/BUILD.gn b/ash/quick_pair/keyed_service/BUILD.gn index 511b8b8f..ad3b8f0 100644 --- a/ash/quick_pair/keyed_service/BUILD.gn +++ b/ash/quick_pair/keyed_service/BUILD.gn
@@ -57,6 +57,7 @@ "//ash/quick_pair/message_stream:test_support", "//ash/quick_pair/pairing:test_support", "//ash/quick_pair/repository", + "//ash/quick_pair/repository:test_support", "//ash/quick_pair/scanning:test_support", "//ash/quick_pair/ui:test_support", "//ash/services/quick_pair",
diff --git a/ash/quick_pair/keyed_service/quick_pair_mediator.cc b/ash/quick_pair/keyed_service/quick_pair_mediator.cc index bffbcad..41e58834 100644 --- a/ash/quick_pair/keyed_service/quick_pair_mediator.cc +++ b/ash/quick_pair/keyed_service/quick_pair_mediator.cc
@@ -131,7 +131,9 @@ void Mediator::OnDeviceFound(scoped_refptr<Device> device) { QP_LOG(INFO) << __func__ << ": " << device; + // On discovery, download and decode device images. ui_broker_->ShowDiscovery(device); + fast_pair_repository_->FetchDeviceImages(device); } void Mediator::OnDeviceLost(scoped_refptr<Device> device) { @@ -156,7 +158,8 @@ void Mediator::OnDevicePaired(scoped_refptr<Device> device) { QP_LOG(INFO) << __func__ << ": Device=" << device; - ui_broker_->RemoveNotifications(std::move(device)); + ui_broker_->RemoveNotifications(device); + fast_pair_repository_->PersistDeviceImages(device); } void Mediator::OnPairFailure(scoped_refptr<Device> device,
diff --git a/ash/quick_pair/keyed_service/quick_pair_mediator_unittest.cc b/ash/quick_pair/keyed_service/quick_pair_mediator_unittest.cc index b06b977..e200da02 100644 --- a/ash/quick_pair/keyed_service/quick_pair_mediator_unittest.cc +++ b/ash/quick_pair/keyed_service/quick_pair_mediator_unittest.cc
@@ -18,7 +18,7 @@ #include "ash/quick_pair/pairing/mock_pairer_broker.h" #include "ash/quick_pair/pairing/pairer_broker.h" #include "ash/quick_pair/pairing/retroactive_pairing_detector.h" -#include "ash/quick_pair/repository/fast_pair_repository.h" +#include "ash/quick_pair/repository/mock_fast_pair_repository.h" #include "ash/quick_pair/scanning/mock_scanner_broker.h" #include "ash/quick_pair/scanning/scanner_broker.h" #include "ash/quick_pair/ui/mock_ui_broker.h" @@ -72,11 +72,16 @@ std::unique_ptr<UIBroker> ui_broker = std::make_unique<MockUIBroker>(); mock_ui_broker_ = static_cast<MockUIBroker*>(ui_broker.get()); + std::unique_ptr<FastPairRepository> fast_pair_repository = + std::make_unique<MockFastPairRepository>(); + mock_fast_pair_repository_ = + static_cast<MockFastPairRepository*>(fast_pair_repository.get()); + mediator_ = std::make_unique<Mediator>( std::move(tracker), std::move(scanner_broker), std::move(retroactive_pairing_detector), std::make_unique<FakeMessageStreamLookup>(), std::move(pairer_broker), - std::move(ui_broker), std::unique_ptr<FastPairRepository>(), + std::move(ui_broker), std::move(fast_pair_repository), std::make_unique<QuickPairProcessManagerImpl>()); device_ = base::MakeRefCounted<Device>(kTestMetadataId, kTestAddress, @@ -91,6 +96,7 @@ FakeRetroactivePairingDetector* fake_retroactive_pairing_detector_; MockPairerBroker* mock_pairer_broker_; MockUIBroker* mock_ui_broker_; + MockFastPairRepository* mock_fast_pair_repository_; std::unique_ptr<Mediator> mediator_; base::test::SingleThreadTaskEnvironment task_environment_; };
diff --git a/ash/quick_pair/pairing/fast_pair/fast_pair_pairer.cc b/ash/quick_pair/pairing/fast_pair/fast_pair_pairer.cc index c29dadfa..a9d7013 100644 --- a/ash/quick_pair/pairing/fast_pair/fast_pair_pairer.cc +++ b/ash/quick_pair/pairing/fast_pair/fast_pair_pairer.cc
@@ -339,6 +339,7 @@ if (device->GetAddress() == device_->ble_address || device->GetAddress() == device_->classic_address()) { std::move(paired_callback_).Run(device_); + std::move(pairing_procedure_complete_).Run(device_); } }
diff --git a/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler.cc b/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler.cc index a45424c..87a271e 100644 --- a/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler.cc +++ b/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler.cc
@@ -31,6 +31,14 @@ if (new_paired_status) return; + if (FastPairRepository::Get()->EvictDeviceImages(device)) { + QP_LOG(INFO) << __func__ << ": Repository evicted device images."; + } else { + QP_LOG(INFO) << __func__ + << ": Repository did not evict device images (no images found " + "or other matching device IDs still paired)."; + } + if (FastPairRepository::Get()->DeleteAssociatedDevice(device)) { QP_LOG(INFO) << __func__ << ": Repository is processing the delete"; } else {
diff --git a/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler_unittest.cc b/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler_unittest.cc index 91a8387..d321abb 100644 --- a/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler_unittest.cc +++ b/ash/quick_pair/pairing/fast_pair/fast_pair_unpair_handler_unittest.cc
@@ -66,5 +66,16 @@ NotifyPairChanged(/*new_pair_state=*/false); } +TEST_F(FastPairUnpairHandlerTest, DoesntEvictIfDevicePaired) { + EXPECT_CALL(*(mock_repository_.get()), EvictDeviceImages).Times(0); + NotifyPairChanged(/*new_pair_state=*/true); +} + +TEST_F(FastPairUnpairHandlerTest, EvictsExpectedDevice) { + EXPECT_CALL(*(mock_repository_.get()), EvictDeviceImages(device_.get())) + .Times(1); + NotifyPairChanged(/*new_pair_state=*/false); +} + } // namespace quick_pair } // namespace ash
diff --git a/ash/quick_pair/repository/BUILD.gn b/ash/quick_pair/repository/BUILD.gn index 83da8b0..1661ab5b 100644 --- a/ash/quick_pair/repository/BUILD.gn +++ b/ash/quick_pair/repository/BUILD.gn
@@ -79,6 +79,7 @@ "//ash/quick_pair/proto:fastpair_proto", "//base", "//base/test:test_support", + "//chromeos/services/bluetooth_config/public/cpp", "//device/bluetooth", "//net/traffic_annotation:test_support", "//testing/gtest",
diff --git a/ash/quick_pair/repository/fake_fast_pair_repository.cc b/ash/quick_pair/repository/fake_fast_pair_repository.cc index 20ca579..772d442 100644 --- a/ash/quick_pair/repository/fake_fast_pair_repository.cc +++ b/ash/quick_pair/repository/fake_fast_pair_repository.cc
@@ -6,6 +6,7 @@ #include "ash/quick_pair/proto/fastpair.pb.h" #include "base/strings/string_util.h" +#include "chromeos/services/bluetooth_config/public/cpp/device_image_info.h" #include "device/bluetooth/bluetooth_device.h" namespace ash { @@ -74,5 +75,27 @@ return saved_account_keys_.erase(device->GetAddress()) == 1; } +// Unimplemented. +void FakeFastPairRepository::FetchDeviceImages(scoped_refptr<Device> device) { + return; +} + +// Unimplemented. +bool FakeFastPairRepository::PersistDeviceImages(scoped_refptr<Device> device) { + return true; +} + +// Unimplemented. +bool FakeFastPairRepository::EvictDeviceImages( + const device::BluetoothDevice* device) { + return true; +} + +// Unimplemented. +absl::optional<const chromeos::bluetooth_config::DeviceImageInfo> +FakeFastPairRepository::GetImagesForDevice(const std::string& device_id) { + return absl::nullopt; +} + } // namespace quick_pair } // namespace ash
diff --git a/ash/quick_pair/repository/fake_fast_pair_repository.h b/ash/quick_pair/repository/fake_fast_pair_repository.h index 9bc7fb35..ad55b91 100644 --- a/ash/quick_pair/repository/fake_fast_pair_repository.h +++ b/ash/quick_pair/repository/fake_fast_pair_repository.h
@@ -13,9 +13,15 @@ #include "base/containers/flat_map.h" #include "third_party/abseil-cpp/absl/types/optional.h" +namespace chromeos { +namespace bluetooth_config { +class DeviceImageInfo; +} // namespace bluetooth_config +} // namespace chromeos + namespace device { class BluetoothDevice; -} +} // namespace device namespace ash { namespace quick_pair { @@ -49,6 +55,11 @@ void AssociateAccountKey(scoped_refptr<Device> device, const std::vector<uint8_t>& account_key) override; bool DeleteAssociatedDevice(const device::BluetoothDevice* device) override; + void FetchDeviceImages(scoped_refptr<Device> device) override; + bool PersistDeviceImages(scoped_refptr<Device> device) override; + bool EvictDeviceImages(const device::BluetoothDevice* device) override; + absl::optional<const chromeos::bluetooth_config::DeviceImageInfo> + GetImagesForDevice(const std::string& device_id) override; private: static void SetInstance(FastPairRepository* instance);
diff --git a/ash/quick_pair/repository/fast_pair/footprints_fetcher.cc b/ash/quick_pair/repository/fast_pair/footprints_fetcher.cc index fba4778..e92216d 100644 --- a/ash/quick_pair/repository/fast_pair/footprints_fetcher.cc +++ b/ash/quick_pair/repository/fast_pair/footprints_fetcher.cc
@@ -4,6 +4,7 @@ #include "ash/quick_pair/repository/fast_pair/footprints_fetcher.h" +#include "ash/quick_pair/common/fast_pair/fast_pair_metrics.h" #include "ash/quick_pair/common/logging.h" #include "ash/quick_pair/proto/fastpair.pb.h" #include "ash/quick_pair/proto/fastpair_data.pb.h" @@ -89,6 +90,7 @@ if (!response_body) { QP_LOG(WARNING) << __func__ << ": No response."; + RecordFootprintsFetcherGetResult(/*success=*/false); std::move(callback).Run(absl::nullopt); return; } @@ -96,10 +98,12 @@ nearby::fastpair::UserReadDevicesResponse devices; if (!devices.ParseFromString(*response_body)) { QP_LOG(WARNING) << __func__ << ": Failed to parse."; + RecordFootprintsFetcherGetResult(/*success=*/false); std::move(callback).Run(absl::nullopt); return; } + RecordFootprintsFetcherGetResult(/*success=*/true); QP_LOG(VERBOSE) << __func__ << ": Successfully retrived footprints data. Paired devices:"; @@ -131,6 +135,7 @@ std::unique_ptr<HttpFetcher> http_fetcher, std::unique_ptr<std::string> response_body) { QP_LOG(VERBOSE) << __func__; + RecordFootprintsFetcherPostResult(/*success=*/response_body ? true : false); if (!response_body) { QP_LOG(WARNING) << __func__ << ": No response."; @@ -158,6 +163,7 @@ std::unique_ptr<HttpFetcher> http_fetcher, std::unique_ptr<std::string> response_body) { QP_LOG(VERBOSE) << __func__; + RecordFootprintsFetcherDeleteResult(/*success=*/response_body ? true : false); if (!response_body) { QP_LOG(WARNING) << __func__ << ": No response.";
diff --git a/ash/quick_pair/repository/fast_pair_repository.h b/ash/quick_pair/repository/fast_pair_repository.h index addfd93..c92b63b 100644 --- a/ash/quick_pair/repository/fast_pair_repository.h +++ b/ash/quick_pair/repository/fast_pair_repository.h
@@ -11,11 +11,18 @@ #include "ash/quick_pair/repository/fast_pair/pairing_metadata.h" #include "base/callback.h" #include "base/containers/flat_map.h" +#include "chromeos/services/bluetooth_config/public/cpp/device_image_info.h" #include "third_party/abseil-cpp/absl/types/optional.h" +namespace chromeos { +namespace bluetooth_config { +class DeviceImageInfo; +} // namespace bluetooth_config +} // namespace chromeos + namespace device { class BluetoothDevice; -} +} // namespace device namespace ash { namespace quick_pair { @@ -62,6 +69,22 @@ virtual bool DeleteAssociatedDevice( const device::BluetoothDevice* device) = 0; + // Fetches the |device| images and a record of the device ID -> model ID + // mapping to memory. + virtual void FetchDeviceImages(scoped_refptr<Device> device) = 0; + + // Persists the images and device ID belonging to |device| to + // disk, if model ID is not already persisted. + virtual bool PersistDeviceImages(scoped_refptr<Device> device) = 0; + + // Evicts the images and device ID belonging to |device| from + // disk, if model ID is not in use by other device IDs. + virtual bool EvictDeviceImages(const device::BluetoothDevice* device) = 0; + + // Returns device images belonging to |device_id|, if found. + virtual absl::optional<const chromeos::bluetooth_config::DeviceImageInfo> + GetImagesForDevice(const std::string& device_id) = 0; + protected: static void SetInstance(FastPairRepository* instance); };
diff --git a/ash/quick_pair/repository/fast_pair_repository_impl.cc b/ash/quick_pair/repository/fast_pair_repository_impl.cc index 1b3eee03..0b2b138d 100644 --- a/ash/quick_pair/repository/fast_pair_repository_impl.cc +++ b/ash/quick_pair/repository/fast_pair_repository_impl.cc
@@ -16,8 +16,11 @@ #include "ash/quick_pair/repository/fast_pair/proto_conversions.h" #include "ash/quick_pair/repository/fast_pair/saved_device_registry.h" #include "ash/services/quick_pair/public/cpp/account_key_filter.h" +#include "base/callback_helpers.h" +#include "base/memory/scoped_refptr.h" #include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" +#include "chromeos/services/bluetooth_config/public/cpp/device_image_info.h" #include "components/image_fetcher/core/image_fetcher.h" #include "device/bluetooth/bluetooth_device.h" @@ -48,10 +51,12 @@ std::string normalized_id = base::ToUpperASCII(hex_model_id); if (metadata_cache_.contains(normalized_id)) { QP_LOG(VERBOSE) << __func__ << ": Data already in cache."; + RecordFastPairRepositoryCacheResult(/*success=*/true); std::move(callback).Run(metadata_cache_[normalized_id].get()); return; } QP_LOG(VERBOSE) << __func__ << ": Not cached, fetching from web service."; + RecordFastPairRepositoryCacheResult(/*success=*/false); device_metadata_fetcher_->LookupHexDeviceId( normalized_id, base::BindOnce(&FastPairRepositoryImpl::OnMetadataFetched, weak_ptr_factory_.GetWeakPtr(), @@ -242,5 +247,62 @@ return true; } +void FastPairRepositoryImpl::FetchDeviceImages(scoped_refptr<Device> device) { + QP_LOG(INFO) << __func__ << ": Fetching device images for model ID " + << device->metadata_id; + // Save a record of the device ID -> model ID for this device so that we can + // display images for device objects that lack a model ID, such as + // device::BluetoothDevice. + device_id_map_->SaveModelIdForDevice(device); + + GetDeviceMetadata( + device->metadata_id, + base::BindOnce(&FastPairRepositoryImpl::CompleteFetchDeviceImages, + weak_ptr_factory_.GetWeakPtr(), device->metadata_id)); +} + +void FastPairRepositoryImpl::CompleteFetchDeviceImages( + const std::string& hex_model_id, + DeviceMetadata* device_metadata) { + QP_LOG(INFO) << __func__ + << ": Completing fetching device images for model ID " + << hex_model_id; + device_image_store_->SaveDeviceImages(hex_model_id, device_metadata, + base::DoNothing()); +} + +bool FastPairRepositoryImpl::PersistDeviceImages(scoped_refptr<Device> device) { + QP_LOG(INFO) << __func__ << ": Persisting device images for model ID " + << device->metadata_id; + device_id_map_->PersistRecordsForDevice(device); + return device_image_store_->PersistDeviceImages(device->metadata_id); +} + +bool FastPairRepositoryImpl::EvictDeviceImages( + const device::BluetoothDevice* device) { + const std::string device_id = device->GetIdentifier(); + absl::optional<const std::string> hex_model_id = + device_id_map_->GetModelIdForDeviceId(device_id); + if (!hex_model_id) + return false; + device_id_map_->EvictDeviceIdRecord(device_id); + + // Before evicting images, check if other device IDs map to this model ID. + if (device_id_map_->HasPersistedRecordsForModelId(hex_model_id.value())) + return false; + + return device_image_store_->EvictDeviceImages(hex_model_id.value()); +} + +absl::optional<const chromeos::bluetooth_config::DeviceImageInfo> +FastPairRepositoryImpl::GetImagesForDevice(const std::string& device_id) { + absl::optional<const std::string> hex_model_id = + device_id_map_->GetModelIdForDeviceId(device_id); + if (!hex_model_id) + return absl::nullopt; + + return device_image_store_->GetImagesForDeviceModel(hex_model_id.value()); +} + } // namespace quick_pair } // namespace ash
diff --git a/ash/quick_pair/repository/fast_pair_repository_impl.h b/ash/quick_pair/repository/fast_pair_repository_impl.h index b9a81470..b8ee590 100644 --- a/ash/quick_pair/repository/fast_pair_repository_impl.h +++ b/ash/quick_pair/repository/fast_pair_repository_impl.h
@@ -13,9 +13,15 @@ #include "base/time/time.h" #include "third_party/abseil-cpp/absl/types/optional.h" +namespace chromeos { +namespace bluetooth_config { +class DeviceImageInfo; +} // namespace bluetooth_config +} // namespace chromeos + namespace device { class BluetoothDevice; -} +} // namespace device namespace nearby { namespace fastpair { @@ -52,6 +58,11 @@ void AssociateAccountKey(scoped_refptr<Device> device, const std::vector<uint8_t>& account_key) override; bool DeleteAssociatedDevice(const device::BluetoothDevice* device) override; + void FetchDeviceImages(scoped_refptr<Device> device) override; + bool PersistDeviceImages(scoped_refptr<Device> device) override; + bool EvictDeviceImages(const device::BluetoothDevice* device) override; + absl::optional<const chromeos::bluetooth_config::DeviceImageInfo> + GetImagesForDevice(const std::string& device_id) override; private: void CheckAccountKeysImpl(const AccountKeyFilter& account_key_filter, @@ -81,6 +92,10 @@ void OnAddToFootprintsComplete(const std::string& mac_address, const std::vector<uint8_t>& account_key, bool success); + // Fethces the |device_metadata| images to the DeviceImageStore for + // |hex_model_id|. + void CompleteFetchDeviceImages(const std::string& hex_model_id, + DeviceMetadata* device_metadata); std::unique_ptr<DeviceMetadataFetcher> device_metadata_fetcher_; std::unique_ptr<FootprintsFetcher> footprints_fetcher_;
diff --git a/ash/quick_pair/repository/mock_fast_pair_repository.h b/ash/quick_pair/repository/mock_fast_pair_repository.h index 822cbc6..3c09cc1 100644 --- a/ash/quick_pair/repository/mock_fast_pair_repository.h +++ b/ash/quick_pair/repository/mock_fast_pair_repository.h
@@ -8,6 +8,7 @@ #include "ash/quick_pair/repository/fast_pair_repository.h" #include "base/memory/scoped_refptr.h" #include "testing/gmock/include/gmock/gmock.h" +#include "third_party/abseil-cpp/absl/types/optional.h" namespace ash { namespace quick_pair { @@ -42,6 +43,22 @@ DeleteAssociatedDevice, (const device::BluetoothDevice* device), (override)); + MOCK_METHOD(void, + FetchDeviceImages, + (scoped_refptr<Device> device), + (override)); + MOCK_METHOD(bool, + PersistDeviceImages, + (scoped_refptr<Device> device), + (override)); + MOCK_METHOD(bool, + EvictDeviceImages, + (const device::BluetoothDevice* device), + (override)); + MOCK_METHOD(absl::optional<const chromeos::bluetooth_config::DeviceImageInfo>, + GetImagesForDevice, + (const std::string& device_id), + (override)); }; } // namespace quick_pair
diff --git a/ash/system/phonehub/camera_roll_thumbnail.h b/ash/system/phonehub/camera_roll_thumbnail.h index 914dee7..c925b93 100644 --- a/ash/system/phonehub/camera_roll_thumbnail.h +++ b/ash/system/phonehub/camera_roll_thumbnail.h
@@ -36,6 +36,9 @@ const char* GetClassName() const override; private: + FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, ImageThumbnail); + FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, VideoThumbnail); + void ButtonPressed(); ui::SimpleMenuModel* GetMenuModel(); void DownloadRequested();
diff --git a/ash/system/phonehub/camera_roll_view.h b/ash/system/phonehub/camera_roll_view.h index 6abbd23..66fb42e 100644 --- a/ash/system/phonehub/camera_roll_view.h +++ b/ash/system/phonehub/camera_roll_view.h
@@ -41,10 +41,13 @@ bool should_disable_annimator_timer_for_test_ = false; private: + friend class CameraRollViewTest; FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, DisplayOptInView); FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, OptInAlready); FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, RightAfterOptIn); FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, ViewLayout); + FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, ImageThumbnail); + FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, VideoThumbnail); class CameraRollItemsView : public views::View { public: @@ -64,6 +67,8 @@ private: FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, ViewLayout); + FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, ImageThumbnail); + FRIEND_TEST_ALL_PREFIXES(CameraRollViewTest, VideoThumbnail); gfx::Point GetCameraRollItemPosition(int index); void CalculateIdealBounds();
diff --git a/ash/system/phonehub/camera_roll_view_unittest.cc b/ash/system/phonehub/camera_roll_view_unittest.cc index 3e6b55f..29e0571 100644 --- a/ash/system/phonehub/camera_roll_view_unittest.cc +++ b/ash/system/phonehub/camera_roll_view_unittest.cc
@@ -7,6 +7,8 @@ #include "ash/components/phonehub/camera_roll_item.h" #include "ash/components/phonehub/fake_camera_roll_manager.h" #include "ash/components/phonehub/fake_user_action_recorder.h" +#include "ash/system/phonehub/camera_roll_thumbnail.h" +#include "ash/system/phonehub/phone_hub_metrics.h" #include "ash/test/ash_test_base.h" #include "camera_roll_view.h" #include "third_party/skia/include/core/SkBitmap.h" @@ -85,12 +87,36 @@ return items; } - const views::View* GetItemsView() const { - return camera_roll_view()->children().at(1); + const std::vector<phonehub::CameraRollItem> CreateSingleItemWithType( + bool is_video) { + phonehub::proto::CameraRollItemMetadata metadata; + metadata.set_key("key"); + metadata.set_last_modified_millis(1577865600); + metadata.set_file_size_bytes(123456); + + if (is_video) { + metadata.set_mime_type("video/mp4"); + metadata.set_file_name("fake_video.mp4"); + } else { + metadata.set_mime_type("image/png"); + metadata.set_file_name("fake_image.png"); + } + + SkBitmap bitmap; + bitmap.allocN32Pixels(1, 1); + gfx::Image thumbnail = gfx::Image::CreateFrom1xBitmap(bitmap); + + return std::vector<phonehub::CameraRollItem>{ + phonehub::CameraRollItem(metadata, thumbnail)}; } - const views::MenuButton* GetThumbnailView(int index) const { - return static_cast<views::MenuButton*>( + CameraRollView::CameraRollItemsView* GetItemsView() const { + return static_cast<CameraRollView::CameraRollItemsView*>( + camera_roll_view()->children().at(1)); + } + + CameraRollThumbnail* GetThumbnailView(int index) const { + return static_cast<CameraRollThumbnail*>( GetItemsView()->children().at(index)); } @@ -189,16 +215,12 @@ // Test the layout size and positions of the items. If the layout is being // intentionally changed this test will need to be updated. fake_camera_roll_manager()->SetCurrentItems(CreateFakeItems(4)); - EXPECT_EQ(camera_roll_view()->items_view_->CalculatePreferredSize(), - gfx::Size(328, 82)); - EXPECT_EQ(camera_roll_view()->items_view_->GetCameraRollItemPosition(0), - gfx::Point(4, 4)); - EXPECT_EQ(camera_roll_view()->items_view_->GetCameraRollItemPosition(1), - gfx::Point(86, 4)); - EXPECT_EQ(camera_roll_view()->items_view_->GetCameraRollItemPosition(2), - gfx::Point(168, 4)); - EXPECT_EQ(camera_roll_view()->items_view_->GetCameraRollItemPosition(3), - gfx::Point(250, 4)); + GetItemsView()->Layout(); + EXPECT_EQ(GetItemsView()->CalculatePreferredSize(), gfx::Size(328, 82)); + EXPECT_EQ(GetThumbnailView(0)->bounds(), gfx::Rect(4, 4, 74, 74)); + EXPECT_EQ(GetThumbnailView(1)->bounds(), gfx::Rect(86, 4, 74, 74)); + EXPECT_EQ(GetThumbnailView(2)->bounds(), gfx::Rect(168, 4, 74, 74)); + EXPECT_EQ(GetThumbnailView(3)->bounds(), gfx::Rect(250, 4, 74, 74)); } TEST_F(CameraRollViewTest, AccessibleNameAndTooltip) { @@ -216,4 +238,24 @@ EXPECT_EQ(u"Recent photo 4 of 4.", GetThumbnailView(3)->GetTooltipText()); } +TEST_F(CameraRollViewTest, ImageThumbnail) { + PresetCameraRollOptInState(/*has_been_dismissed=*/true, + /*can_be_enabled=*/false); + fake_camera_roll_manager()->SetCurrentItems( + CreateSingleItemWithType(/*is_video=*/false)); + + EXPECT_EQ(GetThumbnailView(0)->GetMediaType(), + phone_hub_metrics::CameraRollMediaType::kPhoto); +} + +TEST_F(CameraRollViewTest, VideoThumbnail) { + PresetCameraRollOptInState(/*has_been_dismissed=*/true, + /*can_be_enabled=*/false); + fake_camera_roll_manager()->SetCurrentItems( + CreateSingleItemWithType(/*is_video=*/true)); + + EXPECT_EQ(GetThumbnailView(0)->GetMediaType(), + phone_hub_metrics::CameraRollMediaType::kVideo); +} + } // namespace ash
diff --git a/ash/system/toast/toast_manager_impl.cc b/ash/system/toast/toast_manager_impl.cc index b11bf73..9eb56c7 100644 --- a/ash/system/toast/toast_manager_impl.cc +++ b/ash/system/toast/toast_manager_impl.cc
@@ -31,7 +31,11 @@ DCHECK(!id.empty()); if (current_toast_data_ && current_toast_data_->id == id) { - // TODO(yoshiki): Replaces the visible toast. + // Replace the visible toast by adding the new toast data to the front of + // the queue and hiding the visible toast. Once the visible toast finishes + // hiding, the new toast will be displayed. + queue_.emplace_front(data); + overlay_->Show(false); return; }
diff --git a/ash/system/toast/toast_manager_unittest.cc b/ash/system/toast/toast_manager_unittest.cc index 75172c4..bc000e4 100644 --- a/ash/system/toast/toast_manager_unittest.cc +++ b/ash/system/toast/toast_manager_unittest.cc
@@ -7,6 +7,7 @@ #include <string> #include "ash/public/cpp/shelf_config.h" +#include "ash/public/cpp/toast_data.h" #include "ash/root_window_controller.h" #include "ash/screen_util.h" #include "ash/session/session_controller_impl.h" @@ -121,6 +122,14 @@ void CancelToast(const std::string& id) { manager()->Cancel(id); } + void ReplaceToast(const std::string& id, + const std::string& text, + int32_t duration, + bool visible_on_lock_screen = false) { + manager()->Show(ToastData(id, base::ASCIIToUTF16(text), duration, + visible_on_lock_screen)); + } + void ChangeLockState(bool lock) { SessionInfo info; info.state = lock ? session_manager::SessionState::LOCKED @@ -451,6 +460,54 @@ EXPECT_EQ(2, GetToastSerial()); } +TEST_F(ToastManagerImplTest, ReplaceContentsOfQueuedToast) { + std::string id1 = ShowToast(/*text=*/"TEXT1", ToastData::kInfiniteDuration); + std::string id2 = ShowToast(/*text=*/"TEXT2", ToastData::kInfiniteDuration); + + // Confirm that the first toast is shown. + EXPECT_EQ(u"TEXT1", GetCurrentText()); + EXPECT_EQ(1, GetToastSerial()); + + // Replace the contents of the queued toast. + ReplaceToast(id2, /*text=*/"TEXT2_updated", ToastData::kInfiniteDuration); + + // Confirm that the shown toast is still visible. + EXPECT_EQ(u"TEXT1", GetCurrentText()); + EXPECT_EQ(1, GetToastSerial()); + + // Cancel the shown toast. + CancelToast(id1); + + // Confirm that the next toast is visible with the updated text. + EXPECT_EQ(u"TEXT2_updated", GetCurrentText()); + EXPECT_EQ(2, GetToastSerial()); +} + +TEST_F(ToastManagerImplTest, ReplaceContentsOfCurrentToast) { + std::string id1 = ShowToast(/*text=*/"TEXT1", ToastData::kInfiniteDuration); + std::string id2 = ShowToast(/*text=*/"TEXT2", ToastData::kInfiniteDuration); + + // Confirm that the first toast is shown. + EXPECT_EQ(u"TEXT1", GetCurrentText()); + EXPECT_EQ(1, GetToastSerial()); + + // Replace the contents of the current toast showing. + ReplaceToast(id1, /*text=*/"TEXT1_updated", ToastData::kInfiniteDuration); + + // Confirm that the new toast content is visible. The toast serial should be + // different, indicating the original toast's timeout won't close the new + // toast's. + EXPECT_EQ(u"TEXT1_updated", GetCurrentText()); + EXPECT_EQ(2, GetToastSerial()); + + // Cancel the shown toast. + CancelToast(id1); + + // Confirm that the second toast is now showing. + EXPECT_EQ(u"TEXT2", GetCurrentText()); + EXPECT_EQ(3, GetToastSerial()); +} + TEST_F(ToastManagerImplTest, ShowToastOnLockScreen) { // Simulate device lock. ChangeLockState(true);
diff --git a/ash/system/tray/tray_background_view.cc b/ash/system/tray/tray_background_view.cc index 60704e08..4620d02 100644 --- a/ash/system/tray/tray_background_view.cc +++ b/ash/system/tray/tray_background_view.cc
@@ -82,6 +82,9 @@ // TrayBackgroundView. const base::TimeDelta kShowAnimationDelayMs = base::Milliseconds(100); +// Number of active requests to disable CloseBubble(). +int g_disable_close_bubble_on_window_activated = 0; + // Switches left and right insets if RTL mode is active. void MirrorInsetsIfNecessary(gfx::Insets* insets) { if (base::i18n::IsRTL()) { @@ -339,6 +342,18 @@ weak_factory_.GetWeakPtr())); } +base::ScopedClosureRunner +TrayBackgroundView::DisableCloseBubbleOnWindowActivated() { + ++g_disable_close_bubble_on_window_activated; + return base::ScopedClosureRunner( + base::BindOnce([]() { --g_disable_close_bubble_on_window_activated; })); +} + +// static +bool TrayBackgroundView::ShouldCloseBubbleOnWindowActivated() { + return g_disable_close_bubble_on_window_activated == 0; +} + void TrayBackgroundView::UpdateStatusArea(bool should_log_visible_pod_count) { auto* status_area_widget = shelf_->GetStatusAreaWidget(); if (status_area_widget) { @@ -451,8 +466,6 @@ return nullptr; } -void TrayBackgroundView::CloseBubble() {} - void TrayBackgroundView::ShowBubble() {} void TrayBackgroundView::CalculateTargetBounds() {
diff --git a/ash/system/tray/tray_background_view.h b/ash/system/tray/tray_background_view.h index f16a0e6..32e2a42 100644 --- a/ash/system/tray/tray_background_view.h +++ b/ash/system/tray/tray_background_view.h
@@ -68,9 +68,16 @@ // returns nullptr. virtual views::Widget* GetBubbleWidget() const; + // Returns a lock that prevents window activation from closing bubbles. + static base::ScopedClosureRunner DisableCloseBubbleOnWindowActivated() + WARN_UNUSED_RESULT; + + // Whether a window activation change should close bubbles. + static bool ShouldCloseBubbleOnWindowActivated(); + // Closes the associated tray bubble view if it exists and is currently // showing. - virtual void CloseBubble(); + virtual void CloseBubble() {} // Shows the associated tray bubble if one exists. |show_by_click| indicates // whether the showing operation is initiated by mouse or gesture click.
diff --git a/ash/system/tray/tray_bubble_wrapper.cc b/ash/system/tray/tray_bubble_wrapper.cc index 5cca406..4d64131 100644 --- a/ash/system/tray/tray_bubble_wrapper.cc +++ b/ash/system/tray/tray_bubble_wrapper.cc
@@ -85,6 +85,10 @@ if (!gained_active) return; + // Check for the CloseBubble() lock. + if (!TrayBackgroundView::ShouldCloseBubbleOnWindowActivated()) + return; + views::Widget* bubble_widget = bubble_view()->GetWidget(); // Don't close the bubble if a transient child is gaining or losing // activation.
diff --git a/ash/system/unified/unified_system_tray_bubble.cc b/ash/system/unified/unified_system_tray_bubble.cc index 74da714..bd2b0f9d 100644 --- a/ash/system/unified/unified_system_tray_bubble.cc +++ b/ash/system/unified/unified_system_tray_bubble.cc
@@ -9,6 +9,7 @@ #include "ash/shell.h" #include "ash/system/message_center/unified_message_center_bubble.h" #include "ash/system/status_area_widget.h" +#include "ash/system/tray/tray_background_view.h" #include "ash/system/tray/tray_constants.h" #include "ash/system/tray/tray_event_filter.h" #include "ash/system/tray/tray_utils.h" @@ -280,6 +281,10 @@ if (!gained_active || !bubble_widget_) return; + // Check for the CloseBubble() lock. + if (!TrayBackgroundView::ShouldCloseBubbleOnWindowActivated()) + return; + // Don't close the bubble if a transient child is gaining or losing // activation. if (bubble_widget_ == views::Widget::GetWidgetForNativeView(gained_active) ||
diff --git a/ash/webui/personalization_app/BUILD.gn b/ash/webui/personalization_app/BUILD.gn index f116736a..ad4948f 100644 --- a/ash/webui/personalization_app/BUILD.gn +++ b/ash/webui/personalization_app/BUILD.gn
@@ -9,6 +9,7 @@ static_library("personalization_app") { sources = [ + "personalization_app_theme_provider.h", "personalization_app_ui.cc", "personalization_app_ui.h", "personalization_app_url_constants.cc", @@ -46,6 +47,8 @@ source_set("browser_test_support") { testonly = true sources = [ + "test/fake_personalization_app_theme_provider.cc", + "test/fake_personalization_app_theme_provider.h", "test/fake_personalization_app_wallpaper_provider.cc", "test/fake_personalization_app_wallpaper_provider.h", "test/personalization_app_browsertest_fixture.cc",
diff --git a/ash/webui/personalization_app/mojom/personalization_app.mojom b/ash/webui/personalization_app/mojom/personalization_app.mojom index ed87407..823ad4e8 100644 --- a/ash/webui/personalization_app/mojom/personalization_app.mojom +++ b/ash/webui/personalization_app/mojom/personalization_app.mojom
@@ -176,3 +176,20 @@ // called in preview mode. CancelPreviewWallpaper(); }; + +// Receives information whenever there are theme related changes such as color +// mode. +interface ThemeObserver { + // Triggered by |ColorModeObserver::OnColorModeChanged|. Retrieves information + // whether dark mode is enabled + OnColorModeChanged(bool dark_mode_enabled); +}; + +// Provides APIs to expose theme settings such dark/light color mode. +interface ThemeProvider { + // Binds a listener to start receiving updates on color mode changes. + SetThemeObserver(pending_remote<ThemeObserver> observer); + + // Disables or enables dark color mode. + SetColorModePref(bool dark_mode_enabled); +};
diff --git a/ash/webui/personalization_app/personalization_app_theme_provider.h b/ash/webui/personalization_app/personalization_app_theme_provider.h new file mode 100644 index 0000000..6f7e51d45 --- /dev/null +++ b/ash/webui/personalization_app/personalization_app_theme_provider.h
@@ -0,0 +1,23 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef ASH_WEBUI_PERSONALIZATION_APP_PERSONALIZATION_APP_THEME_PROVIDER_H_ +#define ASH_WEBUI_PERSONALIZATION_APP_PERSONALIZATION_APP_THEME_PROVIDER_H_ + +#include "ash/webui/personalization_app/mojom/personalization_app.mojom.h" +#include "mojo/public/cpp/bindings/pending_receiver.h" + +namespace ash { + +class PersonalizationAppThemeProvider + : public personalization_app::mojom::ThemeProvider { + public: + virtual void BindInterface( + mojo::PendingReceiver<personalization_app::mojom::ThemeProvider> + receiver) = 0; +}; + +} // namespace ash + +#endif // ASH_WEBUI_PERSONALIZATION_APP_PERSONALIZATION_APP_THEME_PROVIDER_H_
diff --git a/ash/webui/personalization_app/personalization_app_ui.cc b/ash/webui/personalization_app/personalization_app_ui.cc index 46f5336..fe32c0c 100644 --- a/ash/webui/personalization_app/personalization_app_ui.cc +++ b/ash/webui/personalization_app/personalization_app_ui.cc
@@ -7,6 +7,7 @@ #include "ash/constants/ash_features.h" #include "ash/grit/ash_personalization_app_resources.h" #include "ash/grit/ash_personalization_app_resources_map.h" +#include "ash/webui/personalization_app/personalization_app_theme_provider.h" #include "ash/webui/personalization_app/personalization_app_url_constants.h" #include "ash/webui/personalization_app/personalization_app_wallpaper_provider.h" #include "base/strings/strcat.h" @@ -77,7 +78,10 @@ {"exitFullscreen", IDS_PERSONALIZATION_APP_EXIT_FULL_SCREEN}, {"ariaLabelExitFullscreen", IDS_PERSONALIZATION_APP_ARIA_LABEL_EXIT_FULL_SCREEN}, - {"setAsWallpaper", IDS_PERSONALIZATION_APP_SET_AS_WALLPAPER}}; + {"setAsWallpaper", IDS_PERSONALIZATION_APP_SET_AS_WALLPAPER}, + {"themeLabel", IDS_PERSONALIZATION_APP_THEME_LABEL}, + {"darkColorMode", IDS_PERSONALIZATION_APP_THEME_DARK_COLOR_MODE}, + {"lightColorMode", IDS_PERSONALIZATION_APP_THEME_LIGHT_COLOR_MODE}}; source->AddLocalizedStrings(kLocalizedStrings); if (features::IsWallpaperGooglePhotosIntegrationEnabled()) { @@ -113,8 +117,10 @@ PersonalizationAppUI::PersonalizationAppUI( content::WebUI* web_ui, + std::unique_ptr<PersonalizationAppThemeProvider> theme_provider, std::unique_ptr<PersonalizationAppWallpaperProvider> wallpaper_provider) : ui::MojoWebUIController(web_ui), + theme_provider_(std::move(theme_provider)), wallpaper_provider_(std::move(wallpaper_provider)) { DCHECK(wallpaper_provider_); @@ -148,6 +154,11 @@ PersonalizationAppUI::~PersonalizationAppUI() = default; void PersonalizationAppUI::BindInterface( + mojo::PendingReceiver<personalization_app::mojom::ThemeProvider> receiver) { + theme_provider_->BindInterface(std::move(receiver)); +} + +void PersonalizationAppUI::BindInterface( mojo::PendingReceiver<personalization_app::mojom::WallpaperProvider> receiver) { wallpaper_provider_->BindInterface(std::move(receiver));
diff --git a/ash/webui/personalization_app/personalization_app_ui.h b/ash/webui/personalization_app/personalization_app_ui.h index 7ba53dff..df962c5 100644 --- a/ash/webui/personalization_app/personalization_app_ui.h +++ b/ash/webui/personalization_app/personalization_app_ui.h
@@ -13,12 +13,14 @@ namespace ash { +class PersonalizationAppThemeProvider; class PersonalizationAppWallpaperProvider; class PersonalizationAppUI : public ui::MojoWebUIController { public: PersonalizationAppUI( content::WebUI* web_ui, + std::unique_ptr<PersonalizationAppThemeProvider> theme_provider, std::unique_ptr<PersonalizationAppWallpaperProvider> wallpaper_provider); PersonalizationAppUI(const PersonalizationAppUI&) = delete; @@ -27,10 +29,15 @@ ~PersonalizationAppUI() override; void BindInterface( + mojo::PendingReceiver<personalization_app::mojom::ThemeProvider> + receiver); + + void BindInterface( mojo::PendingReceiver<personalization_app::mojom::WallpaperProvider> receiver); private: + std::unique_ptr<PersonalizationAppThemeProvider> theme_provider_; std::unique_ptr<PersonalizationAppWallpaperProvider> wallpaper_provider_; WEB_UI_CONTROLLER_TYPE_DECL();
diff --git a/ash/webui/personalization_app/resources/BUILD.gn b/ash/webui/personalization_app/resources/BUILD.gn index fa11e78..fd6fb87 100644 --- a/ash/webui/personalization_app/resources/BUILD.gn +++ b/ash/webui/personalization_app/resources/BUILD.gn
@@ -28,6 +28,11 @@ "trusted/personalization_state.ts", "trusted/personalization_store.ts", "trusted/personalization_test_api.ts", + "trusted/theme/theme_actions.ts", + "trusted/theme/theme_controller.ts", + "trusted/theme/theme_interface_provider.ts", + "trusted/theme/theme_reducers.ts", + "trusted/theme/theme_state.ts", "trusted/utils.ts", "trusted/wallpaper/wallpaper_actions.ts", "trusted/wallpaper/untrusted_message_handler.ts", @@ -49,6 +54,7 @@ "trusted/ambient/ambient_subpage_element.ts", "trusted/personalization_main_element.ts", "trusted/personalization_router_element.ts", + "trusted/personalization_theme_element.ts", "trusted/personalization_toast_element.ts", "trusted/personalization_breadcrumb_element.ts", "trusted/user/user_subpage_element.ts",
diff --git a/ash/webui/personalization_app/resources/common/styles.html b/ash/webui/personalization_app/resources/common/styles.html index 6e39147..78291781c 100644 --- a/ash/webui/personalization_app/resources/common/styles.html +++ b/ash/webui/personalization_app/resources/common/styles.html
@@ -218,5 +218,17 @@ width: 100%; z-index: 1; } + cr-button { + border-color: var(--cros-button-stroke-color-secondary); + border-radius: 16px; + } + cr-button[aria-pressed=true], + cr-button[aria-selected=true] { + background-color: var(--cros-highlight-color); + border: 0; + } + cr-button + cr-button { + margin-inline-start: 8px; + } </style> </template>
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_actions.ts b/ash/webui/personalization_app/resources/trusted/personalization_actions.ts index 19fd56e..2b20806 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_actions.ts +++ b/ash/webui/personalization_app/resources/trusted/personalization_actions.ts
@@ -4,6 +4,7 @@ import {Action} from 'chrome://resources/js/cr/ui/store.js'; +import {ThemeActions} from './theme/theme_actions.js'; import {WallpaperActions} from './wallpaper/wallpaper_actions.js'; /** @@ -24,4 +25,4 @@ return {name: PersonalizationActionName.DISMISS_ERROR}; } -export type Actions = WallpaperActions|DismissErrorAction; +export type Actions = ThemeActions|WallpaperActions|DismissErrorAction;
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_app.ts b/ash/webui/personalization_app/resources/trusted/personalization_app.ts index bda9b01..47f7209 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_app.ts +++ b/ash/webui/personalization_app/resources/trusted/personalization_app.ts
@@ -15,6 +15,7 @@ import './personalization_toast_element.js'; import './personalization_breadcrumb_element.js'; import './personalization_main_element.js'; +import './personalization_theme_element.js'; import './user/user_subpage_element.js'; import './wallpaper/wallpaper_subpage.js'; import {emptyState} from './personalization_state.js';
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_main_element.html b/ash/webui/personalization_app/resources/trusted/personalization_main_element.html index a5102cee..407ed982 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_main_element.html +++ b/ash/webui/personalization_app/resources/trusted/personalization_main_element.html
@@ -2,4 +2,7 @@ <div id="container"> <h1>Personalization</h1> <wallpaper-preview></wallpaper-preview> -</div> \ No newline at end of file + <template is="dom-if" if="[[isDarkLightModeEnabled_()]]"> + <personalization-theme></personalization-theme> + </template> +</div>
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_main_element.ts b/ash/webui/personalization_app/resources/trusted/personalization_main_element.ts index 56c5806..a11c8ce 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_main_element.ts +++ b/ash/webui/personalization_app/resources/trusted/personalization_main_element.ts
@@ -7,6 +7,7 @@ * the personalization hub. */ +import {loadTimeData} from 'chrome://resources/js/load_time_data.m.js'; import {html} from 'chrome://resources/polymer/v3_0/polymer/polymer_bundled.min.js'; import {WithPersonalizationStore} from './personalization_store.js'; @@ -22,6 +23,10 @@ static get properties() { return {}; } + + private isDarkLightModeEnabled_(): boolean { + return loadTimeData.getBoolean('isDarkLightModeEnabled'); + } } customElements.define(PersonalizationMain.is, PersonalizationMain);
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_reducers.ts b/ash/webui/personalization_app/resources/trusted/personalization_reducers.ts index 138336d..aff2da7 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_reducers.ts +++ b/ash/webui/personalization_app/resources/trusted/personalization_reducers.ts
@@ -16,6 +16,8 @@ import {Actions} from './personalization_actions.js'; import {WallpaperImage} from './personalization_app.mojom-webui.js'; import {PersonalizationState} from './personalization_state.js'; +import {themeReducers} from './theme/theme_reducers.js'; +import {ThemeState} from './theme/theme_state.js'; import {WallpaperActionName} from './wallpaper/wallpaper_actions.js'; import {wallpaperReducers} from './wallpaper/wallpaper_reducers.js'; import {WallpaperState} from './wallpaper/wallpaper_state.js'; @@ -94,6 +96,7 @@ } const root = combineReducers<PersonalizationState>({ + theme: combineReducers<ThemeState>(themeReducers), wallpaper: combineReducers<WallpaperState>(wallpaperReducers), error: errorReducer, });
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_state.ts b/ash/webui/personalization_app/resources/trusted/personalization_state.ts index 3232e2c..6fb0278 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_state.ts +++ b/ash/webui/personalization_app/resources/trusted/personalization_state.ts
@@ -2,16 +2,19 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +import {emptyState as emptyThemeState, ThemeState} from './theme/theme_state.js'; import {emptyState as emptyWallpaperState, WallpaperState} from './wallpaper/wallpaper_state.js'; export interface PersonalizationState { wallpaper: WallpaperState; error: string|null; + theme: ThemeState; } export function emptyState(): PersonalizationState { return { wallpaper: emptyWallpaperState(), error: null, + theme: emptyThemeState(), }; }
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_theme_element.html b/ash/webui/personalization_app/resources/trusted/personalization_theme_element.html new file mode 100644 index 0000000..7f6a9c4 --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/personalization_theme_element.html
@@ -0,0 +1,15 @@ +<style include="common-style"> +</style> +<div id="container"> + <h2>[[i18n('themeLabel')]]</h2> + <cr-button id="lightMode" data-color-mode="LIGHT" + on-click="onClickColorModeButton_" + aria-pressed$="[[getLightAriaPressed_(darkModeEnabled_)]]"> + <div class="text">[[i18n('lightColorMode')]]</div> + </cr-button> + <cr-button id="darkMode" data-color-mode="DARK" + on-click="onClickColorModeButton_" + aria-pressed$="[[getDarkAriaPressed_(darkModeEnabled_)]]"> + <div class="text">[[i18n('darkColorMode')]]</div> + </cr-button> +</div>
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_theme_element.ts b/ash/webui/personalization_app/resources/trusted/personalization_theme_element.ts new file mode 100644 index 0000000..7e561a0 --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/personalization_theme_element.ts
@@ -0,0 +1,94 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * @fileoverview This component displays color mode settings. + */ + +import 'chrome://resources/cr_elements/cr_button/cr_button.m.js'; +import 'chrome://resources/polymer/v3_0/iron-icon/iron-icon.js'; +import 'chrome://resources/polymer/v3_0/iron-iconset-svg/iron-iconset-svg.js'; +import '../common/styles.js'; + +import {html} from 'chrome://resources/polymer/v3_0/polymer/polymer_bundled.min.js'; + +import {ThemeObserverInterface, ThemeObserverReceiver, ThemeProviderInterface} from './personalization_app.mojom-webui.js'; +import {WithPersonalizationStore} from './personalization_store.js'; +import {setDarkModeEnabledAction} from './theme/theme_actions.js'; +import {setColorModePref} from './theme/theme_controller.js'; +import {getThemeProvider} from './theme/theme_interface_provider.js'; + +/** + * Set up the observer to listen for color mode changes. + */ +function initThemeObserver( + themeProvider: ThemeProviderInterface, + target: ThemeObserverInterface): ThemeObserverReceiver { + const receiver = new ThemeObserverReceiver(target); + themeProvider.setThemeObserver(receiver.$.bindNewPipeAndPassRemote()); + return receiver; +} + +export class PersonalizationThemeElement extends WithPersonalizationStore + implements ThemeObserverInterface { + static get is() { + return 'personalization-theme'; + } + + static get template() { + return html`{__html_template__}`; + } + + static get properties() { + return { + darkModeEnabled_: Boolean, + }; + } + + private darkModeEnabled_: boolean; + private themeProvider_: ThemeProviderInterface; + private themeObserver_: ThemeObserverReceiver|null; + + constructor() { + super(); + this.themeProvider_ = getThemeProvider(); + this.themeObserver_ = null; + } + + connectedCallback() { + super.connectedCallback(); + this.themeObserver_ = initThemeObserver(this.themeProvider_, this); + this.watch<PersonalizationThemeElement['darkModeEnabled_']>( + 'darkModeEnabled_', state => state.theme.darkModeEnabled); + this.updateFromStore(); + } + + disconnectedCallback() { + if (this.themeObserver_) { + this.themeObserver_.$.close(); + } + } + + onColorModeChanged(darkModeEnabled: boolean) { + this.dispatch(setDarkModeEnabledAction(darkModeEnabled)); + } + + private getLightAriaPressed_(darkModeEnabled: boolean) { + return (!darkModeEnabled).toString(); + } + + private getDarkAriaPressed_(darkModeEnabled: boolean) { + return darkModeEnabled.toString(); + } + + private onClickColorModeButton_(event: Event) { + const eventTarget = event.currentTarget as HTMLElement; + const colorMode = eventTarget.dataset['colorMode']; + setColorModePref( + colorMode === 'DARK', this.themeProvider_, this.getStore()); + } +} + +customElements.define( + PersonalizationThemeElement.is, PersonalizationThemeElement);
diff --git a/ash/webui/personalization_app/resources/trusted/theme/theme_actions.ts b/ash/webui/personalization_app/resources/trusted/theme/theme_actions.ts new file mode 100644 index 0000000..66439e1 --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/theme/theme_actions.ts
@@ -0,0 +1,25 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import {Action} from 'chrome://resources/js/cr/ui/store.js'; + +/** + * @fileoverview Defines the actions to change theme state. + */ + +export enum ThemeActionName { + SET_DARK_MODE_ENABLED = 'set_dark_mode_enabled', +} + +export type ThemeActions = SetDarkModeEnabledAction; + +export type SetDarkModeEnabledAction = Action&{ + name: ThemeActionName.SET_DARK_MODE_ENABLED; + enabled: boolean; +}; + +export function setDarkModeEnabledAction(enabled: boolean): + SetDarkModeEnabledAction { + return {name: ThemeActionName.SET_DARK_MODE_ENABLED, enabled}; +}
diff --git a/ash/webui/personalization_app/resources/trusted/theme/theme_controller.ts b/ash/webui/personalization_app/resources/trusted/theme/theme_controller.ts new file mode 100644 index 0000000..33639cd --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/theme/theme_controller.ts
@@ -0,0 +1,22 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import {ThemeProviderInterface} from '../personalization_app.mojom-webui.js'; +import {PersonalizationStore} from '../personalization_store.js'; +import {setDarkModeEnabledAction} from './theme_actions.js'; + +/** + * @fileoverview contains all of the functions to interact with C++ side through + * mojom calls. Handles setting |PersonalizationStore| state in response to + * mojom data. + */ + +// Disables or enables dark color mode. +export async function setColorModePref( + darkModeEnabled: boolean, provider: ThemeProviderInterface, + store: PersonalizationStore): Promise<void> { + await provider.setColorModePref(darkModeEnabled); + // Dispatch action to highlight color mode. + store.dispatch(setDarkModeEnabledAction(darkModeEnabled)); +}
diff --git a/ash/webui/personalization_app/resources/trusted/theme/theme_interface_provider.ts b/ash/webui/personalization_app/resources/trusted/theme/theme_interface_provider.ts new file mode 100644 index 0000000..d6b6f42 --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/theme/theme_interface_provider.ts
@@ -0,0 +1,29 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * @fileoverview a singleton getter for the theme mojom interface used in + * the Personalization SWA. Also contains utility functions around fetching + * mojom data and mocking out the implementation for testing. + */ + +import 'chrome://resources/mojo/mojo/public/js/bindings.js'; +import 'chrome://resources/mojo/url/mojom/url.mojom-webui.js'; + +import {ThemeProvider, ThemeProviderInterface} from '../personalization_app.mojom-webui.js'; + +let themeProvider: ThemeProviderInterface|null = null; + +export function setThemeProviderForTesting( + testProvider: ThemeProviderInterface): void { + themeProvider = testProvider; +} + +/** Returns a singleton for the ThemeProvider mojom interface. */ +export function getThemeProvider(): ThemeProviderInterface { + if (!themeProvider) { + themeProvider = ThemeProvider.getRemote(); + } + return themeProvider; +}
diff --git a/ash/webui/personalization_app/resources/trusted/theme/theme_reducers.ts b/ash/webui/personalization_app/resources/trusted/theme/theme_reducers.ts new file mode 100644 index 0000000..08b46f8 --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/theme/theme_reducers.ts
@@ -0,0 +1,26 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import {Actions} from '../personalization_actions.js'; +import {ReducerFunction} from '../personalization_reducers.js'; +import {PersonalizationState} from '../personalization_state.js'; + +import {ThemeActionName} from './theme_actions.js'; +import {ThemeState} from './theme_state.js'; + +export function darkModeEnabledReducer( + state: ThemeState['darkModeEnabled'], action: Actions, + _: PersonalizationState): ThemeState['darkModeEnabled'] { + switch (action.name) { + case ThemeActionName.SET_DARK_MODE_ENABLED: + return action.enabled; + default: + return state; + } +} + +export const themeReducers: + {[K in keyof ThemeState]: ReducerFunction<ThemeState[K]>} = { + darkModeEnabled: darkModeEnabledReducer, + };
diff --git a/ash/webui/personalization_app/resources/trusted/theme/theme_state.ts b/ash/webui/personalization_app/resources/trusted/theme/theme_state.ts new file mode 100644 index 0000000..f24d7db --- /dev/null +++ b/ash/webui/personalization_app/resources/trusted/theme/theme_state.ts
@@ -0,0 +1,16 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Stores theme related states. + */ +export interface ThemeState { + darkModeEnabled: boolean; +} + +export function emptyState(): ThemeState { + return { + darkModeEnabled: false, + }; +}
diff --git a/ash/webui/personalization_app/resources/trusted/wallpaper/styles.html b/ash/webui/personalization_app/resources/trusted/wallpaper/styles.html index 040474e..782b9ddf 100644 --- a/ash/webui/personalization_app/resources/trusted/wallpaper/styles.html +++ b/ash/webui/personalization_app/resources/trusted/wallpaper/styles.html
@@ -25,17 +25,5 @@ iframe:focus-within { outline: none; } - cr-button { - border-color: var(--cros-button-stroke-color-secondary); - border-radius: 16px; - } - cr-button[aria-pressed=true], - cr-button[aria-selected=true] { - background-color: var(--cros-highlight-color); - border: 0; - } - cr-button + cr-button { - margin-inline-start: 8px; - } </style> </template>
diff --git a/ash/webui/personalization_app/test/fake_personalization_app_theme_provider.cc b/ash/webui/personalization_app/test/fake_personalization_app_theme_provider.cc new file mode 100644 index 0000000..4f1677f3 --- /dev/null +++ b/ash/webui/personalization_app/test/fake_personalization_app_theme_provider.cc
@@ -0,0 +1,27 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "ash/webui/personalization_app/test/fake_personalization_app_theme_provider.h" + +FakePersonalizationAppThemeProvider::FakePersonalizationAppThemeProvider( + content::WebUI* web_ui) {} + +FakePersonalizationAppThemeProvider::~FakePersonalizationAppThemeProvider() = + default; + +void FakePersonalizationAppThemeProvider::BindInterface( + mojo::PendingReceiver<ash::personalization_app::mojom::ThemeProvider> + receiver) { + theme_receiver_.reset(); + theme_receiver_.Bind(std::move(receiver)); +} + +void FakePersonalizationAppThemeProvider::SetThemeObserver( + mojo::PendingRemote<ash::personalization_app::mojom::ThemeObserver> + observer) {} + +void FakePersonalizationAppThemeProvider::SetColorModePref( + bool dark_mode_enabled) { + return; +}
diff --git a/ash/webui/personalization_app/test/fake_personalization_app_theme_provider.h b/ash/webui/personalization_app/test/fake_personalization_app_theme_provider.h new file mode 100644 index 0000000..a918967 --- /dev/null +++ b/ash/webui/personalization_app/test/fake_personalization_app_theme_provider.h
@@ -0,0 +1,49 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef ASH_WEBUI_PERSONALIZATION_APP_TEST_FAKE_PERSONALIZATION_APP_THEME_PROVIDER_H_ +#define ASH_WEBUI_PERSONALIZATION_APP_TEST_FAKE_PERSONALIZATION_APP_THEME_PROVIDER_H_ + +#include "ash/webui/personalization_app/personalization_app_theme_provider.h" + +#include <stdint.h> + +#include "ash/public/cpp/wallpaper/wallpaper_types.h" +#include "ash/webui/personalization_app/mojom/personalization_app.mojom.h" +#include "base/unguessable_token.h" +#include "mojo/public/cpp/bindings/pending_receiver.h" +#include "mojo/public/cpp/bindings/receiver.h" + +namespace content { +class WebUI; +} // namespace content + +class FakePersonalizationAppThemeProvider + : public ash::PersonalizationAppThemeProvider { + public: + explicit FakePersonalizationAppThemeProvider(content::WebUI* web_ui); + + FakePersonalizationAppThemeProvider( + const FakePersonalizationAppThemeProvider&) = delete; + FakePersonalizationAppThemeProvider& operator=( + const FakePersonalizationAppThemeProvider&) = delete; + + ~FakePersonalizationAppThemeProvider() override; + + void BindInterface( + mojo::PendingReceiver<ash::personalization_app::mojom::ThemeProvider> + receiver) override; + + void SetThemeObserver( + mojo::PendingRemote<ash::personalization_app::mojom::ThemeObserver> + observer) override; + + void SetColorModePref(bool dark_mode_enabled) override; + + private: + mojo::Receiver<ash::personalization_app::mojom::ThemeProvider> + theme_receiver_{this}; +}; + +#endif // ASH_WEBUI_PERSONALIZATION_APP_TEST_FAKE_PERSONALIZATION_APP_THEME_PROVIDER_H_
diff --git a/ash/webui/personalization_app/test/personalization_app_browsertest.js b/ash/webui/personalization_app/test/personalization_app_browsertest.js index ce609d7..9035c54 100644 --- a/ash/webui/personalization_app/test/personalization_app_browsertest.js +++ b/ash/webui/personalization_app/test/personalization_app_browsertest.js
@@ -9,6 +9,7 @@ GEN('#include "ash/webui/personalization_app/test/personalization_app_browsertest_fixture.h"'); GEN('#include "ash/constants/ash_features.h"'); +GEN('#include "chromeos/constants/chromeos_features.h"'); GEN('#include "content/public/test/browser_test.h"'); const ROOT_PAGE = 'chrome://personalization/'; @@ -24,7 +25,9 @@ get featureList() { return { enabled: [ - 'ash::features::kWallpaperWebUI', 'ash::features::kPersonalizationHub' + 'ash::features::kWallpaperWebUI', + 'ash::features::kPersonalizationHub', + 'chromeos::features::kDarkLightMode', ] }; } @@ -62,6 +65,20 @@ testDone(); }); +TEST_F('PersonalizationAppBrowserTest', 'ShowsThemeButtons', () => { + const theme = document.querySelector('personalization-router') + .shadowRoot.querySelector('personalization-main') + .shadowRoot.querySelector('personalization-theme'); + + const lightButton = theme.shadowRoot.getElementById('lightMode'); + assertTrue(!!lightButton); + assertEquals(lightButton.getAttribute('aria-pressed'), 'true'); + const darkButton = theme.shadowRoot.getElementById('darkMode'); + assertTrue(!!darkButton); + assertEquals(darkButton.getAttribute('aria-pressed'), 'false'); + testDone(); +}); + class WallpaperSubpageBrowserTest extends PersonalizationAppBrowserTest { /** @override */ get browsePreload() {
diff --git a/ash/webui/personalization_app/test/personalization_app_browsertest_fixture.cc b/ash/webui/personalization_app/test/personalization_app_browsertest_fixture.cc index 0f4dbfc..81ff842 100644 --- a/ash/webui/personalization_app/test/personalization_app_browsertest_fixture.cc +++ b/ash/webui/personalization_app/test/personalization_app_browsertest_fixture.cc
@@ -8,16 +8,19 @@ #include "ash/webui/personalization_app/personalization_app_ui.h" #include "ash/webui/personalization_app/personalization_app_url_constants.h" +#include "ash/webui/personalization_app/test/fake_personalization_app_theme_provider.h" #include "ash/webui/personalization_app/test/fake_personalization_app_wallpaper_provider.h" #include "chrome/test/base/mojo_web_ui_browser_test.h" std::unique_ptr<content::WebUIController> TestPersonalizationAppWebUIProvider::NewWebUI(content::WebUI* web_ui, const GURL& url) { + auto theme_provider = + std::make_unique<FakePersonalizationAppThemeProvider>(web_ui); auto wallpaper_provider = std::make_unique<FakePersonalizationAppWallpaperProvider>(web_ui); return std::make_unique<ash::PersonalizationAppUI>( - web_ui, std::move(wallpaper_provider)); + web_ui, std::move(theme_provider), std::move(wallpaper_provider)); } void PersonalizationAppBrowserTestFixture::SetUpOnMainThread() {
diff --git a/ash/wm/desks/templates/desks_templates_item_view.cc b/ash/wm/desks/templates/desks_templates_item_view.cc index 90eb0d73..ca3c21b 100644 --- a/ash/wm/desks/templates/desks_templates_item_view.cc +++ b/ash/wm/desks/templates/desks_templates_item_view.cc
@@ -331,8 +331,18 @@ updated_template->set_template_name(name_view_->GetText()); OnTemplateNameChanged(updated_template->template_name()); - DesksTemplatesPresenter::Get()->SaveOrUpdateDeskTemplate( - /*is_update=*/false, std::move(updated_template)); + // Calling `SaveOrUpdateDeskTemplate` will trigger rebuilding the desks + // templates grid views hierarchy which includes `this`. Use a post task as + // some other `ViewObserver`'s may still be using `this`. + // TODO(crbug.com/1266552): Remove the post task once saving and updating does + // not cause a `this` to be deleted anymore. + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::BindOnce( + [](std::unique_ptr<DeskTemplate> desk_template) { + DesksTemplatesPresenter::Get()->SaveOrUpdateDeskTemplate( + /*is_update=*/false, std::move(desk_template)); + }, + std::move(updated_template))); } views::Button::KeyClickAction DesksTemplatesItemView::GetKeyClickActionForEvent(
diff --git a/ash/wm/desks/templates/desks_templates_presenter.cc b/ash/wm/desks/templates/desks_templates_presenter.cc index 1aacb2d..6337eb6 100644 --- a/ash/wm/desks/templates/desks_templates_presenter.cc +++ b/ash/wm/desks/templates/desks_templates_presenter.cc
@@ -234,6 +234,11 @@ if (status != desks_storage::DeskModel::GetEntryByUuidStatus::kOk) return; + // `CreateAndActivateNewDeskForTemplate` may destroy `this`. Copy the member + // variables to a local to prevent UAF. See https://crbug.com/1284138. + base::OnceClosure on_update_ui_closure_for_testing = + std::move(on_update_ui_closure_for_testing_); + // Launch the windows as specified in the template to a new desk. // Calling `CreateAndActivateNewDeskForTemplate` results in exiting overview // mode, which means the presenter doesn't exist anymore on callback (since it @@ -244,8 +249,8 @@ template_name, base::BindOnce(&OnNewDeskCreatedForTemplate, std::move(entry))); - if (on_update_ui_closure_for_testing_) - std::move(on_update_ui_closure_for_testing_).Run(); + if (on_update_ui_closure_for_testing) + std::move(on_update_ui_closure_for_testing).Run(); RecordLaunchTemplateHistogram(); }
diff --git a/ash/wm/overview/overview_grid.h b/ash/wm/overview/overview_grid.h index 1dcdb7ee..f85d611f 100644 --- a/ash/wm/overview/overview_grid.h +++ b/ash/wm/overview/overview_grid.h
@@ -337,10 +337,6 @@ // `was_zero_state` is true then we will expand the desks bar. void ShowDesksTemplatesGrid(bool was_zero_state); - // Hides the overview mode items that must be closed for the grid of desks - // templates to show. - void HideForDesksTemplatesGrid(); - // Hides the grid of desks templates and reshow the overview items. Updates // the templates button if we are not exiting overview. void HideDesksTemplatesGrid(bool exit_overview);
diff --git a/ash/wm/overview/overview_item.cc b/ash/wm/overview/overview_item.cc index 229fd78..e5cd36bc 100644 --- a/ash/wm/overview/overview_item.cc +++ b/ash/wm/overview/overview_item.cc
@@ -218,8 +218,8 @@ DCHECK(item_widget_); PerformFadeOutLayer(item_widget_->GetLayer()); - for (aura::Window* transient_child : - GetTransientTreeIterator(transform_window_.window())) { + for (aura::Window* transient_child : GetTransientTreeIterator(GetWindow())) { + transient_child->SetProperty(kForceVisibleInMiniViewKey, true); PerformFadeOutLayer(transient_child->layer()); }
diff --git a/base/BUILD.gn b/base/BUILD.gn index 64ea2b1..250dde1 100644 --- a/base/BUILD.gn +++ b/base/BUILD.gn
@@ -1394,7 +1394,6 @@ "//build:branding_buildflags", "//build:chromecast_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//build/config/compiler:compiler_buildflags", "//third_party/modp_b64", ] @@ -2231,6 +2230,7 @@ "files/file_util_mac.mm", "mac/backup_util.h", "mac/backup_util.mm", + "mac/bridging.h", "mac/bundle_locations.h", "mac/bundle_locations.mm", "mac/call_with_eh_frame.cc", @@ -3346,7 +3346,6 @@ "//base/third_party/dynamic_annotations", "//build:chromecast_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//testing/gmock", "//testing/gtest", "//third_party/icu",
diff --git a/base/files/file_util_unittest.cc b/base/files/file_util_unittest.cc index 68117262..0860d64 100644 --- a/base/files/file_util_unittest.cc +++ b/base/files/file_util_unittest.cc
@@ -46,7 +46,6 @@ #include "base/time/time.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" -#include "build/os_buildflags.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/multiprocess_func_list.h" #include "testing/platform_test.h"
diff --git a/base/logging.cc b/base/logging.cc index b910fd6e..6d9b6cf 100644 --- a/base/logging.cc +++ b/base/logging.cc
@@ -28,7 +28,6 @@ #include "base/task/common/task_annotator.h" #include "base/trace_event/base_tracing.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #if defined(OS_WIN) #include <io.h>
diff --git a/base/logging.h b/base/logging.h index b6ab61b..5ec3939 100644 --- a/base/logging.h +++ b/base/logging.h
@@ -20,7 +20,6 @@ #include "base/strings/string_piece_forward.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" -#include "build/os_buildflags.h" #if BUILDFLAG(IS_CHROMEOS) #include <cstdio>
diff --git a/base/mac/bridging.h b/base/mac/bridging.h new file mode 100644 index 0000000..49467dd --- /dev/null +++ b/base/mac/bridging.h
@@ -0,0 +1,193 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef BASE_MAC_BRIDGING_H_ +#define BASE_MAC_BRIDGING_H_ + +#include <CoreText/CoreText.h> +#import <Foundation/Foundation.h> + +#include "base/base_export.h" +#include "base/check.h" +#include "build/build_config.h" + +#if defined(OS_IOS) +#import <UIKit/UIKit.h> +#endif + +#if defined(OS_MAC) +#import <AppKit/AppKit.h> +#endif + +#if !defined(__has_feature) || !__has_feature(objc_arc) +#error "base/mac/bridging.h requires ARC." +#endif + +// These functions convert pointers of bridged CFTypes to NSTypes and +// vice-versa. They come in two flavors: those that transfer ownership +// (`OwnershipCast`) and those that just convert the pointer (`PtrCast`). +// +// Examples: +// +// Ownership transference (as in `CFBridgingRetain`/`Release`): +// CFStringRef cf_string = CFStringCreateWithCString(...); +// NSString* ns_string = CFToNSOwnershipCast(cf_string); +// // At this point, `cf_string` does not need releasing. +// +// NSString* ns_string = [[NSString alloc] initWithString:...]; +// CFStringRef cf_string = NSToCFOwnershipCast(ns_string); +// // At this point, `cf_string` must be released. +// +// Pointer conversion (as in `__bridge`): +// // `cf_data` is some `CFDataRef` from somewhere. +// NSImage* ns_image = [[NSImage alloc] initWithData:CFToNSPtrCast(cf_data)]; +// +// // `ns_data` is some `NSData *` from somewhere. +// SecKeyRef sec_key = SecKeyCreateFromData(..., NSToCFPtrCast(ns_data), ...); +// +// The reason to use these functions (rather than using `__bridge` and +// `CFBridgingRetain`/`Release`) is because they are type-safe. The OS-provided +// bridging calls do not type check, while these calls do the appropriate type +// checking via the magic of macros. +// +// Implementation note: Why not templates? Type checking in Core Foundation +// involves functions named in a specific pattern, and only macro token pasting +// works for this purpose. + +#define CF_TO_NS_CAST_IMPL(TypeCF, TypeNS) \ + namespace base::mac { \ + inline BASE_EXPORT TypeNS* _Nullable CFToNSOwnershipCast( \ + TypeCF##Ref CF_CONSUMED _Nullable cf_val) { \ + DCHECK(!cf_val || TypeCF##GetTypeID() == CFGetTypeID(cf_val)); \ + return (__bridge_transfer TypeNS*)cf_val; \ + } \ + inline BASE_EXPORT CF_RETURNS_RETAINED \ + TypeCF##Ref _Nullable NSToCFOwnershipCast(TypeNS* _Nullable ns_val) { \ + TypeCF##Ref cf_val = (__bridge_retained TypeCF##Ref)ns_val; \ + DCHECK(!cf_val || TypeCF##GetTypeID() == CFGetTypeID(cf_val)); \ + return cf_val; \ + } \ + inline BASE_EXPORT TypeNS* _Nullable CFToNSPtrCast( \ + TypeCF##Ref _Nullable cf_val) { \ + DCHECK(!cf_val || TypeCF##GetTypeID() == CFGetTypeID(cf_val)); \ + return (__bridge TypeNS*)cf_val; \ + } \ + inline BASE_EXPORT TypeCF##Ref _Nullable NSToCFPtrCast( \ + TypeNS* _Nullable ns_val) { \ + TypeCF##Ref cf_val = (__bridge TypeCF##Ref)ns_val; \ + DCHECK(!cf_val || TypeCF##GetTypeID() == CFGetTypeID(cf_val)); \ + return cf_val; \ + } \ + } + +#define CF_TO_NS_MUTABLE_CAST_IMPL(name) \ + CF_TO_NS_CAST_IMPL(CF##name, NS##name) \ + namespace base::mac { \ + inline BASE_EXPORT NSMutable##name* _Nullable CFToNSOwnershipCast( \ + CFMutable##name##Ref CF_CONSUMED _Nullable cf_val) { \ + DCHECK(!cf_val || CF##name##GetTypeID() == CFGetTypeID(cf_val)); \ + return (__bridge_transfer NSMutable##name*)cf_val; \ + } \ + inline BASE_EXPORT CF_RETURNS_RETAINED \ + CFMutable##name##Ref _Nullable NSToCFOwnershipCast( \ + NSMutable##name* _Nullable ns_val) { \ + CFMutable##name##Ref cf_val = \ + (__bridge_retained CFMutable##name##Ref)ns_val; \ + DCHECK(!cf_val || CF##name##GetTypeID() == CFGetTypeID(cf_val)); \ + return cf_val; \ + } \ + inline BASE_EXPORT NSMutable##name* _Nullable CFToNSPtrCast( \ + CFMutable##name##Ref _Nullable cf_val) { \ + DCHECK(!cf_val || CF##name##GetTypeID() == CFGetTypeID(cf_val)); \ + return (__bridge NSMutable##name*)cf_val; \ + } \ + inline BASE_EXPORT CFMutable##name##Ref _Nullable NSToCFPtrCast( \ + NSMutable##name* _Nullable ns_val) { \ + CFMutable##name##Ref cf_val = (__bridge CFMutable##name##Ref)ns_val; \ + DCHECK(!cf_val || CF##name##GetTypeID() == CFGetTypeID(cf_val)); \ + return cf_val; \ + } \ + } + +// List of toll-free bridged types taken from: +// https://web.archive.org/web/20111124025525/http://www.cocoadev.com/index.pl?TollFreeBridged +// https://developer.apple.com/library/archive/documentation/CoreFoundation/Conceptual/CFDesignConcepts/Articles/tollFreeBridgedTypes.html#//apple_ref/doc/uid/TP40010677-SW4 + +// Foundation +CF_TO_NS_MUTABLE_CAST_IMPL(Array) +CF_TO_NS_MUTABLE_CAST_IMPL(AttributedString) +CF_TO_NS_CAST_IMPL(CFCalendar, NSCalendar) +CF_TO_NS_MUTABLE_CAST_IMPL(CharacterSet) +CF_TO_NS_MUTABLE_CAST_IMPL(Data) +CF_TO_NS_CAST_IMPL(CFDate, NSDate) +CF_TO_NS_MUTABLE_CAST_IMPL(Dictionary) +CF_TO_NS_CAST_IMPL(CFError, NSError) +CF_TO_NS_CAST_IMPL(CFLocale, NSLocale) +CF_TO_NS_CAST_IMPL(CFNumber, NSNumber) +CF_TO_NS_CAST_IMPL(CFRunLoopTimer, NSTimer) +CF_TO_NS_CAST_IMPL(CFTimeZone, NSTimeZone) +CF_TO_NS_MUTABLE_CAST_IMPL(Set) +CF_TO_NS_CAST_IMPL(CFReadStream, NSInputStream) +CF_TO_NS_CAST_IMPL(CFWriteStream, NSOutputStream) +CF_TO_NS_MUTABLE_CAST_IMPL(String) +CF_TO_NS_CAST_IMPL(CFURL, NSURL) + +// AppKit / UIKit +#if defined(OS_IOS) +CF_TO_NS_CAST_IMPL(CTFont, UIFont) +#else +// The NSFont/CTFont toll-free bridging is broken before 10.15. +// http://www.openradar.me/15341349 rdar://15341349 +// +// TODO(https://crbug.com/1076527): This is fixed in 10.15. When 10.15 is the +// minimum OS for Chromium, remove this specialization and replace it with just: +// +// CF_TO_NS_CAST_IMPL(CTFont, NSFont) + +extern "C" { +Boolean _CFIsObjC(CFTypeID typeID, CFTypeRef obj); +} // extern "C" + +namespace base::mac { + +inline BASE_EXPORT NSFont* _Nullable CFToNSOwnershipCast( + CTFontRef CF_CONSUMED _Nullable cf_val) { + NSFont* ns_val = (__bridge_transfer NSFont*)cf_val; + DCHECK(!cf_val || CTFontGetTypeID() == CFGetTypeID(cf_val) || + (_CFIsObjC(CTFontGetTypeID(), cf_val) && + [ns_val isKindOfClass:[NSFont class]])); + return ns_val; +} + +inline BASE_EXPORT CF_RETURNS_RETAINED _Nullable CTFontRef NSToCFOwnershipCast( + NSFont* _Nullable ns_val) { + CTFontRef cf_val = (__bridge_retained CTFontRef)ns_val; + DCHECK(!cf_val || CTFontGetTypeID() == CFGetTypeID(cf_val) || + [ns_val isKindOfClass:[NSFont class]]); + return cf_val; +} + +inline BASE_EXPORT _Nullable NSFont* CFToNSPtrCast(CTFontRef _Nullable cf_val) { + NSFont* ns_val = (__bridge NSFont*)cf_val; + DCHECK(!cf_val || CTFontGetTypeID() == CFGetTypeID(cf_val) || + (_CFIsObjC(CTFontGetTypeID(), cf_val) && + [ns_val isKindOfClass:[NSFont class]])); + return ns_val; +} + +inline BASE_EXPORT _Nullable CTFontRef NSToCFPtrCast(NSFont* _Nullable ns_val) { + CTFontRef cf_val = (__bridge CTFontRef)ns_val; + DCHECK(!cf_val || CTFontGetTypeID() == CFGetTypeID(cf_val) || + [ns_val isKindOfClass:[NSFont class]]); + return cf_val; +} + +} // namespace base::mac + +#endif + +#undef CF_TO_NS_CAST_IMPL +#undef CF_TO_NS_MUTABLE_CAST_IMPL + +#endif // BASE_MAC_BRIDGING_H_
diff --git a/base/mac/foundation_util.h b/base/mac/foundation_util.h index 28f1d7a..f20ec23 100644 --- a/base/mac/foundation_util.h +++ b/base/mac/foundation_util.h
@@ -153,6 +153,10 @@ } // namespace base::mac +// These casting functions cannot be implemented in a way that will work with +// ARC. Use the casting functions in base/mac/bridging.h instead. +#if !defined(__has_feature) || !__has_feature(objc_arc) + #if !defined(__OBJC__) #define OBJC_CPP_CLASS_DECL(x) class x; #else // __OBJC__ @@ -224,6 +228,8 @@ #undef CF_TO_NS_MUTABLE_CAST_DECL #undef OBJC_CPP_CLASS_DECL +#endif // !defined(__has_feature) || !__has_feature(objc_arc) + namespace base::mac { // CFCast<>() and CFCastStrict<>() cast a basic CFTypeRef to a more
diff --git a/base/process/process_util_unittest.cc b/base/process/process_util_unittest.cc index c4abf41..dfc69405 100644 --- a/base/process/process_util_unittest.cc +++ b/base/process/process_util_unittest.cc
@@ -36,7 +36,6 @@ #include "base/threading/simple_thread.h" #include "base/threading/thread.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/multiprocess_func_list.h"
diff --git a/base/task/thread_pool/task_tracker.cc b/base/task/thread_pool/task_tracker.cc index 6ec32c84..56aa70c 100644 --- a/base/task/thread_pool/task_tracker.cc +++ b/base/task/thread_pool/task_tracker.cc
@@ -446,7 +446,8 @@ } RegisteredTaskSource TaskTracker::RunAndPopNextTask( - RegisteredTaskSource task_source) { + RegisteredTaskSource task_source, + base::Location* posted_from) { DCHECK(task_source); const bool should_run_tasks = BeforeRunTask(task_source->shutdown_behavior()); @@ -462,6 +463,8 @@ } if (task) { + if (posted_from) + *posted_from = task->posted_from; // Run the |task| (whether it's a worker task or the Clear() closure). RunTask(std::move(task.value()), task_source.get(), traits); }
diff --git a/base/task/thread_pool/task_tracker.h b/base/task/thread_pool/task_tracker.h index b2c68855..1db4ef6 100644 --- a/base/task/thread_pool/task_tracker.h +++ b/base/task/thread_pool/task_tracker.h
@@ -118,7 +118,11 @@ // (which indicates that it should be reenqueued). WillPostTask() must have // allowed the task in front of |task_source| to be posted before this is // called. - RegisteredTaskSource RunAndPopNextTask(RegisteredTaskSource task_source); + // |posted_from| is optionally used to capture base::Location of the task ran + // for investigation of memory corruption. + // TODO(crbug.com/1218384): Remove |posted_from| once resolved. + RegisteredTaskSource RunAndPopNextTask(RegisteredTaskSource task_source, + base::Location* posted_from = nullptr); // Returns true once shutdown has started (StartShutdown() was called). // Note: sequential consistency with the thread calling StartShutdown() isn't
diff --git a/base/task/thread_pool/worker_thread.cc b/base/task/thread_pool/worker_thread.cc index 4ebe5426..3bd4944e 100644 --- a/base/task/thread_pool/worker_thread.cc +++ b/base/task/thread_pool/worker_thread.cc
@@ -375,12 +375,15 @@ // Alias pointer for investigation of memory corruption. crbug.com/1218384 TaskSource* task_source_before_run = task_source.get(); base::debug::Alias(&task_source_before_run); + base::Location posted_from; + task_source = + task_tracker_->RunAndPopNextTask(std::move(task_source), &posted_from); - task_source = task_tracker_->RunAndPopNextTask(std::move(task_source)); - - // Alias pointer for investigation of memory corruption. crbug.com/1218384 + // Alias pointer and posted_from for investigation of memory corruption. + // crbug.com/1218384 TaskSource* task_source_before_move = task_source.get(); base::debug::Alias(&task_source_before_move); + DEBUG_ALIAS_FOR_CSTR(posted_from_str, posted_from.ToString().c_str(), 128); delegate_->DidProcessTask(std::move(task_source));
diff --git a/base/test/BUILD.gn b/base/test/BUILD.gn index f626b80f..372dc70 100644 --- a/base/test/BUILD.gn +++ b/base/test/BUILD.gn
@@ -163,7 +163,6 @@ deps = [ "//base/third_party/dynamic_annotations", "//build:chromeos_buildflags", - "//build:os_buildflags", "//third_party/icu:icuuc", "//third_party/libxml:libxml_utils", "//third_party/libxml:xml_reader",
diff --git a/base/test/perf_test_suite.cc b/base/test/perf_test_suite.cc index 15c3ae2..63c6f41 100644 --- a/base/test/perf_test_suite.cc +++ b/base/test/perf_test_suite.cc
@@ -11,7 +11,7 @@ #include "base/process/launch.h" #include "base/strings/string_util.h" #include "base/test/perf_log.h" -#include "build/os_buildflags.h" +#include "build/build_config.h" #include "testing/gtest/include/gtest/gtest.h" #if BUILDFLAG(IS_FUCHSIA)
diff --git a/base/test/test_suite.cc b/base/test/test_suite.cc index 62b14d0..94b9f289 100644 --- a/base/test/test_suite.cc +++ b/base/test/test_suite.cc
@@ -46,7 +46,6 @@ #include "base/time/time.h" #include "base/tracing_buildflags.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/multiprocess_func_list.h"
diff --git a/base/tracing/protos/chrome_track_event.proto b/base/tracing/protos/chrome_track_event.proto index d856269..1102a59 100644 --- a/base/tracing/protos/chrome_track_event.proto +++ b/base/tracing/protos/chrome_track_event.proto
@@ -615,6 +615,39 @@ optional FrameType frame_type = 4; } +message EventLatency { + enum EventType { + MOUSE_PRESSED = 1; + MOUSE_RELEASED = 2; + MOUSE_WHEEL = 3; + KEY_PRESSED = 4; + KEY_RELEASED = 5; + TOUCH_PRESSED = 6; + TOUCH_RELEASED = 7; + TOUCH_MOVED = 8; + GESTURE_SCROLL_BEGIN = 9; + GESTURE_SCROLL_UPDATE = 10; + GESTURE_SCROLL_END = 11; + GESTURE_DOUBLE_TAP = 12; + GESTURE_LONG_PRESS = 13; + GESTURE_LONG_TAP = 14; + GESTURE_SHOW_PRESS = 15; + GESTURE_TAP = 16; + GESTURE_TAP_CANCEL = 17; + GESTURE_TAP_DOWN = 18; + GESTURE_TAP_UNCONFIRMED = 19; + GESTURE_TWO_FINGER_TAP = 20; + FIRST_GESTURE_SCROLL_UPDATE = 21; + MOUSE_DRAGGED = 22; + GESTURE_PINCH_BEGIN = 23; + GESTURE_PINCH_END = 24; + GESTURE_PINCH_UPDATE = 25; + INERTIAL_GESTURE_SCROLL_UPDATE = 26; + } + + optional EventType event_type = 1; +} + message ChromeTrackEvent { // Extension range for Chrome: 1000-1999 // Next ID: 1032 @@ -687,5 +720,7 @@ optional RendererMainThreadTaskExecution renderer_main_thread_task_execution = 1031; + + optional EventLatency event_latency = 1032; } }
diff --git a/build/BUILD.gn b/build/BUILD.gn index d18e914..93759ef 100644 --- a/build/BUILD.gn +++ b/build/BUILD.gn
@@ -45,19 +45,3 @@ "IS_CHROMEOS_WITH_HW_DETAILS=$is_chromeos_with_hw_details", ] } - -buildflag_header("os_buildflags") { - header = "os_buildflags.h" - flags = [ - "IS_ANDROID=$is_android", - "IS_CHROMEOS=$is_chromeos", - "IS_FUCHSIA=$is_fuchsia", - "IS_IOS=$is_ios", - "IS_LINUX=$is_linux", - "IS_MAC=$is_mac", - "IS_NACL=$is_nacl", - "IS_WIN=$is_win", - "IS_APPLE=$is_apple", - "IS_POSIX=$is_posix", - ] -}
diff --git a/build/android/gyp/compile_java.py b/build/android/gyp/compile_java.py index cbce3a4..842ca18 100755 --- a/build/android/gyp/compile_java.py +++ b/build/android/gyp/compile_java.py
@@ -482,7 +482,7 @@ if enable_partial_javac: all_changed_paths_are_java = all( - [p.endswith(".java") for p in changes.IterChangedPaths()]) + p.endswith(".java") for p in changes.IterChangedPaths()) if (all_changed_paths_are_java and not changes.HasStringChanges() and os.path.exists(jar_path) and (jar_info_path is None or os.path.exists(jar_info_path))):
diff --git a/build/android/gyp/java_cpp_enum.py b/build/android/gyp/java_cpp_enum.py index 62f1162a..ca89abce 100755 --- a/build/android/gyp/java_cpp_enum.py +++ b/build/android/gyp/java_cpp_enum.py
@@ -96,7 +96,7 @@ 'k' + self.original_enum_name] for prefix in prefixes: - if all([w.startswith(prefix) for w in self.entries.keys()]): + if all(w.startswith(prefix) for w in self.entries.keys()): prefix_to_strip = prefix break else:
diff --git a/build/android/gyp/java_google_api_keys.py b/build/android/gyp/java_google_api_keys.py index 9a75f038..d2bd34f 100755 --- a/build/android/gyp/java_google_api_keys.py +++ b/build/android/gyp/java_google_api_keys.py
@@ -95,8 +95,6 @@ values = {} values['GOOGLE_API_KEY'] = google_api_keys.GetAPIKey() - values['GOOGLE_API_KEY_PHYSICAL_WEB_TEST'] = ( - google_api_keys.GetAPIKeyAndroidNonStable()) values['GOOGLE_API_KEY_ANDROID_NON_STABLE'] = ( google_api_keys.GetAPIKeyAndroidNonStable()) values['GOOGLE_CLIENT_ID_MAIN'] = google_api_keys.GetClientID('MAIN')
diff --git a/build/build_config.h b/build/build_config.h index 48edcf4bd..1b25276 100644 --- a/build/build_config.h +++ b/build/build_config.h
@@ -2,19 +2,29 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -// This file adds defines about the platform we're currently building on. +// This file doesn't belong to any GN target by design for faster build and +// less developer overhead. + +// This file adds build flags about the OS we're currently building on. They are +// defined directly in this file instead of via a `buildflag_header` target in a +// GN file for faster build. They are defined using the corresponding OS defines +// (e.g. OS_WIN) which are also defined in this file (except for OS_CHROMEOS, +// which is set by the build system). These defines are deprecated and should +// NOT be used directly. For example: +// Please Use: #if BUILDFLAG(IS_WIN) +// Deprecated: #if defined(OS_WIN) // // Operating System: -// OS_AIX / OS_ANDROID / OS_ASMJS / OS_FREEBSD / OS_FUCHSIA / OS_IOS / -// OS_LINUX / OS_MAC / OS_NACL / OS_NETBSD / OS_OPENBSD / -// OS_QNX / OS_SOLARIS / OS_WIN +// IS_AIX / IS_ANDROID / IS_ASMJS / IS_FREEBSD / IS_FUCHSIA / IS_IOS / +// IS_LINUX / IS_MAC / IS_NACL / IS_NETBSD / IS_OPENBSD / +// IS_QNX / IS_SOLARIS / IS_WIN // Operating System family: -// OS_APPLE: IOS or MAC -// OS_BSD: FREEBSD or NETBSD or OPENBSD -// OS_POSIX: AIX or ANDROID or ASMJS or CHROMEOS or FREEBSD or IOS or LINUX +// IS_APPLE: IOS or MAC +// IS_BSD: FREEBSD or NETBSD or OPENBSD +// IS_POSIX: AIX or ANDROID or ASMJS or CHROMEOS or FREEBSD or IOS or LINUX // or MAC or NACL or NETBSD or OPENBSD or QNX or SOLARIS -// -// /!\ Note: OS_CHROMEOS is set by the build system, not this file + +// This file also adds defines specific to the platform, architecture etc. // // Compiler: // COMPILER_MSVC / COMPILER_GCC @@ -36,6 +46,8 @@ #ifndef BUILD_BUILD_CONFIG_H_ #define BUILD_BUILD_CONFIG_H_ +#include "build/buildflag.h" + // A set of macros to use for platform detection. #if defined(__native_client__) // __native_client__ must be first, so that other OS_ defines are not set. @@ -111,6 +123,115 @@ #define OS_POSIX 1 #endif +// OS build flags +#if defined(OS_AIX) +#define BUILDFLAG_INTERNAL_IS_AIX() (1) +#else +#define BUILDFLAG_INTERNAL_IS_AIX() (0) +#endif + +#if defined(OS_ANDROID) +#define BUILDFLAG_INTERNAL_IS_ANDROID() (1) +#else +#define BUILDFLAG_INTERNAL_IS_ANDROID() (0) +#endif + +#if defined(OS_APPLE) +#define BUILDFLAG_INTERNAL_IS_APPLE() (1) +#else +#define BUILDFLAG_INTERNAL_IS_APPLE() (0) +#endif + +#if defined(OS_ASMJS) +#define BUILDFLAG_INTERNAL_IS_ASMJS() (1) +#else +#define BUILDFLAG_INTERNAL_IS_ASMJS() (0) +#endif + +#if defined(OS_BSD) +#define BUILDFLAG_INTERNAL_IS_BSD() (1) +#else +#define BUILDFLAG_INTERNAL_IS_BSD() (0) +#endif + +#if defined(OS_CHROMEOS) +#define BUILDFLAG_INTERNAL_IS_CHROMEOS() (1) +#else +#define BUILDFLAG_INTERNAL_IS_CHROMEOS() (0) +#endif + +#if defined(OS_FREEBSD) +#define BUILDFLAG_INTERNAL_IS_FREEBSD() (1) +#else +#define BUILDFLAG_INTERNAL_IS_FREEBSD() (0) +#endif + +#if defined(OS_FUCHSIA) +#define BUILDFLAG_INTERNAL_IS_FUCHSIA() (1) +#else +#define BUILDFLAG_INTERNAL_IS_FUCHSIA() (0) +#endif + +#if defined(OS_IOS) +#define BUILDFLAG_INTERNAL_IS_IOS() (1) +#else +#define BUILDFLAG_INTERNAL_IS_IOS() (0) +#endif + +#if defined(OS_LINUX) +#define BUILDFLAG_INTERNAL_IS_LINUX() (1) +#else +#define BUILDFLAG_INTERNAL_IS_LINUX() (0) +#endif + +#if defined(OS_MAC) +#define BUILDFLAG_INTERNAL_IS_MAC() (1) +#else +#define BUILDFLAG_INTERNAL_IS_MAC() (0) +#endif + +#if defined(OS_NACL) +#define BUILDFLAG_INTERNAL_IS_NACL() (1) +#else +#define BUILDFLAG_INTERNAL_IS_NACL() (0) +#endif + +#if defined(OS_NETBSD) +#define BUILDFLAG_INTERNAL_IS_NETBSD() (1) +#else +#define BUILDFLAG_INTERNAL_IS_NETBSD() (0) +#endif + +#if defined(OS_OPENBSD) +#define BUILDFLAG_INTERNAL_IS_OPENBSD() (1) +#else +#define BUILDFLAG_INTERNAL_IS_OPENBSD() (0) +#endif + +#if defined(OS_POSIX) +#define BUILDFLAG_INTERNAL_IS_POSIX() (1) +#else +#define BUILDFLAG_INTERNAL_IS_POSIX() (0) +#endif + +#if defined(OS_QNX) +#define BUILDFLAG_INTERNAL_IS_QNX() (1) +#else +#define BUILDFLAG_INTERNAL_IS_QNX() (0) +#endif + +#if defined(OS_SOLARIS) +#define BUILDFLAG_INTERNAL_IS_SOLARIS() (1) +#else +#define BUILDFLAG_INTERNAL_IS_SOLARIS() (0) +#endif + +#if defined(OS_WIN) +#define BUILDFLAG_INTERNAL_IS_WIN() (1) +#else +#define BUILDFLAG_INTERNAL_IS_WIN() (0) +#endif + // Compiler detection. Note: clang masquerades as GCC on POSIX and as MSVC on // Windows. #if defined(__GNUC__)
diff --git a/build/fuchsia/linux.sdk.sha1 b/build/fuchsia/linux.sdk.sha1 index e360093..59be4ea 100644 --- a/build/fuchsia/linux.sdk.sha1 +++ b/build/fuchsia/linux.sdk.sha1
@@ -1 +1 @@ -7.20220107.1.1 +7.20220107.2.5
diff --git a/build/fuchsia/linux_internal.sdk.sha1 b/build/fuchsia/linux_internal.sdk.sha1 index e360093..8a9206c 100644 --- a/build/fuchsia/linux_internal.sdk.sha1 +++ b/build/fuchsia/linux_internal.sdk.sha1
@@ -1 +1 @@ -7.20220107.1.1 +7.20220107.2.4
diff --git a/build/fuchsia/mac.sdk.sha1 b/build/fuchsia/mac.sdk.sha1 index e360093..59be4ea 100644 --- a/build/fuchsia/mac.sdk.sha1 +++ b/build/fuchsia/mac.sdk.sha1
@@ -1 +1 @@ -7.20220107.1.1 +7.20220107.2.5
diff --git a/cc/layers/render_surface_impl.cc b/cc/layers/render_surface_impl.cc index 4a99f99..b3b5794 100644 --- a/cc/layers/render_surface_impl.cc +++ b/cc/layers/render_surface_impl.cc
@@ -170,7 +170,8 @@ bool RenderSurfaceImpl::CopyOfOutputRequired() const { return HasCopyRequest() || ShouldCacheRenderSurface() || - SubtreeCaptureId().is_valid(); + SubtreeCaptureId().is_valid() || + OwningEffectNode()->shared_element_resource_id.IsValid(); } int RenderSurfaceImpl::TransformTreeIndex() const {
diff --git a/cc/metrics/compositor_frame_reporter.cc b/cc/metrics/compositor_frame_reporter.cc index 19d10a8d..678364c 100644 --- a/cc/metrics/compositor_frame_reporter.cc +++ b/cc/metrics/compositor_frame_reporter.cc
@@ -348,6 +348,41 @@ return args.frame_time + (args.interval * 1.5); } +perfetto::protos::pbzero::EventLatency::EventType ToProtoEnum( + EventMetrics::EventType event_type) { +#define CASE(event_type, proto_event_type) \ + case EventMetrics::EventType::event_type: \ + return perfetto::protos::pbzero::EventLatency::proto_event_type + switch (event_type) { + CASE(kMousePressed, MOUSE_PRESSED); + CASE(kMouseReleased, MOUSE_RELEASED); + CASE(kMouseWheel, MOUSE_WHEEL); + CASE(kKeyPressed, KEY_PRESSED); + CASE(kKeyReleased, KEY_RELEASED); + CASE(kTouchPressed, TOUCH_PRESSED); + CASE(kTouchReleased, TOUCH_RELEASED); + CASE(kTouchMoved, TOUCH_MOVED); + CASE(kGestureScrollBegin, GESTURE_SCROLL_BEGIN); + CASE(kGestureScrollUpdate, GESTURE_SCROLL_UPDATE); + CASE(kGestureScrollEnd, GESTURE_SCROLL_END); + CASE(kGestureDoubleTap, GESTURE_DOUBLE_TAP); + CASE(kGestureLongPress, GESTURE_LONG_PRESS); + CASE(kGestureLongTap, GESTURE_LONG_TAP); + CASE(kGestureShowPress, GESTURE_SHOW_PRESS); + CASE(kGestureTap, GESTURE_TAP); + CASE(kGestureTapCancel, GESTURE_TAP_CANCEL); + CASE(kGestureTapDown, GESTURE_TAP_DOWN); + CASE(kGestureTapUnconfirmed, GESTURE_TAP_UNCONFIRMED); + CASE(kGestureTwoFingerTap, GESTURE_TWO_FINGER_TAP); + CASE(kFirstGestureScrollUpdate, FIRST_GESTURE_SCROLL_UPDATE); + CASE(kMouseDragged, MOUSE_DRAGGED); + CASE(kGesturePinchBegin, GESTURE_PINCH_BEGIN); + CASE(kGesturePinchEnd, GESTURE_PINCH_END); + CASE(kGesturePinchUpdate, GESTURE_PINCH_UPDATE); + CASE(kInertialGestureScrollUpdate, INERTIAL_GESTURE_SCROLL_UPDATE); + } +} + } // namespace // CompositorFrameReporter::ProcessedBlinkBreakdown::Iterator ================== @@ -1206,10 +1241,16 @@ event_metrics->GetDispatchStageTimestamp( EventMetrics::DispatchStage::kGenerated); - const auto trace_id = TRACE_ID_LOCAL(event_metrics.get()); - TRACE_EVENT_NESTABLE_ASYNC_BEGIN_WITH_TIMESTAMP1( - kTraceEventCategory, "EventLatency", trace_id, generated_timestamp, - "event", event_metrics->GetTypeName()); + const auto trace_track = + perfetto::Track(base::trace_event::GetNextGlobalTraceId()); + TRACE_EVENT_BEGIN( + kTraceEventCategory, "EventLatency", trace_track, generated_timestamp, + [&](perfetto::EventContext context) { + auto* event = + context.event<perfetto::protos::pbzero::ChromeTrackEvent>(); + auto* event_latency = event->set_event_latency(); + event_latency->set_event_type(ToProtoEnum(event_metrics->type())); + }); // Event dispatch stages. EventMetrics::DispatchStage dispatch_stage = @@ -1235,10 +1276,10 @@ const char* breakdown_name = GetEventLatencyDispatchBreakdownName(dispatch_stage, end_stage); - TRACE_EVENT_NESTABLE_ASYNC_BEGIN_WITH_TIMESTAMP0( - kTraceEventCategory, breakdown_name, trace_id, dispatch_timestamp); - TRACE_EVENT_NESTABLE_ASYNC_END_WITH_TIMESTAMP0( - kTraceEventCategory, breakdown_name, trace_id, end_timestamp); + TRACE_EVENT_BEGIN(kTraceEventCategory, + perfetto::StaticString{breakdown_name}, trace_track, + dispatch_timestamp); + TRACE_EVENT_END(kTraceEventCategory, trace_track, end_timestamp); dispatch_stage = end_stage; dispatch_timestamp = end_timestamp; @@ -1267,11 +1308,10 @@ const char* d2c_breakdown_name = GetEventLatencyDispatchToCompositorBreakdownName(dispatch_stage, stage_it->stage_type); - TRACE_EVENT_NESTABLE_ASYNC_BEGIN_WITH_TIMESTAMP0( - kTraceEventCategory, d2c_breakdown_name, trace_id, dispatch_timestamp); - TRACE_EVENT_NESTABLE_ASYNC_END_WITH_TIMESTAMP0(kTraceEventCategory, - d2c_breakdown_name, trace_id, - stage_it->start_time); + TRACE_EVENT_BEGIN(kTraceEventCategory, + perfetto::StaticString{d2c_breakdown_name}, trace_track, + dispatch_timestamp); + TRACE_EVENT_END(kTraceEventCategory, trace_track, stage_it->start_time); // Compositor stages. for (; stage_it != stage_history_.end(); ++stage_it) { @@ -1285,8 +1325,8 @@ continue; const char* stage_name = GetStageName(stage_type_index); - TRACE_EVENT_NESTABLE_ASYNC_BEGIN_WITH_TIMESTAMP0( - kTraceEventCategory, stage_name, trace_id, stage_it->start_time); + TRACE_EVENT_BEGIN(kTraceEventCategory, perfetto::StaticString{stage_name}, + trace_track, stage_it->start_time); if (stage_it->stage_type == StageType::kSubmitCompositorFrameToPresentationCompositorFrame) { @@ -1298,18 +1338,16 @@ if (start_time >= end_time) continue; const char* breakdown_name = GetVizBreakdownName(it.GetBreakdown()); - TRACE_EVENT_NESTABLE_ASYNC_BEGIN_WITH_TIMESTAMP0( - kTraceEventCategory, breakdown_name, trace_id, start_time); - TRACE_EVENT_NESTABLE_ASYNC_END_WITH_TIMESTAMP0( - kTraceEventCategory, breakdown_name, trace_id, end_time); + TRACE_EVENT_BEGIN(kTraceEventCategory, + perfetto::StaticString{breakdown_name}, trace_track, + start_time); + TRACE_EVENT_END(kTraceEventCategory, trace_track, end_time); } } - TRACE_EVENT_NESTABLE_ASYNC_END_WITH_TIMESTAMP0( - kTraceEventCategory, stage_name, trace_id, stage_it->end_time); + TRACE_EVENT_END(kTraceEventCategory, trace_track, stage_it->end_time); } - TRACE_EVENT_NESTABLE_ASYNC_END_WITH_TIMESTAMP0( - kTraceEventCategory, "EventLatency", trace_id, frame_termination_time_); + TRACE_EVENT_END(kTraceEventCategory, trace_track, frame_termination_time_); } }
diff --git a/cc/trees/layer_tree_host_impl.cc b/cc/trees/layer_tree_host_impl.cc index de7dba8..814e6fab 100644 --- a/cc/trees/layer_tree_host_impl.cc +++ b/cc/trees/layer_tree_host_impl.cc
@@ -1262,14 +1262,12 @@ const size_t surface_index = render_surface_list_size - 1 - i; RenderSurfaceImpl* render_surface = (*frame->render_surface_list)[surface_index]; - const auto& shared_element_id = - render_surface->GetDocumentTransitionSharedElementId(); const bool is_root_surface = render_surface->EffectTreeIndex() == EffectTree::kContentsRootNodeId; const bool should_draw_into_render_pass = is_root_surface || render_surface->contributes_to_drawn_surface() || - render_surface->CopyOfOutputRequired() || shared_element_id.valid(); + render_surface->CopyOfOutputRequired(); if (should_draw_into_render_pass) frame->render_passes.push_back(render_surface->CreateRenderPass()); }
diff --git a/chrome/VERSION b/chrome/VERSION index 29b6fb8..7a6bf14 100644 --- a/chrome/VERSION +++ b/chrome/VERSION
@@ -1,4 +1,4 @@ MAJOR=99 MINOR=0 -BUILD=4813 +BUILD=4814 PATCH=0
diff --git a/chrome/android/BUILD.gn b/chrome/android/BUILD.gn index ebec747..3a126ad7 100644 --- a/chrome/android/BUILD.gn +++ b/chrome/android/BUILD.gn
@@ -1389,6 +1389,7 @@ "//chrome/browser/signin/services/android:javatests", "//chrome/browser/sync/android:java", "//chrome/browser/sync/test/android:test_support_java", + "//chrome/browser/tab:critical_persisted_tab_data_flatbuffer_java", "//chrome/browser/tab:critical_persisted_tab_data_proto_java", "//chrome/browser/tab:java", "//chrome/browser/tab_group:java",
diff --git a/chrome/android/chrome_java_resources.gni b/chrome/android/chrome_java_resources.gni index 1824e362..2eeaac5 100644 --- a/chrome/android/chrome_java_resources.gni +++ b/chrome/android/chrome_java_resources.gni
@@ -547,8 +547,6 @@ "java/res/drawable/signin_header_animation.xml", "java/res/drawable/store_locally_tooltip_background.xml", "java/res/drawable/tab_indicator.xml", - "java/res/drawable/thumbnail_gradient_top_left.xml", - "java/res/drawable/thumbnail_gradient_top_right.xml", "java/res/drawable/tile_view_hairline_border_background.xml", "java/res/drawable/toolbar_shadow.xml", "java/res/drawable/visa_card.xml",
diff --git a/chrome/android/chrome_java_sources.gni b/chrome/android/chrome_java_sources.gni index ce0ae39..115cabc 100644 --- a/chrome/android/chrome_java_sources.gni +++ b/chrome/android/chrome_java_sources.gni
@@ -1072,7 +1072,6 @@ "java/src/org/chromium/chrome/browser/suggestions/SuggestionsOfflineModelObserver.java", "java/src/org/chromium/chrome/browser/suggestions/SuggestionsUiDelegate.java", "java/src/org/chromium/chrome/browser/suggestions/SuggestionsUiDelegateImpl.java", - "java/src/org/chromium/chrome/browser/suggestions/ThumbnailGradient.java", "java/src/org/chromium/chrome/browser/suggestions/mostvisited/MostVisitedSites.java", "java/src/org/chromium/chrome/browser/suggestions/mostvisited/MostVisitedSitesBridge.java", "java/src/org/chromium/chrome/browser/suggestions/mostvisited/MostVisitedSitesMetadataUtils.java",
diff --git a/chrome/android/features/autofill_assistant/javatests/src/org/chromium/chrome/browser/autofill_assistant/TestingAutofillAssistantModuleEntryProvider.java b/chrome/android/features/autofill_assistant/javatests/src/org/chromium/chrome/browser/autofill_assistant/TestingAutofillAssistantModuleEntryProvider.java index 10fb91ce..65f2c81 100644 --- a/chrome/android/features/autofill_assistant/javatests/src/org/chromium/chrome/browser/autofill_assistant/TestingAutofillAssistantModuleEntryProvider.java +++ b/chrome/android/features/autofill_assistant/javatests/src/org/chromium/chrome/browser/autofill_assistant/TestingAutofillAssistantModuleEntryProvider.java
@@ -11,7 +11,6 @@ import org.chromium.chrome.browser.ActivityTabProvider; import org.chromium.chrome.browser.autofill_assistant.onboarding.OnboardingCoordinatorFactory; import org.chromium.chrome.browser.browser_controls.BrowserControlsStateProvider; -import org.chromium.chrome.browser.tab.Tab; import org.chromium.components.browser_ui.bottomsheet.BottomSheetController; import org.chromium.content_public.browser.WebContents; @@ -108,13 +107,13 @@ } @Override - public void getModuleEntry( - Tab tab, Callback<AutofillAssistantModuleEntry> callback, boolean showUi) { + public void getModuleEntry(Callback<AutofillAssistantModuleEntry> callback, + AssistantModuleInstallUi.Provider moduleInstallUiProvider, boolean showUi) { if (mCannotInstall) { callback.onResult(null); return; } mNotInstalled = false; - super.getModuleEntry(tab, callback, showUi); + super.getModuleEntry(callback, moduleInstallUiProvider, showUi); } }
diff --git a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUi.java b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUi.java new file mode 100644 index 0000000..626174e --- /dev/null +++ b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUi.java
@@ -0,0 +1,32 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.chrome.browser.autofill_assistant; + +import org.chromium.base.Consumer; + +/** + * UI informing the user about the status of installing a dynamic feature module. + */ +public interface AssistantModuleInstallUi { + /** + * Used to create {@link AssistantModuleInstallUi}. + */ + public interface Provider { + /** + * Creates a {@link AssistantModuleInstallUi}. + */ + AssistantModuleInstallUi create(Consumer<Boolean> onFailure); + } + + /** + * Show UI indicating the start of a module install. + */ + public void showInstallStartUi(); + + /** + * Show UI indicating the failure of a module install. + */ + public void showInstallFailureUi(); +}
diff --git a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUiProviderChrome.java b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUiProviderChrome.java new file mode 100644 index 0000000..eb9e589 --- /dev/null +++ b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUiProviderChrome.java
@@ -0,0 +1,45 @@ + +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.chrome.browser.autofill_assistant; + +import org.chromium.base.Consumer; +import org.chromium.chrome.R; +import org.chromium.chrome.browser.modules.ModuleInstallUi; +import org.chromium.chrome.browser.tab.Tab; + +/** + * Implementation of {@link AssistantModuleInstallUi.Provider} for Chrome. + */ +public class AssistantModuleInstallUiProviderChrome implements AssistantModuleInstallUi.Provider { + private final Tab mTab; + + public AssistantModuleInstallUiProviderChrome(Tab tab) { + mTab = tab; + } + + @Override + public AssistantModuleInstallUi create(Consumer<Boolean> onFailure) { + ModuleInstallUi ui = new ModuleInstallUi(mTab, R.string.autofill_assistant_module_title, + new ModuleInstallUi.FailureUiListener() { + @Override + public void onFailureUiResponse(boolean retry) { + onFailure.accept(retry); + } + }); + + return new AssistantModuleInstallUi() { + @Override + public void showInstallStartUi() { + ui.showInstallStartUi(); + } + + @Override + public void showInstallFailureUi() { + ui.showInstallFailureUi(); + } + }; + } +}
diff --git a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantDirectActionHandler.java b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantDirectActionHandler.java index d662744f..ef837b4 100644 --- a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantDirectActionHandler.java +++ b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantDirectActionHandler.java
@@ -254,10 +254,10 @@ callback.onResult(null); return; } - mModuleEntryProvider.getModuleEntry(tab, (entry) -> { + mModuleEntryProvider.getModuleEntry((entry) -> { mDelegate = createDelegate(entry); callback.onResult(mDelegate); - }, /* showUi = */ true); + }, new AssistantModuleInstallUiProviderChrome(tab), /* showUi = */ true); } /** Creates a delegate from the given {@link AutofillAssistantModuleEntry}, if possible. */
diff --git a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantModuleEntryProvider.java b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantModuleEntryProvider.java index 0030be4..7cf52d01 100644 --- a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantModuleEntryProvider.java +++ b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantModuleEntryProvider.java
@@ -10,10 +10,7 @@ import org.chromium.base.Callback; import org.chromium.base.Log; import org.chromium.base.SysUtils; -import org.chromium.chrome.R; import org.chromium.chrome.browser.autofill_assistant.metrics.FeatureModuleInstallation; -import org.chromium.chrome.browser.modules.ModuleInstallUi; -import org.chromium.chrome.browser.tab.Tab; /** * Manages the loading of autofill assistant DFM, and provides implementation of @@ -46,13 +43,14 @@ /** Gets the AA module entry, installing it if necessary. */ /* package */ - void getModuleEntry(Tab tab, Callback<AutofillAssistantModuleEntry> callback, boolean showUi) { + void getModuleEntry(Callback<AutofillAssistantModuleEntry> callback, + AssistantModuleInstallUi.Provider moduleInstallUiProvider, boolean showUi) { AutofillAssistantModuleEntry entry = getModuleEntryIfInstalled(); if (entry != null) { callback.onResult(entry); return; } - loadDynamicModule(tab, callback, showUi); + loadDynamicModule(callback, moduleInstallUiProvider, showUi); } /** @@ -85,19 +83,16 @@ AutofillAssistantModule.installDeferred(); } - private static void loadDynamicModule( - Tab tab, Callback<AutofillAssistantModuleEntry> callback, boolean showUi) { - ModuleInstallUi ui = new ModuleInstallUi(tab, R.string.autofill_assistant_module_title, - new ModuleInstallUi.FailureUiListener() { - @Override - public void onFailureUiResponse(boolean retry) { - if (retry) { - loadDynamicModule(tab, callback, showUi); - } else { - callback.onResult(null); - } - } - }); + private static void loadDynamicModule(Callback<AutofillAssistantModuleEntry> callback, + AssistantModuleInstallUi.Provider moduleInstallUiProvider, boolean showUi) { + AssistantModuleInstallUi ui = moduleInstallUiProvider.create((Boolean retry) -> { + if (retry) { + loadDynamicModule(callback, moduleInstallUiProvider, showUi); + } else { + callback.onResult(null); + } + }); + if (showUi) { // Shows toast informing user about install start. ui.showInstallStartUi();
diff --git a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantTabHelper.java b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantTabHelper.java index 7e454a1..fa52fe0e0 100644 --- a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantTabHelper.java +++ b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AutofillAssistantTabHelper.java
@@ -17,9 +17,10 @@ * the tab as observer and connect to its native counterpart in order to fulfill startup * requests from either side. */ - public static void createForTab(Tab tab, AssistantIsGsaFunction isGsaFunction, - AssistantIsMsbbEnabledFunction isMsbbEnabledFunction) { - Starter starter = new Starter(tab, isGsaFunction, isMsbbEnabledFunction); + public static void createForTab(Tab tab) { + Starter starter = new Starter(tab, AssistantDependencyUtilsChrome::isGsa, + AssistantDependencyUtilsChrome::isMakeSearchesAndBrowsingBetterSettingEnabled, + new AssistantModuleInstallUiProviderChrome(tab)); tab.addObserver(starter); tab.getUserDataHost().setUserData(USER_DATA_KEY, starter); }
diff --git a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/Starter.java b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/Starter.java index 234d962..374c24e 100644 --- a/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/Starter.java +++ b/chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/Starter.java
@@ -32,6 +32,7 @@ private final AssistantIsGsaFunction mIsGsaFunction; private final AssistantIsMsbbEnabledFunction mIsMsbbEnabledFunction; + private final AssistantModuleInstallUi.Provider mModuleInstallUiProvider; /** * The WebContents associated with the Tab which this starter is monitoring, unless detached. @@ -65,10 +66,12 @@ * This will wait for dependencies to become available and then create the native-side starter. */ public Starter(Tab tab, AssistantIsGsaFunction isGsaFunction, - AssistantIsMsbbEnabledFunction isMsbbEnabledFunction) { + AssistantIsMsbbEnabledFunction isMsbbEnabledFunction, + AssistantModuleInstallUi.Provider moduleInstallUiProvider) { mTab = tab; mIsGsaFunction = isGsaFunction; mIsMsbbEnabledFunction = isMsbbEnabledFunction; + mModuleInstallUiProvider = moduleInstallUiProvider; detectWebContentsChange(tab); } @@ -202,14 +205,14 @@ return; } - AutofillAssistantModuleEntryProvider.INSTANCE.getModuleEntry(mTab, + AutofillAssistantModuleEntryProvider.INSTANCE.getModuleEntry( (moduleEntry) -> safeNativeOnFeatureModuleInstalled(moduleEntry != null ? FeatureModuleInstallation .DFM_FOREGROUND_INSTALLATION_SUCCEEDED : FeatureModuleInstallation .DFM_FOREGROUND_INSTALLATION_FAILED), - showUi); + mModuleInstallUiProvider, showUi); } @CalledByNative
diff --git a/chrome/android/features/autofill_assistant/public/java_sources.gni b/chrome/android/features/autofill_assistant/public/java_sources.gni index 5e2e8cf..5c409dc 100644 --- a/chrome/android/features/autofill_assistant/public/java_sources.gni +++ b/chrome/android/features/autofill_assistant/public/java_sources.gni
@@ -12,6 +12,8 @@ "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantInfoPageUtil.java", "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantIsGsaFunction.java", "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantIsMsbbEnabledFunction.java", + "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUi.java", + "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantModuleInstallUiProviderChrome.java", "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantOnboardingHelper.java", "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantProfileImageUtil.java", "//chrome/android/features/autofill_assistant/public/java/src/org/chromium/chrome/browser/autofill_assistant/AssistantSnackbar.java",
diff --git a/chrome/android/features/start_surface/internal/java/src/org/chromium/chrome/features/start_surface/StartSurfaceLayout.java b/chrome/android/features/start_surface/internal/java/src/org/chromium/chrome/features/start_surface/StartSurfaceLayout.java index e1b4bb7..cd046de9 100644 --- a/chrome/android/features/start_surface/internal/java/src/org/chromium/chrome/features/start_surface/StartSurfaceLayout.java +++ b/chrome/android/features/start_surface/internal/java/src/org/chromium/chrome/features/start_surface/StartSurfaceLayout.java
@@ -305,6 +305,7 @@ @Override public void startHiding(int nextId, boolean hintAtTabSelection) { int startSurfaceState = mController.getStartSurfaceState(); + StartSurfaceUserData.getInstance().setUnusedTabRestoredAtStartup(false); if (startSurfaceState == StartSurfaceState.SHOWN_HOMEPAGE) { startHidingStartSurface(nextId, hintAtTabSelection); } else {
diff --git a/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/InstantStartTabSwitcherTest.java b/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/InstantStartTabSwitcherTest.java index 7d61a10..04b3601 100644 --- a/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/InstantStartTabSwitcherTest.java +++ b/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/InstantStartTabSwitcherTest.java
@@ -42,6 +42,12 @@ import org.junit.runner.RunWith; import org.chromium.base.MathUtils; +import org.chromium.base.metrics.RecordHistogram; +import org.chromium.base.test.params.ParameterAnnotations; +import org.chromium.base.test.params.ParameterAnnotations.UseMethodParameter; +import org.chromium.base.test.params.ParameterProvider; +import org.chromium.base.test.params.ParameterSet; +import org.chromium.base.test.params.ParameterizedRunner; import org.chromium.base.test.util.CommandLineFlags; import org.chromium.base.test.util.Criteria; import org.chromium.base.test.util.CriteriaHelper; @@ -61,17 +67,20 @@ import org.chromium.chrome.browser.tasks.pseudotab.TabAttributeCache; import org.chromium.chrome.browser.tasks.tab_management.TabUiFeatureUtilities; import org.chromium.chrome.browser.tasks.tab_management.TabUiTestHelper; -import org.chromium.chrome.test.ChromeJUnit4ClassRunner; +import org.chromium.chrome.test.ChromeJUnit4RunnerDelegate; import org.chromium.chrome.test.ChromeTabbedActivityTestRule; import org.chromium.chrome.test.util.ActivityTestUtils; import org.chromium.chrome.test.util.ChromeRenderTestRule; import org.chromium.chrome.test.util.browser.Features.DisableFeatures; import org.chromium.chrome.test.util.browser.Features.EnableFeatures; +import org.chromium.components.embedder_support.util.UrlUtilitiesJni; import org.chromium.content_public.browser.test.util.TestThreadUtils; import org.chromium.ui.test.util.UiRestriction; import org.chromium.ui.test.util.ViewUtils; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; @@ -79,7 +88,8 @@ * Integration tests of tab switcher with Instant Start which requires 2-stage initialization for * Clank startup. */ -@RunWith(ChromeJUnit4ClassRunner.class) +@RunWith(ParameterizedRunner.class) +@ParameterAnnotations.UseRunnerDelegate(ChromeJUnit4RunnerDelegate.class) // clang-format off @CommandLineFlags. Add({ChromeSwitches.DISABLE_FIRST_RUN_EXPERIENCE, "force-fieldtrials=Study/Group"}) @@ -91,6 +101,7 @@ public class InstantStartTabSwitcherTest { // clang-format on private static final String SHADOW_VIEW_TAG = "TabListViewShadow"; + private static final long MAX_TIMEOUT_MS = 30000L; @Rule public ChromeTabbedActivityTestRule mActivityTestRule = new ChromeTabbedActivityTestRule(); @@ -99,6 +110,23 @@ public ChromeRenderTestRule mRenderTestRule = ChromeRenderTestRule.Builder.withPublicCorpus().setRevision(1).build(); + /** + * {@link ParameterProvider} used for parameterized test that provides whether it's single tab + * switcher or carousel tab switcher and whether last visited tab is a search result page. + */ + public static class LVTIsSRPTestParams implements ParameterProvider { + private static final List<ParameterSet> sLVTIsSRPTestParams = + Arrays.asList(new ParameterSet().value(false, false).name("CarouselTab_NotSRP"), + new ParameterSet().value(true, false).name("SingleTab_NotSRP"), + new ParameterSet().value(false, true).name("CarouselTab_SRP"), + new ParameterSet().value(true, true).name("SingleTab_SRP")); + + @Override + public List<ParameterSet> getParameters() { + return sLVTIsSRPTestParams; + } + } + @After public void tearDown() { if (mActivityTestRule.getActivity() != null) { @@ -434,7 +462,7 @@ Assert.assertFalse(HomepageManager.isHomepageEnabled()); // Launches Chrome and verifies that the Tab switcher is showing. - mActivityTestRule.startMainActivityFromLauncher(); + StartSurfaceTestUtils.startMainActivityFromLauncher(mActivityTestRule); ChromeTabbedActivity cta = mActivityTestRule.getActivity(); StartSurfaceTestUtils.waitForTabModel(cta); TabUiTestHelper.verifyTabModelTabCount(cta, 1, 0); @@ -448,6 +476,62 @@ }); } + @Test + @MediumTest + // clang-format off + @CommandLineFlags.Add({ChromeSwitches.DISABLE_NATIVE_INITIALIZATION, + INSTANT_START_TEST_BASE_PARAMS + "/show_last_active_tab_only/true"}) + @UseMethodParameter(LVTIsSRPTestParams.class) + public void testRecordLastVisitedTabIsSRPHistogram(boolean isSingleTabSwitcher, boolean isSRP) + throws IOException { + // clang-format on + testRecordLastVisitedTabIsSRP(isSingleTabSwitcher, isSRP); + } + + @Test + @MediumTest + @DisableFeatures(ChromeFeatureList.INSTANT_START) + @UseMethodParameter(LVTIsSRPTestParams.class) + // clang-format off + @CommandLineFlags.Add({ChromeSwitches.DISABLE_NATIVE_INITIALIZATION, + INSTANT_START_TEST_BASE_PARAMS}) + // clang-format on + public void testRecordLastVisitedTabIsSRPHistogram_NoInstant( + boolean isSingleTabSwitcher, boolean isSRP) throws IOException { + testRecordLastVisitedTabIsSRP(isSingleTabSwitcher, isSRP); + } + + private void testRecordLastVisitedTabIsSRP(boolean isSingleTabSwitcher, boolean isSRP) + throws IOException { + StartSurfaceConfiguration.START_SURFACE_LAST_ACTIVE_TAB_ONLY.setForTesting( + isSingleTabSwitcher); + StartSurfaceTestUtils.createTabStateFile(new int[] {0, 1}, + new String[] {"https://www.google.com/search?q=test", "https://www.google.com"}, + isSRP ? 0 : 1); + StartSurfaceTestUtils.createThumbnailBitmapAndWriteToFile(0); + StartSurfaceTestUtils.createThumbnailBitmapAndWriteToFile(1); + TabAttributeCache.setTitleForTesting(0, "Google SRP"); + TabAttributeCache.setTitleForTesting(1, "Google Homepage"); + StartSurfaceTestUtils.startMainActivityFromLauncher(mActivityTestRule); + StartSurfaceTestUtils.startAndWaitNativeInitialization(mActivityTestRule); + ChromeTabbedActivity cta = mActivityTestRule.getActivity(); + StartSurfaceTestUtils.waitForOverviewVisible(cta); + StartSurfaceTestUtils.waitForDeferredStartup(mActivityTestRule); + + Assert.assertEquals(isSRP, + UrlUtilitiesJni.get().isGoogleSearchUrl( + StartSurfaceUserData.getInstance().getLastVisitedTabAtStartupUrl())); + Assert.assertEquals(1, + RecordHistogram.getHistogramTotalCountForTesting( + ReturnToChromeExperimentsUtil + .LAST_VISITED_TAB_IS_SRP_WHEN_OVERVIEW_IS_SHOWN_AT_LAUNCH_UMA)); + Assert.assertEquals(1, + RecordHistogram.getHistogramValueCountForTesting( + ReturnToChromeExperimentsUtil + .LAST_VISITED_TAB_IS_SRP_WHEN_OVERVIEW_IS_SHOWN_AT_LAUNCH_UMA, + isSRP ? 1 : 0)); + } + private boolean allCardsHaveThumbnail(RecyclerView recyclerView) { RecyclerView.Adapter adapter = recyclerView.getAdapter(); assert adapter != null;
diff --git a/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTest.java b/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTest.java index 800c4d5..da34e879 100644 --- a/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTest.java +++ b/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTest.java
@@ -101,6 +101,7 @@ import org.chromium.chrome.test.util.browser.suggestions.SuggestionsDependenciesRule; import org.chromium.components.browser_ui.bottomsheet.BottomSheetTestSupport; import org.chromium.components.embedder_support.util.UrlUtilities; +import org.chromium.content_public.browser.test.util.RenderProcessHostUtils; import org.chromium.content_public.browser.test.util.TestThreadUtils; import org.chromium.ui.modelutil.MVCListAdapter.ModelList; import org.chromium.ui.test.util.UiRestriction; @@ -571,8 +572,28 @@ // Disable feed placeholder animation because it causes waitForDeferredStartup() to time // out. FeedPlaceholderLayout.DISABLE_ANIMATION_SWITCH}) - public void startSurfaceRecordHistogramsTest() { + public void startSurfaceRecordHistogramsTest_SingleTab() { // clang-format on + startSurfaceRecordHistogramsTest(true); + } + + @Test + @MediumTest + @Restriction({UiRestriction.RESTRICTION_TYPE_PHONE}) + // clang-format off + @EnableFeatures({ChromeFeatureList.TAB_SWITCHER_ON_RETURN + "<Study", + ChromeFeatureList.TAB_GRID_LAYOUT_ANDROID, + ChromeFeatureList.START_SURFACE_ANDROID + "<Study"}) + @CommandLineFlags.Add({START_SURFACE_TEST_BASE_PARAMS, + // Disable feed placeholder animation because it causes waitForDeferredStartup() to time + // out. + FeedPlaceholderLayout.DISABLE_ANIMATION_SWITCH}) + public void startSurfaceRecordHistogramsTest_CarouselTab() { + // clang-format on + startSurfaceRecordHistogramsTest(false); + } + + private void startSurfaceRecordHistogramsTest(boolean isSingleTabSwitcher) { if (!mImmediateReturn) { assertNotEquals(0, ReturnToChromeExperimentsUtil.TAB_SWITCHER_ON_RETURN_MS.getValue()); StartSurfaceTestUtils.pressHomePageButton(mActivityTestRule.getActivity()); @@ -581,20 +602,12 @@ } Assert.assertEquals("single", StartSurfaceConfiguration.START_SURFACE_VARIATION.getValue()); - Assert.assertTrue(StartSurfaceConfiguration.START_SURFACE_LAST_ACTIVE_TAB_ONLY.getValue()); + Assert.assertEquals(isSingleTabSwitcher, + StartSurfaceConfiguration.START_SURFACE_LAST_ACTIVE_TAB_ONLY.getValue()); StartSurfaceTestUtils.waitForOverviewVisible( mLayoutChangedCallbackHelper, mCurrentlyActiveLayout); mActivityTestRule.waitForActivityNativeInitializationComplete(); - - // Waits for the current Tab to complete loading. The deferred startup will be triggered - // after the loading. - Tab tab = mActivityTestRule.getActivity().getActivityTab(); - if (tab != null && tab.isLoading()) { - CriteriaHelper.pollUiThread(() - -> !tab.isLoading(), - MAX_TIMEOUT_MS, CriteriaHelper.DEFAULT_POLLING_INTERVAL); - } - assertTrue("Deferred startup never completed", mActivityTestRule.waitForDeferredStartup()); + StartSurfaceTestUtils.waitForDeferredStartup(mActivityTestRule); boolean isInstantStart = TabUiFeatureUtilities.supportInstantStart(false, mActivityTestRule.getActivity()); @@ -606,11 +619,18 @@ int expectedRecordCount = mImmediateReturn ? 1 : 0; // Histograms should be only recorded when StartSurface is shown immediately after // launch. + if (isSingleTabSwitcher) { + Assert.assertEquals(expectedRecordCount, + RecordHistogram.getHistogramTotalCountForTesting( + StartSurfaceConfiguration.getHistogramName( + SingleTabSwitcherMediator.SINGLE_TAB_TITLE_AVAILABLE_TIME_UMA, + isInstantStart))); + } + Assert.assertEquals(expectedRecordCount, RecordHistogram.getHistogramTotalCountForTesting( - StartSurfaceConfiguration.getHistogramName( - SingleTabSwitcherMediator.SINGLE_TAB_TITLE_AVAILABLE_TIME_UMA, - isInstantStart))); + ReturnToChromeExperimentsUtil + .LAST_VISITED_TAB_IS_SRP_WHEN_OVERVIEW_IS_SHOWN_AT_LAUNCH_UMA)); Assert.assertEquals(expectedRecordCount, RecordHistogram.getHistogramTotalCountForTesting( @@ -1076,6 +1096,42 @@ assertEquals(hasUpdateMenuItem ? 12 : 11, menuItemsModelList.size()); } + @Test + @MediumTest + @Feature({"StartSurface"}) + // clang-format off + @CommandLineFlags.Add({START_SURFACE_TEST_BASE_PARAMS}) + public void test_DoNotLoadLastSelectedTabOnStartup() { + // clang-format on + doTestNotLoadLastSelectedTabOnStartupImpl(); + } + + @Test + @MediumTest + @Feature({"StartSurface"}) + // clang-format off + @CommandLineFlags.Add({START_SURFACE_TEST_BASE_PARAMS + "/show_last_active_tab_only/true"}) + public void test_DoNotLoadLastSelectedTabOnStartupV2() { + // clang-format on + doTestNotLoadLastSelectedTabOnStartupImpl(); + } + + private void doTestNotLoadLastSelectedTabOnStartupImpl() { + assumeTrue(mImmediateReturn); + + ChromeTabbedActivity cta = mActivityTestRule.getActivity(); + StartSurfaceTestUtils.waitForOverviewVisible( + mLayoutChangedCallbackHelper, mCurrentlyActiveLayout); + StartSurfaceTestUtils.waitForTabModel(cta); + TabUiTestHelper.verifyTabModelTabCount(cta, 1, 0); + Assert.assertEquals(0, RenderProcessHostUtils.getCurrentRenderProcessCount()); + + StartSurfaceTestUtils.launchFirstMVTile(cta, /* currentTabCount = */ 1); + TabUiTestHelper.verifyTabModelTabCount(cta, 2, 0); + StartSurfaceTestUtils.waitForCurrentTabLoaded(mActivityTestRule); + Assert.assertEquals(1, RenderProcessHostUtils.getCurrentRenderProcessCount()); + } + /** * Check that the next decision time is within |numOfDays| from now. * @param numOfDays Number of days to check.
diff --git a/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTestUtils.java b/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTestUtils.java index 7adcf7f..55cf296 100644 --- a/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTestUtils.java +++ b/chrome/android/features/start_surface/internal/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceTestUtils.java
@@ -57,6 +57,7 @@ import org.chromium.chrome.browser.suggestions.tile.TileSectionType; import org.chromium.chrome.browser.suggestions.tile.TileSource; import org.chromium.chrome.browser.suggestions.tile.TileTitleSource; +import org.chromium.chrome.browser.tab.Tab; import org.chromium.chrome.browser.tab.TabState; import org.chromium.chrome.browser.tabmodel.TabPersistentStore; import org.chromium.chrome.browser.tabmodel.TabbedModeTabPersistencePolicy; @@ -482,6 +483,30 @@ } /** + * Wait for the deferred startup to be triggered. + * @param activityTestRule The ChromeTabbedActivityTestRule under test. + */ + public static void waitForDeferredStartup(ChromeTabbedActivityTestRule activityTestRule) { + // Waits for the current Tab to complete loading. The deferred startup will be triggered + // after the loading. + waitForCurrentTabLoaded(activityTestRule); + assertTrue("Deferred startup never completed", activityTestRule.waitForDeferredStartup()); + } + + /** + * Waits for the current Tab to complete loading. + * @param activityTestRule The ChromeTabbedActivityTestRule under test. + */ + public static void waitForCurrentTabLoaded(ChromeTabbedActivityTestRule activityTestRule) { + Tab tab = activityTestRule.getActivity().getActivityTab(); + if (tab != null && tab.isLoading()) { + CriteriaHelper.pollUiThread(() + -> !tab.isLoading(), + MAX_TIMEOUT_MS, CriteriaHelper.DEFAULT_POLLING_INTERVAL); + } + } + + /** * Create a file so that a TabState can be restored later. * @param tabId the Tab ID * @param encrypted for Incognito mode
diff --git a/chrome/android/features/start_surface/public/java/src/org/chromium/chrome/features/start_surface/StartSurfaceUserData.java b/chrome/android/features/start_surface/public/java/src/org/chromium/chrome/features/start_surface/StartSurfaceUserData.java index 4c7b533..150189a2 100644 --- a/chrome/android/features/start_surface/public/java/src/org/chromium/chrome/features/start_surface/StartSurfaceUserData.java +++ b/chrome/android/features/start_surface/public/java/src/org/chromium/chrome/features/start_surface/StartSurfaceUserData.java
@@ -18,9 +18,14 @@ private boolean mFocusOnOmnibox; private boolean mCreatedAsNtp; private boolean mOpenedFromStart; - + private String mLastVisitedTabAtStartupUrl; // Saves the Feeds instance state. private String mFeedsInstanceState; + /** + * Tracks whether the last visited Tab is restored at startup but not showing due to the + * overview page is showing at the startup. + */ + private boolean mUnusedTabRestoredAtStartup; /** * Static class that implements the initialization-on-demand holder idiom. @@ -154,4 +159,35 @@ protected String restoreFeedInstanceState() { return mFeedsInstanceState; } + + /** + * Sets the url of last visited tab at start up. + * @param lastVisitedTabAtStartupUrl The url of last visited tab at start up. + */ + public void setLastVisitedTabAtStartupUrl(String lastVisitedTabAtStartupUrl) { + mLastVisitedTabAtStartupUrl = lastVisitedTabAtStartupUrl; + } + + /** + * Returns the saved url of last visited tab at start up. + */ + public String getLastVisitedTabAtStartupUrl() { + return mLastVisitedTabAtStartupUrl; + } + + /** + * Sets whether an unused Tab is restored at startup due to an overview page is showing at the + * startup. + */ + public void setUnusedTabRestoredAtStartup(boolean overviewShownAtStartup) { + mUnusedTabRestoredAtStartup = overviewShownAtStartup; + } + + /** + * Gets whether an unused Tab is restored at startup due to an overview page is showing at the + * startup. + */ + public boolean getUnusedTabRestoredAtStartup() { + return mUnusedTabRestoredAtStartup; + } }
diff --git a/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/SingleTabSwitcherMediator.java b/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/SingleTabSwitcherMediator.java index f565fa2c..1f643b9e 100644 --- a/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/SingleTabSwitcherMediator.java +++ b/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/SingleTabSwitcherMediator.java
@@ -252,6 +252,7 @@ StartSurfaceConfiguration.recordHistogram(SINGLE_TAB_TITLE_AVAILABLE_TIME_UMA, mTabTitleAvailableTime - activityCreationTimeMs, TabUiFeatureUtilities.supportInstantStart(false, mContext)); + ReturnToChromeExperimentsUtil.recordLastVisitedTabIsSRPWhenOverviewIsShownAtLaunch(); } @Override
diff --git a/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabSwitcherMediator.java b/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabSwitcherMediator.java index f972598..63052848 100644 --- a/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabSwitcherMediator.java +++ b/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabSwitcherMediator.java
@@ -774,7 +774,11 @@ } @Override - public void onOverviewShownAtLaunch(long activityCreationTimeMs) {} + public void onOverviewShownAtLaunch(long activityCreationTimeMs) { + if (mMode == TabListMode.CAROUSEL) { + ReturnToChromeExperimentsUtil.recordLastVisitedTabIsSRPWhenOverviewIsShownAtLaunch(); + } + } /** * Do clean-up work after the overview hiding animation is finished.
diff --git a/chrome/android/java/res/drawable/thumbnail_gradient_top_left.xml b/chrome/android/java/res/drawable/thumbnail_gradient_top_left.xml deleted file mode 100644 index bac6d72..0000000 --- a/chrome/android/java/res/drawable/thumbnail_gradient_top_left.xml +++ /dev/null
@@ -1,10 +0,0 @@ -<?xml version="1.0" encoding="utf-8"?> -<shape xmlns:android="http://schemas.android.com/apk/res/android"> - - <!-- The spec for the gradient is #25272B applied at 6% transparency. This gives #0F25272B. --> - - <gradient - android:angle="315" - android:startColor="#0F25272B" - android:endColor="#00000000" /> -</shape> \ No newline at end of file
diff --git a/chrome/android/java/res/drawable/thumbnail_gradient_top_right.xml b/chrome/android/java/res/drawable/thumbnail_gradient_top_right.xml deleted file mode 100644 index 536c7f07..0000000 --- a/chrome/android/java/res/drawable/thumbnail_gradient_top_right.xml +++ /dev/null
@@ -1,10 +0,0 @@ -<?xml version="1.0" encoding="utf-8"?> -<shape xmlns:android="http://schemas.android.com/apk/res/android"> - - <!-- The spec for the gradient is #25272B applied at 6% transparency. This gives #0F25272B. --> - - <gradient - android:angle="225" - android:startColor="#0F25272B" - android:endColor="#00000000" /> -</shape> \ No newline at end of file
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/ChromeTabbedActivity.java b/chrome/android/java/src/org/chromium/chrome/browser/ChromeTabbedActivity.java index 5f3d0be..99a4267 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/ChromeTabbedActivity.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/ChromeTabbedActivity.java
@@ -1708,10 +1708,14 @@ ChromePreferenceKeys.TABBED_ACTIVITY_LAST_BACKGROUNDED_TIME_MS_PREF); assert getActivityTabStartupMetricsTracker() != null; + boolean shouldShowOverviewPageOnStart = shouldShowOverviewPageOnStart(); + if (shouldShowOverviewPageOnStart) { + StartSurfaceUserData.getInstance().setUnusedTabRestoredAtStartup(true); + } if (StartupPaintPreviewHelper.isEnabled()) { StartupPaintPreviewHelper paintPreviewHelper = new StartupPaintPreviewHelper( getWindowAndroid(), getOnCreateTimestampMs(), getBrowserControlsManager(), - getTabModelSelector(), shouldShowOverviewPageOnStart(), () -> { + getTabModelSelector(), shouldShowOverviewPageOnStart, () -> { return getToolbarManager() == null ? null : getToolbarManager().getProgressBarCoordinator();
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/app/flags/ChromeCachedFlags.java b/chrome/android/java/src/org/chromium/chrome/browser/app/flags/ChromeCachedFlags.java index 51b412c..b223f31 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/app/flags/ChromeCachedFlags.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/app/flags/ChromeCachedFlags.java
@@ -67,6 +67,7 @@ // Workaround for crbug.com/1223545: Do not use Arrays.asList(). List<String> featuresToCache = new ArrayList<String>() { { + add(ChromeFeatureList.ANONYMOUS_UPDATE_CHECKS); add(ChromeFeatureList.APP_MENU_MOBILE_SITE_OPTION); add(ChromeFeatureList.APP_TO_WEB_ATTRIBUTION); add(ChromeFeatureList.BOOKMARK_BOTTOM_SHEET);
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/app/tab_activity_glue/ReparentingTask.java b/chrome/android/java/src/org/chromium/chrome/browser/app/tab_activity_glue/ReparentingTask.java index fce01ed8..b6a061ca 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/app/tab_activity_glue/ReparentingTask.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/app/tab_activity_glue/ReparentingTask.java
@@ -26,7 +26,7 @@ import org.chromium.chrome.browser.flags.ChromeFeatureList; import org.chromium.chrome.browser.tab.Tab; import org.chromium.chrome.browser.tab.TabDelegateFactory; -import org.chromium.chrome.browser.tab.TabStateAttributes; +import org.chromium.chrome.browser.tab.TabImpl; import org.chromium.chrome.browser.tabmodel.TabModelSelector; import org.chromium.chrome.browser.tabmodel.TabReparentingParams; import org.chromium.content_public.browser.WebContents; @@ -168,9 +168,7 @@ public void finish(@NonNull Delegate delegate, @Nullable Runnable finalizeCallback) { delegate.getCompositorViewHolder().prepareForTabReparenting(); attach(delegate.getWindowAndroid(), delegate.getTabDelegateFactory()); - if (!mTab.isDestroyed()) { - TabStateAttributes.from(mTab).setIsTabStateDirty(true); - } + ((TabImpl) mTab).setIsTabStateDirty(true); if (finalizeCallback != null) finalizeCallback.run(); }
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/compositor/CompositorViewHolder.java b/chrome/android/java/src/org/chromium/chrome/browser/compositor/CompositorViewHolder.java index 58f4e976..8a736f3 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/compositor/CompositorViewHolder.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/compositor/CompositorViewHolder.java
@@ -65,6 +65,7 @@ import org.chromium.chrome.browser.toolbar.ControlContainer; import org.chromium.chrome.browser.ui.TabObscuringHandler; import org.chromium.chrome.browser.util.ChromeAccessibilityUtil; +import org.chromium.chrome.features.start_surface.StartSurfaceUserData; import org.chromium.components.browser_ui.widget.InsetObserverView; import org.chromium.components.browser_ui.widget.TouchEventObserver; import org.chromium.components.content_capture.OnscreenContentProvider; @@ -1421,7 +1422,13 @@ } private void setTab(Tab tab) { - if (tab != null) tab.loadIfNeeded(); + // The StartSurfaceUserData.getInstance().getUnusedTabRestoredAtStartup() is only true when + // the Start surface is showing in the startup and there isn't any Tab opened. Thus, no + // Tab needs to be loaded. Once a new Tab is opening and Start surface is hiding, this flag + // will be reset. + if (tab != null && !StartSurfaceUserData.getInstance().getUnusedTabRestoredAtStartup()) { + tab.loadIfNeeded(); + } View newView = tab != null ? tab.getView() : null; if (mView == newView) return;
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/download/OMADownloadHandler.java b/chrome/android/java/src/org/chromium/chrome/browser/download/OMADownloadHandler.java index 6e5f347..35e24ca 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/download/OMADownloadHandler.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/download/OMADownloadHandler.java
@@ -120,6 +120,25 @@ "953 Non-Acceptable Content \n\r"; private static final String DOWNLOAD_STATUS_LOADER_ERROR = "954 Loader Error \n\r"; + private static final NetworkTrafficAnnotationTag TRAFFIC_ANNOTATION = + NetworkTrafficAnnotationTag.createComplete("oma_download_handler_android", + "semantics {" + + " sender: 'OMA Download Handler (Android)'" + + " description: 'Uploads file download status to the server URL '" + + " 'specified in the download descriptor XML, as ' " + + " 'required by the OMA DRM specification.'" + + " trigger: 'After an OMA DRM file download completes.'" + + " data: 'Info related to the download.'" + + " destination: OTHER" + + "}" + + "policy {" + + " cookies_allowed: NO" + + " setting: 'This feature cannot be disabled by settings as it is '" + + " 'part of the OMA DRM specification.'" + + " policy_exception_justification:" + + " 'Not implemented.'" + + "}"); + private final Context mContext; private final SharedPreferencesManager mSharedPrefs; private final LongSparseArray<DownloadItem> mSystemDownloadIdMap = @@ -945,7 +964,7 @@ try { URL url = new URL(mOMAInfo.getValue(OMA_INSTALL_NOTIFY_URI)); urlConnection = (HttpURLConnection) ChromiumNetworkAdapter.openConnection( - url, NetworkTrafficAnnotationTag.MISSING_TRAFFIC_ANNOTATION); + url, TRAFFIC_ANNOTATION); urlConnection.setDoOutput(true); urlConnection.setUseCaches(false); urlConnection.setRequestMethod("POST");
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/omaha/RequestGenerator.java b/chrome/android/java/src/org/chromium/chrome/browser/omaha/RequestGenerator.java index 3b9d973..e8e6974 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/omaha/RequestGenerator.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/omaha/RequestGenerator.java
@@ -14,6 +14,8 @@ import org.xmlpull.v1.XmlSerializer; import org.chromium.base.BuildInfo; +import org.chromium.chrome.browser.flags.CachedFeatureFlags; +import org.chromium.chrome.browser.flags.ChromeFeatureList; import org.chromium.chrome.browser.uid.SettingsSecureBasedIdentificationGenerator; import org.chromium.chrome.browser.uid.UniqueIdentificationGeneratorFactory; import org.chromium.ui.base.DeviceFormFactor; @@ -80,8 +82,12 @@ serializer.attribute(null, "requestid", "{" + data.getRequestID() + "}"); serializer.attribute(null, "sessionid", "{" + sessionID + "}"); serializer.attribute(null, "installsource", data.getInstallSource()); - serializer.attribute(null, "userid", "{" + getDeviceID() + "}"); - serializer.attribute(null, "dedup", "uid"); + if (CachedFeatureFlags.isEnabled(ChromeFeatureList.ANONYMOUS_UPDATE_CHECKS)) { + serializer.attribute(null, "dedup", "cr"); + } else { + serializer.attribute(null, "userid", "{" + getDeviceID() + "}"); + serializer.attribute(null, "dedup", "uid"); + } // Set up <os platform="android"... /> serializer.startTag(null, "os");
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/suggestions/ThumbnailGradient.java b/chrome/android/java/src/org/chromium/chrome/browser/suggestions/ThumbnailGradient.java deleted file mode 100644 index 5b6383f..0000000 --- a/chrome/android/java/src/org/chromium/chrome/browser/suggestions/ThumbnailGradient.java +++ /dev/null
@@ -1,148 +0,0 @@ -// Copyright 2017 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.chrome.browser.suggestions; - -import android.content.res.Resources; -import android.graphics.Bitmap; -import android.graphics.Color; -import android.graphics.drawable.BitmapDrawable; -import android.graphics.drawable.Drawable; -import android.graphics.drawable.LayerDrawable; -import android.os.SystemClock; - -import androidx.annotation.IntDef; - -import org.chromium.base.ApiCompatibilityUtils; -import org.chromium.base.metrics.RecordHistogram; -import org.chromium.chrome.R; -import org.chromium.ui.base.LocalizationUtils; - -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; - -/** - * When suggestions cards are displayed on a white background, thumbnails with white backgrounds - * have a gradient overlaid to provide contrast at the edge of the cards. - */ -public class ThumbnailGradient { - /** If all the RGB values of a pixel are greater than this value, it is counted as 'light'. */ - private static final int LIGHT_PIXEL_THRESHOLD = 0xcc; - - /** The percent of the border pictures that need to be 'light' for a Bitmap to be 'light'. */ - private static final float PIXEL_BORDER_RATIO = 0.4f; - - /** Where the image is located in the card. */ - @IntDef({ThumbnailLocation.START, ThumbnailLocation.END}) - @Retention(RetentionPolicy.SOURCE) - public @interface ThumbnailLocation { - int START = 0; - int END = 1; - } - - /** The corner of the thumbnail where the gradient is darkest. */ - @IntDef({GradientDirection.TOP_LEFT, GradientDirection.TOP_RIGHT}) - @Retention(RetentionPolicy.SOURCE) - private @interface GradientDirection { - int TOP_LEFT = 0; - int TOP_RIGHT = 1; - } - - /** - * Calls {@link #createDrawableWithGradientIfNeeded(Bitmap, int, Resources)} with the default - * {@link ThumbnailLocation#END}. - */ - public static Drawable createDrawableWithGradientIfNeeded(Bitmap bitmap, Resources resources) { - return createDrawableWithGradientIfNeeded(bitmap, ThumbnailLocation.END, resources); - } - - /** - * If the {@link Bitmap} should have a gradient applied this method returns a Drawable - * containing the Bitmap and a gradient. Otherwise it returns a BitmapDrawable containing just - * the Bitmap. - * @param bitmap The {@link Bitmap} used to create the drawable. - * @param thumbnailLocation Where the image is located in the card. - * @param resources The {@link Resources} for the current activity. - */ - public static Drawable createDrawableWithGradientIfNeeded( - Bitmap bitmap, @ThumbnailLocation int thumbnailLocation, Resources resources) { - int direction = getGradientDirection(thumbnailLocation); - - // We want to keep an eye on how long this takes. - long time = SystemClock.elapsedRealtime(); - boolean lightImage = hasLightCorner(bitmap, direction); - RecordHistogram.recordTimesHistogram( - "Thumbnails.Gradient.ImageDetectionTime", SystemClock.elapsedRealtime() - time); - - RecordHistogram.recordBooleanHistogram( - "Thumbnails.Gradient.ImageRequiresGradient", lightImage); - - if (lightImage) { - Drawable gradient = ApiCompatibilityUtils.getDrawable(resources, - direction == GradientDirection.TOP_LEFT - ? R.drawable.thumbnail_gradient_top_left - : R.drawable.thumbnail_gradient_top_right); - - return new LayerDrawable( - new Drawable[] {new BitmapDrawable(resources, bitmap), gradient}); - } - - return new BitmapDrawable(resources, bitmap); - } - - /** - * Determines whether a Bitmap has a light corner. - */ - private static boolean hasLightCorner(Bitmap bitmap, @GradientDirection int direction) { - int lightPixels = 0; - - final int width = bitmap.getWidth(); - final int height = bitmap.getHeight(); - // We test all the pixels along the top and one side. The |-1| is so we don't count the - // corner twice. - final int threshold = (int) ((width + height - 1) * PIXEL_BORDER_RATIO); - - for (int x = 0; x < width; x++) { - if (isPixelLight(bitmap.getPixel(x, 0))) lightPixels++; - } - - // If we've already exceeded the threshold of light pixels, don't bother counting the rest. - if (lightPixels > threshold) { - return true; - } - - final int x = direction == GradientDirection.TOP_LEFT ? 0 : width - 1; - // Avoid counting the corner pixels twice. - for (int y = 1; y < height - 1; y++) { - if (isPixelLight(bitmap.getPixel(x, y))) lightPixels++; - } - - return lightPixels > threshold; - } - - /** - * Whether a pixel counts as light. - */ - private static boolean isPixelLight(int color) { - return Color.red(color) > LIGHT_PIXEL_THRESHOLD && Color.blue(color) > LIGHT_PIXEL_THRESHOLD - && Color.green(color) > LIGHT_PIXEL_THRESHOLD; - } - - /** - * The gradient should come from the upper corner of the thumbnail that is touching the side of - * the card. - */ - @GradientDirection - private static int getGradientDirection(@ThumbnailLocation int thumbnailLocation) { - // The drawable resource does not get flipped automatically if we are in RTL, so we must - // flip it ourselves. - boolean rtl = LocalizationUtils.isLayoutRtl(); - - // If the thumbnail is on the left side of the card, the gradient should be applied - // to the top left corner. If it is on the right side of the card, the gradient should be - // applied to the top right corner. - return thumbnailLocation == ThumbnailLocation.END == rtl ? GradientDirection.TOP_LEFT - : GradientDirection.TOP_RIGHT; - } -}
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/tab/TabHelpers.java b/chrome/android/java/src/org/chromium/chrome/browser/tab/TabHelpers.java index 9262692c..1a771c3 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/tab/TabHelpers.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/tab/TabHelpers.java
@@ -5,7 +5,6 @@ package org.chromium.chrome.browser.tab; import org.chromium.chrome.browser.SwipeRefreshHandler; -import org.chromium.chrome.browser.autofill_assistant.AssistantDependencyUtilsChrome; import org.chromium.chrome.browser.autofill_assistant.AutofillAssistantTabHelper; import org.chromium.chrome.browser.complex_tasks.TaskTabHelper; import org.chromium.chrome.browser.contextualsearch.ContextualSearchTabHelper; @@ -38,8 +37,7 @@ TabBrowserControlsConstraintsHelper.createForTab(tab); ContinuousSearchTabHelper.createForTab(tab); if (ReaderModeManager.isEnabled()) ReaderModeManager.createForTab(tab); - AutofillAssistantTabHelper.createForTab(tab, AssistantDependencyUtilsChrome::isGsa, - AssistantDependencyUtilsChrome::isMakeSearchesAndBrowsingBetterSettingEnabled); + AutofillAssistantTabHelper.createForTab(tab); // The following will start prefetching data for the price drops feature, so // we should only do it if the user is eligible for the feature (e.g. has sync enabled).
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/tab/TabImpl.java b/chrome/android/java/src/org/chromium/chrome/browser/tab/TabImpl.java index e6ef840..074a257 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/tab/TabImpl.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/tab/TabImpl.java
@@ -135,6 +135,9 @@ private boolean mIsClosing; private boolean mIsShowingErrorPage; + /** Whether or not the TabState has changed. */ + private boolean mIsTabStateDirty = true; + /** * Saves how this tab was launched (from a link, external app, etc) so that * we can determine the different circumstances in which it should be @@ -782,6 +785,18 @@ return !(activity instanceof ChromeActivity); } + /** + * @return Whether the TabState representing this Tab has been updated. + */ + public boolean isTabStateDirty() { + return mIsTabStateDirty; + } + + @Override + public void setIsTabStateDirty(boolean isDirty) { + mIsTabStateDirty = isDirty; + } + @Override public void setIsTabSaveEnabled(boolean isTabSaveEnabled) { mIsTabSaveEnabledSupplier.set(isTabSaveEnabled); @@ -1043,6 +1058,7 @@ * @param url URL that was loaded. */ void didFinishPageLoad(GURL url) { + mIsTabStateDirty = true; updateTitle(); for (TabObserver observer : mObservers) observer.onPageLoadFinished(this, url); @@ -1103,6 +1119,7 @@ * Called when navigation entries were removed. */ void notifyNavigationEntriesDeleted() { + mIsTabStateDirty = true; for (TabObserver observer : mObservers) observer.onNavigationEntriesDeleted(this); } @@ -1198,6 +1215,7 @@ void updateTitle(String title) { if (TextUtils.equals(CriticalPersistedTabData.from(this).getTitle(), title)) return; + mIsTabStateDirty = true; CriticalPersistedTabData.from(this).setTitle(title); notifyPageTitleChanged(); }
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/tab/TabWebContentsObserver.java b/chrome/android/java/src/org/chromium/chrome/browser/tab/TabWebContentsObserver.java index 562656b..acbc4424 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/tab/TabWebContentsObserver.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/tab/TabWebContentsObserver.java
@@ -334,9 +334,7 @@ if (!navigation.hasCommitted()) return; if (navigation.isInPrimaryMainFrame()) { - if (!mTab.isDestroyed()) { - TabStateAttributes.from(mTab).setIsTabStateDirty(true); - } + mTab.setIsTabStateDirty(true); mTab.updateTitle(); mTab.handleDidFinishNavigation(navigation.getUrl(), navigation.pageTransition()); mTab.setIsShowingErrorPage(navigation.isErrorPage()); @@ -380,9 +378,7 @@ @Override public void navigationEntriesChanged() { - if (!mTab.isDestroyed()) { - TabStateAttributes.from(mTab).setIsTabStateDirty(true); - } + mTab.setIsTabStateDirty(true); } @Override
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/tabmodel/TabPersistentStore.java b/chrome/android/java/src/org/chromium/chrome/browser/tabmodel/TabPersistentStore.java index 393785a..7b2ff52 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/tabmodel/TabPersistentStore.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/tabmodel/TabPersistentStore.java
@@ -41,15 +41,16 @@ import org.chromium.chrome.browser.tab.Tab; import org.chromium.chrome.browser.tab.TabCreationState; import org.chromium.chrome.browser.tab.TabIdManager; +import org.chromium.chrome.browser.tab.TabImpl; import org.chromium.chrome.browser.tab.TabLaunchType; import org.chromium.chrome.browser.tab.TabState; -import org.chromium.chrome.browser.tab.TabStateAttributes; import org.chromium.chrome.browser.tab.TabStateExtractor; import org.chromium.chrome.browser.tab.state.CriticalPersistedTabData; import org.chromium.chrome.browser.tab.state.FilePersistedTabDataStorage; import org.chromium.chrome.browser.tab.state.PersistedTabData; import org.chromium.chrome.browser.tabpersistence.TabStateDirectory; import org.chromium.chrome.browser.tabpersistence.TabStateFileManager; +import org.chromium.chrome.features.start_surface.StartSurfaceUserData; import org.chromium.components.embedder_support.util.UrlConstants; import org.chromium.components.embedder_support.util.UrlUtilities; import org.chromium.content_public.browser.LoadUrlParams; @@ -130,9 +131,6 @@ new TabModelSelectorTabObserver(mTabModelSelector) { @Override public void onNavigationEntriesDeleted(Tab tab) { - if (!tab.isDestroyed()) { - TabStateAttributes.from(tab).setIsTabStateDirty(true); - } addTabToSaveQueue(tab); } @@ -145,20 +143,6 @@ public void onRootIdChanged(Tab tab, int newRootId) { addTabToSaveQueue(tab); } - - @Override - public void onPageLoadFinished(Tab tab, GURL url) { - if (!tab.isDestroyed()) { - TabStateAttributes.from(tab).setIsTabStateDirty(true); - } - } - - @Override - public void onTitleUpdated(Tab tab) { - if (!tab.isDestroyed()) { - TabStateAttributes.from(tab).setIsTabStateDirty(true); - } - } }; mTabModelObserver = new TabModelObserver() { @@ -734,8 +718,9 @@ boolean wasIncognitoTabModelSelected = mTabModelSelector.isIncognitoSelected(); int selectedModelTabCount = mTabModelSelector.getCurrentModel().getCount(); - // TODO(hanxi): Sets the correct value for skipLoadingTab. - TabModelUtils.setIndex(model, TabModelUtils.getTabIndexById(model, tabId), false); + TabModelUtils.setIndex(model, TabModelUtils.getTabIndexById(model, tabId), + StartSurfaceUserData.getInstance().getUnusedTabRestoredAtStartup()); + StartSurfaceUserData.getInstance().setLastVisitedTabAtStartupUrl(tabToRestore.url); boolean isIncognitoTabModelSelected = mTabModelSelector.isIncognitoSelected(); // Setting the index will cause the tab's model to be selected. Set it back to the model @@ -825,8 +810,8 @@ } private void addTabToSaveQueueIfApplicable(Tab tab) { - if (tab == null || tab.isDestroyed()) return; - if (mTabsToSave.contains(tab) || !TabStateAttributes.from(tab).isTabStateDirty() + if (tab == null) return; + if (mTabsToSave.contains(tab) || !((TabImpl) tab).isTabStateDirty() || isTabUrlContentScheme(tab)) { return; } @@ -1267,9 +1252,7 @@ protected void onPostExecute(Void v) { if (mDestroyed || isCancelled()) return; if (mStateSaved) { - if (!mTab.isDestroyed()) { - TabStateAttributes.from(mTab).setIsTabStateDirty(false); - } + ((TabImpl) mTab).setIsTabStateDirty(false); mTab.setIsTabSaveEnabled(isCriticalPersistedTabDataEnabled()); } mSaveTabTask = null;
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeExperimentsUtil.java b/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeExperimentsUtil.java index d877564c5..359a519 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeExperimentsUtil.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeExperimentsUtil.java
@@ -47,6 +47,7 @@ import org.chromium.chrome.features.start_surface.StartSurfaceUserData; import org.chromium.components.embedder_support.util.UrlConstants; import org.chromium.components.embedder_support.util.UrlUtilities; +import org.chromium.components.embedder_support.util.UrlUtilitiesJni; import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; import org.chromium.components.segmentation_platform.SegmentationPlatformService; import org.chromium.components.signin.identitymanager.ConsentLevel; @@ -67,6 +68,9 @@ @VisibleForTesting public static final long INVALID_DECISION_TIMESTAMP = -1L; public static final long MILLISECONDS_PER_DAY = TimeUtils.SECONDS_PER_DAY * 1000; + @VisibleForTesting + public static final String LAST_VISITED_TAB_IS_SRP_WHEN_OVERVIEW_IS_SHOWN_AT_LAUNCH_UMA = + "Startup.Android.LastVisitedTabIsSRPWhenOverviewShownAtLaunch"; private static final String START_SEGMENTATION_PLATFORM_KEY = "chrome_start_android"; @@ -880,6 +884,17 @@ onUIClicked(ChromePreferenceKeys.TAP_MV_TILES_COUNT); } + /** + * Record whether the last visited tab shown in the single tab switcher or carousel tab switcher + * is a search result page or not. This should be called when Start surface is shown at startup. + */ + public static void recordLastVisitedTabIsSRPWhenOverviewIsShownAtLaunch() { + RecordHistogram.recordBooleanHistogram( + LAST_VISITED_TAB_IS_SRP_WHEN_OVERVIEW_IS_SHOWN_AT_LAUNCH_UMA, + UrlUtilitiesJni.get().isGoogleSearchUrl( + StartSurfaceUserData.getInstance().getLastVisitedTabAtStartupUrl())); + } + @VisibleForTesting public static String getBehaviourTypeKeyForTesting(String key) { return getBehaviourType(key);
diff --git a/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/OmahaBaseTest.java b/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/OmahaBaseTest.java index b9ecdbef..daa94d3f 100644 --- a/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/OmahaBaseTest.java +++ b/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/OmahaBaseTest.java
@@ -18,8 +18,10 @@ import org.junit.runner.RunWith; import org.chromium.base.ApiCompatibilityUtils; +import org.chromium.base.FeatureList; import org.chromium.base.test.util.AdvancedMockContext; import org.chromium.base.test.util.Feature; +import org.chromium.chrome.browser.flags.ChromeFeatureList; import org.chromium.chrome.test.ChromeJUnit4ClassRunner; import org.chromium.chrome.test.omaha.MockRequestGenerator; import org.chromium.chrome.test.omaha.MockRequestGenerator.DeviceType; @@ -188,10 +190,14 @@ Context targetContext = InstrumentationRegistry.getTargetContext(); OmahaBase.setIsDisabledForTesting(false); mContext = new AdvancedMockContext(targetContext); + FeatureList.TestValues overrides = new FeatureList.TestValues(); + overrides.addFeatureFlagOverride(ChromeFeatureList.ANONYMOUS_UPDATE_CHECKS, true); + FeatureList.setTestValues(overrides); } @After public void tearDown() { + FeatureList.setTestValues(null); OmahaBase.setIsDisabledForTesting(true); }
diff --git a/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/RequestGeneratorTest.java b/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/RequestGeneratorTest.java index 8abf7fc..8fac2c1 100644 --- a/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/RequestGeneratorTest.java +++ b/chrome/android/javatests/src/org/chromium/chrome/browser/omaha/RequestGeneratorTest.java
@@ -14,13 +14,17 @@ import androidx.test.filters.SmallTest; +import org.junit.After; import org.junit.Assert; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import org.chromium.base.FeatureList; import org.chromium.base.test.util.AdvancedMockContext; import org.chromium.base.test.util.Feature; +import org.chromium.chrome.browser.flags.ChromeFeatureList; import org.chromium.chrome.browser.signin.services.IdentityServicesProvider; import org.chromium.chrome.browser.uid.SettingsSecureBasedIdentificationGenerator; import org.chromium.chrome.browser.uid.UniqueIdentificationGenerator; @@ -42,6 +46,18 @@ @Rule public final AccountManagerTestRule mAccountManagerTestRule = new AccountManagerTestRule(); + @Before + public void setUp() { + FeatureList.TestValues overrides = new FeatureList.TestValues(); + overrides.addFeatureFlagOverride(ChromeFeatureList.ANONYMOUS_UPDATE_CHECKS, true); + FeatureList.setTestValues(overrides); + } + + @After + public void tearDown() { + FeatureList.setTestValues(null); + } + @Test @SmallTest @Feature({"Omaha"}) @@ -121,6 +137,31 @@ createAndCheckXML(DeviceType.TABLET, false); } + @Test + @SmallTest + @Feature({"Omaha"}) + public void testXMLCreationWithUID() { + FeatureList.TestValues overrides = new FeatureList.TestValues(); + overrides.addFeatureFlagOverride(ChromeFeatureList.ANONYMOUS_UPDATE_CHECKS, false); + FeatureList.setTestValues(overrides); + IdentityServicesProvider.setInstanceForTests(mock(IdentityServicesProvider.class)); + when(IdentityServicesProvider.get().getIdentityManager(any())) + .thenReturn(mock(IdentityManager.class)); + when(IdentityServicesProvider.get().getIdentityManager(any()).hasPrimaryAccount(anyInt())) + .thenReturn(true); + MockRequestGenerator generator = new MockRequestGenerator( + new AdvancedMockContext(InstrumentationRegistry.getTargetContext()), + DeviceType.TABLET); + String xml = null; + try { + xml = generator.generateXML( + "", "", 0, 0, new RequestData(false, 0, "", INSTALL_SOURCE)); + } catch (RequestFailureException e) { + Assert.fail("XML generation failed."); + } + checkForAttributeAndValue(xml, "request", "userid", "{" + generator.getDeviceID() + "}"); + } + /** * Checks that the XML is being created properly. */ @@ -179,8 +220,6 @@ checkForTag(xml, "updatecheck")); } - checkForAttributeAndValue(xml, "request", "userid", "{" + generator.getDeviceID() + "}"); - return generator; }
diff --git a/chrome/android/javatests/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabDataTest.java b/chrome/android/javatests/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabDataTest.java index b1c2d8f..a3bd67b 100644 --- a/chrome/android/javatests/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabDataTest.java +++ b/chrome/android/javatests/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabDataTest.java
@@ -33,7 +33,8 @@ import org.chromium.chrome.browser.tab.TabStateExtractor; import org.chromium.chrome.browser.tab.TabUserAgent; import org.chromium.chrome.browser.tab.WebContentsState; -import org.chromium.chrome.browser.tab.proto.CriticalPersistedTabData.CriticalPersistedTabDataProto; +import org.chromium.chrome.browser.tab.flatbuffer.LaunchTypeAtCreation; +import org.chromium.chrome.browser.tab.flatbuffer.UserAgentType; import org.chromium.chrome.test.ChromeTabbedActivityTestRule; import org.chromium.chrome.test.batch.BlankCTATabInitialStateRule; import org.chromium.components.embedder_support.util.UrlConstants; @@ -311,16 +312,31 @@ @SmallTest @Test public void testWebContentsStateBug_crbug_1220839() throws InterruptedException { + PersistedTabDataConfiguration.setUseTestConfig(false); String url = mTestServer.getURL("/chrome/test/data/browsing_data/e.html"); Tab tab = sActivityTestRule.loadUrlInNewTab(url); + final Semaphore semaphore = new Semaphore(0); + // Saving serialized CriticalPersistedTabData ensures we get a direct ByteBuffer + // which is assumed in the rest of Clank. See crbug.com/1220839 for more details. ThreadUtils.runOnUiThreadBlocking(() -> { CriticalPersistedTabData criticalPersistedTabData = new CriticalPersistedTabData(tab, "", "", PARENT_ID, ROOT_ID, TIMESTAMP, TabStateExtractor.getWebContentsState(tab), CONTENT_STATE_VERSION, OPENER_APP_ID, THEME_COLOR, LAUNCH_TYPE_AT_CREATION, USER_AGENT_A); - ByteBuffer serialized = criticalPersistedTabData.getSerializeSupplier().get(); PersistedTabDataConfiguration config = PersistedTabDataConfiguration.get( - ShoppingPersistedTabData.class, tab.isIncognito()); + CriticalPersistedTabData.class, tab.isIncognito()); + FilePersistedTabDataStorage persistedTabDataStorage = new FilePersistedTabDataStorage(); + persistedTabDataStorage.save(tab.getId(), config.getId(), () -> { + return criticalPersistedTabData.getSerializeSupplier().get(); + }, semaphore::release); + }); + semaphore.acquire(); + ThreadUtils.runOnUiThreadBlocking(() -> { + PersistedTabDataConfiguration config = PersistedTabDataConfiguration.get( + CriticalPersistedTabData.class, tab.isIncognito()); + + ByteBuffer serialized = + CriticalPersistedTabData.restore(tab.getId(), tab.isIncognito()); CriticalPersistedTabData deserialized = new CriticalPersistedTabData( tab, serialized, config.getStorage(), config.getId()); Assert.assertEquals( @@ -581,9 +597,8 @@ @SmallTest @Test public void testConvertProtoLaunchTypeToTabLaunchType() { - for (CriticalPersistedTabDataProto.LaunchTypeAtCreation type : - CriticalPersistedTabDataProto.LaunchTypeAtCreation.values()) { - if (type == CriticalPersistedTabDataProto.LaunchTypeAtCreation.UNKNOWN) continue; + for (int type = 0; type < LaunchTypeAtCreation.names.length; type++) { + if (type == LaunchTypeAtCreation.UNKNOWN) continue; CriticalPersistedTabData.getLaunchType(type); } } @@ -593,37 +608,24 @@ public void testConvertTabUserAgentToProtoUserAgentType() { for (@TabUserAgent int tabUserAgent = 0; tabUserAgent <= TabUserAgent.SIZE; tabUserAgent++) { - CriticalPersistedTabDataProto.UserAgentType protoUserAgentType = - CriticalPersistedTabData.getUserAgentType(tabUserAgent); - if (tabUserAgent == TabUserAgent.DEFAULT) { - Assert.assertEquals("TabUserAgent value is mapped incorrectly.", protoUserAgentType, - CriticalPersistedTabDataProto.UserAgentType.DEFAULT); - } else { - Assert.assertNotEquals("TabUserAgent value is invalid.", protoUserAgentType, - CriticalPersistedTabDataProto.UserAgentType.DEFAULT); - } + int flatBufferUserAgentType = CriticalPersistedTabData.getUserAgentType(tabUserAgent); + Assert.assertNotEquals("TabUserAgent value is invalid.", flatBufferUserAgentType, + UserAgentType.USER_AGENT_UNKNOWN); if (tabUserAgent != TabUserAgent.SIZE) continue; Assert.assertEquals("TabUserAgent and ProtoUserAgentType should have the same size.", - protoUserAgentType, - CriticalPersistedTabDataProto.UserAgentType.USER_AGENT_SIZE); + flatBufferUserAgentType, UserAgentType.USER_AGENT_SIZE); } } @SmallTest @Test public void testConvertProtoUserAgentTypeToTabUserAgent() { - for (CriticalPersistedTabDataProto.UserAgentType type : - CriticalPersistedTabDataProto.UserAgentType.values()) { + for (int type = 0; type < UserAgentType.names.length; type++) { + if (type == UserAgentType.USER_AGENT_UNKNOWN) continue; @TabUserAgent - int tabUserAgent = CriticalPersistedTabData.getUserAgentType(type); - if (type == CriticalPersistedTabDataProto.UserAgentType.DEFAULT) { - Assert.assertEquals("ProtoUserAgentType value is mapped incorrectly.", tabUserAgent, - TabUserAgent.DEFAULT); - } else { - Assert.assertNotEquals( - "ProtoUserAgentType value is invalid.", tabUserAgent, TabUserAgent.DEFAULT); - } - if (type != CriticalPersistedTabDataProto.UserAgentType.USER_AGENT_SIZE) continue; + int tabUserAgent = CriticalPersistedTabData.getTabUserAgentType(type); + Assert.assertNotNull("ProtoUserAgentType value is invalid.", tabUserAgent); + if (type != UserAgentType.USER_AGENT_SIZE) continue; Assert.assertEquals("TabUserAgent and ProtoUserAgentType should have the same size.", tabUserAgent, TabUserAgent.SIZE); }
diff --git a/chrome/android/javatests/src/org/chromium/chrome/browser/tabmodel/TabPersistentStoreUnitTest.java b/chrome/android/javatests/src/org/chromium/chrome/browser/tabmodel/TabPersistentStoreUnitTest.java index 254f778..565cac0 100644 --- a/chrome/android/javatests/src/org/chromium/chrome/browser/tabmodel/TabPersistentStoreUnitTest.java +++ b/chrome/android/javatests/src/org/chromium/chrome/browser/tabmodel/TabPersistentStoreUnitTest.java
@@ -41,7 +41,6 @@ import org.chromium.chrome.browser.tab.TabImpl; import org.chromium.chrome.browser.tab.TabLaunchType; import org.chromium.chrome.browser.tab.TabState; -import org.chromium.chrome.browser.tab.TabStateAttributes; import org.chromium.chrome.browser.tabmodel.TabPersistentStore.TabModelSelectorMetadata; import org.chromium.chrome.browser.tabmodel.TabPersistentStore.TabRestoreDetails; import org.chromium.components.embedder_support.util.UrlConstants; @@ -137,7 +136,7 @@ TabImpl emptyNtpTab = mock(TabImpl.class); when(emptyNtpTab.getUrl()).thenReturn(new GURL(UrlConstants.NTP_URL)); - TabStateAttributes.from(emptyNtpTab).setIsTabStateDirty(true); + when(emptyNtpTab.isTabStateDirty()).thenReturn(true); when(emptyNtpTab.canGoBack()).thenReturn(false); when(emptyNtpTab.canGoForward()).thenReturn(false); @@ -146,7 +145,7 @@ TabImpl ntpWithBackNavTab = mock(TabImpl.class); when(ntpWithBackNavTab.getUrl()).thenReturn(new GURL(UrlConstants.NTP_URL)); - TabStateAttributes.from(ntpWithBackNavTab).setIsTabStateDirty(true); + when(ntpWithBackNavTab.isTabStateDirty()).thenReturn(true); when(ntpWithBackNavTab.canGoBack()).thenReturn(true); when(ntpWithBackNavTab.canGoForward()).thenReturn(false); @@ -155,7 +154,7 @@ TabImpl ntpWithForwardNavTab = mock(TabImpl.class); when(ntpWithForwardNavTab.getUrl()).thenReturn(new GURL(UrlConstants.NTP_URL)); - TabStateAttributes.from(ntpWithForwardNavTab).setIsTabStateDirty(true); + when(ntpWithForwardNavTab.isTabStateDirty()).thenReturn(true); when(ntpWithForwardNavTab.canGoBack()).thenReturn(false); when(ntpWithForwardNavTab.canGoForward()).thenReturn(true); @@ -164,7 +163,7 @@ TabImpl ntpWithAllTheNavsTab = mock(TabImpl.class); when(ntpWithAllTheNavsTab.getUrl()).thenReturn(new GURL(UrlConstants.NTP_URL)); - TabStateAttributes.from(ntpWithAllTheNavsTab).setIsTabStateDirty(true); + when(ntpWithAllTheNavsTab.isTabStateDirty()).thenReturn(true); when(ntpWithAllTheNavsTab.canGoBack()).thenReturn(true); when(ntpWithAllTheNavsTab.canGoForward()).thenReturn(true);
diff --git a/chrome/android/junit/src/org/chromium/chrome/browser/tab/TabUnitTest.java b/chrome/android/junit/src/org/chromium/chrome/browser/tab/TabUnitTest.java index ccf2a54..7e70563 100644 --- a/chrome/android/junit/src/org/chromium/chrome/browser/tab/TabUnitTest.java +++ b/chrome/android/junit/src/org/chromium/chrome/browser/tab/TabUnitTest.java
@@ -119,21 +119,21 @@ verify(mCriticalPersistedTabDataObserver).onRootIdChanged(mTab, TAB2_ID); assertThat(CriticalPersistedTabData.from(mTab).getRootId(), equalTo(TAB2_ID)); - assertThat(TabStateAttributes.from(mTab).isTabStateDirty(), equalTo(true)); + assertThat(mTab.isTabStateDirty(), equalTo(true)); } @Test @SmallTest public void testSetRootIdWithoutChange() { assertThat(CriticalPersistedTabData.from(mTab).getRootId(), equalTo(TAB1_ID)); - TabStateAttributes.from(mTab).setIsTabStateDirty(false); + mTab.setIsTabStateDirty(false); CriticalPersistedTabData.from(mTab).setRootId(TAB1_ID); verify(mCriticalPersistedTabDataObserver, never()) .onRootIdChanged(any(Tab.class), anyInt()); assertThat(CriticalPersistedTabData.from(mTab).getRootId(), equalTo(TAB1_ID)); - assertThat(TabStateAttributes.from(mTab).isTabStateDirty(), equalTo(false)); + assertThat(mTab.isTabStateDirty(), equalTo(false)); } @Test
diff --git a/chrome/android/static_initializers.gni b/chrome/android/static_initializers.gni index cf8ea29..977465e 100644 --- a/chrome/android/static_initializers.gni +++ b/chrome/android/static_initializers.gni
@@ -20,12 +20,9 @@ # Comments show static_initializers according to # tools/linux/dump-static-initializers.py. - # 000101 (initializer offset 0x142e924 size 0xc) - # std::__1::ios_base::Init::Init() - # # iostream.cpp (initializer offset 0x142e930 size 0x2) # [empty ctor, but it still has cost on gcc <4.6] - expected_static_initializer_count = 2 + expected_static_initializer_count = 1 # TODO(https://crbug.com/1177849): Remove from tflite: # tflite_engine.cc (initializer offset 0x34cb88c size 0x2c)
diff --git a/chrome/browser/BUILD.gn b/chrome/browser/BUILD.gn index c6cb2db..fa185249 100644 --- a/chrome/browser/BUILD.gn +++ b/chrome/browser/BUILD.gn
@@ -1645,10 +1645,6 @@ "signin/identity_manager_factory.h", "signin/investigator_dependency_provider.cc", "signin/investigator_dependency_provider.h", - "signin/primary_account_policy_manager.cc", - "signin/primary_account_policy_manager.h", - "signin/primary_account_policy_manager_factory.cc", - "signin/primary_account_policy_manager_factory.h", "signin/reauth_result.h", "signin/reauth_tab_helper.cc", "signin/reauth_tab_helper.h", @@ -1939,7 +1935,6 @@ "//base/allocator:buildflags", "//build:branding_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//build/config/compiler:compiler_buildflags", "//cc", "//chrome:extra_resources",
diff --git a/chrome/browser/android/customtabs/custom_tabs_browsertest.cc b/chrome/browser/android/customtabs/custom_tabs_browsertest.cc index 2d8129e..5b1ec79 100644 --- a/chrome/browser/android/customtabs/custom_tabs_browsertest.cc +++ b/chrome/browser/android/customtabs/custom_tabs_browsertest.cc
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" +#include <tuple> + #include "base/run_loop.h" #include "base/strings/string_number_conversions.h" #include "base/strings/string_util.h" @@ -94,7 +95,7 @@ auto* web_contents = content::WebContents::FromRenderFrameHost(host); ClientDataHeaderWebContentsObserver::FromWebContents(web_contents) ->SetHeader(kHeaderValue2); - ignore_result(ExecJs(host, "document.images[0].src = 'google3.jpg'")); + std::ignore = ExecJs(host, "document.images[0].src = 'google3.jpg'"); } private:
diff --git a/chrome/browser/apps/app_service/media_access_browsertest.cc b/chrome/browser/apps/app_service/media_access_browsertest.cc index 7e1afa9..f6bc1ab 100644 --- a/chrome/browser/apps/app_service/media_access_browsertest.cc +++ b/chrome/browser/apps/app_service/media_access_browsertest.cc
@@ -288,7 +288,7 @@ // Request accessing the camera for |web_contents|. MediaRequestChangeForWebContent( - web_contents, web_contents->GetLastCommittedURL(), + web_contents, web_contents->GetVisibleURL(), blink::mojom::MediaStreamType::DEVICE_VIDEO_CAPTURE, content::MEDIA_REQUEST_STATE_DONE); @@ -297,7 +297,7 @@ // Request accessing the microphone for |web_contents|. MediaRequestChangeForWebContent( - web_contents, web_contents->GetLastCommittedURL(), + web_contents, web_contents->GetVisibleURL(), blink::mojom::MediaStreamType::DEVICE_AUDIO_CAPTURE, content::MEDIA_REQUEST_STATE_DONE); @@ -306,7 +306,7 @@ // Stop accessing the microphone for |web_contents|. MediaRequestChangeForWebContent( - web_contents, web_contents->GetLastCommittedURL(), + web_contents, web_contents->GetVisibleURL(), blink::mojom::MediaStreamType::DEVICE_AUDIO_CAPTURE, content::MEDIA_REQUEST_STATE_CLOSING); @@ -315,7 +315,7 @@ // Stop accessing the camera for |web_contents|. MediaRequestChangeForWebContent( - web_contents, web_contents->GetLastCommittedURL(), + web_contents, web_contents->GetVisibleURL(), blink::mojom::MediaStreamType::DEVICE_VIDEO_CAPTURE, content::MEDIA_REQUEST_STATE_CLOSING);
diff --git a/chrome/browser/apps/app_service/publishers/arc_apps_unittest.cc b/chrome/browser/apps/app_service/publishers/arc_apps_unittest.cc index 96d244a..5ff34f35 100644 --- a/chrome/browser/apps/app_service/publishers/arc_apps_unittest.cc +++ b/chrome/browser/apps/app_service/publishers/arc_apps_unittest.cc
@@ -111,7 +111,8 @@ // Verifies that a call to set the supported links preference from ARC persists // the setting in app service. -TEST_F(ArcAppsPublisherTest, SetSupportedLinksFromArc) { +// Flaky: https://crbug.com/1285361. +TEST_F(ArcAppsPublisherTest, DISABLED_SetSupportedLinksFromArc) { constexpr char kTestAuthority[] = "www.example.com"; const auto& fake_apps = arc_test()->fake_apps(); std::string package_name = fake_apps[0].package_name;
diff --git a/chrome/browser/ash/accessibility/dictation_browsertest.cc b/chrome/browser/ash/accessibility/dictation_browsertest.cc index 32cfe60..fa966a9 100644 --- a/chrome/browser/ash/accessibility/dictation_browsertest.cc +++ b/chrome/browser/ash/accessibility/dictation_browsertest.cc
@@ -286,6 +286,12 @@ DictationTest(const DictationTest&) = delete; DictationTest& operator=(const DictationTest&) = delete; + void SetUpCommandLine(base::CommandLine* command_line) override { + DictationBaseTest::SetUpCommandLine(command_line); + scoped_feature_list_.InitAndDisableFeature( + ::features::kExperimentalAccessibilityDictationExtension); + } + void SetUpOnMainThread() override { DictationBaseTest::SetUpOnMainThread(); @@ -340,6 +346,7 @@ std::unique_ptr<ui::MockIMEInputContextHandler> input_context_handler_; std::unique_ptr<ui::test::EventGenerator> generator_; ui::CompositionText empty_composition_text_; + base::test::ScopedFeatureList scoped_feature_list_; }; INSTANTIATE_TEST_SUITE_P(
diff --git a/chrome/browser/ash/app_mode/kiosk_profile_loader.cc b/chrome/browser/ash/app_mode/kiosk_profile_loader.cc index 49d06b07..0eb84ed5 100644 --- a/chrome/browser/ash/app_mode/kiosk_profile_loader.cc +++ b/chrome/browser/ash/app_mode/kiosk_profile_loader.cc
@@ -5,11 +5,11 @@ #include "chrome/browser/ash/app_mode/kiosk_profile_loader.h" #include <memory> +#include <tuple> #include "ash/components/login/auth/auth_status_consumer.h" #include "ash/components/login/auth/user_context.h" #include "base/bind.h" -#include "base/ignore_result.h" #include "base/location.h" #include "base/logging.h" #include "base/memory/weak_ptr.h" @@ -179,7 +179,7 @@ void KioskProfileLoader::OnAuthSuccess(const UserContext& user_context) { // LoginPerformer will delete itself. login_performer_->set_delegate(NULL); - ignore_result(login_performer_.release()); + std::ignore = login_performer_.release(); failed_mount_attempts_ = 0;
diff --git a/chrome/browser/ash/file_system_provider/operations/get_metadata.cc b/chrome/browser/ash/file_system_provider/operations/get_metadata.cc index d691b88..9e4621c 100644 --- a/chrome/browser/ash/file_system_provider/operations/get_metadata.cc +++ b/chrome/browser/ash/file_system_provider/operations/get_metadata.cc
@@ -9,9 +9,9 @@ #include <algorithm> #include <memory> #include <string> +#include <tuple> #include <utility> -#include "base/ignore_result.h" #include "chrome/common/extensions/api/file_system_provider.h" #include "chrome/common/extensions/api/file_system_provider_internal.h" @@ -55,8 +55,8 @@ // Allow to pass invalid modification time, since there is no way to verify // it easily on any earlier stage. base::Time output_modification_time; - ignore_result(base::Time::FromString(input_modification_time.c_str(), - &output_modification_time)); + std::ignore = base::Time::FromString(input_modification_time.c_str(), + &output_modification_time); output->modification_time = std::make_unique<base::Time>(output_modification_time); }
diff --git a/chrome/browser/ash/login/existing_user_controller.cc b/chrome/browser/ash/login/existing_user_controller.cc index 1053ec8..8079f64 100644 --- a/chrome/browser/ash/login/existing_user_controller.cc +++ b/chrome/browser/ash/login/existing_user_controller.cc
@@ -5,6 +5,7 @@ #include "chrome/browser/ash/login/existing_user_controller.h" #include <memory> +#include <tuple> #include <utility> #include <vector> @@ -23,7 +24,6 @@ #include "base/callback_helpers.h" #include "base/command_line.h" #include "base/compiler_specific.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/metrics/histogram_functions.h" #include "base/scoped_observation.h" @@ -911,7 +911,7 @@ // LoginPerformer instance will delete itself in case of successful auth. login_performer_->set_delegate(nullptr); - ignore_result(login_performer_.release()); + std::ignore = login_performer_.release(); if (user_context.GetAuthFlow() == UserContext::AUTH_FLOW_OFFLINE) { base::UmaHistogramCounts100("Login.OfflineSuccess.Attempts",
diff --git a/chrome/browser/ash/login/users/chrome_user_manager_impl.cc b/chrome/browser/ash/login/users/chrome_user_manager_impl.cc index 21c2a46..3eb1ef1 100644 --- a/chrome/browser/ash/login/users/chrome_user_manager_impl.cc +++ b/chrome/browser/ash/login/users/chrome_user_manager_impl.cc
@@ -8,6 +8,7 @@ #include <memory> #include <set> +#include <tuple> #include <utility> #include <vector> @@ -25,7 +26,6 @@ #include "base/containers/span.h" #include "base/feature_list.h" #include "base/format_macros.h" -#include "base/ignore_result.h" #include "base/location.h" #include "base/logging.h" #include "base/memory/ptr_util.h" @@ -162,7 +162,7 @@ // Runs on SequencedWorkerPool thread. Passes resolved locale to UI thread. void ResolveLocale(const std::string& raw_locale, std::string* resolved_locale) { - ignore_result(l10n_util::CheckAndResolveLocale(raw_locale, resolved_locale)); + std::ignore = l10n_util::CheckAndResolveLocale(raw_locale, resolved_locale); } bool GetUserLockAttributes(const user_manager::User* user,
diff --git a/chrome/browser/ash/settings/device_settings_provider.cc b/chrome/browser/ash/settings/device_settings_provider.cc index a613af3f..1d9742c8e 100644 --- a/chrome/browser/ash/settings/device_settings_provider.cc +++ b/chrome/browser/ash/settings/device_settings_provider.cc
@@ -154,6 +154,7 @@ kReportDeviceSystemInfo, kReportDevicePrintJobs, kReportDeviceLoginLogout, + kReportCRDSessions, kReportOsUpdateStatus, kReportRunningKioskApp, kReportUploadFrequency,
diff --git a/chrome/browser/ash/smb_client/smb_service_unittest.cc b/chrome/browser/ash/smb_client/smb_service_unittest.cc index 6626467..9b04c7c1 100644 --- a/chrome/browser/ash/smb_client/smb_service_unittest.cc +++ b/chrome/browser/ash/smb_client/smb_service_unittest.cc
@@ -7,6 +7,7 @@ #include <stddef.h> #include <memory> +#include <tuple> #include <utility> #include "ash/components/disks/disk_mount_manager.h" @@ -15,7 +16,6 @@ #include "base/bind.h" #include "base/callback_helpers.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/memory/ptr_util.h" #include "base/run_loop.h" @@ -710,10 +710,10 @@ std::string(kSharePath) + base::NumberToString(i); const std::string mount_path = std::string(kMountPath) + base::NumberToString(i); - ignore_result(MountBasicShare(share_path, mount_path, + std::ignore = MountBasicShare(share_path, mount_path, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kSuccess, result); - }))); + })); } // Check: After mounting the maximum number of shares, requesting to mount an @@ -722,24 +722,24 @@ std::string(kSharePath) + base::NumberToString(kMaxSmbFsShares); const std::string mount_path = std::string(kMountPath) + base::NumberToString(kMaxSmbFsShares); - ignore_result(MountBasicShare( + std::ignore = MountBasicShare( share_path, mount_path, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kTooManyOpened, result); - }))); + })); } TEST_F(SmbServiceWithSmbfsTest, GetSmbFsShareForPath) { CreateService(profile_); WaitForSetupComplete(); - ignore_result(MountBasicShare(kSharePath, kMountPath, + std::ignore = MountBasicShare(kSharePath, kMountPath, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kSuccess, result); - }))); - ignore_result(MountBasicShare(kSharePath2, kMountPath2, + })); + std::ignore = MountBasicShare(kSharePath2, kMountPath2, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kSuccess, result); - }))); + })); SmbFsShare* share = smb_service_->GetSmbFsShareForPath(base::FilePath(kMountPath)); @@ -764,23 +764,23 @@ CreateService(profile_); WaitForSetupComplete(); - ignore_result(MountBasicShare(kSharePath, kMountPath, + std::ignore = MountBasicShare(kSharePath, kMountPath, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kSuccess, result); - }))); + })); // A second mount with the same share path should fail. - ignore_result(MountBasicShare( + std::ignore = MountBasicShare( kSharePath, kMountPath2, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kMountExists, result); - }))); + })); // Unmounting and mounting again should succeed. smb_service_->UnmountSmbFs(base::FilePath(kMountPath)); - ignore_result(MountBasicShare(kSharePath, kMountPath2, + std::ignore = MountBasicShare(kSharePath, kMountPath2, base::BindOnce([](SmbMountResult result) { EXPECT_EQ(SmbMountResult::kSuccess, result); - }))); + })); } } // namespace smb_client
diff --git a/chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.cc b/chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.cc new file mode 100644 index 0000000..40a333bb --- /dev/null +++ b/chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.cc
@@ -0,0 +1,47 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.h" + +#include "ash/style/ash_color_provider.h" +#include "chrome/browser/profiles/profile.h" + +ChromePersonalizationAppThemeProvider::ChromePersonalizationAppThemeProvider( + content::WebUI* web_ui) + : web_ui_(web_ui), profile_(Profile::FromWebUI(web_ui_)) {} + +ChromePersonalizationAppThemeProvider:: + ~ChromePersonalizationAppThemeProvider() = default; + +void ChromePersonalizationAppThemeProvider::BindInterface( + mojo::PendingReceiver<ash::personalization_app::mojom::ThemeProvider> + receiver) { + theme_receiver_.reset(); + theme_receiver_.Bind(std::move(receiver)); +} + +void ChromePersonalizationAppThemeProvider::SetThemeObserver( + mojo::PendingRemote<ash::personalization_app::mojom::ThemeObserver> + observer) { + // May already be bound if user refreshes page. + theme_observer_remote_.reset(); + theme_observer_remote_.Bind(std::move(observer)); + if (!color_mode_observer_.IsObserving()) + color_mode_observer_.Observe(ash::AshColorProvider::Get()); + // Call it once to get the current color mode. + OnColorModeChanged(ash::ColorProvider::Get()->IsDarkModeEnabled()); +} + +void ChromePersonalizationAppThemeProvider::OnColorModeChanged( + bool dark_mode_enabled) { + DCHECK(theme_observer_remote_.is_bound()); + theme_observer_remote_->OnColorModeChanged(dark_mode_enabled); +} + +void ChromePersonalizationAppThemeProvider::SetColorModePref( + bool dark_mode_enabled) { + auto* color_provider = ash::AshColorProvider::Get(); + if (color_provider->IsDarkModeEnabled() != dark_mode_enabled) + color_provider->ToggleColorMode(); +}
diff --git a/chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.h b/chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.h new file mode 100644 index 0000000..e82fbb5 --- /dev/null +++ b/chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.h
@@ -0,0 +1,65 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_BROWSER_ASH_WEB_APPLICATIONS_PERSONALIZATION_APP_CHROME_PERSONALIZATION_APP_THEME_PROVIDER_H_ +#define CHROME_BROWSER_ASH_WEB_APPLICATIONS_PERSONALIZATION_APP_CHROME_PERSONALIZATION_APP_THEME_PROVIDER_H_ + +#include "ash/public/cpp/style/color_mode_observer.h" +#include "ash/public/cpp/style/color_provider.h" +#include "ash/webui/personalization_app/personalization_app_theme_provider.h" +#include "base/scoped_observation.h" +#include "mojo/public/cpp/bindings/receiver.h" +#include "mojo/public/cpp/bindings/remote.h" + +class Profile; + +namespace content { +class WebUI; +} // namespace content + +class ChromePersonalizationAppThemeProvider + : public ash::PersonalizationAppThemeProvider, + ash::ColorModeObserver { + public: + explicit ChromePersonalizationAppThemeProvider(content::WebUI* web_ui); + + ChromePersonalizationAppThemeProvider( + const ChromePersonalizationAppThemeProvider&) = delete; + ChromePersonalizationAppThemeProvider& operator=( + const ChromePersonalizationAppThemeProvider&) = delete; + + ~ChromePersonalizationAppThemeProvider() override; + + // PersonalizationAppThemeProvider: + void BindInterface( + mojo::PendingReceiver<ash::personalization_app::mojom::ThemeProvider> + receiver) override; + + // ash::personalization_app::mojom::ThemeProvider: + void SetThemeObserver( + mojo::PendingRemote<ash::personalization_app::mojom::ThemeObserver> + observer) override; + + void SetColorModePref(bool dark_mode_enabled) override; + + // ash::ColorModeObserver: + void OnColorModeChanged(bool dark_mode_enabled) override; + + private: + content::WebUI* const web_ui_ = nullptr; + + // Pointer to profile of user that opened personalization SWA. Not owned. + Profile* const profile_ = nullptr; + + base::ScopedObservation<ash::ColorProvider, ash::ColorModeObserver> + color_mode_observer_{this}; + + mojo::Receiver<ash::personalization_app::mojom::ThemeProvider> + theme_receiver_{this}; + + mojo::Remote<ash::personalization_app::mojom::ThemeObserver> + theme_observer_remote_; +}; + +#endif // CHROME_BROWSER_ASH_WEB_APPLICATIONS_PERSONALIZATION_APP_CHROME_PERSONALIZATION_APP_THEME_PROVIDER_H_
diff --git a/chrome/browser/autofill/autofill_browsertest.cc b/chrome/browser/autofill/autofill_browsertest.cc index bb45dea..d868e8f3 100644 --- a/chrome/browser/autofill/autofill_browsertest.cc +++ b/chrome/browser/autofill/autofill_browsertest.cc
@@ -6,10 +6,10 @@ #include <memory> #include <string> +#include <tuple> #include "base/command_line.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/memory/ref_counted.h" #include "base/rand_util.h" @@ -872,9 +872,9 @@ int host_id = prerender_helper().AddPrerender(prerender_url); auto* rfh = prerender_helper().GetPrerenderedMainFrameHost(host_id); - ignore_result( + std::ignore = content::ExecJs(rfh, "document.querySelector('#NAME_FIRST').focus();", - content::EXECUTE_SCRIPT_NO_USER_GESTURE)); + content::EXECUTE_SCRIPT_NO_USER_GESTURE); // Since the initial prerender page load has finished at this point and we // have issued our programmatic focus, we need to check that the expectations
diff --git a/chrome/browser/browsing_data/access_context_audit_database.cc b/chrome/browser/browsing_data/access_context_audit_database.cc index 8b67a6b..70fbe5a 100644 --- a/chrome/browser/browsing_data/access_context_audit_database.cc +++ b/chrome/browser/browsing_data/access_context_audit_database.cc
@@ -4,8 +4,9 @@ #include "chrome/browser/browsing_data/access_context_audit_database.h" +#include <tuple> + #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/metrics/histogram_functions.h" #include "base/rand_util.h" @@ -43,7 +44,7 @@ // or hardware issues, not coding errors at the client level, so displaying // the error would probably lead to confusion. The ignored call signals the // test-expectation framework that the error was handled. - ignore_result(sql::Database::IsExpectedSqliteError(extended_error)); + std::ignore = sql::Database::IsExpectedSqliteError(extended_error); return; }
diff --git a/chrome/browser/captive_portal/captive_portal_browsertest.cc b/chrome/browser/captive_portal/captive_portal_browsertest.cc index 1d99805d..d798dce 100644 --- a/chrome/browser/captive_portal/captive_portal_browsertest.cc +++ b/chrome/browser/captive_portal/captive_portal_browsertest.cc
@@ -9,6 +9,7 @@ #include <memory> #include <set> #include <string> +#include <tuple> #include <utility> #include <vector> @@ -18,7 +19,6 @@ #include "base/compiler_specific.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/path_service.h" #include "base/run_loop.h" @@ -916,7 +916,7 @@ EXPECT_EQ(expected_num_jobs, static_cast<int>(ongoing_mock_requests_.size())); for (auto& job : ongoing_mock_requests_) - ignore_result(job.client.Unbind().PassPipe().release()); + std::ignore = job.client.Unbind().PassPipe().release(); ongoing_mock_requests_.clear(); }
diff --git a/chrome/browser/chrome_browser_interface_binders.cc b/chrome/browser/chrome_browser_interface_binders.cc index 80e960e..e792281 100644 --- a/chrome/browser/chrome_browser_interface_binders.cc +++ b/chrome/browser/chrome_browser_interface_binders.cc
@@ -975,6 +975,12 @@ ash::PersonalizationAppUI>(map); } + if (ash::features::IsPersonalizationHubEnabled()) { + RegisterWebUIControllerInterfaceBinder< + ash::personalization_app::mojom::ThemeProvider, + ash::PersonalizationAppUI>(map); + } + RegisterWebUIControllerInterfaceBinder< launcher_internals::mojom::PageHandlerFactory, chromeos::LauncherInternalsUI>(map);
diff --git a/chrome/browser/chrome_browser_main.cc b/chrome/browser/chrome_browser_main.cc index e194037..712138e 100644 --- a/chrome/browser/chrome_browser_main.cc +++ b/chrome/browser/chrome_browser_main.cc
@@ -10,6 +10,7 @@ #include <memory> #include <set> #include <string> +#include <tuple> #include <utility> #include <vector> @@ -24,7 +25,6 @@ #include "base/feature_list.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/memory/ptr_util.h" #include "base/memory/scoped_refptr.h" @@ -1867,7 +1867,7 @@ // The below call to browser_shutdown::ShutdownPostThreadsStop() deletes // |browser_process_|. We release it so that we don't keep holding onto an // invalid reference. - ignore_result(browser_process_.release()); + std::ignore = browser_process_.release(); #if BUILDFLAG(ENABLE_DOWNGRADE_PROCESSING) if (result_code_ == chrome::RESULT_CODE_DOWNGRADE_AND_RELAUNCH) {
diff --git a/chrome/browser/chrome_security_exploit_browsertest.cc b/chrome/browser/chrome_security_exploit_browsertest.cc index e0bc0c3..8902f577 100644 --- a/chrome/browser/chrome_security_exploit_browsertest.cc +++ b/chrome/browser/chrome_security_exploit_browsertest.cc
@@ -2,9 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> + #include "base/bind.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/memory/ptr_util.h" #include "base/memory/raw_ptr.h" #include "base/run_loop.h" @@ -500,8 +501,8 @@ // The renderer should always get killed, but sometimes ExecuteScript returns // true anyway, so just ignore the result. - ignore_result( - content::ExecuteScript(rfh, "URL.createObjectURL(new Blob(['foo']))")); + std::ignore = + content::ExecuteScript(rfh, "URL.createObjectURL(new Blob(['foo']))"); // If the process is killed, this test passes. EXPECT_EQ( @@ -544,8 +545,8 @@ // The renderer should always get killed, but sometimes ExecuteScript returns // true anyway, so just ignore the result. - ignore_result( - content::ExecuteScript(rfh, "URL.createObjectURL(new Blob(['foo']))")); + std::ignore = + content::ExecuteScript(rfh, "URL.createObjectURL(new Blob(['foo']))"); // If the process is killed, this test passes. EXPECT_EQ(
diff --git a/chrome/browser/chromeos/BUILD.gn b/chrome/browser/chromeos/BUILD.gn index d4e3b674..ffb4574 100644 --- a/chrome/browser/chromeos/BUILD.gn +++ b/chrome/browser/chromeos/BUILD.gn
@@ -3195,6 +3195,8 @@ "../ash/web_applications/os_settings_web_app_info.h", "../ash/web_applications/os_url_handler_system_web_app_info.cc", "../ash/web_applications/os_url_handler_system_web_app_info.h", + "../ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.cc", + "../ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.h", "../ash/web_applications/personalization_app/chrome_personalization_app_wallpaper_provider.cc", "../ash/web_applications/personalization_app/chrome_personalization_app_wallpaper_provider.h", "../ash/web_applications/personalization_app/personalization_app_info.cc",
diff --git a/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManager.java b/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManager.java index abab6ea..893cfa9a 100644 --- a/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManager.java +++ b/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManager.java
@@ -21,6 +21,7 @@ import androidx.annotation.Nullable; import androidx.annotation.VisibleForTesting; +import org.chromium.base.Callback; import org.chromium.base.ContextUtils; import org.chromium.base.IntentUtils; import org.chromium.base.Log; @@ -62,6 +63,8 @@ "org.chromium.chrome.browser.price_tracking.DESTINATION_URL"; static final String EXTRA_ACTION_ID = "org.chromium.chrome.browser.price_tracking.ACTION_ID"; static final String EXTRA_OFFER_ID = "org.chromium.chrome.browser.price_tracking.OFFER_ID"; + static final String EXTRA_PRODUCT_CLUSTER_ID = + "org.chromium.chrome.browser.price_tracking.PRODUCT_CLUSTER_ID"; private static NotificationManagerProxy sNotificationManagerForTesting; @@ -76,6 +79,7 @@ String destinationUrl = IntentUtils.safeGetStringExtra(intent, EXTRA_DESTINATION_URL); String actionId = IntentUtils.safeGetStringExtra(intent, EXTRA_ACTION_ID); String offerId = IntentUtils.safeGetStringExtra(intent, EXTRA_OFFER_ID); + String clusterId = IntentUtils.safeGetStringExtra(intent, EXTRA_PRODUCT_CLUSTER_ID); if (TextUtils.isEmpty(offerId)) { Log.e(TAG, "No offer id is provided when handling turn off alert action."); @@ -90,7 +94,7 @@ assert ACTION_ID_TURN_OFF_ALERT.equals(actionId) : "Currently only turn off alert action uses this activity."; priceDropNotificationManager.onNotificationActionClicked( - actionId, destinationUrl, offerId, /*recordMetrics=*/false); + actionId, destinationUrl, offerId, clusterId, /*recordMetrics=*/false); // Finish immediately. Could be better to have a callback from shopping backend. finish(); }); @@ -194,28 +198,52 @@ */ public void onNotificationActionClicked( String actionId, String url, @Nullable String offerId, boolean recordMetrics) { + onNotificationActionClicked(actionId, url, offerId, null, recordMetrics); + } + + /** + * Handles the notification action click events. + * + * @param actionId the id used to identify certain action. + * @param url of the tab which triggered the notification. + * @param offerId the id of the offer associated with this notification. + * @param clusterId The id of the cluster associated with the product notification. + * @param recordMetrics Whether to record metrics using {@link NotificationUmaTracker}. Only + * Chime notification code path should set this to true. + */ + public void onNotificationActionClicked(String actionId, String url, @Nullable String offerId, + @Nullable String clusterId, boolean recordMetrics) { if (actionId.equals(ACTION_ID_VISIT_SITE) && recordMetrics) { NotificationUmaTracker.getInstance().onNotificationActionClick( NotificationUmaTracker.ActionType.PRICE_DROP_VISIT_SITE, NotificationUmaTracker.SystemNotificationType.PRICE_DROP_ALERTS, NotificationIntentInterceptor.INVALID_CREATE_TIME); } else if (actionId.equals(ACTION_ID_TURN_OFF_ALERT)) { - if (offerId == null) return; + if (offerId == null && clusterId == null) return; SubscriptionsManagerImpl subscriptionsManager = (new CommerceSubscriptionsServiceFactory()) .getForLastUsedProfile() .getSubscriptionsManager(); - subscriptionsManager.unsubscribe( - new CommerceSubscription(CommerceSubscriptionType.PRICE_TRACK, offerId, - SubscriptionManagementType.CHROME_MANAGED, TrackingIdType.OFFER_ID), - (status) -> { - assert status - == SubscriptionsManager.StatusCode.OK - : "Failed to remove subscriptions."; - Log.e(TAG, - String.format(Locale.US, - "Failed to remove subscriptions. Status: %d", status)); - }); + Callback<Integer> callback = (status) -> { + assert status + == SubscriptionsManager.StatusCode.OK : "Failed to remove subscriptions."; + Log.e(TAG, + String.format( + Locale.US, "Failed to remove subscriptions. Status: %d", status)); + }; + if (offerId != null) { + subscriptionsManager.unsubscribe( + new CommerceSubscription(CommerceSubscriptionType.PRICE_TRACK, offerId, + SubscriptionManagementType.CHROME_MANAGED, TrackingIdType.OFFER_ID), + callback); + } + if (clusterId != null) { + subscriptionsManager.unsubscribe( + new CommerceSubscription(CommerceSubscriptionType.PRICE_TRACK, clusterId, + SubscriptionManagementType.USER_MANAGED, + TrackingIdType.PRODUCT_CLUSTER_ID), + callback); + } if (recordMetrics) { NotificationUmaTracker.getInstance().onNotificationActionClick( NotificationUmaTracker.ActionType.PRICE_DROP_TURN_OFF_ALERT, @@ -252,12 +280,26 @@ * @param offerId The offer id of the product. */ public Intent getNotificationActionClickIntent(String actionId, String url, String offerId) { + return getNotificationActionClickIntent(actionId, url, offerId, null); + } + + /** + * Gets the notification action click intents. + * + * @param actionId the id used to identify certain action. + * @param url of the tab which triggered the notification. + * @param offerId The offer id of the product. + * @param clusterId The cluster id of the product. + */ + public Intent getNotificationActionClickIntent( + String actionId, String url, String offerId, String clusterId) { if (ACTION_ID_VISIT_SITE.equals(actionId)) return getNotificationClickIntent(url); if (ACTION_ID_TURN_OFF_ALERT.equals(actionId)) { Intent intent = new Intent(mContext, TrampolineActivity.class); intent.putExtra(EXTRA_DESTINATION_URL, url); intent.putExtra(EXTRA_ACTION_ID, actionId); intent.putExtra(EXTRA_OFFER_ID, offerId); + if (clusterId != null) intent.putExtra(EXTRA_PRODUCT_CLUSTER_ID, clusterId); return intent; } return null;
diff --git a/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifier.java b/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifier.java index 2ff303b..ac5b8fd 100644 --- a/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifier.java +++ b/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifier.java
@@ -42,12 +42,14 @@ static class NotificationData { public NotificationData(CharSequence title, CharSequence text, String iconUrl, - String destinationUrl, String offerId, List<ActionData> actions) { + String destinationUrl, String offerId, String productClusterId, + List<ActionData> actions) { this.title = title; this.text = text; this.iconUrl = iconUrl; this.destinationUrl = destinationUrl; this.offerId = offerId; + this.productClusterId = productClusterId; this.actions = actions; } @@ -74,6 +76,10 @@ */ public final String offerId; /** + * Associated cluster ID. + */ + public final String productClusterId; + /** * A list of button actions. */ public final List<ActionData> actions; @@ -170,8 +176,9 @@ PriceTrackingNotificationConfig.getNotificationTimeoutMs()); if (notificationData.actions != null) { for (ActionData action : notificationData.actions) { - PendingIntentProvider actionClickIntentProvider = createClickIntent( - action.actionId, notificationData.destinationUrl, notificationData.offerId); + PendingIntentProvider actionClickIntentProvider = + createClickIntent(action.actionId, notificationData.destinationUrl, + notificationData.offerId, notificationData.productClusterId); notificationBuilder.addAction(0, action.text, actionClickIntentProvider, actionIdToUmaActionType(action.actionId)); } @@ -197,9 +204,10 @@ mContext, 0, intent, PendingIntent.FLAG_UPDATE_CURRENT); } - private PendingIntentProvider createClickIntent(String actionId, String url, String offerId) { + private PendingIntentProvider createClickIntent( + String actionId, String url, String offerId, String clusterId) { Intent intent = mPriceDropNotificationManager.getNotificationActionClickIntent( - actionId, url, offerId); + actionId, url, offerId, clusterId); return PendingIntentProvider.getActivity( mContext, 0, intent, PendingIntent.FLAG_UPDATE_CURRENT); }
diff --git a/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceTrackingNotificationBridge.java b/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceTrackingNotificationBridge.java index 3ddba7d..1565b88 100644 --- a/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceTrackingNotificationBridge.java +++ b/chrome/browser/commerce/price_tracking/android/java/src/org/chromium/chrome/browser/price_tracking/PriceTrackingNotificationBridge.java
@@ -111,11 +111,12 @@ // Use UnsignedLongs to convert OfferId to avoid overflow. String offerId = UnsignedLongs.toString(priceDropPayload.getOfferId()); + String clusterId = UnsignedLongs.toString(priceDropPayload.getProductClusterId()); ChromeMessage chromeMessage = chromeNotification.getChromeMessage(); PriceDropNotifier.NotificationData notificationData = new PriceDropNotifier.NotificationData(title, text, chromeMessage.hasIconImageUrl() ? chromeMessage.getIconImageUrl() : null, - priceDropPayload.getDestinationUrl(), offerId, + priceDropPayload.getDestinationUrl(), offerId, clusterId, parseActions(chromeNotification)); mNotifier.showNotification(notificationData); }
diff --git a/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManagerTest.java b/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManagerTest.java index e3db2da..2c20ad02 100644 --- a/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManagerTest.java +++ b/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotificationManagerTest.java
@@ -77,6 +77,7 @@ private static final String ACTION_ID_TURN_OFF_ALERT = "turn_off_alert"; private static final String TEST_URL = "www.test.com"; private static final String OFFER_ID = "offer_id"; + private static final String PRODUCT_CLUSTER_ID = "cluster_id"; private MockNotificationManagerProxy mMockNotificationManager; private PriceDropNotificationManager mPriceDropNotificationManager; @@ -220,12 +221,15 @@ @MediumTest public void testGetNotificationActionClickIntent() { verifyClickIntent(mPriceDropNotificationManager.getNotificationActionClickIntent( - ACTION_ID_VISIT_SITE, TEST_URL, OFFER_ID)); + ACTION_ID_VISIT_SITE, TEST_URL, OFFER_ID, PRODUCT_CLUSTER_ID)); Intent turnOffAlertIntent = mPriceDropNotificationManager.getNotificationActionClickIntent( - ACTION_ID_TURN_OFF_ALERT, TEST_URL, OFFER_ID); + ACTION_ID_TURN_OFF_ALERT, TEST_URL, OFFER_ID, PRODUCT_CLUSTER_ID); assertNotNull(turnOffAlertIntent); assertEquals(PriceDropNotificationManager.TrampolineActivity.class.getName(), turnOffAlertIntent.getComponent().getClassName()); + assertEquals(PRODUCT_CLUSTER_ID, + IntentUtils.safeGetStringExtra( + turnOffAlertIntent, PriceDropNotificationManager.EXTRA_PRODUCT_CLUSTER_ID)); assertEquals(OFFER_ID, IntentUtils.safeGetStringExtra( turnOffAlertIntent, PriceDropNotificationManager.EXTRA_OFFER_ID)); @@ -252,12 +256,12 @@ SubscriptionManagementType.CHROME_MANAGED, TrackingIdType.OFFER_ID); mPriceDropNotificationManager.onNotificationActionClicked( - ACTION_ID_TURN_OFF_ALERT, TEST_URL, null, false); + ACTION_ID_TURN_OFF_ALERT, TEST_URL, null, null, false); verify(mMockSubscriptionsManager, times(0)) .unsubscribe(eq(commerceSubscription), any(Callback.class)); mPriceDropNotificationManager.onNotificationActionClicked( - ACTION_ID_TURN_OFF_ALERT, TEST_URL, offerId, false); + ACTION_ID_TURN_OFF_ALERT, TEST_URL, offerId, null, false); verify(mMockSubscriptionsManager, times(1)) .unsubscribe(eq(commerceSubscription), any(Callback.class)); }
diff --git a/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifierUnitTest.java b/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifierUnitTest.java index e3dcb517..fc7d5d10 100644 --- a/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifierUnitTest.java +++ b/chrome/browser/commerce/price_tracking/android/javatests/src/org/chromium/chrome/browser/price_tracking/PriceDropNotifierUnitTest.java
@@ -68,6 +68,7 @@ private static final String ICON_URL = "http://www.example.com/icon"; private static final String DESTINATION_URL = "http://www.example.com/destination"; private static final String OFFER_ID = "offer_id"; + private static final String PRODUCT_CLUSTER_ID = "cluster_id"; private static final String ACTION_TEXT_0 = "action_text_0"; private static final String ACTION_TEXT_1 = "action_text_1"; @@ -166,8 +167,8 @@ } private void showNotification(List<ActionData> actionDataList) { - PriceDropNotifier.NotificationData data = new NotificationData( - TITLE, TEXT, ICON_URL, DESTINATION_URL, OFFER_ID, actionDataList); + PriceDropNotifier.NotificationData data = new NotificationData(TITLE, TEXT, ICON_URL, + DESTINATION_URL, OFFER_ID, PRODUCT_CLUSTER_ID, actionDataList); mPriceDropNotifier.showNotification(data); } @@ -205,7 +206,7 @@ @Test public void testShowNotificationNoIconURL() { PriceDropNotifier.NotificationData data = new NotificationData( - TITLE, TEXT, /*iconUrl=*/null, DESTINATION_URL, OFFER_ID, null); + TITLE, TEXT, /*iconUrl=*/null, DESTINATION_URL, OFFER_ID, PRODUCT_CLUSTER_ID, null); mPriceDropNotifier.showNotification(data); verify(mNotificationBuilder, times(0)).setLargeIcon(any()); verify(mNotificationBuilder, times(0)).setBigPictureStyle(any(), any());
diff --git a/chrome/browser/commerce/price_tracking/proto/notifications.proto b/chrome/browser/commerce/price_tracking/proto/notifications.proto index a8cf97d..475c6d0c 100644 --- a/chrome/browser/commerce/price_tracking/proto/notifications.proto +++ b/chrome/browser/commerce/price_tracking/proto/notifications.proto
@@ -58,4 +58,5 @@ optional string destination_url = 3; optional commerce.ProductPrice current_price = 4; optional commerce.ProductPrice previous_price = 5; + optional uint64 product_cluster_id = 6; }
diff --git a/chrome/browser/component_updater/first_party_sets_component_installer.cc b/chrome/browser/component_updater/first_party_sets_component_installer.cc index 0f469640..7451e1f 100644 --- a/chrome/browser/component_updater/first_party_sets_component_installer.cc +++ b/chrome/browser/component_updater/first_party_sets_component_installer.cc
@@ -4,9 +4,13 @@ #include "chrome/browser/component_updater/first_party_sets_component_installer.h" +#include <utility> + #include "base/bind.h" +#include "base/callback.h" #include "base/cxx17_backports.h" #include "base/feature_list.h" +#include "base/files/file.h" #include "base/files/file_util.h" #include "base/logging.h" #include "base/memory/ref_counted.h" @@ -17,18 +21,21 @@ #include "base/version.h" #include "chrome/browser/browser_process.h" #include "chrome/browser/first_party_sets/first_party_sets_pref_names.h" +#include "components/component_updater/component_installer.h" #include "components/component_updater/component_updater_paths.h" #include "components/prefs/pref_service.h" #include "content/public/browser/network_service_instance.h" #include "net/base/features.h" #include "net/cookies/cookie_util.h" #include "services/network/public/mojom/network_service.mojom.h" -#include "third_party/abseil-cpp/absl/types/optional.h" using component_updater::ComponentUpdateService; namespace { +using SetsReadyOnceCallback = component_updater:: + FirstPartySetsComponentInstallerPolicy::SetsReadyOnceCallback; + constexpr base::FilePath::CharType kFirstPartySetsSetsFileName[] = FILE_PATH_LITERAL("sets.json"); @@ -61,6 +68,10 @@ if (!base::FeatureList::IsEnabled(net::features::kFirstPartySets)) { return false; } + // Some java tests start the browser in minimal mode, in which + // `g_browser_process` is unset. + if (!g_browser_process) + return false; auto* local_state = g_browser_process->local_state(); if (!local_state || !local_state->HasPrefPath(first_party_sets::kFirstPartySetsEnabled)) { @@ -69,14 +80,21 @@ return local_state->GetBoolean(first_party_sets::kFirstPartySetsEnabled); } -// Invokes `on_sets_ready` with the contents of the component, if: -// * the component has been installed; and +// Invokes `on_sets_ready`, if: // * First-Party Sets is enabled; and -// * the component was read successfully. -void SetFirstPartySetsConfig( - base::OnceCallback<void(base::File)> on_sets_ready) { +// * `on_sets_ready` is not null. +// +// If the component has been installed and can be read, we pass the component +// file; otherwise, we pass an invalid file. +void SetFirstPartySetsConfig(SetsReadyOnceCallback on_sets_ready) { + if (!IsFirstPartySetsEnabled() || on_sets_ready.is_null()) { + return; + } + const base::FilePath instance_path = GetConfigPathInstance(); - if (instance_path.empty() || !IsFirstPartySetsEnabled()) { + if (instance_path.empty()) { + // Registration is complete, but no component version exists on disk. + std::move(on_sets_ready).Run(base::File()); return; } @@ -95,12 +113,16 @@ // static void FirstPartySetsComponentInstallerPolicy::ReconfigureAfterNetworkRestart( - base::OnceCallback<void(base::File)> on_sets_ready) { + SetsReadyOnceCallback on_sets_ready) { SetFirstPartySetsConfig(std::move(on_sets_ready)); } +void FirstPartySetsComponentInstallerPolicy::OnRegistrationComplete() { + SetFirstPartySetsConfig(std::move(on_sets_ready_)); +} + FirstPartySetsComponentInstallerPolicy::FirstPartySetsComponentInstallerPolicy( - base::RepeatingCallback<void(base::File)> on_sets_ready) + SetsReadyOnceCallback on_sets_ready) : on_sets_ready_(std::move(on_sets_ready)) {} FirstPartySetsComponentInstallerPolicy:: @@ -150,7 +172,7 @@ GetConfigPathInstance() = GetInstalledPath(install_dir); - SetFirstPartySetsConfig(on_sets_ready_); + SetFirstPartySetsConfig(std::move(on_sets_ready_)); } // Called during startup and installation before ComponentReady(). @@ -201,14 +223,23 @@ void RegisterFirstPartySetsComponent(ComponentUpdateService* cus) { VLOG(1) << "Registering First-Party Sets component."; - base::MakeRefCounted<ComponentInstaller>( - std::make_unique<FirstPartySetsComponentInstallerPolicy>( - /*on_sets_ready=*/base::BindRepeating([](base::File sets_file) { - VLOG(1) << "Received First-Party Sets"; - content::GetNetworkService()->SetFirstPartySets( - std::move(sets_file)); - }))) - ->Register(cus, base::OnceClosure()); + auto policy = std::make_unique<FirstPartySetsComponentInstallerPolicy>( + /*on_sets_ready=*/base::BindOnce([](base::File sets_file) { + VLOG(1) << "Received First-Party Sets"; + content::GetNetworkService()->SetFirstPartySets(std::move(sets_file)); + })); + + FirstPartySetsComponentInstallerPolicy* raw_policy = policy.get(); + // Dereferencing `raw_policy` this way is safe because the closure is invoked + // by the ComponentInstaller instance, which owns `policy` (so they have the + // same lifetime). Therefore if/when the closure is invoked, `policy` is still + // alive. + base::MakeRefCounted<ComponentInstaller>(std::move(policy)) + ->Register(cus, base::BindOnce( + [](FirstPartySetsComponentInstallerPolicy* policy) { + policy->OnRegistrationComplete(); + }, + raw_policy)); } // static
diff --git a/chrome/browser/component_updater/first_party_sets_component_installer.h b/chrome/browser/component_updater/first_party_sets_component_installer.h index 6fcf843..6ccdfe5de 100644 --- a/chrome/browser/component_updater/first_party_sets_component_installer.h +++ b/chrome/browser/component_updater/first_party_sets_component_installer.h
@@ -28,10 +28,12 @@ class FirstPartySetsComponentInstallerPolicy : public ComponentInstallerPolicy { public: + using SetsReadyOnceCallback = base::OnceCallback<void(base::File)>; + // |on_sets_ready| will be called on the UI thread when the sets are ready. It // is exposed here for testing. explicit FirstPartySetsComponentInstallerPolicy( - base::RepeatingCallback<void(base::File)> on_sets_ready); + SetsReadyOnceCallback on_sets_ready); ~FirstPartySetsComponentInstallerPolicy() override; FirstPartySetsComponentInstallerPolicy( @@ -42,7 +44,9 @@ // Calls the callback with the current First-Party Sets data, if the data // exists and can be read. static void ReconfigureAfterNetworkRestart( - base::OnceCallback<void(base::File)>); + SetsReadyOnceCallback on_sets_ready); + + void OnRegistrationComplete(); // Resets static state. Should only be used to clear state during testing. static void ResetForTesting(); @@ -59,10 +63,14 @@ FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, NonexistentFile_OnComponentReady); FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, + NonexistentFile_OnRegistrationComplete); + FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, LoadsSets_OnComponentReady); FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, LoadsSets_OnNetworkRestart); FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, + IgnoreNewSets_NoInitialComponent); + FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, IgnoreNewSets_OnComponentReady); FRIEND_TEST_ALL_PREFIXES(FirstPartySetsComponentInstallerTest, IgnoreNewSets_OnNetworkRestart); @@ -98,7 +106,9 @@ static base::FilePath GetInstalledPath(const base::FilePath& base); - base::RepeatingCallback<void(base::File)> on_sets_ready_; + // We use a OnceCallback to ensure we only pass along the sets file once + // during Chrome's lifetime (modulo reconfiguring the network service). + SetsReadyOnceCallback on_sets_ready_; }; // Call once during startup to make the component update service aware of
diff --git a/chrome/browser/component_updater/first_party_sets_component_installer_unittest.cc b/chrome/browser/component_updater/first_party_sets_component_installer_unittest.cc index a3ba4ae..df54ca7 100644 --- a/chrome/browser/component_updater/first_party_sets_component_installer_unittest.cc +++ b/chrome/browser/component_updater/first_party_sets_component_installer_unittest.cc
@@ -5,6 +5,8 @@ #include "chrome/browser/component_updater/first_party_sets_component_installer.h" #include "base/callback_helpers.h" +#include "base/check.h" +#include "base/files/file.h" #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/run_loop.h" @@ -86,6 +88,23 @@ base::RunLoop().RunUntilIdle(); } +TEST_F(FirstPartySetsComponentInstallerTest, + NonexistentFile_OnRegistrationComplete) { + ASSERT_TRUE( + base::DeleteFile(FirstPartySetsComponentInstallerPolicy::GetInstalledPath( + component_install_dir_.GetPath()))); + + base::RunLoop run_loop; + FirstPartySetsComponentInstallerPolicy( + base::BindLambdaForTesting([&](base::File file) { + EXPECT_FALSE(file.IsValid()); + run_loop.Quit(); + })) + .OnRegistrationComplete(); + + run_loop.Run(); +} + TEST_F(FirstPartySetsComponentInstallerTest, LoadsSets_OnComponentReady) { SEQUENCE_CHECKER(sequence_checker); const std::string expectation = "some first party sets"; @@ -93,6 +112,7 @@ auto policy = std::make_unique<FirstPartySetsComponentInstallerPolicy>( base::BindLambdaForTesting([&](base::File file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_TRUE(file.IsValid()); EXPECT_EQ(ReadToString(std::move(file)), expectation); run_loop.Quit(); })); @@ -108,6 +128,41 @@ run_loop.Run(); } +// Test that when the first version of the component is installed, +// ComponentReady is a no-op, because OnRegistrationComplete already executed +// the OnceCallback. +TEST_F(FirstPartySetsComponentInstallerTest, IgnoreNewSets_NoInitialComponent) { + SEQUENCE_CHECKER(sequence_checker); + + int callback_calls = 0; + FirstPartySetsComponentInstallerPolicy policy( + // This should run only once for the OnRegistrationComplete call. + base::BindLambdaForTesting([&](base::File file) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_FALSE(file.IsValid()); + callback_calls++; + })); + + policy.OnRegistrationComplete(); + env_.RunUntilIdle(); + EXPECT_EQ(callback_calls, 1); + + // Install the component, which should be ignored. + base::ScopedTempDir install_dir; + CHECK(install_dir.CreateUniqueTempDirUnderPath( + component_install_dir_.GetPath())); + ASSERT_TRUE( + base::WriteFile(FirstPartySetsComponentInstallerPolicy::GetInstalledPath( + install_dir.GetPath()), + "first party sets content")); + policy.ComponentReady(base::Version(), install_dir.GetPath(), + base::Value(base::Value::Type::DICTIONARY)); + + env_.RunUntilIdle(); + + EXPECT_EQ(callback_calls, 1); +} + // Test if a component has been installed, ComponentReady will be no-op when // newer versions are installed. TEST_F(FirstPartySetsComponentInstallerTest, IgnoreNewSets_OnComponentReady) { @@ -125,6 +180,7 @@ // It should run only once for the first ComponentReady call. base::BindLambdaForTesting([&](base::File file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_TRUE(file.IsValid()); EXPECT_EQ(ReadToString(std::move(file)), sets_v1); callback_calls++; })); @@ -160,6 +216,7 @@ FirstPartySetsComponentInstallerPolicy policy( base::BindLambdaForTesting([&](base::File file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_TRUE(file.IsValid()); EXPECT_EQ(ReadToString(std::move(file)), expectation); run_loop.Quit(); })); @@ -181,6 +238,7 @@ FirstPartySetsComponentInstallerPolicy::ReconfigureAfterNetworkRestart( base::BindLambdaForTesting([&](base::File file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_TRUE(file.IsValid()); EXPECT_EQ(ReadToString(std::move(file)), expectation); run_loop.Quit(); })); @@ -205,6 +263,7 @@ FirstPartySetsComponentInstallerPolicy policy( base::BindLambdaForTesting([&](base::File file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_TRUE(file.IsValid()); EXPECT_EQ(ReadToString(std::move(file)), sets_v1); })); @@ -231,6 +290,7 @@ FirstPartySetsComponentInstallerPolicy::ReconfigureAfterNetworkRestart( base::BindLambdaForTesting([&](base::File file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker); + EXPECT_TRUE(file.IsValid()); EXPECT_EQ(ReadToString(std::move(file)), sets_v1); callback_calls++; }));
diff --git a/chrome/browser/enterprise/reporting/report_scheduler_unittest.cc b/chrome/browser/enterprise/reporting/report_scheduler_unittest.cc index e872580..96286739 100644 --- a/chrome/browser/enterprise/reporting/report_scheduler_unittest.cc +++ b/chrome/browser/enterprise/reporting/report_scheduler_unittest.cc
@@ -104,7 +104,8 @@ MockReportUploader& operator=(const MockReportUploader&) = delete; ~MockReportUploader() override = default; - MOCK_METHOD2(SetRequestAndUpload, void(ReportRequestQueue, ReportCallback)); + MOCK_METHOD3(SetRequestAndUpload, + void(ReportType, ReportRequestQueue, ReportCallback)); }; class MockRealTimeReportGenerator : public RealTimeReportGenerator { @@ -313,8 +314,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -334,8 +335,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kTransientError)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kTransientError)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -355,8 +356,8 @@ EXPECT_CALL_SetupRegistrationWithSetDMToken(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kPersistentError)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kPersistentError)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -381,7 +382,7 @@ EXPECT_CALL_SetupRegistrationWithSetDMToken(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(0))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)).Times(0); + EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _, _)).Times(0); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -409,8 +410,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -429,8 +430,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -466,8 +467,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -502,8 +503,9 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kBrowserVersion, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, + SetRequestAndUpload(ReportType::kBrowserVersion, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); g_browser_process->GetBuildState()->SetUpdate( @@ -525,8 +527,9 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kBrowserVersion, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kPersistentError)); + EXPECT_CALL(*uploader_, + SetRequestAndUpload(ReportType::kBrowserVersion, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kPersistentError)); CreateScheduler(); g_browser_process->GetBuildState()->SetUpdate( @@ -559,8 +562,10 @@ // Hang on to the uploader's ReportCallback. ReportUploader::ReportCallback saved_callback; - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce([&saved_callback](ReportRequestQueue requests, + EXPECT_CALL(*uploader_, + SetRequestAndUpload(ReportType::kBrowserVersion, _, _)) + .WillOnce([&saved_callback](ReportType report_type, + ReportRequestQueue requests, ReportUploader::ReportCallback callback) { saved_callback = std::move(callback); }); @@ -582,8 +587,8 @@ EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); auto new_uploader = std::make_unique<MockReportUploader>(); - EXPECT_CALL(*new_uploader, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*new_uploader, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); std::move(saved_callback).Run(ReportUploader::kSuccess); ExpectLastUploadTimestampUpdated(false); ::testing::Mock::VerifyAndClearExpectations(generator_); @@ -612,8 +617,9 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kBrowserVersion, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, + SetRequestAndUpload(ReportType::kBrowserVersion, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); @@ -642,8 +648,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); @@ -664,7 +670,7 @@ TEST_F(ReportSchedulerTest, ExtensionRequestWithRealTimePipeline) { EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(_, _)).Times(0); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)).Times(0); + EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _, _)).Times(0); Profile* profile = profile_manager_.CreateTestingProfile("profile");
diff --git a/chrome/browser/extensions/api/developer_private/developer_private_api.cc b/chrome/browser/extensions/api/developer_private/developer_private_api.cc index 7b56123..f9e3dbb 100644 --- a/chrome/browser/extensions/api/developer_private/developer_private_api.cc +++ b/chrome/browser/extensions/api/developer_private/developer_private_api.cc
@@ -7,13 +7,13 @@ #include <stddef.h> #include <memory> #include <string> +#include <tuple> #include <utility> #include <vector> #include "base/bind.h" #include "base/files/file_util.h" #include "base/guid.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/memory/scoped_refptr.h" #include "base/metrics/histogram_macros.h" @@ -159,7 +159,7 @@ std::string data; // This call can fail, but it doesn't matter for our purposes. If it fails, // we simply return an empty string for the manifest, and ignore it. - ignore_result(base::ReadFileToString(path, &data)); + std::ignore = base::ReadFileToString(path, &data); return data; }
diff --git a/chrome/browser/extensions/api/extension_action/page_action_browsertest.cc b/chrome/browser/extensions/api/extension_action/page_action_browsertest.cc index 7b8ef728..99fcd5b 100644 --- a/chrome/browser/extensions/api/extension_action/page_action_browsertest.cc +++ b/chrome/browser/extensions/api/extension_action/page_action_browsertest.cc
@@ -144,7 +144,11 @@ } // Regression test for crbug.com/44415. -IN_PROC_BROWSER_TEST_P(PageActionBrowserTest, PageActionRefreshCrash) { +// TODO(crbug.com/1284555): Re-enable test to run with service workers after +// fixing flakiness. +using PageActionWithoutServiceWorkerTest = ExtensionBrowserTest; +IN_PROC_BROWSER_TEST_F(PageActionWithoutServiceWorkerTest, + PageActionRefreshCrash) { ExtensionRegistry* registry = extensions::ExtensionRegistry::Get(browser()->profile());
diff --git a/chrome/browser/extensions/api/messaging/native_messaging_test_util.cc b/chrome/browser/extensions/api/messaging/native_messaging_test_util.cc index 8b4a85d..c673c30 100644 --- a/chrome/browser/extensions/api/messaging/native_messaging_test_util.cc +++ b/chrome/browser/extensions/api/messaging/native_messaging_test_util.cc
@@ -5,11 +5,11 @@ #include "chrome/browser/extensions/api/messaging/native_messaging_test_util.h" #include <memory> +#include <tuple> #include <utility> #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/json/json_file_value_serializer.h" #include "base/path_service.h" #include "base/strings/stringprintf.h" @@ -126,7 +126,7 @@ ScopedTestNativeMessagingHost::~ScopedTestNativeMessagingHost() { base::ScopedAllowBlockingForTesting allow_blocking; - ignore_result(temp_dir_.Delete()); + std::ignore = temp_dir_.Delete(); } } // namespace extensions
diff --git a/chrome/browser/extensions/crx_installer.cc b/chrome/browser/extensions/crx_installer.cc index 5643078..c1889aa 100644 --- a/chrome/browser/extensions/crx_installer.cc +++ b/chrome/browser/extensions/crx_installer.cc
@@ -32,6 +32,8 @@ #include "chrome/browser/extensions/load_error_reporter.h" #include "chrome/browser/extensions/permissions_updater.h" #include "chrome/browser/extensions/webstore_installer.h" +#include "chrome/browser/profiles/keep_alive/profile_keep_alive_types.h" +#include "chrome/browser/profiles/keep_alive/scoped_profile_keep_alive.h" #include "chrome/browser/profiles/profile.h" #include "chrome/common/chrome_paths.h" #include "chrome/common/chrome_switches.h" @@ -1057,6 +1059,9 @@ } void CrxInstaller::NotifyCrxInstallBegin() { + profile_keep_alive_ = std::make_unique<ScopedProfileKeepAlive>( + profile_, ProfileKeepAliveOrigin::kCrxInstaller); + InstallTrackerFactory::GetForBrowserContext(profile()) ->OnBeginCrxInstall(expected_id_); } @@ -1121,6 +1126,8 @@ FROM_HERE, base::BindOnce(std::move(installer_callback_), error))) { NOTREACHED(); } + + profile_keep_alive_.reset(); } void CrxInstaller::CleanupTempFiles() {
diff --git a/chrome/browser/extensions/crx_installer.h b/chrome/browser/extensions/crx_installer.h index ea0f700..089ab148 100644 --- a/chrome/browser/extensions/crx_installer.h +++ b/chrome/browser/extensions/crx_installer.h
@@ -30,6 +30,7 @@ #include "third_party/abseil-cpp/absl/types/optional.h" class ExtensionServiceTest; +class ScopedProfileKeepAlive; class SkBitmap; namespace base { @@ -368,6 +369,9 @@ // The Profile the extension is being installed in. raw_ptr<Profile> profile_; + // Prevent Profile destruction until the CrxInstaller is done. + std::unique_ptr<ScopedProfileKeepAlive> profile_keep_alive_; + // The extension being installed. scoped_refptr<const Extension> extension_;
diff --git a/chrome/browser/extensions/process_management_browsertest.cc b/chrome/browser/extensions/process_management_browsertest.cc index 1d3f959..56342e8 100644 --- a/chrome/browser/extensions/process_management_browsertest.cc +++ b/chrome/browser/extensions/process_management_browsertest.cc
@@ -4,7 +4,8 @@ #include <stddef.h> -#include "base/ignore_result.h" +#include <tuple> + #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" #include "base/test/scoped_feature_list.h" @@ -709,8 +710,8 @@ // WaitForLoadStop() will return false on a 404, but that can happen if we // navigate to a blocked or nonexistent extension page. - ignore_result(content::WaitForLoadStop( - browser()->tab_strip_model()->GetActiveWebContents())); + std::ignore = content::WaitForLoadStop( + browser()->tab_strip_model()->GetActiveWebContents()); EXPECT_TRUE(content::ExecuteScriptAndExtractBool(web_contents, kGetAccess, &can_access));
diff --git a/chrome/browser/file_system_access/chrome_file_system_access_permission_context_browsertest.cc b/chrome/browser/file_system_access/chrome_file_system_access_permission_context_browsertest.cc index 44361a3..bca3a39 100644 --- a/chrome/browser/file_system_access/chrome_file_system_access_permission_context_browsertest.cc +++ b/chrome/browser/file_system_access/chrome_file_system_access_permission_context_browsertest.cc
@@ -4,9 +4,10 @@ #include "chrome/browser/file_system_access/chrome_file_system_access_permission_context.h" +#include <tuple> + #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" -#include "base/ignore_result.h" #include "chrome/browser/file_system_access/file_system_access_permission_request_manager.h" #include "chrome/browser/profiles/profile.h" #include "chrome/browser/ui/browser.h" @@ -194,7 +195,7 @@ // In order to get the file handle without the file picker dialog in the // prerendered page, BroadcastChannel gets the file handle from the current // active page. - ignore_result( + std::ignore = content::ExecJs(prerendered_frame_host, R"( var createWritableAndClose = (async () => { let b = new BroadcastChannel('channel'); @@ -208,17 +209,17 @@ await w.close(); return "";})(); )", - content::EvalJsOptions::EXECUTE_SCRIPT_NO_USER_GESTURE)); + content::EvalJsOptions::EXECUTE_SCRIPT_NO_USER_GESTURE); // The active page picks files and sends it to the prerendered page to test // 'close()' in prerendering. - ignore_result(content::ExecJs( + std::ignore = content::ExecJs( GetWebContents(), "(async () => {" " let [e] = await self.showOpenFilePicker();" " self.entry = e;" " new BroadcastChannel('channel').postMessage({entry: e});" - " return e.name; })()")); + " return e.name; })()"); // PerformAfterWriteChecks() is not called in prerendering. EXPECT_FALSE(permission_context.performed_after_write_checks());
diff --git a/chrome/browser/first_run/first_run.cc b/chrome/browser/first_run/first_run.cc index 18e21abc..2dadc0c 100644 --- a/chrome/browser/first_run/first_run.cc +++ b/chrome/browser/first_run/first_run.cc
@@ -6,13 +6,13 @@ #include <algorithm> #include <memory> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/command_line.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/metrics/histogram_macros.h" #include "base/metrics/user_metrics.h" #include "base/no_destructor.h" @@ -344,7 +344,7 @@ // Causes the first run sentinel creation time to be read and cached, while // I/O is still allowed. - ignore_result(GetFirstRunSentinelCreationTime()); + std::ignore = GetFirstRunSentinelCreationTime(); } base::Time GetFirstRunSentinelCreationTime() {
diff --git a/chrome/browser/flags/android/chrome_feature_list.cc b/chrome/browser/flags/android/chrome_feature_list.cc index ddfeffdc..e4b3f1c 100644 --- a/chrome/browser/flags/android/chrome_feature_list.cc +++ b/chrome/browser/flags/android/chrome_feature_list.cc
@@ -102,6 +102,7 @@ &download::features::kSmartSuggestionForLargeDownloads, &download::features::kUseDownloadOfflineContentProvider, &embedder_support::kShowTrustedPublisherURL, + &features::kAnonymousUpdateChecks, &features::kContinuousSearch, &features::kEarlyLibraryLoad, &features::kGenericSensorExtraClasses,
diff --git a/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/CachedFeatureFlags.java b/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/CachedFeatureFlags.java index af111e7..c07010d 100644 --- a/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/CachedFeatureFlags.java +++ b/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/CachedFeatureFlags.java
@@ -48,6 +48,7 @@ */ private static Map<String, Boolean> sDefaults = ImmutableMap.<String, Boolean>builder() + .put(ChromeFeatureList.ANONYMOUS_UPDATE_CHECKS, true) .put(ChromeFeatureList.BOOKMARK_BOTTOM_SHEET, false) .put(ChromeFeatureList.CONDITIONAL_TAB_STRIP_ANDROID, false) .put(ChromeFeatureList.LENS_CAMERA_ASSISTED_SEARCH, false)
diff --git a/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/ChromeFeatureList.java b/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/ChromeFeatureList.java index 2d6b15f..08839df 100644 --- a/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/ChromeFeatureList.java +++ b/chrome/browser/flags/android/java/src/org/chromium/chrome/browser/flags/ChromeFeatureList.java
@@ -216,6 +216,7 @@ "AndroidLayoutChangeTabReparenting"; public static final String ANDROID_SEARCH_ENGINE_CHOICE_NOTIFICATION = "AndroidSearchEngineChoiceNotification"; + public static final String ANONYMOUS_UPDATE_CHECKS = "AnonymousUpdateChecks"; public static final String APP_LANGUAGE_PROMPT = "AppLanguagePrompt"; public static final String ASSISTANT_CONSENT_SIMPLIFIED_TEXT = "AssistantConsentSimplifiedText"; public static final String ASSISTANT_CONSENT_V2 = "AssistantConsentV2";
diff --git a/chrome/browser/google/google_update_policy_fetcher_win.cc b/chrome/browser/google/google_update_policy_fetcher_win.cc index fb1ce9f..458369f 100644 --- a/chrome/browser/google/google_update_policy_fetcher_win.cc +++ b/chrome/browser/google/google_update_policy_fetcher_win.cc
@@ -6,10 +6,11 @@ #include <ATLComTime.h> #include <wrl/client.h> + +#include <tuple> #include <utility> #include "base/bind.h" -#include "base/ignore_result.h" #include "base/numerics/safe_conversions.h" #include "base/strings/string_split.h" #include "base/strings/string_util_win.h" @@ -71,11 +72,11 @@ ::COleDateTime date_time(date); base::Time time; if (date_time.m_status == ::COleDateTime::valid) { - ignore_result(base::Time::FromLocalExploded( + std::ignore = base::Time::FromLocalExploded( {date_time.GetYear(), date_time.GetMonth(), date_time.GetDayOfWeek(), date_time.GetDay(), date_time.GetHour(), date_time.GetMinute(), date_time.GetSecond(), 0}, - &time)); + &time); } return time; }
diff --git a/chrome/browser/media/cdm_document_service_impl_test.cc b/chrome/browser/media/cdm_document_service_impl_test.cc index df8d388..ae85d16 100644 --- a/chrome/browser/media/cdm_document_service_impl_test.cc +++ b/chrome/browser/media/cdm_document_service_impl_test.cc
@@ -5,10 +5,10 @@ #include "chrome/browser/media/cdm_document_service_impl.h" #include <memory> +#include <tuple> #include "base/files/file.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/json/values_util.h" #include "base/logging.h" #include "base/run_loop.h" @@ -172,7 +172,7 @@ // Call GetMediaFoundationCdmData to create the origin id first, otherwise // `SetCdmClientToken()` will assume the preference data associated with the // origin was recently cleared and will not save the client token. - ignore_result(GetMediaFoundationCdmData()); + std::ignore = GetMediaFoundationCdmData(); std::vector<uint8_t> expected_client_token = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; SetCdmClientToken(expected_client_token); @@ -192,12 +192,12 @@ // Call GetMediaFoundationCdmData to create the origin id first, otherwise // `SetCdmClientToken()` will assume the preference data associated with the // origin was recently cleared and will not save the client token. - ignore_result(GetMediaFoundationCdmData()); + std::ignore = GetMediaFoundationCdmData(); std::vector<uint8_t> expected_client_token = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; SetCdmClientToken(expected_client_token); NavigateToUrlAndCreateCdmDocumentService(GURL(kTestOrigin2)); - ignore_result(GetMediaFoundationCdmData()); + std::ignore = GetMediaFoundationCdmData(); SetCdmClientToken({1, 2, 3, 4, 5}); NavigateToUrlAndCreateCdmDocumentService(GURL(kTestOrigin)); @@ -210,7 +210,7 @@ // remove that entry and return without saving the client token. TEST_F(CdmDocumentServiceImplTest, SetClientTokenAfterCorruption) { NavigateToUrlAndCreateCdmDocumentService(GURL(kTestOrigin)); - ignore_result(GetMediaFoundationCdmData()); + std::ignore = GetMediaFoundationCdmData(); CorruptCdmPreference(); std::vector<uint8_t> expected_client_token = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
diff --git a/chrome/browser/media/encrypted_media_supported_types_browsertest.cc b/chrome/browser/media/encrypted_media_supported_types_browsertest.cc index 61bc7e7..5b3019b 100644 --- a/chrome/browser/media/encrypted_media_supported_types_browsertest.cc +++ b/chrome/browser/media/encrypted_media_supported_types_browsertest.cc
@@ -5,11 +5,11 @@ #include <stddef.h> #include <string> +#include <tuple> #include <vector> #include "base/base_switches.h" #include "base/files/file_path.h" -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" @@ -76,7 +76,7 @@ // Any support is acceptable. This can be used around new CDM check-in time // where test expectations can change based on the new CDM's capability. // For any usage of EXPECT_ANY, add a TODO explaining the plan to fix it. -#define EXPECT_ANY(test) ignore_result(test) +#define EXPECT_ANY(test) std::ignore = test #if BUILDFLAG(ENABLE_AV1_DECODER) #define EXPECT_AV1 EXPECT_SUCCESS
diff --git a/chrome/browser/media/history/media_history_store.cc b/chrome/browser/media/history/media_history_store.cc index 3be6bfd..34579ac 100644 --- a/chrome/browser/media/history/media_history_store.cc +++ b/chrome/browser/media/history/media_history_store.cc
@@ -4,10 +4,11 @@ #include "chrome/browser/media/history/media_history_store.h" +#include <tuple> + #include "base/callback.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/metrics/histogram_functions.h" #include "base/strings/stringprintf.h" #include "base/task/task_runner_util.h" @@ -52,7 +53,7 @@ // or hardware issues, not coding errors at the client level, so displaying // the error would probably lead to confusion. The ignored call signals the // test-expectation framework that the error was handled. - ignore_result(sql::Database::IsExpectedSqliteError(extended_error)); + std::ignore = sql::Database::IsExpectedSqliteError(extended_error); return; }
diff --git a/chrome/browser/media/webrtc/conditional_focus_browsertest.cc b/chrome/browser/media/webrtc/conditional_focus_browsertest.cc index 07ee4d1..e4cbfbd 100644 --- a/chrome/browser/media/webrtc/conditional_focus_browsertest.cc +++ b/chrome/browser/media/webrtc/conditional_focus_browsertest.cc
@@ -258,8 +258,16 @@ browser()->tab_strip_model()->GetWebContentsAt(0)); } +// TODO(crbug.com/1285418): Flaky on Win and Linux bots. +#if defined(OS_WIN) || defined(OS_LINUX) +#define MAYBE_ExceptionRaisedIfFocusCalledMultipleTimes \ + DISABLED_ExceptionRaisedIfFocusCalledMultipleTimes +#else +#define MAYBE_ExceptionRaisedIfFocusCalledMultipleTimes \ + ExceptionRaisedIfFocusCalledMultipleTimes +#endif IN_PROC_BROWSER_TEST_F(ConditionalFocusBrowserTest, - ExceptionRaisedIfFocusCalledMultipleTimes) { + MAYBE_ExceptionRaisedIfFocusCalledMultipleTimes) { // Setup. SetUpTestTabs(); Capture(0, FocusEnumValue::kFocusCapturedSurface);
diff --git a/chrome/browser/media/webrtc/test_stats_dictionary_unittest.cc b/chrome/browser/media/webrtc/test_stats_dictionary_unittest.cc index c893d5f..a5e0c48 100644 --- a/chrome/browser/media/webrtc/test_stats_dictionary_unittest.cc +++ b/chrome/browser/media/webrtc/test_stats_dictionary_unittest.cc
@@ -6,10 +6,10 @@ #include <memory> #include <set> +#include <tuple> #include <vector> #include "base/check.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/memory/ref_counted.h" #include "base/values.h" @@ -52,7 +52,7 @@ CHECK(value); base::DictionaryValue* dictionary; CHECK(value->GetAsDictionary(&dictionary)); - ignore_result(value.release()); + std::ignore = value.release(); report_ = new TestStatsReportDictionary( std::unique_ptr<base::DictionaryValue>(dictionary)); }
diff --git a/chrome/browser/media/webrtc/webrtc_browsertest_base.cc b/chrome/browser/media/webrtc/webrtc_browsertest_base.cc index aef3d05b..9cf2dd6 100644 --- a/chrome/browser/media/webrtc/webrtc_browsertest_base.cc +++ b/chrome/browser/media/webrtc/webrtc_browsertest_base.cc
@@ -7,8 +7,8 @@ #include <stddef.h> #include <limits> +#include <tuple> -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/lazy_instance.h" #include "base/logging.h" @@ -566,7 +566,7 @@ base::DictionaryValue* dictionary; CHECK(parsed_json); CHECK(parsed_json->GetAsDictionary(&dictionary)); - ignore_result(parsed_json.release()); + std::ignore = parsed_json.release(); return scoped_refptr<content::TestStatsReportDictionary>( new content::TestStatsReportDictionary( std::unique_ptr<base::DictionaryValue>(dictionary)));
diff --git a/chrome/browser/media/webrtc/webrtc_internals_perf_browsertest.cc b/chrome/browser/media/webrtc/webrtc_internals_perf_browsertest.cc index b8e5553b..9169c9a 100644 --- a/chrome/browser/media/webrtc/webrtc_internals_perf_browsertest.cc +++ b/chrome/browser/media/webrtc/webrtc_internals_perf_browsertest.cc
@@ -3,10 +3,10 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include "base/command_line.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/strings/string_split.h" #include "base/test/test_timeouts.h" @@ -76,7 +76,7 @@ base::JSONReader::ReadDeprecated(all_stats_json); base::DictionaryValue* result; if (parsed_json.get() && parsed_json->GetAsDictionary(&result)) { - ignore_result(parsed_json.release()); + std::ignore = parsed_json.release(); return result; }
diff --git a/chrome/browser/media/webrtc/webrtc_video_display_perf_browsertest.cc b/chrome/browser/media/webrtc/webrtc_video_display_perf_browsertest.cc index 33dd615..e027535 100644 --- a/chrome/browser/media/webrtc/webrtc_video_display_perf_browsertest.cc +++ b/chrome/browser/media/webrtc/webrtc_video_display_perf_browsertest.cc
@@ -3,8 +3,8 @@ // found in the LICENSE file. #include <algorithm> +#include <tuple> -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/strings/string_tokenizer.h" #include "base/strings/stringprintf.h" @@ -132,7 +132,7 @@ base::DictionaryValue* dictionary = nullptr; if (!parsed_json.get() || !parsed_json->GetAsDictionary(&dictionary)) return goog_decode_ms; - ignore_result(parsed_json.release()); + std::ignore = parsed_json.release(); // |dictionary| should have exactly two entries, one per ssrc. if (!dictionary || dictionary->DictSize() != 2u)
diff --git a/chrome/browser/media_galleries/media_galleries_test_util.cc b/chrome/browser/media_galleries/media_galleries_test_util.cc index 2691835e..481bbbc 100644 --- a/chrome/browser/media_galleries/media_galleries_test_util.cc +++ b/chrome/browser/media_galleries/media_galleries_test_util.cc
@@ -7,12 +7,12 @@ #include <stddef.h> #include <memory> +#include <tuple> #include <utility> #include "base/base_paths.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/strings/string_number_conversions.h" #include "base/threading/thread_restrictions.h" @@ -97,7 +97,7 @@ EnsureMediaDirectoriesExists::~EnsureMediaDirectoriesExists() { base::ScopedAllowBlockingForTesting allow_blocking; - ignore_result(fake_dir_.Delete()); + std::ignore = fake_dir_.Delete(); } void EnsureMediaDirectoriesExists::ChangeMediaPathOverrides() {
diff --git a/chrome/browser/metrics/chrome_browser_main_extra_parts_metrics.cc b/chrome/browser/metrics/chrome_browser_main_extra_parts_metrics.cc index 7393962c..91990009b 100644 --- a/chrome/browser/metrics/chrome_browser_main_extra_parts_metrics.cc +++ b/chrome/browser/metrics/chrome_browser_main_extra_parts_metrics.cc
@@ -24,7 +24,6 @@ #include "build/build_config.h" #include "build/chromeos_buildflags.h" #include "build/config/compiler/compiler_buildflags.h" -#include "build/os_buildflags.h" #include "chrome/browser/about_flags.h" #include "chrome/browser/browser_process.h" #include "chrome/browser/buildflags.h"
diff --git a/chrome/browser/metrics/perf/profile_provider_unittest_main.cc b/chrome/browser/metrics/perf/profile_provider_unittest_main.cc index 068e93c2..7ddcc6fe 100644 --- a/chrome/browser/metrics/perf/profile_provider_unittest_main.cc +++ b/chrome/browser/metrics/perf/profile_provider_unittest_main.cc
@@ -4,10 +4,11 @@ #include "chrome/browser/metrics/perf/profile_provider_chromeos.h" +#include <tuple> + #include "base/bind.h" #include "base/callback_helpers.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/metrics/field_trial.h" #include "base/metrics/statistics_recorder.h" #include "base/run_loop.h" @@ -214,7 +215,7 @@ ASSERT_TRUE(profile.has_perf_data()); // Collection succeeded: don't output the error log. - ignore_result(scoped_log_error.Release()); + std::ignore = scoped_log_error.Release(); } protected:
diff --git a/chrome/browser/metrics/perf/windowed_incognito_observer.cc b/chrome/browser/metrics/perf/windowed_incognito_observer.cc index dd4e95e..6c5b8f7 100644 --- a/chrome/browser/metrics/perf/windowed_incognito_observer.cc +++ b/chrome/browser/metrics/perf/windowed_incognito_observer.cc
@@ -4,7 +4,8 @@ #include "chrome/browser/metrics/perf/windowed_incognito_observer.h" -#include "base/ignore_result.h" +#include <tuple> + #include "base/no_destructor.h" #include "chrome/browser/profiles/profile.h" #include "chrome/browser/ui/browser.h" @@ -30,7 +31,7 @@ // static void WindowedIncognitoMonitor::Init() { - ignore_result(WindowedIncognitoMonitor::Get()); + std::ignore = WindowedIncognitoMonitor::Get(); } // static
diff --git a/chrome/browser/navigation_predictor/navigation_predictor_browsertest.cc b/chrome/browser/navigation_predictor/navigation_predictor_browsertest.cc index d8e1e46..e3a51ef4 100644 --- a/chrome/browser/navigation_predictor/navigation_predictor_browsertest.cc +++ b/chrome/browser/navigation_predictor/navigation_predictor_browsertest.cc
@@ -3,6 +3,7 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include "base/run_loop.h" #include "base/sequence_checker.h" @@ -750,8 +751,8 @@ // Create a fenced frame. const GURL& fenced_frame_url = test_server()->GetURL("/fenced_frames/simple_page_with_anchors.html"); - ignore_result(fenced_frame_test_helper().CreateFencedFrame( - web_contents()->GetMainFrame(), fenced_frame_url)); + std::ignore = fenced_frame_test_helper().CreateFencedFrame( + web_contents()->GetMainFrame(), fenced_frame_url); // Make sure the fenced frame doesn't log any anchors. anchor_entries = test_ukm_recorder->GetEntriesByName(AnchorEntry::kEntryName);
diff --git a/chrome/browser/net/errorpage_browsertest.cc b/chrome/browser/net/errorpage_browsertest.cc index 1fc3187..8460a2ad 100644 --- a/chrome/browser/net/errorpage_browsertest.cc +++ b/chrome/browser/net/errorpage_browsertest.cc
@@ -1082,7 +1082,8 @@ : web_app::SystemWebAppBrowserTestBase(true) {} }; -IN_PROC_BROWSER_TEST_F(ErrorPageOfflineAppLaunchTest, DiagnosticsConnectivity) { +IN_PROC_BROWSER_TEST_F(ErrorPageOfflineAppLaunchTest, + DISABLED_DiagnosticsConnectivity) { WaitForTestSystemAppInstall(); ASSERT_TRUE(ui_test_utils::NavigateToURL( browser(),
diff --git a/chrome/browser/net/profile_network_context_service.cc b/chrome/browser/net/profile_network_context_service.cc index e0a3e473..c052559 100644 --- a/chrome/browser/net/profile_network_context_service.cc +++ b/chrome/browser/net/profile_network_context_service.cc
@@ -707,10 +707,25 @@ local_state->GetFilePath(prefs::kDiskCacheDir); if (!disk_cache_dir.empty()) base_cache_path = disk_cache_dir.Append(base_cache_path.BaseName()); - network_context_params->http_cache_path = + base::FilePath http_cache_path = base_cache_path.Append(chrome::kCacheDirname); - network_context_params->http_cache_max_size = - local_state->GetInteger(prefs::kDiskCacheSize); + if (base::FeatureList::IsEnabled(features::kDisableHttpDiskCache)) { + // Clear any existing on-disk cache first since if the user tries to + // remove the cache it would only affect the in-memory cache while in the + // experiment. + base::ThreadPool::PostTask( + FROM_HERE, + {base::TaskPriority::BEST_EFFORT, base::MayBlock(), + base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}, + base::BindOnce(base::GetDeletePathRecursivelyCallback(), + http_cache_path)); + network_context_params->http_cache_max_size = + features::kDisableHttpDiskCacheMemoryCacheSizeParam.Get(); + } else { + network_context_params->http_cache_path = http_cache_path; + network_context_params->http_cache_max_size = + local_state->GetInteger(prefs::kDiskCacheSize); + } network_context_params->file_paths = ::network::mojom::NetworkContextFilePaths::New(); @@ -747,6 +762,8 @@ network_context_params->file_paths->transport_security_persister_file_name = base::FilePath(chrome::kTransportSecurityPersisterFilename); + network_context_params->file_paths->sct_auditing_pending_reports_file_name = + base::FilePath(chrome::kSCTAuditingPendingReportsFileName); } const base::Value* hsts_policy_bypass_list = g_browser_process->local_state()->GetList(prefs::kHSTSPolicyBypassList);
diff --git a/chrome/browser/page_load_metrics/page_load_metrics_browsertest.cc b/chrome/browser/page_load_metrics/page_load_metrics_browsertest.cc index 2931e64c..ffb259f 100644 --- a/chrome/browser/page_load_metrics/page_load_metrics_browsertest.cc +++ b/chrome/browser/page_load_metrics/page_load_metrics_browsertest.cc
@@ -26,7 +26,6 @@ #include "base/time/time.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" -#include "build/os_buildflags.h" #include "chrome/browser/page_load_metrics/observers/aborts_page_load_metrics_observer.h" #include "chrome/browser/page_load_metrics/observers/core/ukm_page_load_metrics_observer.h" #include "chrome/browser/page_load_metrics/observers/document_write_page_load_metrics_observer.h"
diff --git a/chrome/browser/pdf/pdf_extension_test.cc b/chrome/browser/pdf/pdf_extension_test.cc index f300450..ce1e519 100644 --- a/chrome/browser/pdf/pdf_extension_test.cc +++ b/chrome/browser/pdf/pdf_extension_test.cc
@@ -18,7 +18,6 @@ #include "base/files/file_enumerator.h" #include "base/files/file_util.h" #include "base/hash/hash.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/memory/raw_ptr.h" #include "base/memory/ref_counted.h" @@ -1183,7 +1182,7 @@ // TODO(crbug.com/1228987): Load success or failure is non-deterministic // currently, due to races between viewport messages and loading. For this // test, we only care that loading terminated, not about success or failure. - ignore_result(pdf_extension_test_util::EnsurePDFHasLoaded(contents)); + std::ignore = pdf_extension_test_util::EnsurePDFHasLoaded(contents); } IN_PROC_BROWSER_TEST_P(PDFExtensionJSTest, ViewerToolbar) {
diff --git a/chrome/browser/performance_manager/policies/bfcache_policy_browsertest.cc b/chrome/browser/performance_manager/policies/bfcache_policy_browsertest.cc index 271e025..06b0b9f 100644 --- a/chrome/browser/performance_manager/policies/bfcache_policy_browsertest.cc +++ b/chrome/browser/performance_manager/policies/bfcache_policy_browsertest.cc
@@ -12,7 +12,7 @@ #include "base/run_loop.h" #include "base/test/bind.h" #include "base/test/scoped_feature_list.h" -#include "build/os_buildflags.h" +#include "build/build_config.h" #include "chrome/browser/ui/browser.h" #include "chrome/test/base/in_process_browser_test.h" #include "chrome/test/base/ui_test_utils.h"
diff --git a/chrome/browser/plugins/plugin_response_interceptor_url_loader_throttle.cc b/chrome/browser/plugins/plugin_response_interceptor_url_loader_throttle.cc index ae5c53a9..8179ba9 100644 --- a/chrome/browser/plugins/plugin_response_interceptor_url_loader_throttle.cc +++ b/chrome/browser/plugins/plugin_response_interceptor_url_loader_throttle.cc
@@ -4,12 +4,12 @@ #include "chrome/browser/plugins/plugin_response_interceptor_url_loader_throttle.h" +#include <tuple> #include <utility> #include "base/bind.h" #include "base/feature_list.h" #include "base/guid.h" -#include "base/ignore_result.h" #include "chrome/browser/extensions/api/streams_private/streams_private_api.h" #include "chrome/browser/plugins/plugin_utils.h" #include "content/public/browser/browser_task_traits.h" @@ -131,7 +131,7 @@ std::string payload = view_id; mojo::PendingRemote<network::mojom::URLLoader> dummy_new_loader; - ignore_result(dummy_new_loader.InitWithNewPipeAndPassReceiver()); + std::ignore = dummy_new_loader.InitWithNewPipeAndPassReceiver(); mojo::Remote<network::mojom::URLLoaderClient> new_client; mojo::PendingReceiver<network::mojom::URLLoaderClient> new_client_receiver = new_client.BindNewPipeAndPassReceiver();
diff --git a/chrome/browser/portal/portal_recently_audible_browsertest.cc b/chrome/browser/portal/portal_recently_audible_browsertest.cc index c799323..a84c9593 100644 --- a/chrome/browser/portal/portal_recently_audible_browsertest.cc +++ b/chrome/browser/portal/portal_recently_audible_browsertest.cc
@@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> + #include "base/containers/contains.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/test/test_timeouts.h" #include "base/threading/thread_task_runner_handle.h" @@ -238,7 +239,7 @@ // Ideally this would never briefly flicker to false, but it can because the // hystersis here applies at the WebContents level, not the tab level, and // portals swaps WebContents. So if it does change to false, ignore that... - ignore_result(ActiveTabChangesTo(TabAlertState::AUDIO_PLAYING, false)); + std::ignore = ActiveTabChangesTo(TabAlertState::AUDIO_PLAYING, false); // ...for it will shortly become true again. EXPECT_TRUE(ActiveTabChangesTo(TabAlertState::AUDIO_PLAYING, true));
diff --git a/chrome/browser/predictors/predictor_database.cc b/chrome/browser/predictors/predictor_database.cc index 9465c76b..e593300 100644 --- a/chrome/browser/predictors/predictor_database.cc +++ b/chrome/browser/predictors/predictor_database.cc
@@ -11,7 +11,6 @@ #include "base/check.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/task/sequenced_task_runner.h" #include "chrome/browser/predictors/autocomplete_action_predictor_table.h" #include "chrome/browser/predictors/loading_predictor_config.h" @@ -118,10 +117,6 @@ autocomplete_table_->Initialize(db_.get()); resource_prefetch_tables_->Initialize(db_.get()); - // The logged_in_predictor table is obsolete as of Chrome 44. - // TODO(davidben): Remove this after April 16, 2016. - ignore_result(db_->Execute("DROP TABLE IF EXISTS logged_in_predictor")); - LogDatabaseStats(); }
diff --git a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.cc b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.cc index 1d3d1ea..e3eaad9 100644 --- a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.cc +++ b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.cc
@@ -5,6 +5,7 @@ #include "chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.h" #include "chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.h" +#include "streaming_search_prefetch_request.h" StreamingSearchPrefetchRequest::StreamingSearchPrefetchRequest( const GURL& prefetch_url, @@ -19,7 +20,9 @@ std::unique_ptr<network::ResourceRequest> resource_request, const net::NetworkTrafficAnnotationTag& network_traffic_annotation) { streaming_url_loader_ = std::make_unique<StreamingSearchPrefetchURLLoader>( - this, profile, std::move(resource_request), network_traffic_annotation); + this, profile, std::move(resource_request), network_traffic_annotation, + base::BindOnce(&StreamingSearchPrefetchRequest::StopPrefetch, + weak_factory_.GetWeakPtr())); } std::unique_ptr<SearchPrefetchURLLoader>
diff --git a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.h b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.h index ff051c6..ad75d40 100644 --- a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.h +++ b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_request.h
@@ -47,6 +47,8 @@ private: // The ongoing prefetch request. Null before and after the fetch. std::unique_ptr<StreamingSearchPrefetchURLLoader> streaming_url_loader_; + + base::WeakPtrFactory<StreamingSearchPrefetchRequest> weak_factory_{this}; }; #endif // CHROME_BROWSER_PREFETCH_SEARCH_PREFETCH_STREAMING_SEARCH_PREFETCH_REQUEST_H_
diff --git a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.cc b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.cc index cf2a420..0c3f0c9 100644 --- a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.cc +++ b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.cc
@@ -26,15 +26,18 @@ #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/mojom/early_hints.mojom.h" #include "services/network/public/mojom/url_response_head.mojom.h" +#include "streaming_search_prefetch_url_loader.h" #include "url/gurl.h" StreamingSearchPrefetchURLLoader::StreamingSearchPrefetchURLLoader( StreamingSearchPrefetchRequest* streaming_prefetch_request, Profile* profile, std::unique_ptr<network::ResourceRequest> resource_request, - const net::NetworkTrafficAnnotationTag& network_traffic_annotation) + const net::NetworkTrafficAnnotationTag& network_traffic_annotation, + base::OnceClosure stop_prefetch_closure) : resource_request_(std::move(resource_request)), - streaming_prefetch_request_(streaming_prefetch_request) { + streaming_prefetch_request_(streaming_prefetch_request), + stop_prefetch_closure_(std::move(stop_prefetch_closure)) { DCHECK(streaming_prefetch_request_); auto url_loader_factory = profile->GetDefaultStoragePartition() ->GetURLLoaderFactoryForBrowserProcess(); @@ -126,7 +129,7 @@ if (streaming_prefetch_request_) { streaming_prefetch_request_->ErrorEncountered(); } else { - delete this; + PostTaskToStopPrefetchAndDeleteSelf(); } } @@ -210,7 +213,7 @@ mojo::CreateDataPipe(&options, producer_handle_, consumer_handle); if (rv != MOJO_RESULT_OK) { - delete this; + PostTaskToStopPrefetchAndDeleteSelf(); return; } @@ -232,7 +235,7 @@ MojoResult result, const mojo::HandleSignalsState& state) { if (result != MOJO_RESULT_OK) { - delete this; + PostTaskToStopPrefetchAndDeleteSelf(); return; } PushData(); @@ -258,7 +261,7 @@ } if (result != MOJO_RESULT_OK) { - delete this; + PostTaskToStopPrefetchAndDeleteSelf(); return; } @@ -347,16 +350,24 @@ DCHECK(streaming_prefetch_request_); streaming_prefetch_request_->ErrorEncountered(); } else { - delete this; + PostTaskToStopPrefetchAndDeleteSelf(); } } void StreamingSearchPrefetchURLLoader::OnURLLoaderClientMojoDisconnect() { DCHECK(forwarding_client_); DCHECK(!streaming_prefetch_request_); - delete this; + PostTaskToStopPrefetchAndDeleteSelf(); } void StreamingSearchPrefetchURLLoader::ClearOwnerPointer() { streaming_prefetch_request_ = nullptr; } + +void StreamingSearchPrefetchURLLoader::PostTaskToStopPrefetchAndDeleteSelf() { + // To avoid UAF bugs, post a separate task to delete this object. + if (stop_prefetch_closure_) { + base::SequencedTaskRunnerHandle::Get()->PostTask( + FROM_HERE, std::move(stop_prefetch_closure_)); + } +}
diff --git a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.h b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.h index c49a0e0..bdf2924 100644 --- a/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.h +++ b/chrome/browser/prefetch/search_prefetch/streaming_search_prefetch_url_loader.h
@@ -40,7 +40,8 @@ StreamingSearchPrefetchRequest* streaming_prefetch_request, Profile* profile, std::unique_ptr<network::ResourceRequest> resource_request, - const net::NetworkTrafficAnnotationTag& network_traffic_annotation); + const net::NetworkTrafficAnnotationTag& network_traffic_annotation, + base::OnceClosure stop_prefetch_closure); ~StreamingSearchPrefetchURLLoader() override; @@ -106,6 +107,9 @@ // Clears |producer_handle_| and |handle_watcher_|. void Finish(); + // Post a task to delete this object by running stop_prefetch_closure_. + void PostTaskToStopPrefetchAndDeleteSelf(); + // Sets up mojo forwarding to the navigation path. Resumes // |network_url_loader_| calls. Serves the start of the response to the // navigation path. After this method is called, |this| manages its own @@ -165,6 +169,9 @@ mojo::ScopedDataPipeProducerHandle producer_handle_; std::unique_ptr<mojo::SimpleWatcher> handle_watcher_; + // Closure to cancel this prefetch. Running this callback will destroy |this|. + base::OnceClosure stop_prefetch_closure_; + base::WeakPtrFactory<StreamingSearchPrefetchURLLoader> weak_factory_{this}; };
diff --git a/chrome/browser/printing/cloud_print/cloud_print_proxy_service_unittest.cc b/chrome/browser/printing/cloud_print/cloud_print_proxy_service_unittest.cc index fb6b838..1e4d705 100644 --- a/chrome/browser/printing/cloud_print/cloud_print_proxy_service_unittest.cc +++ b/chrome/browser/printing/cloud_print/cloud_print_proxy_service_unittest.cc
@@ -8,11 +8,11 @@ #include <memory> #include <string> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/location.h" #include "base/run_loop.h" #include "base/task/single_thread_task_runner.h" @@ -161,7 +161,7 @@ &TestCloudPrintProxyService::HandleCloudPrintProxyRequest, base::Unretained(this))); mojo::PendingRemote<service_manager::mojom::InterfaceProvider> handle; - ignore_result(handle.InitWithNewPipeAndPassReceiver()); + std::ignore = handle.InitWithNewPipeAndPassReceiver(); process_control_.SetMojoHandle(std::move(handle)); }
diff --git a/chrome/browser/profiles/chrome_browser_main_extra_parts_profiles.cc b/chrome/browser/profiles/chrome_browser_main_extra_parts_profiles.cc index 074ea8a..98cd5b7 100644 --- a/chrome/browser/profiles/chrome_browser_main_extra_parts_profiles.cc +++ b/chrome/browser/profiles/chrome_browser_main_extra_parts_profiles.cc
@@ -9,7 +9,6 @@ #include "build/build_config.h" #include "build/chromeos_buildflags.h" -#include "build/os_buildflags.h" #include "chrome/browser/accuracy_tips/accuracy_service_factory.h" #include "chrome/browser/autocomplete/autocomplete_classifier_factory.h" #include "chrome/browser/autocomplete/in_memory_url_index_factory.h"
diff --git a/chrome/browser/profiles/keep_alive/profile_keep_alive_types.cc b/chrome/browser/profiles/keep_alive/profile_keep_alive_types.cc index 7c364d8..a5d217af 100644 --- a/chrome/browser/profiles/keep_alive/profile_keep_alive_types.cc +++ b/chrome/browser/profiles/keep_alive/profile_keep_alive_types.cc
@@ -53,6 +53,8 @@ return out << "kWebAppUpdate"; case ProfileKeepAliveOrigin::kGettingWebAppInfo: return out << "kGettingWebAppInfo"; + case ProfileKeepAliveOrigin::kCrxInstaller: + return out << "kCrxInstaller"; } NOTREACHED(); return out << static_cast<int>(origin);
diff --git a/chrome/browser/profiles/keep_alive/profile_keep_alive_types.h b/chrome/browser/profiles/keep_alive/profile_keep_alive_types.h index 5a3148e..02e54bc 100644 --- a/chrome/browser/profiles/keep_alive/profile_keep_alive_types.h +++ b/chrome/browser/profiles/keep_alive/profile_keep_alive_types.h
@@ -104,7 +104,10 @@ // --list-apps switch. kGettingWebAppInfo = 23, - kMaxValue = kGettingWebAppInfo, + // An extension .crx is being installed. + kCrxInstaller = 24, + + kMaxValue = kCrxInstaller, }; std::ostream& operator<<(std::ostream& out,
diff --git a/chrome/browser/profiles/keep_alive/scoped_profile_keep_alive.cc b/chrome/browser/profiles/keep_alive/scoped_profile_keep_alive.cc index a709aac..5f26b2a 100644 --- a/chrome/browser/profiles/keep_alive/scoped_profile_keep_alive.cc +++ b/chrome/browser/profiles/keep_alive/scoped_profile_keep_alive.cc
@@ -45,6 +45,10 @@ const Profile* profile, ProfileKeepAliveOrigin origin) { DCHECK_CURRENTLY_ON(content::BrowserThread::UI); + // |g_browser_process| could be nullptr if this is called during shutdown, + // e.g. in tests. + if (!g_browser_process) + return; // |profile_manager| could be nullptr if this is called during shutdown, e.g. // for system/guest profiles or in tests. auto* profile_manager = g_browser_process->profile_manager();
diff --git a/chrome/browser/profiles/profile_manager.cc b/chrome/browser/profiles/profile_manager.cc index 798a594..e7d59e6 100644 --- a/chrome/browser/profiles/profile_manager.cc +++ b/chrome/browser/profiles/profile_manager.cc
@@ -68,7 +68,6 @@ #include "chrome/browser/profiles/profiles_state.h" #include "chrome/browser/signin/account_reconcilor_factory.h" #include "chrome/browser/signin/identity_manager_factory.h" -#include "chrome/browser/signin/primary_account_policy_manager_factory.h" #include "chrome/browser/signin/signin_util.h" #include "chrome/browser/sync/sync_service_factory.h" #include "chrome/browser/ui/startup/startup_browser_creator.h" @@ -1548,7 +1547,7 @@ // had enough time to initialize and should have updated the user signout // flag attached to the profile. signin_util::EnsureUserSignoutAllowedIsInitializedForProfile(profile); - PrimaryAccountPolicyManagerFactory::GetForProfile(profile)->Initialize(); + signin_util::EnsurePrimaryAccountAllowedForProfile(profile); #if !defined(OS_ANDROID) // The caret browsing command-line switch toggles caret browsing on
diff --git a/chrome/browser/resources/access_code_cast/access_code_cast.ts b/chrome/browser/resources/access_code_cast/access_code_cast.ts index a71f5d7..b40c6f8 100644 --- a/chrome/browser/resources/access_code_cast/access_code_cast.ts +++ b/chrome/browser/resources/access_code_cast/access_code_cast.ts
@@ -63,6 +63,10 @@ this.router = BrowserProxy.getInstance().callbackRouter; this.accessCode = ''; + + window.onblur = () => { + this.close(); + }; } ready() {
diff --git a/chrome/browser/resources/chromeos/edu_coexistence/edu_coexistence_template.html b/chrome/browser/resources/chromeos/edu_coexistence/edu_coexistence_template.html index c41e908..d6f32db 100644 --- a/chrome/browser/resources/chromeos/edu_coexistence/edu_coexistence_template.html +++ b/chrome/browser/resources/chromeos/edu_coexistence/edu_coexistence_template.html
@@ -2,11 +2,11 @@ :host { overflow-y: hidden; --background-gradient-0: linear-gradient(0deg, - rgba(var(--google-grey-refresh-100-rgb), 1) 0, - rgba(var(--google-grey-refresh-100-rgb), 0) 8px); + rgba(var(--google-grey-100-rgb), 1) 0, + rgba(var(--google-grey-100-rgb), 0) 8px); --background-gradient-180: linear-gradient(180deg, - rgba(var(--google-grey-refresh-100-rgb), 1) 0, - rgba(var(--google-grey-refresh-100-rgb), 0) 8px); + rgba(var(--google-grey-100-rgb), 1) 0, + rgba(var(--google-grey-100-rgb), 0) 8px); } </style> <div class="template-container">
diff --git a/chrome/browser/resources/chromeos/emoji_picker/emoji_button.html b/chrome/browser/resources/chromeos/emoji_picker/emoji_button.html index 46f3abb..8b30593d 100644 --- a/chrome/browser/resources/chromeos/emoji_picker/emoji_button.html +++ b/chrome/browser/resources/chromeos/emoji_picker/emoji_button.html
@@ -54,7 +54,7 @@ * dark mode since dark mode is the same. */ background: linear-gradient( - 315deg, var(--google-grey-refresh-500) 4px, + 315deg, var(--google-grey-500) 4px, var(--emoji-background) 4px, var(--emoji-background)); content: ''; display: block;
diff --git a/chrome/browser/resources/chromeos/login/components/common_styles/common_styles.html b/chrome/browser/resources/chromeos/login/components/common_styles/common_styles.html index 133fa9d..41b569b 100644 --- a/chrome/browser/resources/chromeos/login/components/common_styles/common_styles.html +++ b/chrome/browser/resources/chromeos/login/components/common_styles/common_styles.html
@@ -45,22 +45,22 @@ OobeScrollableBehavior. */ .scrollable.can-scroll:not(.is-scrolled):not(.scrolled-to-bottom) { background: linear-gradient(0deg, - rgba(var(--google-grey-refresh-100-rgb), 1) 0, - rgba(var(--google-grey-refresh-100-rgb), 0) 8px); + rgba(var(--google-grey-100-rgb), 1) 0, + rgba(var(--google-grey-100-rgb), 0) 8px); } .scrollable.can-scroll.is-scrolled:not(.scrolled-to-bottom) { background: linear-gradient(0deg, - rgba(var(--google-grey-refresh-100-rgb), 1) 0, - rgba(var(--google-grey-refresh-100-rgb), 0) 8px), + rgba(var(--google-grey-100-rgb), 1) 0, + rgba(var(--google-grey-100-rgb), 0) 8px), linear-gradient(180deg, - rgba(var(--google-grey-refresh-100-rgb), 1) 0, - rgba(var(--google-grey-refresh-100-rgb), 0) 8px); + rgba(var(--google-grey-100-rgb), 1) 0, + rgba(var(--google-grey-100-rgb), 0) 8px); } .scrollable.is-scrolled.scrolled-to-bottom { background: linear-gradient(180deg, - rgba(var(--google-grey-refresh-100-rgb), 1) 0, - rgba(var(--google-grey-refresh-100-rgb), 0) 8px); + rgba(var(--google-grey-100-rgb), 1) 0, + rgba(var(--google-grey-100-rgb), 0) 8px); } /* Links styles used within OOBE */
diff --git a/chrome/browser/resources/chromeos/login/screens/common/hw_data_collection.html b/chrome/browser/resources/chromeos/login/screens/common/hw_data_collection.html index de9106f7..f776e512 100644 --- a/chrome/browser/resources/chromeos/login/screens/common/hw_data_collection.html +++ b/chrome/browser/resources/chromeos/login/screens/common/hw_data_collection.html
@@ -21,7 +21,7 @@ <template> <style include="oobe-dialog-host-styles"> #dataUsageLabelContainer { - color: var(--google-grey-refresh-700); /* #5F6368 */ + color: var(--google-grey-700); /* #5F6368 */ line-height: 18px; }
diff --git a/chrome/browser/resources/chromeos/login/screens/oobe/oobe_eula.html b/chrome/browser/resources/chromeos/login/screens/oobe/oobe_eula.html index 9f10f5a..a0bd9e84 100644 --- a/chrome/browser/resources/chromeos/login/screens/oobe/oobe_eula.html +++ b/chrome/browser/resources/chromeos/login/screens/oobe/oobe_eula.html
@@ -80,7 +80,7 @@ } #usageStatsLabelContainer { - color: var(--google-grey-refresh-700); /* #5F6368 */ + color: var(--google-grey-700); /* #5F6368 */ line-height: 18px; }
diff --git a/chrome/browser/resources/chromeos/login/screens/oobe/welcome.html b/chrome/browser/resources/chromeos/login/screens/oobe/welcome.html index d573284c..77d714e 100644 --- a/chrome/browser/resources/chromeos/login/screens/oobe/welcome.html +++ b/chrome/browser/resources/chromeos/login/screens/oobe/welcome.html
@@ -79,7 +79,7 @@ } #oobeAdvancedOptionsScreen .advanced-option-subtitle { - color: var(--google-grey-refresh-700); + color: var(--google-grey-700); } </style> <oobe-welcome-dialog id="welcomeScreen" role="dialog" for-step="greeting"
diff --git a/chrome/browser/resources/feedback_webui/css/feedback.css b/chrome/browser/resources/feedback_webui/css/feedback.css index 007821b..c3636b2 100644 --- a/chrome/browser/resources/feedback_webui/css/feedback.css +++ b/chrome/browser/resources/feedback_webui/css/feedback.css
@@ -8,7 +8,6 @@ } body { - background-color: #fbfbfb; display: flex; flex-direction: column; height: 100%; @@ -24,9 +23,9 @@ .title-bar { -webkit-align-items: center; -webkit-app-region: drag; - background-color: #fff; - box-shadow: 0 1px #d0d0d0; - color: rgb(80, 80, 82); + background-color: var(--feedback-bg-color); + box-shadow: 0 1px var(--feedback-box-shadow-color); + color: var(--feedback-primary-color); display: flex; flex-grow: 0; font-size: 15px; @@ -50,7 +49,7 @@ } .content { - color: #444; + color: var(--feedback-primary-color); font-size: 12px; } @@ -59,7 +58,7 @@ } .content #description-text { - border-color: #c8c8c8; + border-color: var(--feedback-separator-color); box-sizing: border-box; height: 120px; line-height: 18px; @@ -86,18 +85,18 @@ } .content .text-field-container > select { - border: 1px solid #c8c8c8; - color: #585858; + border: 1px solid var(--feedback-separator-color); + color: var(--feedback-secondary-color); flex: 1 1 auto; height: 100%; padding-inline-start: 5px; } .content .text-field-container > input[type=text] { - flex: 1 1 auto; border: 1px solid; - border-color: #c8c8c8; - color: #585858; + border-color: var(--feedback-separator-color); + color: var(--feedback-secondary-color); + flex: 1 1 auto; height: 100%; padding-inline-start: 5px; } @@ -107,7 +106,7 @@ } .content .description-empty-notification { - color: rgb(204, 0, 0); + color: var(--feedback-alert-color); font-weight: bold; } @@ -228,7 +227,7 @@ } .content .attach-file-notification { - color: rgb(204, 0, 0); + color: var(--feedback-alert-color); font-weight: bold; } @@ -239,7 +238,7 @@ button.blue-button { color: #fff; - text-shadow: 1px sharp drop shadow rgb(45, 106, 218); + text-shadow: 1px sharp drop shadow var(--feedback-prominent-color); } /* Used for elements that are hidden but not ignored by screen reader. */
diff --git a/chrome/browser/resources/feedback_webui/css/feedback_shared_vars.css b/chrome/browser/resources/feedback_webui/css/feedback_shared_vars.css index 195bb77..a8a9583 100644 --- a/chrome/browser/resources/feedback_webui/css/feedback_shared_vars.css +++ b/chrome/browser/resources/feedback_webui/css/feedback_shared_vars.css
@@ -4,25 +4,42 @@ html:not(body) { /* Google colors used in UI. */ - --google-blue-300: rgb(138, 180, 248); + --google-blue-50: rgb(232, 240, 254); + --google-blue-300-rgb: 138, 180, 248; + --google-blue-300: rgb(var(--google-blue-300-rgb)); --google-blue-600: rgb(26, 115, 232); --google-grey-200: rgb(232, 234, 237); - --google-grey-900: rgb(32, 33, 36); + --google-grey-500: rgb(154, 160, 166); + --google-grey-700: rgb(95, 99, 104); + --google-grey-900-rgb: 32, 33, 36; + --google-grey-900: rgb(var(--google-grey-900-rgb)); --google-red-300: rgb(242, 139, 130); --google-red-600: rgb(217, 48, 37); /* Feedback specific variables. */ - --feedback-bg-color: #fff; + --feedback-alert-color: var(--google-red-600); + --feedback-bg-color: rgb(255, 255, 255); + --feedback-box-shadow-color: #d0d0d0; + --feedback-highlight-color: var(--google-blue-50); --feedback-link-color: var(--google-blue-600); --feedback-primary-color: var(--google-grey-900); + --feedback-prominent-color: var(--google-blue-600); + --feedback-secondary-color: var(--google-grey-700); + --feedback-separator-color: rgba(0, 0, 0, 0.14); } @media (prefers-color-scheme: dark) { html:not(body) { + --feedback-alert-color: var(--google-red-300); --feedback-bg-color: var(--google-grey-900); + --feedback-box-shadow-color: rgba(var(--google-grey-900-rgb), 0.04); + --feedback-highlight-color: rgba(var(--google-blue-300-rgb), 0.3); --feedback-link-color: var(--google-blue-300); --feedback-primary-color: var(--google-grey-200); + --feedback-prominent-color: var(--google-blue-300); + --feedback-secondary-color: var(--google-grey-500); + --feedback-separator-color: rgba(255, 255, 255, 0.14); } }
diff --git a/chrome/browser/resources/feedback_webui/css/sys_info.css b/chrome/browser/resources/feedback_webui/css/sys_info.css index f2d8842..98fe1d8 100644 --- a/chrome/browser/resources/feedback_webui/css/sys_info.css +++ b/chrome/browser/resources/feedback_webui/css/sys_info.css
@@ -13,11 +13,15 @@ } #tableTitle { - color: #4a4a4a; + color: var(--feedback-primary-color); } #status { - color: rgb(66, 133, 244); + color: var(--feedback-prominent-color); display: inline-block; margin: .5em .5em; } + +.list:not(.filtered) tr:nth-child(odd) td { + background-color: var(--feedback-highlight-color); +}
diff --git a/chrome/browser/resources/history/history_item.html b/chrome/browser/resources/history/history_item.html index 78f2c68..51952ca 100644 --- a/chrome/browser/resources/history/history_item.html +++ b/chrome/browser/resources/history/history_item.html
@@ -224,6 +224,7 @@ <template is="dom-if" if="[[item.starred]]"> <cr-icon-button id="bookmark-star" iron-icon="cr:star" on-click="onRemoveBookmarkTap_" + title="$i18n{removeBookmark}" aria-hidden="true"> </cr-icon-button> </template>
diff --git a/chrome/browser/resources/management/management_browser_proxy.ts b/chrome/browser/resources/management/management_browser_proxy.ts index a05336d1..78bdba3 100644 --- a/chrome/browser/resources/management/management_browser_proxy.ts +++ b/chrome/browser/resources/management/management_browser_proxy.ts
@@ -68,6 +68,7 @@ EXTENSION = 'extension', ANDROID_APPLICATION = 'android application', LOGIN_LOGOUT = 'login-logout', + CRD_SESSIONS = 'crd sessions', }
diff --git a/chrome/browser/resources/management/management_ui.ts b/chrome/browser/resources/management/management_ui.ts index 164100d..17883b6 100644 --- a/chrome/browser/resources/management/management_ui.ts +++ b/chrome/browser/resources/management/management_ui.ts
@@ -280,6 +280,8 @@ return 'management:play-store'; case DeviceReportingType.LOGIN_LOGOUT: return 'management:timelapse'; + case DeviceReportingType.CRD_SESSIONS: + return 'management:timelapse'; default: return 'cr:computer'; }
diff --git a/chrome/browser/resources/pdf/BUILD.gn b/chrome/browser/resources/pdf/BUILD.gn index f20f682..3dc781e 100644 --- a/chrome/browser/resources/pdf/BUILD.gn +++ b/chrome/browser/resources/pdf/BUILD.gn
@@ -53,7 +53,8 @@ input_files = print_preview_non_webcomponents_files + shared_non_webcomponents_files + print_preview_webcomponents_files + shared_webcomponents_files - input_files_base_dir = rebase_path("$target_gen_dir/$preprocess_folder", "//") + input_files_base_dir = + rebase_path("$target_gen_dir/$preprocess_folder", root_build_dir) deps = [ ":preprocess", ":preprocess_generated",
diff --git a/chrome/browser/resources/pdf/elements/shared-css.html b/chrome/browser/resources/pdf/elements/shared-css.html index 66dfb56..eea9186 100644 --- a/chrome/browser/resources/pdf/elements/shared-css.html +++ b/chrome/browser/resources/pdf/elements/shared-css.html
@@ -23,7 +23,7 @@ --cr-menu-shadow: rgba(0, 0, 0, .3) 0 1px 2px 0, rgba(0, 0, 0, .15) 0 3px 6px 2px; --cr-primary-text-color: var(--google-grey-200); - --cr-menu-background-focus-color: var(--google-grey-refresh-700); + --cr-menu-background-focus-color: var(--google-grey-700); --cr-menu-background-sheen: rgba(255, 255, 255, .06); --cr-separator-line: var(--cr-separator-height) solid rgba(255, 255, 255, .1);
diff --git a/chrome/browser/resources/pdf/elements/viewer-pdf-sidenav.html b/chrome/browser/resources/pdf/elements/viewer-pdf-sidenav.html index cf7dea86..cc5c8ff3 100644 --- a/chrome/browser/resources/pdf/elements/viewer-pdf-sidenav.html +++ b/chrome/browser/resources/pdf/elements/viewer-pdf-sidenav.html
@@ -1,6 +1,6 @@ <style include="pdf-shared cr-hidden-style cr-shared-style"> :host { - --sidenav-selected-tab-color: var(--google-blue-refresh-300); + --sidenav-selected-tab-color: var(--google-blue-300); background-color: var(--viewer-pdf-toolbar-background-color); display: flex; height: 100%;
diff --git a/chrome/browser/resources/pdf/elements/viewer-properties-dialog.html b/chrome/browser/resources/pdf/elements/viewer-properties-dialog.html index 19c99f2..18d3ad9 100644 --- a/chrome/browser/resources/pdf/elements/viewer-properties-dialog.html +++ b/chrome/browser/resources/pdf/elements/viewer-properties-dialog.html
@@ -12,7 +12,7 @@ } .break > td { - --break-color: var(--google-grey-refresh-300); + --break-color: var(--google-grey-300); border-bottom: 1px solid var(--break-color); padding-bottom: var(--break-padding); }
diff --git a/chrome/browser/resources/pdf/elements/viewer-thumbnail.html b/chrome/browser/resources/pdf/elements/viewer-thumbnail.html index d9fd672..b0cdb08c 100644 --- a/chrome/browser/resources/pdf/elements/viewer-thumbnail.html +++ b/chrome/browser/resources/pdf/elements/viewer-thumbnail.html
@@ -1,6 +1,6 @@ <style> :host { - --focus-border-color: var(--google-blue-refresh-300); + --focus-border-color: var(--google-blue-300); display: block; }
diff --git a/chrome/browser/resources/pdf/elements/viewer-toolbar.html b/chrome/browser/resources/pdf/elements/viewer-toolbar.html index e6d49e2..08eadb8 100644 --- a/chrome/browser/resources/pdf/elements/viewer-toolbar.html +++ b/chrome/browser/resources/pdf/elements/viewer-toolbar.html
@@ -132,7 +132,7 @@ } paper-progress { - --paper-progress-active-color: var(--google-blue-refresh-300); + --paper-progress-active-color: var(--google-blue-300); --paper-progress-container-color: transparent; --paper-progress-height: 3px; bottom: 0;
diff --git a/chrome/browser/resources/pdf/pdf_viewer_pp.js b/chrome/browser/resources/pdf/pdf_viewer_pp.js index 28ce7cf1..eb4edb5a 100644 --- a/chrome/browser/resources/pdf/pdf_viewer_pp.js +++ b/chrome/browser/resources/pdf/pdf_viewer_pp.js
@@ -362,7 +362,7 @@ } /** - * The background color used for print preview (--google-grey-refresh-300). Keep + * The background color used for print preview (--google-grey-300). Keep * in sync with `ChromePdfStreamDelegate::MapToOriginalUrl()`. * @type {number} */ @@ -370,7 +370,7 @@ /** * The background color used for print preview when dark mode is enabled - * (--google-grey-refresh-700). + * (--google-grey-700). * @type {number} */ const PRINT_PREVIEW_DARK_BACKGROUND_COLOR = 0xff5f6368;
diff --git a/chrome/browser/safe_browsing/chrome_cleaner/reporter_runner_win.cc b/chrome/browser/safe_browsing/chrome_cleaner/reporter_runner_win.cc index bf117cf..11d1d3c4 100644 --- a/chrome/browser/safe_browsing/chrome_cleaner/reporter_runner_win.cc +++ b/chrome/browser/safe_browsing/chrome_cleaner/reporter_runner_win.cc
@@ -8,6 +8,7 @@ #include <algorithm> #include <memory> +#include <tuple> #include <utility> #include <vector> @@ -17,7 +18,6 @@ #include "base/command_line.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/metrics/field_trial.h" #include "base/metrics/histogram_macros.h" #include "base/metrics/sparse_histogram.h" @@ -470,7 +470,7 @@ // The reporter sequence has been scheduled to run, so don't notify that // it has not been scheduled. - ignore_result(scoped_runner.Release()); + std::ignore = scoped_runner.Release(); } private: @@ -620,8 +620,8 @@ if (!invocations_.container().empty()) { // If there are other invocations to start, then we shouldn't finalize // this object. ScopedClosureRunner::Release requires its return value to - // be used, so simply ignore_result it, since it will not be needed. - ignore_result(scoped_runner.Release()); + // be used, so simply std::ignore it, since it will not be needed. + std::ignore = scoped_runner.Release(); PostNextInvocation(); }
diff --git a/chrome/browser/safe_browsing/incident_reporting/last_download_finder.cc b/chrome/browser/safe_browsing/incident_reporting/last_download_finder.cc index 1de851e..c0670cb 100644 --- a/chrome/browser/safe_browsing/incident_reporting/last_download_finder.cc +++ b/chrome/browser/safe_browsing/incident_reporting/last_download_finder.cc
@@ -10,10 +10,10 @@ #include <algorithm> #include <functional> #include <memory> +#include <tuple> #include <utility> #include "base/bind.h" -#include "base/ignore_result.h" #include "base/memory/ptr_util.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" @@ -187,7 +187,7 @@ download_request->set_url(download.url_chain.back().spec()); // digests is a required field, so force it to exist. // TODO(grt): Include digests in reports; http://crbug.com/389123. - ignore_result(download_request->mutable_digests()); + std::ignore = download_request->mutable_digests(); download_request->set_length(download.received_bytes); for (size_t i = 0; i < download.url_chain.size(); ++i) { const GURL& url = download.url_chain[i];
diff --git a/chrome/browser/share/android/javatests/src/org/chromium/chrome/browser/share/share_sheet/ChromeProvidedSharingOptionsProviderTest.java b/chrome/browser/share/android/javatests/src/org/chromium/chrome/browser/share/share_sheet/ChromeProvidedSharingOptionsProviderTest.java index 11f292c..669ed7b 100644 --- a/chrome/browser/share/android/javatests/src/org/chromium/chrome/browser/share/share_sheet/ChromeProvidedSharingOptionsProviderTest.java +++ b/chrome/browser/share/android/javatests/src/org/chromium/chrome/browser/share/share_sheet/ChromeProvidedSharingOptionsProviderTest.java
@@ -162,6 +162,7 @@ @Test @MediumTest + @DisabledTest(message = "http://crbug/1285362") @Features.EnableFeatures({ChromeFeatureList.LIGHTWEIGHT_REACTIONS}) public void getPropertyModels_lightweightReactionsEnabled() { setUpChromeProvidedSharingOptionsProviderTest(
diff --git a/chrome/browser/shared_highlighting/shared_highlighting_browsertest.cc b/chrome/browser/shared_highlighting/shared_highlighting_browsertest.cc index daf7c78e..75bd2a8 100644 --- a/chrome/browser/shared_highlighting/shared_highlighting_browsertest.cc +++ b/chrome/browser/shared_highlighting/shared_highlighting_browsertest.cc
@@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "build/build_config.h"
diff --git a/chrome/browser/signin/primary_account_policy_manager.cc b/chrome/browser/signin/primary_account_policy_manager.cc deleted file mode 100644 index 35890f8..0000000 --- a/chrome/browser/signin/primary_account_policy_manager.cc +++ /dev/null
@@ -1,52 +0,0 @@ -// Copyright 2022 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "chrome/browser/signin/primary_account_policy_manager.h" - -#include "chrome/browser/browser_process.h" -#include "chrome/browser/profiles/profile.h" -#include "chrome/browser/signin/signin_util.h" -#include "components/signin/public/base/signin_metrics.h" -#include "components/signin/public/base/signin_pref_names.h" - -PrimaryAccountPolicyManager::PrimaryAccountPolicyManager(Profile* profile) - : profile_(profile) { - DCHECK(profile_); - DCHECK(!profile_->IsOffTheRecord()); -} - -PrimaryAccountPolicyManager::~PrimaryAccountPolicyManager() = default; - -void PrimaryAccountPolicyManager::Initialize() { - signin_util::EnsurePrimaryAccountAllowedForProfile( - profile_, signin_metrics::SIGNIN_NOT_ALLOWED_ON_PROFILE_INIT); - - signin_allowed_.Init( - prefs::kSigninAllowed, profile_->GetPrefs(), - base::BindRepeating( - &PrimaryAccountPolicyManager::OnSigninAllowedPrefChanged, - weak_pointer_factory_.GetWeakPtr())); - - local_state_pref_registrar_.Init(g_browser_process->local_state()); - local_state_pref_registrar_.Add( - prefs::kGoogleServicesUsernamePattern, - base::BindRepeating( - &PrimaryAccountPolicyManager::OnGoogleServicesUsernamePatternChanged, - weak_pointer_factory_.GetWeakPtr())); -} - -void PrimaryAccountPolicyManager::Shutdown() { - local_state_pref_registrar_.RemoveAll(); - signin_allowed_.Destroy(); -} - -void PrimaryAccountPolicyManager::OnGoogleServicesUsernamePatternChanged() { - signin_util::EnsurePrimaryAccountAllowedForProfile( - profile_, signin_metrics::GOOGLE_SERVICE_NAME_PATTERN_CHANGED); -} - -void PrimaryAccountPolicyManager::OnSigninAllowedPrefChanged() { - signin_util::EnsurePrimaryAccountAllowedForProfile( - profile_, signin_metrics::SIGNOUT_PREF_CHANGED); -}
diff --git a/chrome/browser/signin/primary_account_policy_manager.h b/chrome/browser/signin/primary_account_policy_manager.h deleted file mode 100644 index 1901750..0000000 --- a/chrome/browser/signin/primary_account_policy_manager.h +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2022 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CHROME_BROWSER_SIGNIN_PRIMARY_ACCOUNT_POLICY_MANAGER_H_ -#define CHROME_BROWSER_SIGNIN_PRIMARY_ACCOUNT_POLICY_MANAGER_H_ - -#include "base/memory/raw_ptr.h" -#include "base/memory/weak_ptr.h" -#include "components/keyed_service/core/keyed_service.h" -#include "components/prefs/pref_change_registrar.h" -#include "components/prefs/pref_member.h" - -class Profile; - -class PrimaryAccountPolicyManager : public KeyedService { - public: - explicit PrimaryAccountPolicyManager(Profile* profile); - ~PrimaryAccountPolicyManager() override; - - PrimaryAccountPolicyManager(const PrimaryAccountPolicyManager&) = delete; - PrimaryAccountPolicyManager& operator=(const PrimaryAccountPolicyManager&) = - delete; - - void Initialize(); - void Shutdown() override; - - private: - void OnSigninAllowedPrefChanged(); - void OnGoogleServicesUsernamePatternChanged(); - - raw_ptr<Profile> profile_; - - // Helper object to listen for changes to the signin allowed preference. - BooleanPrefMember signin_allowed_; - - // Helper object to listen for changes to signin preferences stored in non- - // profile-specific local prefs (like kGoogleServicesUsernamePattern). - PrefChangeRegistrar local_state_pref_registrar_; - - base::WeakPtrFactory<PrimaryAccountPolicyManager> weak_pointer_factory_{this}; -}; - -#endif // CHROME_BROWSER_SIGNIN_PRIMARY_ACCOUNT_POLICY_MANAGER_H_
diff --git a/chrome/browser/signin/primary_account_policy_manager_factory.cc b/chrome/browser/signin/primary_account_policy_manager_factory.cc deleted file mode 100644 index da80239..0000000 --- a/chrome/browser/signin/primary_account_policy_manager_factory.cc +++ /dev/null
@@ -1,40 +0,0 @@ -// Copyright 2022 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "chrome/browser/signin/primary_account_policy_manager_factory.h" - -#include "chrome/browser/profiles/profile.h" -#include "chrome/browser/signin/identity_manager_factory.h" -#include "components/keyed_service/content/browser_context_dependency_manager.h" - -// static -PrimaryAccountPolicyManagerFactory* -PrimaryAccountPolicyManagerFactory::GetInstance() { - static base::NoDestructor<PrimaryAccountPolicyManagerFactory> instance; - return instance.get(); -} - -// static -PrimaryAccountPolicyManager* PrimaryAccountPolicyManagerFactory::GetForProfile( - Profile* profile) { - DCHECK(profile); - return static_cast<PrimaryAccountPolicyManager*>( - GetInstance()->GetServiceForBrowserContext(profile, true)); -} - -PrimaryAccountPolicyManagerFactory::PrimaryAccountPolicyManagerFactory() - : BrowserContextKeyedServiceFactory( - "PrimaryAccountPolicyManager", - BrowserContextDependencyManager::GetInstance()) { - DependsOn(IdentityManagerFactory::GetInstance()); -} - -PrimaryAccountPolicyManagerFactory::~PrimaryAccountPolicyManagerFactory() = - default; - -KeyedService* PrimaryAccountPolicyManagerFactory::BuildServiceInstanceFor( - content::BrowserContext* context) const { - Profile* profile = Profile::FromBrowserContext(context); - return new PrimaryAccountPolicyManager(profile); -}
diff --git a/chrome/browser/signin/primary_account_policy_manager_factory.h b/chrome/browser/signin/primary_account_policy_manager_factory.h deleted file mode 100644 index a82b211..0000000 --- a/chrome/browser/signin/primary_account_policy_manager_factory.h +++ /dev/null
@@ -1,32 +0,0 @@ -// Copyright 2022 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CHROME_BROWSER_SIGNIN_PRIMARY_ACCOUNT_POLICY_MANAGER_FACTORY_H_ -#define CHROME_BROWSER_SIGNIN_PRIMARY_ACCOUNT_POLICY_MANAGER_FACTORY_H_ - -#include "base/no_destructor.h" -#include "chrome/browser/signin/primary_account_policy_manager.h" -#include "components/keyed_service/content/browser_context_keyed_service_factory.h" - -class Profile; - -class PrimaryAccountPolicyManagerFactory - : public BrowserContextKeyedServiceFactory { - public: - // Returns an instance of the factory singleton. - static PrimaryAccountPolicyManagerFactory* GetInstance(); - static PrimaryAccountPolicyManager* GetForProfile(Profile* profile); - - private: - friend base::NoDestructor<PrimaryAccountPolicyManagerFactory>; - - PrimaryAccountPolicyManagerFactory(); - ~PrimaryAccountPolicyManagerFactory() override; - - // BrowserContextKeyedServiceFactory: - KeyedService* BuildServiceInstanceFor( - content::BrowserContext* context) const override; -}; - -#endif // CHROME_BROWSER_SIGNIN_PRIMARY_ACCOUNT_POLICY_MANAGER_FACTORY_H_
diff --git a/chrome/browser/signin/primary_account_policy_manager_unittest.cc b/chrome/browser/signin/primary_account_policy_manager_unittest.cc deleted file mode 100644 index 3844e09e..0000000 --- a/chrome/browser/signin/primary_account_policy_manager_unittest.cc +++ /dev/null
@@ -1,100 +0,0 @@ -// Copyright 2022 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "chrome/browser/signin/primary_account_policy_manager.h" - -#include "build/buildflag.h" -#include "build/chromeos_buildflags.h" -#include "chrome/browser/signin/identity_test_environment_profile_adaptor.h" -#include "chrome/browser/signin/primary_account_policy_manager_factory.h" -#include "chrome/test/base/scoped_testing_local_state.h" -#include "chrome/test/base/testing_browser_process.h" -#include "chrome/test/base/testing_profile.h" -#include "chrome/test/base/testing_profile_manager.h" -#include "components/signin/public/base/signin_pref_names.h" -#include "components/signin/public/identity_manager/identity_test_environment.h" -#include "content/public/test/browser_task_environment.h" -#include "testing/gtest/include/gtest/gtest.h" - -class PrimaryAccountPolicyManagerTest : public testing::Test { - public: - PrimaryAccountPolicyManagerTest() - : profile_manager_(TestingBrowserProcess::GetGlobal()) {} - - void SetUp() override { ASSERT_TRUE(profile_manager_.SetUp()); } - - ~PrimaryAccountPolicyManagerTest() override = default; - - void CreateTestingProfile() { - DCHECK(!profile_); - - profile_ = profile_manager_.CreateTestingProfile( - "primary_account_policy_manager_test_profile_path", - IdentityTestEnvironmentProfileAdaptor:: - GetIdentityTestEnvironmentFactories()); - identity_test_env_adaptor_ = - std::make_unique<IdentityTestEnvironmentProfileAdaptor>(profile_); - } - - void DestroyProfile() { - identity_test_env_adaptor_.reset(); - profile_ = nullptr; - profile_manager_.DeleteTestingProfile( - "primary_account_policy_manager_test_profile_path"); - } - - PrefService* GetLocalState() { return profile_manager_.local_state()->Get(); } - - TestingProfile* GetProfile() { - DCHECK(profile_); - return profile_; - } - - signin::IdentityTestEnvironment* GetIdentityTestEnv() { - DCHECK(identity_test_env_adaptor_); - return identity_test_env_adaptor_->identity_test_env(); - } - - private: - content::BrowserTaskEnvironment task_environment_; - TestingProfileManager profile_manager_; - TestingProfile* profile_ = nullptr; - std::unique_ptr<IdentityTestEnvironmentProfileAdaptor> - identity_test_env_adaptor_; -}; - -#if !BUILDFLAG(IS_CHROMEOS_ASH) -// All primary accounts are allowed on ChromeOS, so this the -// PrimaryAccountPolicyManagerTest does not clear the primary account on -// ChromeOS. -// -// TODO(msarda): Exclude |PrimaryAccountPolicyManager| from the ChromeOS -// build. -TEST_F(PrimaryAccountPolicyManagerTest, - ClearPrimarySyncAccountWhenSigninNotAllowed) { - CreateTestingProfile(); - GetIdentityTestEnv()->MakePrimaryAccountAvailable( - "test@foo.com", signin::ConsentLevel::kSync); - GetProfile()->GetPrefs()->SetBoolean(prefs::kSigninAllowed, false); - - EXPECT_FALSE(GetIdentityTestEnv()->identity_manager()->HasPrimaryAccount( - signin::ConsentLevel::kSync)); - EXPECT_FALSE(GetIdentityTestEnv()->identity_manager()->HasPrimaryAccount( - signin::ConsentLevel::kSignin)); -} - -TEST_F(PrimaryAccountPolicyManagerTest, - ClearPrimarySyncAccountWhenPatternNotAllowed) { - CreateTestingProfile(); - GetIdentityTestEnv()->MakePrimaryAccountAvailable( - "test@foo.com", signin::ConsentLevel::kSync); - GetLocalState()->SetString(prefs::kGoogleServicesUsernamePattern, - ".*@bar.com"); - - EXPECT_FALSE(GetIdentityTestEnv()->identity_manager()->HasPrimaryAccount( - signin::ConsentLevel::kSync)); - EXPECT_FALSE(GetIdentityTestEnv()->identity_manager()->HasPrimaryAccount( - signin::ConsentLevel::kSignin)); -} -#endif // !BUILDFLAG(IS_CHROMEOS_ASH)
diff --git a/chrome/browser/signin/signin_util.cc b/chrome/browser/signin/signin_util.cc index fc2ef2b..d3ab11f8 100644 --- a/chrome/browser/signin/signin_util.cc +++ b/chrome/browser/signin/signin_util.cc
@@ -290,9 +290,7 @@ UserSignoutSetting::GetForProfile(profile)->set_state(new_state); } -void EnsurePrimaryAccountAllowedForProfile( - Profile* profile, - signin_metrics::ProfileSignout clear_primary_account_source) { +void EnsurePrimaryAccountAllowedForProfile(Profile* profile) { // All primary accounts are allowed on ChromeOS, so this method is a no-op on // ChromeOS. #if !BUILDFLAG(IS_CHROMEOS_ASH) @@ -320,7 +318,7 @@ auto* primary_account_mutator = identity_manager->GetPrimaryAccountMutator(); primary_account_mutator->ClearPrimaryAccount( - clear_primary_account_source, + signin_metrics::SIGNIN_NOT_ALLOWED_ON_PROFILE_INIT, signin_metrics::SignoutDelete::kIgnoreMetric); break; }
diff --git a/chrome/browser/signin/signin_util.h b/chrome/browser/signin/signin_util.h index 0a6693f..10a436d 100644 --- a/chrome/browser/signin/signin_util.h +++ b/chrome/browser/signin/signin_util.h
@@ -8,7 +8,6 @@ #include <string> #include "build/build_config.h" -#include "components/signin/public/base/signin_metrics.h" class Profile; @@ -55,11 +54,7 @@ // is no longer allowed, then this clears the primary account. // * If |IsUserSignoutAllowedForProfile| is not allowed and the primary account // is not longer allowed, then this removes the profile. -// -// TODO(msarda): Move to |primary_account_policy_manager.h| -void EnsurePrimaryAccountAllowedForProfile( - Profile* profile, - signin_metrics::ProfileSignout clear_primary_account_source); +void EnsurePrimaryAccountAllowedForProfile(Profile* profile); #if !defined(OS_ANDROID) // Returns true if profile separation is enforced by policy.
diff --git a/chrome/browser/ssl/sct_reporting_service_browsertest.cc b/chrome/browser/ssl/sct_reporting_service_browsertest.cc index c0de69c8..5c6abad 100644 --- a/chrome/browser/ssl/sct_reporting_service_browsertest.cc +++ b/chrome/browser/ssl/sct_reporting_service_browsertest.cc
@@ -3,9 +3,9 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include "base/callback.h" -#include "base/ignore_result.h" #include "base/synchronization/lock.h" #include "base/test/scoped_feature_list.h" #include "build/build_config.h" @@ -102,7 +102,7 @@ true); // The report server must be initialized here so the reporting URL can be // set before the network service is initialized. - ignore_result(report_server()->InitializeAndListen()); + std::ignore = report_server()->InitializeAndListen(); SCTReportingService::GetReportURLInstance() = report_server()->GetURL("/"); } ~SCTReportingServiceBrowserTest() override {
diff --git a/chrome/browser/sync/test/integration/two_client_web_apps_integration_test_mac_win_linux.cc b/chrome/browser/sync/test/integration/two_client_web_apps_integration_test_mac_win_linux.cc index 5fd5c3f..40fddc3 100644 --- a/chrome/browser/sync/test/integration/two_client_web_apps_integration_test_mac_win_linux.cc +++ b/chrome/browser/sync/test/integration/two_client_web_apps_integration_test_mac_win_linux.cc
@@ -405,7 +405,7 @@ IN_PROC_BROWSER_TEST_F( TwoClientWebAppsIntegrationTestMacWinLinux, - WebAppIntegration_TurnSyncOff_InstCrtShctWindowedSiteA_TurnSyncOn_SwitchProfileClientClient2_InListNotLclyInstSiteA) { + DISABLED_WebAppIntegration_TurnSyncOff_InstCrtShctWindowedSiteA_TurnSyncOn_SwitchProfileClientClient2_InListNotLclyInstSiteA) { // Test contents are generated by script. Please do not modify! // See `chrome/test/webapps/README.md` for more info. // Sheriffs: Disabling this test is supported. @@ -418,7 +418,7 @@ IN_PROC_BROWSER_TEST_F( TwoClientWebAppsIntegrationTestMacWinLinux, - WebAppIntegration_TurnSyncOff_InstOmniboxSiteA_TurnSyncOn_SwitchProfileClientClient2_InListNotLclyInstSiteA) { + DISABLED_WebAppIntegration_TurnSyncOff_InstOmniboxSiteA_TurnSyncOn_SwitchProfileClientClient2_InListNotLclyInstSiteA) { // Test contents are generated by script. Please do not modify! // See `chrome/test/webapps/README.md` for more info. // Sheriffs: Disabling this test is supported.
diff --git a/chrome/browser/tab/BUILD.gn b/chrome/browser/tab/BUILD.gn index 619931c..2769c88 100644 --- a/chrome/browser/tab/BUILD.gn +++ b/chrome/browser/tab/BUILD.gn
@@ -24,7 +24,6 @@ "java/src/org/chromium/chrome/browser/tab/TabObserver.java", "java/src/org/chromium/chrome/browser/tab/TabResolver.java", "java/src/org/chromium/chrome/browser/tab/TabState.java", - "java/src/org/chromium/chrome/browser/tab/TabStateAttributes.java", "java/src/org/chromium/chrome/browser/tab/TabViewManager.java", "java/src/org/chromium/chrome/browser/tab/TabViewProvider.java", "java/src/org/chromium/chrome/browser/tab/TabWebContentsDelegateAndroid.java", @@ -61,6 +60,7 @@ ] deps = [ + ":critical_persisted_tab_data_flatbuffer_java", ":critical_persisted_tab_data_proto_java", ":java_resources", "//base:base_java", @@ -90,6 +90,7 @@ "//components/strings:components_strings_grd", "//content/public/android:content_java", "//net/android:net_java", + "//third_party/android_deps:com_google_flatbuffers_flatbuffers_java_java", "//third_party/android_deps:guava_android_java", "//third_party/android_deps:protobuf_lite_runtime_java", "//third_party/androidx:androidx_annotation_annotation_java", @@ -132,6 +133,11 @@ ] } +flatbuffer_java_library("critical_persisted_tab_data_flatbuffer_java") { + root_dir = "java/src/org/chromium/chrome/browser/tab/state/flatbuffer" + sources = [ "$root_dir/critical_persisted_tab_data.fbs" ] +} + android_library("junit") { bypass_platform_checks = true testonly = true
diff --git a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/Tab.java b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/Tab.java index d9cdc11..188f660 100644 --- a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/Tab.java +++ b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/Tab.java
@@ -273,6 +273,14 @@ void goForward(); /** + * Set whether the TabState representing this Tab has been updated. + * This method will ultimately be deprecated when the migration + * to CriticalPersistedTabData is complete. + * @param isDirty Whether the Tab's state has changed. + */ + void setIsTabStateDirty(boolean isTabStateDirty); + + /** * Set whether {@link Tab} metadata (specifically all {@link PersistedTabData}) * will be saved. Not all Tabs need to be persisted across restarts. * The default value when a Tab is initialized is false.
diff --git a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/TabStateAttributes.java b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/TabStateAttributes.java deleted file mode 100644 index 098c9e7..0000000 --- a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/TabStateAttributes.java +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2022 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.chrome.browser.tab; - -import org.chromium.base.UserData; -import org.chromium.base.UserDataHost; - -/** - * Attributes related to {@link TabState} - */ -public class TabStateAttributes implements UserData { - private static final Class<TabStateAttributes> USER_DATA_KEY = TabStateAttributes.class; - /** Whether or not the TabState has changed. */ - private boolean mIsTabStateDirty = true; - - /** - * @return {@link TabStateAttributes} for a {@link Tab} - */ - public static TabStateAttributes from(Tab tab) { - UserDataHost host = tab.getUserDataHost(); - TabStateAttributes attrs = host.getUserData(USER_DATA_KEY); - return attrs != null ? attrs : host.setUserData(USER_DATA_KEY, new TabStateAttributes()); - } - - private TabStateAttributes() {} - - /** - * @return true if the {@link TabState} has been changed - */ - public boolean isTabStateDirty() { - return mIsTabStateDirty; - } - - /** - * Set whether the TabState representing this Tab has been updated. - * This method will ultimately be deprecated when the migration - * to CriticalPersistedTabData is complete. - * @param isTabStateDirty whether the Tab's state has changed. - */ - public void setIsTabStateDirty(boolean isTabStateDirty) { - mIsTabStateDirty = isTabStateDirty; - } -}
diff --git a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabData.java b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabData.java index ef8a8acb..2ee58b3 100644 --- a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabData.java +++ b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/CriticalPersistedTabData.java
@@ -10,20 +10,20 @@ import androidx.annotation.Nullable; import androidx.annotation.VisibleForTesting; -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; +import com.google.flatbuffers.FlatBufferBuilder; import org.chromium.base.Callback; -import org.chromium.base.Log; import org.chromium.base.ObserverList; import org.chromium.base.TraceEvent; import org.chromium.base.supplier.Supplier; import org.chromium.chrome.browser.tab.Tab; import org.chromium.chrome.browser.tab.TabLaunchType; -import org.chromium.chrome.browser.tab.TabStateAttributes; import org.chromium.chrome.browser.tab.TabUserAgent; import org.chromium.chrome.browser.tab.WebContentsState; import org.chromium.chrome.browser.tab.WebContentsStateBridge; +import org.chromium.chrome.browser.tab.flatbuffer.CriticalPersistedTabDataFlatBuffer; +import org.chromium.chrome.browser.tab.flatbuffer.LaunchTypeAtCreation; +import org.chromium.chrome.browser.tab.flatbuffer.UserAgentType; import org.chromium.chrome.browser.tab.proto.CriticalPersistedTabData.CriticalPersistedTabDataProto; import org.chromium.components.embedder_support.util.UrlConstants; import org.chromium.components.embedder_support.util.UrlUtilities; @@ -32,7 +32,6 @@ import org.chromium.url.GURL; import java.nio.ByteBuffer; -import java.util.Locale; /** * Data which is core to the app and must be retrieved as quickly as possible on startup. @@ -43,6 +42,7 @@ CriticalPersistedTabData.class; private static final int UNSPECIFIED_THEME_COLOR = Color.TRANSPARENT; + private static final String NULL_OPENER_APP_ID = " "; public static final long INVALID_TIMESTAMP = -1; /** @@ -215,41 +215,28 @@ @Override boolean deserialize(@Nullable ByteBuffer bytes) { try (TraceEvent e = TraceEvent.scoped("CriticalPersistedTabData.Deserialize")) { - CriticalPersistedTabDataProto criticalPersistedTabDataProto = - CriticalPersistedTabDataProto.parseFrom(bytes); - mParentId = criticalPersistedTabDataProto.getParentId(); - mRootId = criticalPersistedTabDataProto.getRootId(); - mTimestampMillis = criticalPersistedTabDataProto.getTimestampMillis(); - ByteString webContentsStateByteString = - criticalPersistedTabDataProto.getWebContentsStateBytes(); - mWebContentsState = new WebContentsState( - ByteBuffer.allocateDirect(webContentsStateByteString.size())); - webContentsStateByteString.copyTo(mWebContentsState.buffer()); + CriticalPersistedTabDataFlatBuffer deserialized = + CriticalPersistedTabDataFlatBuffer.getRootAsCriticalPersistedTabDataFlatBuffer( + bytes); + mParentId = deserialized.parentId(); + mRootId = deserialized.rootId(); + mTimestampMillis = deserialized.timestampMillis(); + mWebContentsState = + new WebContentsState(deserialized.webContentsStateBytesAsByteBuffer().slice()); mWebContentsState.setVersion(WebContentsState.CONTENTS_STATE_CURRENT_VERSION); mUrl = mWebContentsState.getVirtualUrlFromState() == null ? GURL.emptyGURL() : new GURL(mWebContentsState.getVirtualUrlFromState()); mTitle = mWebContentsState.getDisplayTitleFromState(); - mContentStateVersion = criticalPersistedTabDataProto.getContentStateVersion(); - mOpenerAppId = TextUtils.isEmpty(criticalPersistedTabDataProto.getOpenerAppId()) + mContentStateVersion = deserialized.contentStateVersion(); + mOpenerAppId = NULL_OPENER_APP_ID.equals(deserialized.openerAppId()) ? null - : criticalPersistedTabDataProto.getOpenerAppId(); - mThemeColor = criticalPersistedTabDataProto.getThemeColor(); - mTabLaunchTypeAtCreation = - getLaunchType(criticalPersistedTabDataProto.getLaunchTypeAtCreation()); - if (criticalPersistedTabDataProto.hasUserAgent()) { - mUserAgent = getUserAgentType(criticalPersistedTabDataProto.getUserAgent()); - } else { - mUserAgent = TabUserAgent.UNSET; - } + : deserialized.openerAppId(); + mThemeColor = deserialized.themeColor(); + mTabLaunchTypeAtCreation = getLaunchType(deserialized.launchTypeAtCreation()); + mUserAgent = getTabUserAgentType(deserialized.userAgent()); return true; - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, - String.format(Locale.ENGLISH, - "There was a problem deserializing Tab %d. Details: %s", mTab.getId(), - e.getMessage())); } - return false; } @Override @@ -258,142 +245,135 @@ } @VisibleForTesting - static @Nullable @TabLaunchType Integer getLaunchType( - CriticalPersistedTabDataProto.LaunchTypeAtCreation protoLaunchType) { - switch (protoLaunchType) { - case FROM_LINK: + static @Nullable @TabLaunchType Integer getLaunchType(int flatBufferLaunchType) { + switch (flatBufferLaunchType) { + case LaunchTypeAtCreation.FROM_LINK: return TabLaunchType.FROM_LINK; - case FROM_EXTERNAL_APP: + case LaunchTypeAtCreation.FROM_EXTERNAL_APP: return TabLaunchType.FROM_EXTERNAL_APP; - case FROM_CHROME_UI: + case LaunchTypeAtCreation.FROM_CHROME_UI: return TabLaunchType.FROM_CHROME_UI; - case FROM_RESTORE: + case LaunchTypeAtCreation.FROM_RESTORE: return TabLaunchType.FROM_RESTORE; - case FROM_LONGPRESS_FOREGROUND: + case LaunchTypeAtCreation.FROM_LONGPRESS_FOREGROUND: return TabLaunchType.FROM_LONGPRESS_FOREGROUND; - case FROM_LONGPRESS_BACKGROUND: + case LaunchTypeAtCreation.FROM_LONGPRESS_BACKGROUND: return TabLaunchType.FROM_LONGPRESS_BACKGROUND; - case FROM_REPARENTING: + case LaunchTypeAtCreation.FROM_REPARENTING: return TabLaunchType.FROM_REPARENTING; - case FROM_LAUNCHER_SHORTCUT: + case LaunchTypeAtCreation.FROM_LAUNCHER_SHORTCUT: return TabLaunchType.FROM_LAUNCHER_SHORTCUT; - case FROM_SPECULATIVE_BACKGROUND_CREATION: + case LaunchTypeAtCreation.FROM_SPECULATIVE_BACKGROUND_CREATION: return TabLaunchType.FROM_SPECULATIVE_BACKGROUND_CREATION; - case FROM_BROWSER_ACTIONS: + case LaunchTypeAtCreation.FROM_BROWSER_ACTIONS: return TabLaunchType.FROM_BROWSER_ACTIONS; - case FROM_LAUNCH_NEW_INCOGNITO_TAB: + case LaunchTypeAtCreation.FROM_LAUNCH_NEW_INCOGNITO_TAB: return TabLaunchType.FROM_LAUNCH_NEW_INCOGNITO_TAB; - case FROM_STARTUP: + case LaunchTypeAtCreation.FROM_STARTUP: return TabLaunchType.FROM_STARTUP; - case FROM_START_SURFACE: + case LaunchTypeAtCreation.FROM_START_SURFACE: return TabLaunchType.FROM_START_SURFACE; - case FROM_TAB_GROUP_UI: + case LaunchTypeAtCreation.FROM_TAB_GROUP_UI: return TabLaunchType.FROM_TAB_GROUP_UI; - case FROM_LONGPRESS_BACKGROUND_IN_GROUP: + case LaunchTypeAtCreation.FROM_LONGPRESS_BACKGROUND_IN_GROUP: return TabLaunchType.FROM_LONGPRESS_BACKGROUND_IN_GROUP; - case FROM_APP_WIDGET: + case LaunchTypeAtCreation.FROM_APP_WIDGET: return TabLaunchType.FROM_APP_WIDGET; - case SIZE: + case LaunchTypeAtCreation.SIZE: return TabLaunchType.SIZE; default: assert false : "Unexpected deserialization of LaunchAtCreationType: " - + protoLaunchType; + + flatBufferLaunchType; // shouldn't happen return null; } } @VisibleForTesting - static CriticalPersistedTabDataProto.LaunchTypeAtCreation getLaunchType( - @Nullable @TabLaunchType Integer protoLaunchType) { - if (protoLaunchType == null) { - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.UNKNOWN; + static int getLaunchType(@Nullable @TabLaunchType Integer tabLaunchType) { + if (tabLaunchType == null) { + return LaunchTypeAtCreation.UNKNOWN; } - switch (protoLaunchType) { + switch (tabLaunchType) { case TabLaunchType.FROM_LINK: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_LINK; + return LaunchTypeAtCreation.FROM_LINK; case TabLaunchType.FROM_EXTERNAL_APP: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_EXTERNAL_APP; + return LaunchTypeAtCreation.FROM_EXTERNAL_APP; case TabLaunchType.FROM_CHROME_UI: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_CHROME_UI; + return LaunchTypeAtCreation.FROM_CHROME_UI; case TabLaunchType.FROM_RESTORE: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_RESTORE; + return LaunchTypeAtCreation.FROM_RESTORE; case TabLaunchType.FROM_LONGPRESS_FOREGROUND: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_LONGPRESS_FOREGROUND; + return LaunchTypeAtCreation.FROM_LONGPRESS_FOREGROUND; case TabLaunchType.FROM_LONGPRESS_BACKGROUND: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_LONGPRESS_BACKGROUND; + return LaunchTypeAtCreation.FROM_LONGPRESS_BACKGROUND; case TabLaunchType.FROM_REPARENTING: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_REPARENTING; + return LaunchTypeAtCreation.FROM_REPARENTING; case TabLaunchType.FROM_LAUNCHER_SHORTCUT: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_LAUNCHER_SHORTCUT; + return LaunchTypeAtCreation.FROM_LAUNCHER_SHORTCUT; case TabLaunchType.FROM_SPECULATIVE_BACKGROUND_CREATION: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation - .FROM_SPECULATIVE_BACKGROUND_CREATION; + return LaunchTypeAtCreation.FROM_SPECULATIVE_BACKGROUND_CREATION; case TabLaunchType.FROM_BROWSER_ACTIONS: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_BROWSER_ACTIONS; + return LaunchTypeAtCreation.FROM_BROWSER_ACTIONS; case TabLaunchType.FROM_LAUNCH_NEW_INCOGNITO_TAB: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation - .FROM_LAUNCH_NEW_INCOGNITO_TAB; + return LaunchTypeAtCreation.FROM_LAUNCH_NEW_INCOGNITO_TAB; case TabLaunchType.FROM_STARTUP: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_STARTUP; + return LaunchTypeAtCreation.FROM_STARTUP; case TabLaunchType.FROM_START_SURFACE: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_START_SURFACE; + return LaunchTypeAtCreation.FROM_START_SURFACE; case TabLaunchType.FROM_TAB_GROUP_UI: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_TAB_GROUP_UI; + return LaunchTypeAtCreation.FROM_TAB_GROUP_UI; case TabLaunchType.FROM_LONGPRESS_BACKGROUND_IN_GROUP: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation - .FROM_LONGPRESS_BACKGROUND_IN_GROUP; + return LaunchTypeAtCreation.FROM_LONGPRESS_BACKGROUND_IN_GROUP; case TabLaunchType.FROM_APP_WIDGET: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.FROM_APP_WIDGET; + return LaunchTypeAtCreation.FROM_APP_WIDGET; case TabLaunchType.SIZE: - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.SIZE; + return LaunchTypeAtCreation.SIZE; default: - assert false : "Unexpected serialization of LaunchAtCreationType: " - + protoLaunchType; + assert false : "Unexpected serialization of LaunchAtCreationType: " + tabLaunchType; // shouldn't happen - return CriticalPersistedTabDataProto.LaunchTypeAtCreation.UNKNOWN; + return LaunchTypeAtCreation.UNKNOWN; } } @VisibleForTesting - static @TabUserAgent int getUserAgentType( - CriticalPersistedTabDataProto.UserAgentType protoUserAgent) { - switch (protoUserAgent) { - case DEFAULT: + static @TabUserAgent int getTabUserAgentType(int flatbufferUserAgentType) { + switch (flatbufferUserAgentType) { + case UserAgentType.DEFAULT: return TabUserAgent.DEFAULT; - case MOBILE: + case UserAgentType.MOBILE: return TabUserAgent.MOBILE; - case DESKTOP: + case UserAgentType.DESKTOP: return TabUserAgent.DESKTOP; - case UNSET: + case UserAgentType.UNSET: return TabUserAgent.UNSET; - case USER_AGENT_SIZE: + case UserAgentType.USER_AGENT_SIZE: return TabUserAgent.SIZE; default: - assert false : "Unexpected deserialization of UserAgentType: " + protoUserAgent; + assert false : "Unexpected deserialization of UserAgentType: " + + flatbufferUserAgentType; // shouldn't happen return TabUserAgent.DEFAULT; } } @VisibleForTesting - static CriticalPersistedTabDataProto.UserAgentType getUserAgentType( - @TabUserAgent int protoUserAgent) { - switch (protoUserAgent) { + static int getUserAgentType(@TabUserAgent int userAgent) { + switch (userAgent) { case TabUserAgent.DEFAULT: - return CriticalPersistedTabDataProto.UserAgentType.DEFAULT; + return UserAgentType.DEFAULT; case TabUserAgent.MOBILE: - return CriticalPersistedTabDataProto.UserAgentType.MOBILE; + return UserAgentType.MOBILE; case TabUserAgent.DESKTOP: - return CriticalPersistedTabDataProto.UserAgentType.DESKTOP; + return UserAgentType.DESKTOP; case TabUserAgent.UNSET: - return CriticalPersistedTabDataProto.UserAgentType.UNSET; + return UserAgentType.UNSET; case TabUserAgent.SIZE: - return CriticalPersistedTabDataProto.UserAgentType.USER_AGENT_SIZE; + return UserAgentType.USER_AGENT_SIZE; default: - assert false : "Unexpected serialization of UserAgentType: " + protoUserAgent; + assert false : "Unexpected serialization of UserAgentType: " + userAgent; // shouldn't happen - return CriticalPersistedTabDataProto.UserAgentType.DEFAULT; + return UserAgentType.USER_AGENT_UNKNOWN; } } @@ -429,32 +409,54 @@ CriticalPersistedTabDataProto.Builder builder; final WebContentsState webContentsState; final ByteBuffer byteBuffer; + final String openerAppId; + final int parentId; + final int rootId; + final long timestampMillis; + final int webContentsStateVersion; + final int themeColor; + final int launchType; + final int userAgentType; + FlatBufferBuilder fbb = new FlatBufferBuilder(); try (TraceEvent e = TraceEvent.scoped("CriticalPersistedTabData.PreSerialize")) { webContentsState = mWebContentsState == null ? getWebContentsStateFromTab(mTab) : mWebContentsState; byteBuffer = webContentsState == null ? null : webContentsState.buffer(); - builder = CriticalPersistedTabDataProto.newBuilder() - .setParentId(mParentId) - .setRootId(mRootId) - .setTimestampMillis(mTimestampMillis) - .setContentStateVersion(mContentStateVersion) - .setOpenerAppId(mOpenerAppId == null ? "" : mOpenerAppId) - .setThemeColor(mThemeColor) - .setLaunchTypeAtCreation(getLaunchType(mTabLaunchTypeAtCreation)) - .setUserAgent(getUserAgentType(mUserAgent)); + if (byteBuffer != null) { + byteBuffer.rewind(); + } + openerAppId = mOpenerAppId; + parentId = mParentId; + rootId = mRootId; + timestampMillis = mTimestampMillis; + webContentsStateVersion = mContentStateVersion; + themeColor = mThemeColor; + launchType = getLaunchType(mTabLaunchTypeAtCreation); + userAgentType = getUserAgentType(mUserAgent); } return () -> { try (TraceEvent e = TraceEvent.scoped("CriticalPersistedTabData.Serialize")) { - // TODO(crbug.com/1203298) migrate to ByteString.copyFrom(ByteBuffer ...) - // in a thread safe way to avoid intermediate ByteBuffer -> byte[]. Be careful as - // this has caused crashes in the past crbug.com/1195550. - return builder - .setWebContentsStateBytes(byteBuffer == null - ? ByteString.EMPTY - : ByteString.copyFrom(getContentStateByteArray(byteBuffer))) - .build() - .toByteString() - .asReadOnlyByteBuffer(); + int wcs = CriticalPersistedTabDataFlatBuffer.createWebContentsStateBytesVector(fbb, + byteBuffer == null ? ByteBuffer.allocate(0).put(new byte[] {}) + : byteBuffer); + int oaid = + fbb.createString(mOpenerAppId == null ? NULL_OPENER_APP_ID : mOpenerAppId); + CriticalPersistedTabDataFlatBuffer.startCriticalPersistedTabDataFlatBuffer(fbb); + CriticalPersistedTabDataFlatBuffer.addParentId(fbb, parentId); + CriticalPersistedTabDataFlatBuffer.addRootId(fbb, rootId); + CriticalPersistedTabDataFlatBuffer.addTimestampMillis(fbb, timestampMillis); + CriticalPersistedTabDataFlatBuffer.addWebContentsStateBytes(fbb, wcs); + CriticalPersistedTabDataFlatBuffer.addContentStateVersion( + fbb, webContentsStateVersion); + CriticalPersistedTabDataFlatBuffer.addOpenerAppId(fbb, oaid); + CriticalPersistedTabDataFlatBuffer.addThemeColor(fbb, themeColor); + CriticalPersistedTabDataFlatBuffer.addLaunchTypeAtCreation(fbb, launchType); + CriticalPersistedTabDataFlatBuffer.addUserAgent(fbb, userAgentType); + int r = CriticalPersistedTabDataFlatBuffer.endCriticalPersistedTabDataFlatBuffer( + fbb); + fbb.finish(r); + + return fbb.dataBuffer(); } }; } @@ -565,13 +567,13 @@ * Set root id */ public void setRootId(int rootId) { - if (mRootId == rootId || mTab.isDestroyed()) return; + if (mRootId == rootId) return; // TODO(crbug.com/1059640) add in setters for all mutable fields mRootId = rootId; for (CriticalPersistedTabDataObserver observer : mObservers) { observer.onRootIdChanged(mTab, rootId); } - TabStateAttributes.from(mTab).setIsTabStateDirty(true); + mTab.setIsTabStateDirty(true); save(); }
diff --git a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/FilePersistedTabDataStorage.java b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/FilePersistedTabDataStorage.java index 3b042a3..cb7338e9 100644 --- a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/FilePersistedTabDataStorage.java +++ b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/FilePersistedTabDataStorage.java
@@ -27,6 +27,7 @@ import org.chromium.content_public.browser.UiThreadTaskTraits; import java.io.File; +import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; @@ -34,6 +35,7 @@ import java.lang.annotation.RetentionPolicy; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -373,11 +375,14 @@ @Override public ByteBuffer executeSyncTask() { boolean success = false; - byte[] res = null; + ByteBuffer res = null; + FileInputStream fileInputStream = null; try { long startTime = SystemClock.elapsedRealtime(); AtomicFile atomicFile = new AtomicFile(mFile); - res = atomicFile.readFully(); + fileInputStream = atomicFile.openRead(); + FileChannel channel = fileInputStream.getChannel(); + res = channel.map(MapMode.READ_ONLY, channel.position(), channel.size()); success = true; RecordHistogram.recordTimesHistogram( String.format(Locale.US, "Tabs.PersistedTabData.Storage.LoadTime.%s", @@ -398,7 +403,7 @@ } RecordHistogram.recordBooleanHistogram( "Tabs.PersistedTabData.Storage.Restore." + getUmaTag(), success); - return res == null ? null : ByteBuffer.wrap(res); + return res; } @Override
diff --git a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/PersistedTabDataConfiguration.java b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/PersistedTabDataConfiguration.java index cce2702..db519ce 100644 --- a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/PersistedTabDataConfiguration.java +++ b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/PersistedTabDataConfiguration.java
@@ -16,8 +16,8 @@ public enum PersistedTabDataConfiguration { // TODO(crbug.com/1059650) investigate should this go in the app code? // Also investigate if the storage instance should be shared. - CRITICAL_PERSISTED_TAB_DATA("CPTD"), - ENCRYPTED_CRITICAL_PERSISTED_TAB_DATA("ECPTD"), + CRITICAL_PERSISTED_TAB_DATA("CPTDFB"), + ENCRYPTED_CRITICAL_PERSISTED_TAB_DATA("ECPTDFB"), MOCK_PERSISTED_TAB_DATA("MPTD"), ENCRYPTED_MOCK_PERSISTED_TAB_DATA("EMPTD"), SHOPPING_PERSISTED_TAB_DATA("SPTD"),
diff --git a/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/flatbuffer/critical_persisted_tab_data.fbs b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/flatbuffer/critical_persisted_tab_data.fbs new file mode 100644 index 0000000..340cfe6f --- /dev/null +++ b/chrome/browser/tab/java/src/org/chromium/chrome/browser/tab/state/flatbuffer/critical_persisted_tab_data.fbs
@@ -0,0 +1,64 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +namespace org.chromium.chrome.browser.tab.flatbuffer; + +enum LaunchTypeAtCreation:int { + FROM_LINK = 0, + FROM_EXTERNAL_APP = 1, + FROM_CHROME_UI = 2, + FROM_RESTORE = 3, + FROM_LONGPRESS_FOREGROUND = 4, + FROM_LONGPRESS_BACKGROUND = 5, + FROM_REPARENTING = 6, + FROM_LAUNCHER_SHORTCUT = 7, + FROM_SPECULATIVE_BACKGROUND_CREATION = 8, + FROM_BROWSER_ACTIONS = 9, + FROM_LAUNCH_NEW_INCOGNITO_TAB = 10, + FROM_STARTUP = 11, + FROM_START_SURFACE = 12, + FROM_TAB_GROUP_UI = 13, + FROM_LONGPRESS_BACKGROUND_IN_GROUP = 14, + FROM_APP_WIDGET = 15, + SIZE = 16, + UNKNOWN = 17, +} + +enum UserAgentType:int { + DEFAULT = 0, + MOBILE = 1, + DESKTOP = 2, + UNSET = 3, + USER_AGENT_SIZE = 4, + USER_AGENT_UNKNOWN = 5, +} + +table CriticalPersistedTabDataFlatBuffer { + // Parent Tab identifier. + parent_id:int; + + // Root Tab identifier. + root_id:int; + + // Timestamp when Tab was last accessed. + timestamp_millis:long; + + // WebContentsState. + web_contents_state_bytes:[byte]; + + // Content State version. + content_state_version:int; + + // Identifier for app which opened the Tab. + opener_app_id:string; + + // Theme color. + theme_color:int; + + // Launch type at creation. + launch_type_at_creation:LaunchTypeAtCreation; + + // User Agent. + user_agent:UserAgentType=DEFAULT; +}
diff --git a/chrome/browser/thumbnail/generator/android/thumbnail_media_parser_impl.cc b/chrome/browser/thumbnail/generator/android/thumbnail_media_parser_impl.cc index 3952e07..aa0c2b4 100644 --- a/chrome/browser/thumbnail/generator/android/thumbnail_media_parser_impl.cc +++ b/chrome/browser/thumbnail/generator/android/thumbnail_media_parser_impl.cc
@@ -4,10 +4,11 @@ #include "chrome/browser/thumbnail/generator/android/thumbnail_media_parser_impl.h" +#include <tuple> + #include "base/bind.h" #include "base/files/file.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/numerics/safe_conversions.h" #include "base/task/post_task.h" #include "base/task/task_runner_util.h" @@ -270,7 +271,7 @@ // ThumbnailMediaParser does not use them, but the Mojo argument is // currently marked as required so pass a remote but drop the other end. mojo::PendingRemote<media::mojom::FrameInterfaceFactory> interfaces; - ignore_result(interfaces.InitWithNewPipeAndPassReceiver()); + std::ignore = interfaces.InitWithNewPipeAndPassReceiver(); content::GetMediaService().CreateInterfaceFactory( media_interface_factory_.BindNewPipeAndPassReceiver(), std::move(interfaces));
diff --git a/chrome/browser/ui/ash/clipboard_history_browsertest.cc b/chrome/browser/ui/ash/clipboard_history_browsertest.cc index 43140256..537d12ef 100644 --- a/chrome/browser/ui/ash/clipboard_history_browsertest.cc +++ b/chrome/browser/ui/ash/clipboard_history_browsertest.cc
@@ -4,6 +4,7 @@ #include <list> #include <memory> +#include <tuple> #include "ash/clipboard/clipboard_history.h" #include "ash/clipboard/clipboard_history_controller_impl.h" @@ -15,7 +16,6 @@ #include "ash/public/cpp/clipboard_image_model_factory.h" #include "ash/shell.h" #include "base/bind.h" -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/test/bind.h" #include "base/test/metrics/histogram_tester.h" @@ -805,8 +805,8 @@ // Wait for the paste event to propagate to the web contents. // The web contents will notify us a paste occurred by updating page title. - ignore_result( - content::TitleWatcher(web_contents, u"Paste 1").WaitAndGetTitle()); + std::ignore = + content::TitleWatcher(web_contents, u"Paste 1").WaitAndGetTitle(); // Confirm the expected paste data. base::ListValue last_paste = GetLastPaste(); @@ -825,8 +825,8 @@ // Wait for the paste event to propagate to the web contents. // The web contents will notify us a paste occurred by updating page title. - ignore_result( - content::TitleWatcher(web_contents, u"Paste 2").WaitAndGetTitle()); + std::ignore = + content::TitleWatcher(web_contents, u"Paste 2").WaitAndGetTitle(); // Confirm the expected paste data. last_paste = GetLastPaste();
diff --git a/chrome/browser/ui/commander/commander_controller_unittest.cc b/chrome/browser/ui/commander/commander_controller_unittest.cc index 2ea2c6d..7ef7bef5 100644 --- a/chrome/browser/ui/commander/commander_controller_unittest.cc +++ b/chrome/browser/ui/commander/commander_controller_unittest.cc
@@ -5,10 +5,10 @@ #include "chrome/browser/ui/commander/commander_controller.h" #include <string> +#include <tuple> #include "base/callback.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/run_loop.h" #include "base/test/bind.h" @@ -139,7 +139,7 @@ TEST_F(CommanderControllerTest, ResultSetIdsDifferAcrossCalls) { std::vector<std::unique_ptr<CommandSource>> sources; - ignore_result(AddSource(&sources, CreateNoOpCommandSource())); + std::ignore = AddSource(&sources, CreateNoOpCommandSource()); base::RunLoop run_loop; auto controller = CommanderController::CreateWithSourcesForTesting(std::move(sources));
diff --git a/chrome/browser/ui/login/login_handler_browsertest.cc b/chrome/browser/ui/login/login_handler_browsertest.cc index f3045a6..2f3d1758 100644 --- a/chrome/browser/ui/login/login_handler_browsertest.cc +++ b/chrome/browser/ui/login/login_handler_browsertest.cc
@@ -5,10 +5,10 @@ #include <algorithm> #include <list> #include <map> +#include <tuple> #include "base/bind.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/metrics/field_trial.h" #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" @@ -2428,10 +2428,10 @@ const GURL kAuthIFrameUrl = embedded_test_server()->GetURL(kAuthBasicPage); content::RenderFrameHost* prerender_rfh = prerender_helper().GetPrerenderedMainFrameHost(host_id); - ignore_result(ExecJs(prerender_rfh, - "var i = document.createElement('iframe'); i.src = '" + - kAuthIFrameUrl.spec() + - "'; document.body.appendChild(i);")); + std::ignore = + ExecJs(prerender_rfh, + "var i = document.createElement('iframe'); i.src = '" + + kAuthIFrameUrl.spec() + "'; document.body.appendChild(i);"); // The prerender should be destroyed. host_observer.WaitForDestroyed(); @@ -2473,8 +2473,8 @@ imgElement.src = '/auth-basic/favicon.gif'; document.body.appendChild(imgElement); )"; - ignore_result(ExecJs(prerender_helper().GetPrerenderedMainFrameHost(host_id), - fetch_subresource_script)); + std::ignore = ExecJs(prerender_helper().GetPrerenderedMainFrameHost(host_id), + fetch_subresource_script); // The prerender should be destroyed. host_observer.WaitForDestroyed();
diff --git a/chrome/browser/ui/views/bookmarks/bookmark_bar_view_test.cc b/chrome/browser/ui/views/bookmarks/bookmark_bar_view_test.cc index a670b6c..3f4cdd2 100644 --- a/chrome/browser/ui/views/bookmarks/bookmark_bar_view_test.cc +++ b/chrome/browser/ui/views/bookmarks/bookmark_bar_view_test.cc
@@ -731,13 +731,7 @@ } }; -// TODO(crbug.com/1281365): Flakes on Linux. -#if defined(OS_LINUX) -#define MAYBE_Submenus DISABLED_Submenus -#else -#define MAYBE_Submenus Submenus -#endif -VIEW_TEST(BookmarkBarViewTest3, MAYBE_Submenus) +VIEW_TEST(BookmarkBarViewTest3, Submenus) // Observer that posts a task upon the context menu creation. // This is necessary for Linux as the context menu has to check the clipboard, @@ -828,13 +822,7 @@ BookmarkContextMenuNotificationObserver observer_; }; -// TODO(crbug.com/1281302): Flaky on Linux -#if defined(OS_LINUX) -#define MAYBE_ContextMenus DISABLED_ContextMenus -#else -#define MAYBE_ContextMenus ContextMenus -#endif -VIEW_TEST(BookmarkBarViewTest4, MAYBE_ContextMenus) +VIEW_TEST(BookmarkBarViewTest4, ContextMenus) // Tests drag and drop within the same menu. class BookmarkBarViewTest5 : public BookmarkBarViewDragTestBase { @@ -1268,14 +1256,7 @@ BookmarkContextMenuNotificationObserver observer_; }; -// TODO(https://crbug.com/1281220): Flaky on Linux -#if defined(OS_LINUX) -#define MAYBE_CloseMenuAfterClosingContextMenu \ - DISABLED_CloseMenuAfterClosingContextMenu -#else -#define MAYBE_CloseMenuAfterClosingContextMenu CloseMenuAfterClosingContextMenu -#endif -VIEW_TEST(BookmarkBarViewTest11, MAYBE_CloseMenuAfterClosingContextMenu) +VIEW_TEST(BookmarkBarViewTest11, CloseMenuAfterClosingContextMenu) // Tests showing a modal dialog from a context menu. class BookmarkBarViewTest12 : public BookmarkBarViewEventTestBase { @@ -1484,13 +1465,7 @@ BookmarkContextMenuNotificationObserver observer_; }; -// TODO(https://crbug.com/1281270): Flaky on Linux. -#if defined(OS_LINUX) -#define MAYBE_ContextMenus2 DISABLED_ContextMenus2 -#else -#define MAYBE_ContextMenus2 ContextMenus2 -#endif -VIEW_TEST(BookmarkBarViewTest14, MAYBE_ContextMenus2) +VIEW_TEST(BookmarkBarViewTest14, ContextMenus2) // Makes sure deleting from the context menu keeps the bookmark menu showing. class BookmarkBarViewTest15 : public BookmarkBarViewEventTestBase {
diff --git a/chrome/browser/ui/views/commander_frontend_views.cc b/chrome/browser/ui/views/commander_frontend_views.cc index f4de5c9..fd8bcc6 100644 --- a/chrome/browser/ui/views/commander_frontend_views.cc +++ b/chrome/browser/ui/views/commander_frontend_views.cc
@@ -4,9 +4,10 @@ #include "chrome/browser/ui/views/commander_frontend_views.h" +#include <tuple> + #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" @@ -213,7 +214,7 @@ focus_loss_watcher_.reset(); widget_delegate_->SetOwnedByWidget(true); - ignore_result(widget_delegate_.release()); + std::ignore = widget_delegate_.release(); widget_->Close(); widget_ = nullptr; }
diff --git a/chrome/browser/ui/views/commander_frontend_views_browsertest.cc b/chrome/browser/ui/views/commander_frontend_views_browsertest.cc index 596a2bc..25ca754 100644 --- a/chrome/browser/ui/views/commander_frontend_views_browsertest.cc +++ b/chrome/browser/ui/views/commander_frontend_views_browsertest.cc
@@ -4,7 +4,8 @@ #include "chrome/browser/ui/views/commander_frontend_views.h" -#include "base/ignore_result.h" +#include <tuple> + #include "base/memory/raw_ptr.h" #include "chrome/browser/ui/commander/commander_backend.h" #include "chrome/browser/ui/commander/commander_view_model.h" @@ -234,7 +235,7 @@ auto frontend = std::make_unique<CommanderFrontendViews>(backend_.get()); frontend->Show(browser()); - ignore_result(WaitForCommanderWidgetAttachedTo(browser())); + std::ignore = WaitForCommanderWidgetAttachedTo(browser()); frontend->OnOptionSelected(8, 13); ASSERT_EQ(backend_->command_selected_invocations().size(), 1u); @@ -245,7 +246,7 @@ IN_PROC_BROWSER_TEST_F(CommanderFrontendViewsTest, PassesOnTextChanged) { auto frontend = std::make_unique<CommanderFrontendViews>(backend_.get()); frontend->Show(browser()); - ignore_result(WaitForCommanderWidgetAttachedTo(browser())); + std::ignore = WaitForCommanderWidgetAttachedTo(browser()); const std::u16string input = u"orange"; frontend->OnTextChanged(input); @@ -258,7 +259,7 @@ PassesOnCompositeCommandCancelled) { auto frontend = std::make_unique<CommanderFrontendViews>(backend_.get()); frontend->Show(browser()); - ignore_result(WaitForCommanderWidgetAttachedTo(browser())); + std::ignore = WaitForCommanderWidgetAttachedTo(browser()); EXPECT_EQ(backend_->composite_command_cancelled_invocation_count(), 0); frontend->OnCompositeCommandCancelled();
diff --git a/chrome/browser/ui/views/frame/glass_browser_frame_view_browsertest_win.cc b/chrome/browser/ui/views/frame/glass_browser_frame_view_browsertest_win.cc index dbd2ebcb..060c697 100644 --- a/chrome/browser/ui/views/frame/glass_browser_frame_view_browsertest_win.cc +++ b/chrome/browser/ui/views/frame/glass_browser_frame_view_browsertest_win.cc
@@ -4,9 +4,10 @@ #include "chrome/browser/ui/views/frame/glass_browser_frame_view.h" +#include <tuple> + #include "base/bind.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/strings/stringprintf.h" #include "base/test/bind.h" @@ -241,7 +242,7 @@ web_app_frame_toolbar_helper_.SetupGeometryChangeCallback(web_contents); browser_view_->ToggleWindowControlsOverlayEnabled(); content::TitleWatcher title_watcher(web_contents, u"ongeometrychange"); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); } raw_ptr<BrowserView> browser_view_ = nullptr;
diff --git a/chrome/browser/ui/views/frame/opaque_browser_frame_view_browsertest.cc b/chrome/browser/ui/views/frame/opaque_browser_frame_view_browsertest.cc index 75fce1e..5cfaff6 100644 --- a/chrome/browser/ui/views/frame/opaque_browser_frame_view_browsertest.cc +++ b/chrome/browser/ui/views/frame/opaque_browser_frame_view_browsertest.cc
@@ -4,8 +4,9 @@ #include "chrome/browser/ui/views/frame/opaque_browser_frame_view.h" +#include <tuple> + #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/test/bind.h" #include "build/build_config.h" @@ -352,7 +353,7 @@ web_app_frame_toolbar_helper_.SetupGeometryChangeCallback(web_contents); browser_view_->ToggleWindowControlsOverlayEnabled(); content::TitleWatcher title_watcher(web_contents, u"ongeometrychange"); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); } raw_ptr<BrowserView> browser_view_ = nullptr;
diff --git a/chrome/browser/ui/views/frame/top_controls_slide_controller_chromeos_browsertest.cc b/chrome/browser/ui/views/frame/top_controls_slide_controller_chromeos_browsertest.cc index df9517c..6b03ab8e 100644 --- a/chrome/browser/ui/views/frame/top_controls_slide_controller_chromeos_browsertest.cc +++ b/chrome/browser/ui/views/frame/top_controls_slide_controller_chromeos_browsertest.cc
@@ -6,6 +6,7 @@ #include <memory> #include <numeric> +#include <tuple> #include <vector> #include "ash/constants/ash_switches.h" @@ -17,7 +18,6 @@ #include "base/bind.h" #include "base/callback_helpers.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/strings/safe_sprintf.h" #include "build/chromeos_buildflags.h" @@ -867,7 +867,7 @@ EXPECT_TRUE(bool_result); // Evaluate an empty sentence to make sure that the event processing is done // in the content. - ignore_result(content::EvalJs(contents, ";")); + std::ignore = content::EvalJs(contents, ";"); SCOPED_TRACE("Scroll to hide should now work."); ScrollAndExpectTopChromeToBe(ScrollDirection::kDown, @@ -1085,7 +1085,7 @@ send_key_event(ui::VKEY_RETURN); // Evaluate an empty sentence to make sure that the event processing is done // in the content. - ignore_result(content::EvalJs(contents, ";")); + std::ignore = content::EvalJs(contents, ";"); // Verify that the selected option has changed and the fourth option is // selected.
diff --git a/chrome/browser/ui/views/omnibox/omnibox_view_views_browsertest.cc b/chrome/browser/ui/views/omnibox/omnibox_view_views_browsertest.cc index 4a74a40b..4bd17d9 100644 --- a/chrome/browser/ui/views/omnibox/omnibox_view_views_browsertest.cc +++ b/chrome/browser/ui/views/omnibox/omnibox_view_views_browsertest.cc
@@ -9,11 +9,14 @@ #include "base/command_line.h" #include "base/feature_list.h" #include "base/memory/raw_ptr.h" +#include "base/test/bind.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" #include "chrome/app/chrome_command_ids.h" #include "chrome/browser/external_protocol/external_protocol_handler.h" +#include "chrome/browser/interstitials/security_interstitial_page_test_utils.h" #include "chrome/browser/search_engines/template_url_service_factory.h" +#include "chrome/browser/ssl/typed_navigation_upgrade_throttle.h" #include "chrome/browser/ui/browser.h" #include "chrome/browser/ui/browser_commands.h" #include "chrome/browser/ui/browser_window.h" @@ -34,6 +37,8 @@ #include "content/public/browser/web_contents.h" #include "content/public/common/url_constants.h" #include "content/public/test/browser_test.h" +#include "content/public/test/test_navigation_observer.h" +#include "content/public/test/url_loader_interceptor.h" #include "third_party/blink/public/common/chrome_debug_urls.h" #include "ui/accessibility/accessibility_switches.h" #include "ui/accessibility/ax_action_data.h" @@ -129,6 +134,19 @@ } } + OmniboxView* omnibox() { + return browser()->window()->GetLocationBar()->GetOmniboxView(); + } + + void PressEnterAndWaitForNavigations(size_t num_expected_navigations) { + content::TestNavigationObserver navigation_observer( + browser()->tab_strip_model()->GetActiveWebContents(), + num_expected_navigations); + EXPECT_TRUE(ui_test_utils::SendKeyPressSync(browser(), ui::VKEY_RETURN, + false, false, false, false)); + navigation_observer.Wait(); + } + private: // InProcessBrowserTest: void SetUpOnMainThread() override { @@ -989,3 +1007,69 @@ #endif // !defined(OS_MAC) || defined(USE_AURA) } + +// SendKeyPressSync times out on Mac, probably due to https://crbug.com/824418. +#if defined(OS_MAC) +#define MAYBE_DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade \ + DISABLED_DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade +#else +#define MAYBE_DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade \ + DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade +#endif + +// Test that triggers a zero suggest autocomplete request by clicking on the +// omnibox. These should never attempt an https upgrade or involve the typed +// navigation upgrade throttle. +// This is a regression test for https://crbug.com/1251065 +IN_PROC_BROWSER_TEST_F( + OmniboxViewViewsTest, + MAYBE_DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade) { + // Since the embedded test server only works for URLs with non-default ports, + // use a URLLoaderInterceptor to mimic port-free operation. This allows the + // rest of the test to operate as if all URLs are using the default ports. + content::URLLoaderInterceptor interceptor(base::BindLambdaForTesting( + [&](content::URLLoaderInterceptor::RequestParams* params) { + if (params->url_request.url.host() == "site-with-good-https.com") { + std::string headers = + "HTTP/1.1 200 OK\nContent-Type: text/html; charset=utf-8\n"; + std::string body = "<html><title>Success</title>Hello world</html>"; + content::URLLoaderInterceptor::WriteResponse(headers, body, + params->client.get()); + return true; + } + // Not handled by us. + return false; + })); + + base::HistogramTester histograms; + const GURL url("https://site-with-good-https.com"); + + // Type "https://site-with-good-https.com". This should load fine without + // hitting TypedNavigationUpgradeThrottle. + omnibox()->SetUserText(base::UTF8ToUTF16(url.spec()), true); + PressEnterAndWaitForNavigations(1); + content::WebContents* contents = + browser()->tab_strip_model()->GetActiveWebContents(); + EXPECT_EQ(url, contents->GetLastCommittedURL()); + EXPECT_FALSE(chrome_browser_interstitials::IsShowingInterstitial(contents)); + + histograms.ExpectTotalCount(TypedNavigationUpgradeThrottle::kHistogramName, + 0); + ui_test_utils::HistoryEnumerator enumerator(browser()->profile()); + EXPECT_TRUE(base::Contains(enumerator.urls(), url)); + + // Now click the omnibox. This should trigger a zero suggest request with the + // text "site-with-good-https.com" despite the omnibox URL being + // https://site-with-good-https.com. Autocomplete input class shouldn't try + // to upgrade this request. + const gfx::Rect omnibox_bounds = + BrowserView::GetBrowserViewForBrowser(browser()) + ->GetViewByID(VIEW_ID_OMNIBOX) + ->GetBoundsInScreen(); + const gfx::Point click_location = omnibox_bounds.CenterPoint(); + ASSERT_NO_FATAL_FAILURE( + Click(ui_controls::LEFT, click_location, click_location)); + PressEnterAndWaitForNavigations(1); + histograms.ExpectTotalCount(TypedNavigationUpgradeThrottle::kHistogramName, + 0); +}
diff --git a/chrome/browser/ui/views/overlay/overlay_window_views.cc b/chrome/browser/ui/views/overlay/overlay_window_views.cc index d575bf7..1087db0 100644 --- a/chrome/browser/ui/views/overlay/overlay_window_views.cc +++ b/chrome/browser/ui/views/overlay/overlay_window_views.cc
@@ -851,12 +851,7 @@ void OverlayWindowViews::Close() { views::Widget::Close(); - - if (has_registered_frame_sink_hierarchy_) { - DCHECK(GetCurrentFrameSinkId()); - GetCompositor()->RemoveChildFrameSink(*GetCurrentFrameSinkId()); - has_registered_frame_sink_hierarchy_ = false; - } + MaybeUnregisterFrameSinkHierarchy(); } void OverlayWindowViews::ShowInactive() { @@ -889,12 +884,7 @@ void OverlayWindowViews::Hide() { views::Widget::Hide(); - - if (has_registered_frame_sink_hierarchy_) { - DCHECK(GetCurrentFrameSinkId()); - GetCompositor()->RemoveChildFrameSink(*GetCurrentFrameSinkId()); - has_registered_frame_sink_hierarchy_ = false; - } + MaybeUnregisterFrameSinkHierarchy(); } bool OverlayWindowViews::IsVisible() { @@ -996,10 +986,7 @@ // The PiP window may have a previous surface set. If the window stays open // since then, we need to unregister the previous frame sink; otherwise the // surface frame sink should already be removed when the window closed. - if (has_registered_frame_sink_hierarchy_) { - DCHECK(GetCurrentFrameSinkId()); - GetCompositor()->RemoveChildFrameSink(*GetCurrentFrameSinkId()); - } + MaybeUnregisterFrameSinkHierarchy(); // Add the new frame sink to the PiP window and set the surface. GetCompositor()->AddChildFrameSink(surface_id.frame_sink_id()); @@ -1026,11 +1013,7 @@ void OverlayWindowViews::OnNativeWidgetDestroying() { views::Widget::OnNativeWidgetDestroying(); - if (has_registered_frame_sink_hierarchy_) { - DCHECK(GetCurrentFrameSinkId()); - GetCompositor()->RemoveChildFrameSink(*GetCurrentFrameSinkId()); - has_registered_frame_sink_hierarchy_ = false; - } + MaybeUnregisterFrameSinkHierarchy(); } void OverlayWindowViews::OnNativeWidgetDestroyed() { @@ -1094,11 +1077,7 @@ } void OverlayWindowViews::OnNativeWidgetRemovingFromCompositor() { - if (has_registered_frame_sink_hierarchy_) { - DCHECK(GetCurrentFrameSinkId()); - GetCompositor()->RemoveChildFrameSink(*GetCurrentFrameSinkId()); - has_registered_frame_sink_hierarchy_ = false; - } + MaybeUnregisterFrameSinkHierarchy(); } void OverlayWindowViews::OnKeyEvent(ui::KeyEvent* event) { @@ -1429,3 +1408,11 @@ return nullptr; } + +void OverlayWindowViews::MaybeUnregisterFrameSinkHierarchy() { + if (has_registered_frame_sink_hierarchy_) { + DCHECK(GetCurrentFrameSinkId()); + GetCompositor()->RemoveChildFrameSink(*GetCurrentFrameSinkId()); + has_registered_frame_sink_hierarchy_ = false; + } +}
diff --git a/chrome/browser/ui/views/overlay/overlay_window_views.h b/chrome/browser/ui/views/overlay/overlay_window_views.h index 229b618..c1d1661 100644 --- a/chrome/browser/ui/views/overlay/overlay_window_views.h +++ b/chrome/browser/ui/views/overlay/overlay_window_views.h
@@ -202,6 +202,11 @@ // returns nullptr. const viz::FrameSinkId* GetCurrentFrameSinkId() const; + // Unregisters the current frame sink id for the surface displayed in the + // |video_view_| from its parent frame sink if the frame sink hierarchy has + // been registered before. + void MaybeUnregisterFrameSinkHierarchy(); + // Not owned; |controller_| owns |this|. raw_ptr<content::PictureInPictureWindowController> controller_;
diff --git a/chrome/browser/ui/views/toolbar/app_menu_browsertest.cc b/chrome/browser/ui/views/toolbar/app_menu_browsertest.cc index 58ee6c4..f266aa7e 100644 --- a/chrome/browser/ui/views/toolbar/app_menu_browsertest.cc +++ b/chrome/browser/ui/views/toolbar/app_menu_browsertest.cc
@@ -39,13 +39,7 @@ // properly (this was triggering a crash in AppMenu where it was trying to make // use of RecentTabsMenuModelDelegate before created). See // https://crbug.com/1249741 for more. -#if BUILDFLAG(IS_CHROMEOS_LACROS) -// TODO(crbug.com/1284776): Re-enable once flakiness is fixed. -#define MAYBE_ShowWithRecentlyClosedWindow DISABLED_ShowWithRecentlyClosedWindow -#else -#define MAYBE_ShowWithRecentlyClosedWindow ShowWithRecentlyClosedWindow -#endif -IN_PROC_BROWSER_TEST_F(AppMenuBrowserTest, MAYBE_ShowWithRecentlyClosedWindow) { +IN_PROC_BROWSER_TEST_F(AppMenuBrowserTest, ShowWithRecentlyClosedWindow) { // Create an additional browser, close it, and ensure it is added to the // TabRestoreService. sessions::TabRestoreService* tab_restore_service =
diff --git a/chrome/browser/ui/views/web_apps/frame_toolbar/web_app_frame_toolbar_browsertest.cc b/chrome/browser/ui/views/web_apps/frame_toolbar/web_app_frame_toolbar_browsertest.cc index e9bfaefb..ef08ab9 100644 --- a/chrome/browser/ui/views/web_apps/frame_toolbar/web_app_frame_toolbar_browsertest.cc +++ b/chrome/browser/ui/views/web_apps/frame_toolbar/web_app_frame_toolbar_browsertest.cc
@@ -3,8 +3,8 @@ // found in the LICENSE file. #include <cmath> +#include <tuple> -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/run_loop.h" #include "base/test/bind.h" @@ -465,7 +465,7 @@ helper()->SetupGeometryChangeCallback(web_contents); content::TitleWatcher title_watcher(web_contents, u"ongeometrychange"); helper()->browser_view()->ToggleWindowControlsOverlayEnabled(); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); } bool GetWindowControlOverlayVisibility() { @@ -490,7 +490,7 @@ ->app_browser() ->tab_strip_model() ->GetActiveWebContents())); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); } gfx::Rect GetWindowControlOverlayBoundingClientRect() { @@ -532,7 +532,7 @@ helper()->SetupGeometryChangeCallback(web_contents); content::TitleWatcher title_watcher(web_contents, u"ongeometrychange"); helper()->browser_view()->GetWidget()->SetBounds(new_bounds); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); } gfx::Rect GetWindowControlOverlayBoundingClientRectFromEvent() {
diff --git a/chrome/browser/ui/web_applications/web_app_browsertest.cc b/chrome/browser/ui/web_applications/web_app_browsertest.cc index c9e1c42..2d36656 100644 --- a/chrome/browser/ui/web_applications/web_app_browsertest.cc +++ b/chrome/browser/ui/web_applications/web_app_browsertest.cc
@@ -18,7 +18,6 @@ #include "base/time/time.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" -#include "build/os_buildflags.h" #include "chrome/browser/apps/app_service/app_launch_params.h" #include "chrome/browser/apps/app_service/app_service_proxy.h" #include "chrome/browser/apps/app_service/app_service_proxy_factory.h"
diff --git a/chrome/browser/ui/web_applications/web_app_ui_manager_impl.cc b/chrome/browser/ui/web_applications/web_app_ui_manager_impl.cc index cf7f41a2..28b65a1 100644 --- a/chrome/browser/ui/web_applications/web_app_ui_manager_impl.cc +++ b/chrome/browser/ui/web_applications/web_app_ui_manager_impl.cc
@@ -25,6 +25,7 @@ #include "chrome/browser/ui/web_applications/web_app_metrics.h" #include "chrome/browser/ui/webui/web_app_internals/web_app_internals_source.h" #include "chrome/browser/web_applications/extensions/web_app_extension_shortcut.h" +#include "chrome/browser/web_applications/os_integration_manager.h" #include "chrome/browser/web_applications/system_web_apps/system_web_app_manager.h" #include "chrome/browser/web_applications/web_app_callback_app_identity.h" #include "chrome/browser/web_applications/web_app_provider.h" @@ -50,6 +51,12 @@ #endif #if defined(OS_WIN) +#include "base/process/process.h" +#include "chrome/browser/browser_process.h" +#include "chrome/browser/ui/browser_list.h" +#include "chrome/browser/web_applications/web_app_install_finalizer.h" +#include "components/keep_alive_registry/keep_alive_types.h" +#include "components/keep_alive_registry/scoped_keep_alive.h" #include "ui/gfx/native_widget_types.h" #endif // defined(OS_WIN) @@ -68,23 +75,36 @@ } #if defined(OS_WIN) - -// UninstallWebAppWithDialogFromStartupSwitch handles WebApp uninstallation from -// the Windows Settings. +// ScopedKeepAlive not only keeps the process from terminating early +// during uninstall, it also ensures the process will terminate when it +// is destroyed if there is no active browser window. void UninstallWebAppWithDialogFromStartupSwitch(const AppId& app_id, Profile* profile, WebAppProvider* provider) { - if (!provider->registrar().IsLocallyInstalled(app_id)) { - // App does not exist and controller is destroyed. - return; + // ScopedKeepAlive does not only keeps the process from early termination, + // but ensure the process termination when there is no active browser window. + std::unique_ptr<ScopedKeepAlive> scoped_keep_alive = + std::make_unique<ScopedKeepAlive>(KeepAliveOrigin::WEB_APP_UNINSTALL, + KeepAliveRestartOption::DISABLED); + if (provider->install_finalizer().CanUserUninstallWebApp(app_id)) { + WebAppUiManagerImpl::Get(provider)->dialog_manager().UninstallWebApp( + app_id, webapps::WebappUninstallSource::kOsSettings, + gfx::kNullNativeWindow, + base::BindOnce([](std::unique_ptr<ScopedKeepAlive> scoped_keep_alive, + bool success) {}, + std::move(scoped_keep_alive))); + } else { + // There is a chance that a previous invalid uninstall operation (due + // to a crash or otherwise) could end up orphaning an OsSettings entry. + // In this case we clean up the OsSettings entry. + web_app::OsHooksOptions options; + options[OsHookType::kUninstallationViaOsSettings] = true; + provider->os_integration_manager().UninstallOsHooks( + app_id, options, + base::BindOnce([](std::unique_ptr<ScopedKeepAlive> scoped_keep_alive, + OsHooksErrors os_hooks_errors) {}, + std::move(scoped_keep_alive))); } - - // Note: WebAppInstallFinalizer::UninstallWebApp creates a ScopedKeepAlive - // object which ensures the browser stays alive during the WebApp - // uninstall. - WebAppUiManagerImpl::Get(provider)->dialog_manager().UninstallWebApp( - app_id, webapps::WebappUninstallSource::kOsSettings, - gfx::kNullNativeWindow, base::DoNothing()); } #endif // defined(OS_WIN)
diff --git a/chrome/browser/ui/webui/access_code_cast/access_code_cast_ui.cc b/chrome/browser/ui/webui/access_code_cast/access_code_cast_ui.cc index dd54345..3047403 100644 --- a/chrome/browser/ui/webui/access_code_cast/access_code_cast_ui.cc +++ b/chrome/browser/ui/webui/access_code_cast/access_code_cast_ui.cc
@@ -71,7 +71,9 @@ webui_ = webui; } -void AccessCodeCastDialog::OnDialogClosed(const std::string& json_retval) {} +void AccessCodeCastDialog::OnDialogClosed(const std::string& json_retval) { + delete this; +} void AccessCodeCastDialog::OnCloseContents(content::WebContents* source, bool* out_close_dialog) {
diff --git a/chrome/browser/ui/webui/chrome_web_ui_controller_factory.cc b/chrome/browser/ui/webui/chrome_web_ui_controller_factory.cc index c72d362..28306c79e 100644 --- a/chrome/browser/ui/webui/chrome_web_ui_controller_factory.cc +++ b/chrome/browser/ui/webui/chrome_web_ui_controller_factory.cc
@@ -207,6 +207,7 @@ #include "chrome/browser/ash/web_applications/chrome_file_manager_ui_delegate.h" #include "chrome/browser/ash/web_applications/help_app/help_app_ui_delegate.h" #include "chrome/browser/ash/web_applications/media_app/chrome_media_app_ui_delegate.h" +#include "chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_theme_provider.h" #include "chrome/browser/ash/web_applications/personalization_app/chrome_personalization_app_wallpaper_provider.h" #include "chrome/browser/feedback/feedback_dialog_utils.h" #include "chrome/browser/nearby_sharing/nearby_sharing_service_factory.h" @@ -577,6 +578,18 @@ chrome::WebUIFeedbackSource::kConnectivityDiagnostics), /*show_feedback_button=*/!chrome::IsRunningInAppMode()); } + +template <> +WebUIController* NewWebUI<ash::PersonalizationAppUI>(WebUI* web_ui, + const GURL& url) { + auto theme_provider = + std::make_unique<ChromePersonalizationAppThemeProvider>(web_ui); + auto wallpaper_provider = + std::make_unique<ChromePersonalizationAppWallpaperProvider>(web_ui); + return new ash::PersonalizationAppUI(web_ui, std::move(theme_provider), + std::move(wallpaper_provider)); +} + #endif // BUILDFLAG(IS_CHROMEOS_ASH) #if BUILDFLAG(ENABLE_DICE_SUPPORT) @@ -960,8 +973,7 @@ } if (url.host_piece() == ash::kChromeUIPersonalizationAppHost && chromeos::features::IsWallpaperWebUIEnabled()) { - return &NewComponentUI<ash::PersonalizationAppUI, - ChromePersonalizationAppWallpaperProvider>; + return &NewWebUI<ash::PersonalizationAppUI>; } if (url.host_piece() == ash::kChromeUISystemExtensionsInternalsHost && base::FeatureList::IsEnabled(ash::features::kSystemExtensions)) {
diff --git a/chrome/browser/ui/webui/components/components_ui.cc b/chrome/browser/ui/webui/components/components_ui.cc index c6e87025..440d7f6 100644 --- a/chrome/browser/ui/webui/components/components_ui.cc +++ b/chrome/browser/ui/webui/components/components_ui.cc
@@ -75,7 +75,8 @@ user_manager::UserManager::Get()->IsLoggedInAsPublicAccount() #elif BUILDFLAG(IS_CHROMEOS_LACROS) chromeos::LacrosService::Get()->init_params()->session_type == - crosapi::mojom::SessionType::kPublicSession + crosapi::mojom::SessionType::kPublicSession || + profile->IsGuestSession() #else profile->IsOffTheRecord() #endif
diff --git a/chrome/browser/ui/webui/management/management_ui.cc b/chrome/browser/ui/webui/management/management_ui.cc index 4f000cf..d7128ac 100644 --- a/chrome/browser/ui/webui/management/management_ui.cc +++ b/chrome/browser/ui/webui/management/management_ui.cc
@@ -72,6 +72,7 @@ {kManagementPrinting, IDS_MANAGEMENT_REPORT_PRINTING}, {kManagementReportPrintJobs, IDS_MANAGEMENT_REPORT_PRINT_JOBS}, {kManagementReportLoginLogout, IDS_MANAGEMENT_REPORT_LOGIN_LOGOUT}, + {kManagementReportCRDSessions, IDS_MANAGEMENT_REPORT_CRD_SESSIONS}, {kManagementCrostini, IDS_MANAGEMENT_CROSTINI}, {kManagementCrostiniContainerConfiguration, IDS_MANAGEMENT_CROSTINI_CONTAINER_CONFIGURATION},
diff --git a/chrome/browser/ui/webui/management/management_ui_handler.cc b/chrome/browser/ui/webui/management/management_ui_handler.cc index 77cca35e..25f8db3 100644 --- a/chrome/browser/ui/webui/management/management_ui_handler.cc +++ b/chrome/browser/ui/webui/management/management_ui_handler.cc
@@ -171,6 +171,7 @@ "managementReportAndroidApplications"; const char kManagementReportPrintJobs[] = "managementReportPrintJobs"; const char kManagementReportLoginLogout[] = "managementReportLoginLogout"; +const char kManagementReportCRDSessions[] = "managementReportCRDSessions"; const char kManagementReportDlpEvents[] = "managementReportDlpEvents"; const char kManagementPrinting[] = "managementPrinting"; const char kManagementCrostini[] = "managementCrostini"; @@ -221,7 +222,8 @@ kExtensions, kAndroidApplication, kDlpEvents, - kLoginLogout + kLoginLogout, + kCRDSessions, }; // Corresponds to DeviceReportingType in management_browser_proxy.js @@ -257,6 +259,8 @@ return "dlp events"; case DeviceReportingType::kLoginLogout: return "login-logout"; + case DeviceReportingType::kCRDSessions: + return "crd sessions"; default: NOTREACHED() << "Unknown device reporting type"; return "device"; @@ -630,6 +634,14 @@ AddDeviceReportingElement(report_sources, kManagementReportLoginLogout, DeviceReportingType::kLoginLogout); } + + bool report_crd_sessions = false; + chromeos::CrosSettings::Get()->GetBoolean(ash::kReportCRDSessions, + &report_crd_sessions); + if (report_crd_sessions) { + AddDeviceReportingElement(report_sources, kManagementReportCRDSessions, + DeviceReportingType::kCRDSessions); + } } bool ManagementUIHandler::IsUpdateRequiredEol() const {
diff --git a/chrome/browser/ui/webui/management/management_ui_handler.h b/chrome/browser/ui/webui/management/management_ui_handler.h index 474311e..2d9841b 100644 --- a/chrome/browser/ui/webui/management/management_ui_handler.h +++ b/chrome/browser/ui/webui/management/management_ui_handler.h
@@ -36,6 +36,7 @@ extern const char kManagementReportPrintJobs[]; extern const char kManagementReportDlpEvents[]; extern const char kManagementReportLoginLogout[]; +extern const char kManagementReportCRDSessions[]; extern const char kManagementPrinting[]; extern const char kManagementCrostini[]; extern const char kManagementCrostiniContainerConfiguration[];
diff --git a/chrome/browser/ui/webui/settings/settings_localized_strings_provider.cc b/chrome/browser/ui/webui/settings/settings_localized_strings_provider.cc index 8fc7ea14..d63ac6d 100644 --- a/chrome/browser/ui/webui/settings/settings_localized_strings_provider.cc +++ b/chrome/browser/ui/webui/settings/settings_localized_strings_provider.cc
@@ -208,7 +208,8 @@ user_manager::UserManager::Get()->IsLoggedInAsPublicAccount()); #elif BUILDFLAG(IS_CHROMEOS_LACROS) chromeos::LacrosService::Get()->init_params()->session_type == - crosapi::mojom::SessionType::kPublicSession); + crosapi::mojom::SessionType::kPublicSession || + profile->IsGuestSession()); #else profile->IsGuestSession()); #endif
diff --git a/chrome/browser/ui/webui/settings/site_settings_handler.cc b/chrome/browser/ui/webui/settings/site_settings_handler.cc index 52c1642..9fcd702 100644 --- a/chrome/browser/ui/webui/settings/site_settings_handler.cc +++ b/chrome/browser/ui/webui/settings/site_settings_handler.cc
@@ -622,9 +622,13 @@ void SiteSettingsHandler::OnJavascriptAllowed() { ObserveSourcesForProfile(profile_); - if (profile_->HasPrimaryOTRProfile()) - ObserveSourcesForProfile( - profile_->GetPrimaryOTRProfile(/*create_if_needed=*/true)); + if (profile_->HasPrimaryOTRProfile()) { + auto* primary_otr_profile = + profile_->GetPrimaryOTRProfile(/*create_if_needed=*/true); + // Avoid duplicate observation. + if (primary_otr_profile != profile_) + ObserveSourcesForProfile(primary_otr_profile); + } // Here we only subscribe to the HostZoomMap for the default storage partition // since we don't allow the user to manage the zoom levels for apps.
diff --git a/chrome/browser/webshare/win/fake_storage_file_statics.cc b/chrome/browser/webshare/win/fake_storage_file_statics.cc index 3088add..4c200158 100644 --- a/chrome/browser/webshare/win/fake_storage_file_statics.cc +++ b/chrome/browser/webshare/win/fake_storage_file_statics.cc
@@ -9,10 +9,10 @@ #include <wrl/module.h> #include <memory> +#include <tuple> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/test/bind.h" #include "base/test/fake_iasync_operation_win.h" #include "base/threading/thread_task_runner_handle.h" @@ -63,7 +63,7 @@ // and release ownership of the original 'back' to the caller. base::win::ScopedHString holder(display_name_with_extension); display_name_with_extension_ = holder.GetAsUTF8(); - ignore_result(holder.release()); + std::ignore = holder.release(); } FakeStorageFile(const FakeStorageFile&) = delete; FakeStorageFile& operator=(const FakeStorageFile&) = delete;
diff --git a/chrome/browser/webshare/win/fake_uri_runtime_class_factory.cc b/chrome/browser/webshare/win/fake_uri_runtime_class_factory.cc index d6fa279..acfe510a 100644 --- a/chrome/browser/webshare/win/fake_uri_runtime_class_factory.cc +++ b/chrome/browser/webshare/win/fake_uri_runtime_class_factory.cc
@@ -5,8 +5,8 @@ #include "chrome/browser/webshare/win/fake_uri_runtime_class_factory.h" #include <string> +#include <tuple> -#include "base/ignore_result.h" #include "base/notreached.h" #include "base/win/scoped_hstring.h" #include "testing/gtest/include/gtest/gtest.h" @@ -151,7 +151,7 @@ // and release ownership of the original 'back' to the caller. base::win::ScopedHString holder(uri); auto uri_string = holder.GetAsUTF8(); - ignore_result(holder.release()); + std::ignore = holder.release(); if (uri_string.empty()) { ADD_FAILURE() << "CreateUri called with empty uri.";
diff --git a/chrome/browser/xsurface/android/java/src/org/chromium/chrome/browser/xsurface/SurfaceActionsHandler.java b/chrome/browser/xsurface/android/java/src/org/chromium/chrome/browser/xsurface/SurfaceActionsHandler.java index c776472..0894a01 100644 --- a/chrome/browser/xsurface/android/java/src/org/chromium/chrome/browser/xsurface/SurfaceActionsHandler.java +++ b/chrome/browser/xsurface/android/java/src/org/chromium/chrome/browser/xsurface/SurfaceActionsHandler.java
@@ -6,6 +6,7 @@ import android.view.View; +import java.util.List; /** * Interface to provide chromium calling points for an external surface. */ @@ -48,6 +49,15 @@ /** * Dismiss the open bottom sheet (or do nothing if there isn't one). + * */ default void dismissBottomSheet() {} + + /** + * Notifies the host app that url with broadTopicMids and entityMids was clicked. + * @param url The URL that the user clicked on + * @param entityMids Sorted list (most relevant to least) of entity MIDs that correspond to the + * clicked URL + */ + default void updateUserProfileOnLinkClick(String url, List<Long> entityMids) {} }
diff --git a/chrome/build/linux.pgo.txt b/chrome/build/linux.pgo.txt index 863fa0a..29dd729 100644 --- a/chrome/build/linux.pgo.txt +++ b/chrome/build/linux.pgo.txt
@@ -1 +1 @@ -chrome-linux-main-1641556507-e2d334876b7c435f81feb7b610d45fdc9019eb04.profdata +chrome-linux-main-1641578242-e0d6e8e2f0ae433564dcd4bdef588458f8d9b944.profdata
diff --git a/chrome/build/mac.pgo.txt b/chrome/build/mac.pgo.txt index 0fb791f..656a72cf 100644 --- a/chrome/build/mac.pgo.txt +++ b/chrome/build/mac.pgo.txt
@@ -1 +1 @@ -chrome-mac-main-1641556507-72cc532d3d10b7e16fb62ec615fb675b5366780f.profdata +chrome-mac-main-1641578242-7bbbc65858af9963fc8e8bd496e65e9b5bb8af08.profdata
diff --git a/chrome/build/win32.pgo.txt b/chrome/build/win32.pgo.txt index c22bcf97..d94a2d4 100644 --- a/chrome/build/win32.pgo.txt +++ b/chrome/build/win32.pgo.txt
@@ -1 +1 @@ -chrome-win32-main-1641556507-c8f9e38ff290f4fa5b5425e7ee57d20422cc8617.profdata +chrome-win32-main-1641589129-0dde4ecb02410249eb818b00718aa90d6a6d6006.profdata
diff --git a/chrome/build/win64.pgo.txt b/chrome/build/win64.pgo.txt index 67f523ea..52651dd 100644 --- a/chrome/build/win64.pgo.txt +++ b/chrome/build/win64.pgo.txt
@@ -1 +1 @@ -chrome-win64-main-1641556507-6812bf0408b84a4c9d107273eda1d933a9a4d878.profdata +chrome-win64-main-1641589129-dcfca0fc35387db301739a73849e085e42f51e67.profdata
diff --git a/chrome/common/BUILD.gn b/chrome/common/BUILD.gn index 7515a12..7f441d95d 100644 --- a/chrome/common/BUILD.gn +++ b/chrome/common/BUILD.gn
@@ -241,7 +241,6 @@ deps = [ "//build:chromeos_buildflags", - "//build:os_buildflags", "//components/google/core/common", "//components/live_caption:constants", "//components/metrics:call_stack_profile_builder",
diff --git a/chrome/common/channel_info_mac.mm b/chrome/common/channel_info_mac.mm index 6f13900..59da30c 100644 --- a/chrome/common/channel_info_mac.mm +++ b/chrome/common/channel_info_mac.mm
@@ -6,8 +6,9 @@ #import <Foundation/Foundation.h> +#include <tuple> + #include "base/check.h" -#include "base/ignore_result.h" #include "base/mac/bundle_locations.h" #include "base/no_destructor.h" #include "base/strings/sys_string_conversions.h" @@ -127,8 +128,8 @@ } // namespace void CacheChannelInfo() { - ignore_result(GetChannelState()); - ignore_result(SideBySideCapable()); + std::ignore = GetChannelState(); + std::ignore = SideBySideCapable(); } std::string GetChannelName(WithExtendedStable with_extended_stable) {
diff --git a/chrome/common/chrome_constants.cc b/chrome/common/chrome_constants.cc index cc31b30..5e5e7a2 100644 --- a/chrome/common/chrome_constants.cc +++ b/chrome/common/chrome_constants.cc
@@ -139,6 +139,8 @@ FPL("previews_opt_out.db"); const base::FilePath::CharType kQueryTileStorageDirname[] = FPL("Query Tiles"); const base::FilePath::CharType kReadmeFilename[] = FPL("README"); +const base::FilePath::CharType kSCTAuditingPendingReportsFileName[] = + FPL("SCT Auditing Pending Reports"); const base::FilePath::CharType kSecurePreferencesFilename[] = FPL("Secure Preferences"); const base::FilePath::CharType kServiceStateFileName[] = FPL("Service State");
diff --git a/chrome/common/chrome_constants.h b/chrome/common/chrome_constants.h index ad02022..dee5b6d 100644 --- a/chrome/common/chrome_constants.h +++ b/chrome/common/chrome_constants.h
@@ -67,6 +67,7 @@ extern const base::FilePath::CharType kQueryTileStorageDirname[]; extern const base::FilePath::CharType kReadmeFilename[]; extern const base::FilePath::CharType kReportingAndNelStoreFilename[]; +extern const base::FilePath::CharType kSCTAuditingPendingReportsFileName[]; extern const base::FilePath::CharType kSecurePreferencesFilename[]; extern const base::FilePath::CharType kSegmentationPlatformStorageDirName[]; extern const base::FilePath::CharType kServiceStateFileName[];
diff --git a/chrome/common/chrome_features.cc b/chrome/common/chrome_features.cc index 75dd8cd..8c8188e 100644 --- a/chrome/common/chrome_features.cc +++ b/chrome/common/chrome_features.cc
@@ -52,6 +52,11 @@ const base::Feature kAlwaysReinstallSystemWebApps{ "ReinstallSystemWebApps", base::FEATURE_DISABLED_BY_DEFAULT}; +#if defined(OS_ANDROID) +const base::Feature kAnonymousUpdateChecks{"AnonymousUpdateChecks", + base::FEATURE_ENABLED_BY_DEFAULT}; +#endif + #if BUILDFLAG(IS_CHROMEOS_ASH) // Controls whether web apps can be installed via APKs on Chrome OS. const base::Feature kApkWebAppInstalls{"ApkWebAppInstalls", @@ -323,6 +328,15 @@ const base::Feature kDesktopPWAsWebBundles{"DesktopPWAsWebBundles", base::FEATURE_DISABLED_BY_DEFAULT}; +// Tries disabling the HTTP disk cache. +const base::Feature kDisableHttpDiskCache{"DisableHttpDiskCache", + base::FEATURE_DISABLED_BY_DEFAULT}; + +// The size of the memory cache when the HTTP disk cache is disabled. 0 uses the +// default size. +const base::FeatureParam<int> kDisableHttpDiskCacheMemoryCacheSizeParam{ + &kDisableHttpDiskCache, "MemoryCacheSize", 0}; + // Enable DNS over HTTPS (DoH). const base::Feature kDnsOverHttps { "DnsOverHttps",
diff --git a/chrome/common/chrome_features.h b/chrome/common/chrome_features.h index 73dc22583..fdab413 100644 --- a/chrome/common/chrome_features.h +++ b/chrome/common/chrome_features.h
@@ -53,6 +53,11 @@ COMPONENT_EXPORT(CHROME_FEATURES) extern const base::Feature kAlwaysReinstallSystemWebApps; +#if defined(OS_ANDROID) +COMPONENT_EXPORT(CHROME_FEATURES) +extern const base::Feature kAnonymousUpdateChecks; +#endif + #if BUILDFLAG(IS_CHROMEOS_ASH) COMPONENT_EXPORT(CHROME_FEATURES) extern const base::Feature kApkWebAppInstalls; @@ -219,6 +224,11 @@ extern const base::Feature kDesktopPWAsWebBundles; COMPONENT_EXPORT(CHROME_FEATURES) +extern const base::Feature kDisableHttpDiskCache; +COMPONENT_EXPORT(CHROME_FEATURES) +extern const base::FeatureParam<int> kDisableHttpDiskCacheMemoryCacheSizeParam; + +COMPONENT_EXPORT(CHROME_FEATURES) extern const base::Feature kDnsOverHttps; COMPONENT_EXPORT(CHROME_FEATURES) extern const base::FeatureParam<bool> kDnsOverHttpsFallbackParam;
diff --git a/chrome/common/logging_chrome.cc b/chrome/common/logging_chrome.cc index a3594ca7..78840ec 100644 --- a/chrome/common/logging_chrome.cc +++ b/chrome/common/logging_chrome.cc
@@ -3,7 +3,6 @@ // found in the LICENSE file. #include "build/build_config.h" -#include "build/os_buildflags.h" // Need to include this before most other files because it defines // IPC_MESSAGE_LOG_ENABLED. We need to use it to define
diff --git a/chrome/common/net/x509_certificate_model_nss.cc b/chrome/common/net/x509_certificate_model_nss.cc index 8c43dff..46d5897 100644 --- a/chrome/common/net/x509_certificate_model_nss.cc +++ b/chrome/common/net/x509_certificate_model_nss.cc
@@ -20,8 +20,8 @@ #include <algorithm> #include <memory> +#include <tuple> -#include "base/ignore_result.h" #include "base/logging.h" #include "base/numerics/safe_conversions.h" #include "base/strings/string_number_conversions.h" @@ -258,7 +258,7 @@ NSSCMSContentInfo *cinfo = NSS_CMSMessage_GetContentInfo(message.get()); if (NSS_CMSContentInfo_SetContent_SignedData( message.get(), cinfo, signed_data.get()) == SECSuccess) { - ignore_result(signed_data.release()); + std::ignore = signed_data.release(); } else { DLOG(ERROR) << "NSS_CMSMessage_GetContentInfo failed"; return std::string();
diff --git a/chrome/installer/util/beacons.cc b/chrome/installer/util/beacons.cc index d72ee95..0745dca 100644 --- a/chrome/installer/util/beacons.cc +++ b/chrome/installer/util/beacons.cc
@@ -6,7 +6,8 @@ #include <stdint.h> -#include "base/ignore_result.h" +#include <tuple> + #include "base/notreached.h" #include "base/win/registry.h" #include "base/win/win_util.h" @@ -16,7 +17,7 @@ void UpdateDefaultBrowserBeaconForPath(const base::FilePath& chrome_exe) { // Getting Chrome's default state causes the beacon to be updated via a call // to UpdateDefaultBrowserBeaconWithState below. - ignore_result(ShellUtil::GetChromeDefaultStateFromPath(chrome_exe)); + std::ignore = ShellUtil::GetChromeDefaultStateFromPath(chrome_exe); } void UpdateDefaultBrowserBeaconWithState(
diff --git a/chrome/installer/util/beacons_unittest.cc b/chrome/installer/util/beacons_unittest.cc index e9cadd5..5e5417f 100644 --- a/chrome/installer/util/beacons_unittest.cc +++ b/chrome/installer/util/beacons_unittest.cc
@@ -7,7 +7,6 @@ #include <memory> #include <tuple> -#include "base/ignore_result.h" #include "base/test/test_reg_util_win.h" #include "base/test/test_timeouts.h" #include "base/threading/platform_thread.h" @@ -141,7 +140,7 @@ // Software\Chromium, so it always exists. // Silence unused variable warnings. - ignore_result(wrong_root); + std::ignore = wrong_root; #endif // The right key should exist.
diff --git a/chrome/installer/util/logging_installer.cc b/chrome/installer/util/logging_installer.cc index 30ac80c..42c37dc 100644 --- a/chrome/installer/util/logging_installer.cc +++ b/chrome/installer/util/logging_installer.cc
@@ -7,11 +7,12 @@ #include <stdint.h> #include <windows.h> +#include <tuple> + #include "base/command_line.h" #include "base/files/file.h" #include "base/files/file_path.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/logging_win.h" #include "base/path_service.h" @@ -126,7 +127,7 @@ // Fallback to current directory if getting the temp directory fails. base::FilePath tmp_path; - ignore_result(base::PathService::Get(base::DIR_TEMP, &tmp_path)); + std::ignore = base::PathService::Get(base::DIR_TEMP, &tmp_path); return tmp_path.Append(kLogFilename); }
diff --git a/chrome/service/cloud_print/print_system_win.cc b/chrome/service/cloud_print/print_system_win.cc index d6ad50cb..77f3c99b 100644 --- a/chrome/service/cloud_print/print_system_win.cc +++ b/chrome/service/cloud_print/print_system_win.cc
@@ -9,11 +9,11 @@ #include <wrl/client.h> #include <memory> +#include <tuple> #include "base/bind.h" #include "base/command_line.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/json/json_writer.h" #include "base/memory/free_deleter.h" #include "base/memory/raw_ptr.h" @@ -448,7 +448,7 @@ /*autorotate=*/false, use_color, printing::PdfRenderSettings::Mode::NORMAL))) { // The object will self-destruct when the child process dies. - ignore_result(utility_host.release()); + std::ignore = utility_host.release(); } else { client_task_runner->PostTask( FROM_HERE, base::BindOnce(&Core::PrintJobDone, this, false)); @@ -587,7 +587,7 @@ this, client_task_runner.get()); if (utility_host->StartGetPrinterCapsAndDefaults(printer_name_)) { // The object will self-destruct when the child process dies. - ignore_result(utility_host.release()); + std::ignore = utility_host.release(); } else { client_task_runner->PostTask( FROM_HERE, base::BindOnce(&PrinterCapsHandler::OnChildDied, this)); @@ -601,7 +601,7 @@ this, client_task_runner.get()); if (utility_host->StartGetPrinterSemanticCapsAndDefaults(printer_name_)) { // The object will self-destruct when the child process dies. - ignore_result(utility_host.release()); + std::ignore = utility_host.release(); } else { client_task_runner->PostTask( FROM_HERE, base::BindOnce(&PrinterCapsHandler::OnChildDied, this));
diff --git a/chrome/services/speech/soda/soda_client.cc b/chrome/services/speech/soda/soda_client.cc index b48f456b..cd620fa 100644 --- a/chrome/services/speech/soda/soda_client.cc +++ b/chrome/services/speech/soda/soda_client.cc
@@ -4,7 +4,8 @@ #include "chrome/services/speech/soda/soda_client.h" -#include "base/ignore_result.h" +#include <tuple> + #include "base/logging.h" #include "base/metrics/histogram_functions.h" #include "build/build_config.h" @@ -75,7 +76,7 @@ // there have been no crashes on 10.15+, likely due to a change in the // __cxa_atexit implementation. if (base::mac::IsAtMostOS10_14()) - ignore_result(lib_.release()); + std::ignore = lib_.release(); #endif // defined(OS_MAC) }
diff --git a/chrome/test/BUILD.gn b/chrome/test/BUILD.gn index 4ed72e1..02d3c6c 100644 --- a/chrome/test/BUILD.gn +++ b/chrome/test/BUILD.gn
@@ -1071,7 +1071,6 @@ "//base/test:test_support", "//build:branding_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//chrome:packed_resources", "//chrome:resources", "//chrome:strings", @@ -4825,7 +4824,6 @@ "../browser/signin/chrome_signin_proxying_url_loader_factory_unittest.cc", "../browser/signin/chrome_signin_status_metrics_provider_delegate_unittest.cc", "../browser/signin/chrome_signin_url_loader_throttle_unittest.cc", - "../browser/signin/primary_account_policy_manager_unittest.cc", "../browser/signin/reauth_tab_helper_unittest.cc", "../browser/signin/reauth_util_unittest.cc", "../browser/signin/signin_profile_attributes_updater_unittest.cc",
diff --git a/chrome/test/base/js2gtest.js b/chrome/test/base/js2gtest.js index 47d3488..67aeb4d5 100644 --- a/chrome/test/base/js2gtest.js +++ b/chrome/test/base/js2gtest.js
@@ -168,7 +168,8 @@ addSetPreloadInfo = true; } output(` -#include "base/ignore_result.h" +#include <tuple> + #include "url/gurl.h" #include "testing/gtest/include/gtest/gtest.h"`); // Add includes specified by test fixture. @@ -473,7 +474,7 @@ } if (testServer) { output(` - ignore_result(embedded_test_server()->Start());`); + std::ignore = embedded_test_server()->Start();`); } output(` }`);
diff --git a/chrome/test/data/webui/BUILD.gn b/chrome/test/data/webui/BUILD.gn index 32f8d0d..32f3f7a 100644 --- a/chrome/test/data/webui/BUILD.gn +++ b/chrome/test/data/webui/BUILD.gn
@@ -134,7 +134,7 @@ deps += [ ":modulize" ] } - if (is_chromeos_ash || is_win) { + if (is_chromeos_ash || is_win || is_mac) { sources += [ "inline_login/inline_login_browsertest.js" ] deps += [ "//build:chromeos_buildflags" ] } @@ -670,6 +670,15 @@ # TypeScript related targets +checked_in_dts_files = [ "test_browser_proxy.d.ts" ] + +# Copies checked-in .d.ts files to the preprocess folder so that they are +# discovered by TSC the same way generated .d.ts files are. +copy("copy_checked_in_dts_files") { + sources = checked_in_dts_files + outputs = [ "$target_gen_dir/tsc/{{source_target_relative}}" ] +} + ts_definitions("generate_definitions") { root_dir = "./" out_dir = "$target_gen_dir/tsc" @@ -679,10 +688,12 @@ "mock_controller.js", "mock_timer.js", "mojo_webui_test_support.js", - "test_browser_proxy.js", "test_plural_string_proxy.js", "test_store.js", "test_util.js", ] - extra_deps = [ "//ui/webui/resources:generate_definitions" ] + extra_deps = [ + ":copy_checked_in_dts_files", + "//ui/webui/resources:generate_definitions", + ] }
diff --git a/chrome/test/data/webui/chromeos/personalization_app/BUILD.gn b/chrome/test/data/webui/chromeos/personalization_app/BUILD.gn index 19429b1..1a90624 100644 --- a/chrome/test/data/webui/chromeos/personalization_app/BUILD.gn +++ b/chrome/test/data/webui/chromeos/personalization_app/BUILD.gn
@@ -39,8 +39,10 @@ "personalization_breadcrumb_element_test.ts", "personalization_main_element_test.ts", "personalization_router_element_test.ts", + "personalization_theme_element_test.ts", "personalization_toast_element_test.ts", "test_personalization_store.ts", + "test_theme_interface_provider.ts", "test_wallpaper_interface_provider.ts", "user_subpage_element_test.ts", "wallpaper_collections_element_test.ts",
diff --git a/chrome/test/data/webui/chromeos/personalization_app/personalization_app_component_test.ts b/chrome/test/data/webui/chromeos/personalization_app/personalization_app_component_test.ts index bd5f9c8f..0e47c38 100644 --- a/chrome/test/data/webui/chromeos/personalization_app/personalization_app_component_test.ts +++ b/chrome/test/data/webui/chromeos/personalization_app/personalization_app_component_test.ts
@@ -12,6 +12,7 @@ import {PersonalizationBreadcrumbTest} from './personalization_breadcrumb_element_test.js'; import {PersonalizationMainTest} from './personalization_main_element_test.js'; import {PersonalizationRouterTest} from './personalization_router_element_test.js'; +import {PersonalizationThemeTest} from './personalization_theme_element_test.js'; import {PersonalizationToastTest} from './personalization_toast_element_test.js'; import {UserSubpageTest} from './user_subpage_element_test.js'; import {WallpaperCollectionsTest} from './wallpaper_collections_element_test.js'; @@ -33,6 +34,7 @@ PersonalizationBreadcrumbTest, PersonalizationMainTest, PersonalizationRouterTest, + PersonalizationThemeTest, PersonalizationToastTest, UserSubpageTest, WallpaperCollectionsTest,
diff --git a/chrome/test/data/webui/chromeos/personalization_app/personalization_app_test_utils.ts b/chrome/test/data/webui/chromeos/personalization_app/personalization_app_test_utils.ts index 7ce5122..8130f538 100644 --- a/chrome/test/data/webui/chromeos/personalization_app/personalization_app_test_utils.ts +++ b/chrome/test/data/webui/chromeos/personalization_app/personalization_app_test_utils.ts
@@ -8,6 +8,7 @@ */ import {emptyState, PersonalizationState} from 'chrome://personalization/trusted/personalization_state.js'; +import {setThemeProviderForTesting} from 'chrome://personalization/trusted/theme/theme_interface_provider.js'; import {setWallpaperProviderForTesting} from 'chrome://personalization/trusted/wallpaper/wallpaper_interface_provider.js'; import {flush, PolymerElement} from 'chrome://resources/polymer/v3_0/polymer/polymer_bundled.min.js'; @@ -15,6 +16,7 @@ import {flushTasks} from 'chrome://webui-test/test_util.js'; import {TestPersonalizationStore} from './test_personalization_store.js'; +import {TestThemeProvider} from './test_theme_interface_provider.js'; import {TestWallpaperProvider} from './test_wallpaper_interface_provider.js'; /** @@ -56,10 +58,12 @@ export function baseSetup(initialState: PersonalizationState = emptyState()) { const wallpaperProvider = new TestWallpaperProvider(); setWallpaperProviderForTesting(wallpaperProvider); + const themeProvider = new TestThemeProvider(); + setThemeProviderForTesting(themeProvider); const personalizationStore = new TestPersonalizationStore(initialState); personalizationStore.replaceSingleton(); document.body.innerHTML = ''; - return {wallpaperProvider, personalizationStore}; + return {themeProvider, wallpaperProvider, personalizationStore}; } function getDebugString(w: any) {
diff --git a/chrome/test/data/webui/chromeos/personalization_app/personalization_theme_element_test.ts b/chrome/test/data/webui/chromeos/personalization_app/personalization_theme_element_test.ts new file mode 100644 index 0000000..1d920ec --- /dev/null +++ b/chrome/test/data/webui/chromeos/personalization_app/personalization_theme_element_test.ts
@@ -0,0 +1,81 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** @fileoverview Test suite for theme-element component. */ + +import {emptyState} from 'chrome://personalization/trusted/personalization_state.js'; +import {PersonalizationThemeElement} from 'chrome://personalization/trusted/personalization_theme_element.js'; +import {ThemeActionName} from 'chrome://personalization/trusted/theme/theme_actions.js'; +import {assertDeepEquals, assertEquals, assertFalse, assertTrue} from 'chrome://webui-test/chai_assert.js'; +import {flushTasks, waitAfterNextRender} from 'chrome://webui-test/test_util.js'; + +import {baseSetup, initElement} from './personalization_app_test_utils.js'; +import {TestPersonalizationStore} from './test_personalization_store.js'; +import {TestThemeProvider} from './test_theme_interface_provider.js'; + +export function PersonalizationThemeTest() { + let personalizationThemeElement: PersonalizationThemeElement|null; + let themeProvider: TestThemeProvider; + let personalizationStore: TestPersonalizationStore; + + setup(() => { + const mocks = baseSetup(); + themeProvider = mocks.themeProvider; + personalizationStore = mocks.personalizationStore; + }); + + teardown(async () => { + if (personalizationThemeElement) { + personalizationThemeElement.remove(); + } + personalizationThemeElement = null; + await flushTasks(); + }); + + test('displays content', async () => { + personalizationStore.data.theme = {darkModeEnabled: false}; + personalizationThemeElement = initElement(PersonalizationThemeElement); + await waitAfterNextRender(personalizationThemeElement); + + assertEquals( + personalizationThemeElement.i18n('themeLabel'), + personalizationThemeElement.shadowRoot!.querySelector('h2')!.innerText); + }); + + test('sets color mode in store on first load', async () => { + personalizationStore.expectAction(ThemeActionName.SET_DARK_MODE_ENABLED); + personalizationThemeElement = initElement(PersonalizationThemeElement); + const action = await personalizationStore.waitForAction( + ThemeActionName.SET_DARK_MODE_ENABLED); + assertTrue(action.enabled); + }); + + test('sets theme data in store on changed', async () => { + // Make sure state starts as expected. + assertDeepEquals(emptyState(), personalizationStore.data); + + personalizationThemeElement = initElement(PersonalizationThemeElement); + + await themeProvider.whenCalled('setThemeObserver'); + + personalizationStore.expectAction(ThemeActionName.SET_DARK_MODE_ENABLED); + themeProvider.themeObserverRemote!.onColorModeChanged( + /*darkModeEnabled=*/ false); + + const {enabled} = await personalizationStore.waitForAction( + ThemeActionName.SET_DARK_MODE_ENABLED); + assertFalse(enabled); + }); + + test('shows pressed button on load', async () => { + personalizationThemeElement = initElement(PersonalizationThemeElement); + personalizationStore.data.theme.darkModeEnabled = true; + personalizationStore.notifyObservers(); + await waitAfterNextRender(personalizationThemeElement); + const radioButton = + personalizationThemeElement.shadowRoot!.getElementById('darkMode'); + assertTrue(!!radioButton); + assertEquals(radioButton!.getAttribute('aria-pressed'), 'true'); + }); +}
diff --git a/chrome/test/data/webui/chromeos/personalization_app/test_theme_interface_provider.ts b/chrome/test/data/webui/chromeos/personalization_app/test_theme_interface_provider.ts new file mode 100644 index 0000000..286d232 --- /dev/null +++ b/chrome/test/data/webui/chromeos/personalization_app/test_theme_interface_provider.ts
@@ -0,0 +1,34 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import {ThemeObserverInterface, ThemeObserverRemote, ThemeProviderInterface} from 'chrome://personalization/trusted/personalization_app.mojom-webui.js'; +import {TestBrowserProxy} from 'chrome://webui-test/test_browser_proxy.js'; + +/** + * @implements {ThemeProviderInterface} + * @extends {TestBrowserProxy} + */ +export class TestThemeProvider extends TestBrowserProxy implements + ThemeProviderInterface { + constructor() { + super([ + 'setThemeObserver', + 'setColorModePref', + ]); + } + + themeObserverRemote: ThemeObserverInterface|null = null; + + setThemeObserver(remote: ThemeObserverRemote) { + this.methodCalled('setThemeObserver'); + this.themeObserverRemote = remote; + window.setTimeout(() => { + this.themeObserverRemote!.onColorModeChanged(/*darkModeEnabled=*/ true); + }, 0); + } + + setColorModePref(darkModeEnabled: boolean) { + this.methodCalled('setColorModePref', darkModeEnabled); + } +}
diff --git a/chrome/test/data/webui/cr_components/most_visited_focus_test.ts b/chrome/test/data/webui/cr_components/most_visited_focus_test.ts index bb3d1a8..1f1b7a6 100644 --- a/chrome/test/data/webui/cr_components/most_visited_focus_test.ts +++ b/chrome/test/data/webui/cr_components/most_visited_focus_test.ts
@@ -47,8 +47,7 @@ setup(() => { document.body.innerHTML = ''; - const handler = TestBrowserProxy.fromClass(MostVisitedPageHandlerRemote) as - unknown as MostVisitedPageHandlerRemote; + const handler = TestBrowserProxy.fromClass(MostVisitedPageHandlerRemote); const callbackRouter = new MostVisitedPageCallbackRouter(); MostVisitedBrowserProxy.setInstance( new MostVisitedBrowserProxy(handler, callbackRouter));
diff --git a/chrome/test/data/webui/cr_components/most_visited_test.ts b/chrome/test/data/webui/cr_components/most_visited_test.ts index 4d7c5e0a..1c8c1284 100644 --- a/chrome/test/data/webui/cr_components/most_visited_test.ts +++ b/chrome/test/data/webui/cr_components/most_visited_test.ts
@@ -81,9 +81,7 @@ } function createBrowserProxy() { - handler = TestBrowserProxy.fromClass(MostVisitedPageHandlerRemote) as - unknown as MostVisitedPageHandlerRemote & - TestBrowserProxy; + handler = TestBrowserProxy.fromClass(MostVisitedPageHandlerRemote); const callbackRouter = new MostVisitedPageCallbackRouter(); MostVisitedBrowserProxy.setInstance( new MostVisitedBrowserProxy(handler, callbackRouter)); @@ -115,9 +113,7 @@ } function createWindowProxy() { - windowProxy = TestBrowserProxy.fromClass(MostVisitedWindowProxy) as unknown as - MostVisitedWindowProxy & - TestBrowserProxy; + windowProxy = TestBrowserProxy.fromClass(MostVisitedWindowProxy); windowProxy.setResultMapperFor('matchMedia', (query: string) => { const mediaListenerList = new FakeMediaQueryList(query); if (query === '(min-width: 672px)') {
diff --git a/chrome/test/data/webui/test_browser_proxy.d.ts b/chrome/test/data/webui/test_browser_proxy.d.ts new file mode 100644 index 0000000..78edd4c --- /dev/null +++ b/chrome/test/data/webui/test_browser_proxy.d.ts
@@ -0,0 +1,16 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +export class TestBrowserProxy { + static fromClass<T>(clazz: {new(): T}): T&TestBrowserProxy; + constructor(methodNames?: Array<string>); + methodCalled(methodName: string, ...args: any[]): any; + whenCalled(methodName: string): Promise<any>; + resetResolver(methodName: string): void; + reset(): void; + getCallCount(methodName: string): number; + getArgs(methodName: string): Array<any>; + setResultMapperFor(methodName: string, resultMapper: Function): void; + setResultFor(methodName: string, value: any): void; +}
diff --git a/chrome/test/media_router/media_router_cast_ui_for_test.cc b/chrome/test/media_router/media_router_cast_ui_for_test.cc index 8619c0e..7efa64aa 100644 --- a/chrome/test/media_router/media_router_cast_ui_for_test.cc +++ b/chrome/test/media_router/media_router_cast_ui_for_test.cc
@@ -6,7 +6,6 @@ #include "base/memory/raw_ptr.h" #include "chrome/browser/media/router/media_router_feature.h" -#include "chrome/browser/ui/media_router/media_router_file_dialog.h" #include "chrome/browser/ui/media_router/media_router_ui.h" #include "chrome/browser/ui/views/media_router/cast_dialog_sink_button.h" #include "chrome/browser/ui/views/media_router/media_router_dialog_controller_views.h" @@ -27,41 +26,6 @@ ui::EF_LEFT_MOUSE_BUTTON, 0); } -// File dialog with a preset file URL. -class TestMediaRouterFileDialog : public MediaRouterFileDialog { - public: - TestMediaRouterFileDialog(MediaRouterFileDialogDelegate* delegate, GURL url) - : MediaRouterFileDialog(nullptr), delegate_(delegate), file_url_(url) {} - ~TestMediaRouterFileDialog() override = default; - - GURL GetLastSelectedFileUrl() override { return file_url_; } - - void OpenFileDialog(Browser* browser) override { - delegate_->FileDialogFileSelected(ui::SelectedFileInfo()); - } - - private: - raw_ptr<MediaRouterFileDialogDelegate> delegate_; - GURL file_url_; -}; - -// File dialog which fails on open. -class TestFailMediaRouterFileDialog : public MediaRouterFileDialog { - public: - TestFailMediaRouterFileDialog(MediaRouterFileDialogDelegate* delegate, - const IssueInfo& issue) - : MediaRouterFileDialog(nullptr), delegate_(delegate), issue_(issue) {} - ~TestFailMediaRouterFileDialog() override = default; - - void OpenFileDialog(Browser* browser) override { - delegate_->FileDialogSelectionFailed(issue_); - } - - private: - raw_ptr<MediaRouterFileDialogDelegate> delegate_; - const IssueInfo issue_; -}; - } // namespace // static @@ -157,19 +121,6 @@ ObserveDialog(WatchType::kDialogHidden); } -void MediaRouterCastUiForTest::SetLocalFile(const GURL& file_url) { - dialog_controller_->ui()->set_media_router_file_dialog_for_test( - std::make_unique<TestMediaRouterFileDialog>(dialog_controller_->ui(), - file_url)); -} - -void MediaRouterCastUiForTest::SetLocalFileSelectionIssue( - const IssueInfo& issue) { - dialog_controller_->ui()->set_media_router_file_dialog_for_test( - std::make_unique<TestFailMediaRouterFileDialog>(dialog_controller_->ui(), - issue)); -} - void MediaRouterCastUiForTest::OnDialogCreated() { MediaRouterUiForTestBase::OnDialogCreated(); CastDialogView::GetInstance()->KeepShownForTesting();
diff --git a/chrome/test/media_router/media_router_cast_ui_for_test.h b/chrome/test/media_router/media_router_cast_ui_for_test.h index 1654098..1b022b7 100644 --- a/chrome/test/media_router/media_router_cast_ui_for_test.h +++ b/chrome/test/media_router/media_router_cast_ui_for_test.h
@@ -42,8 +42,6 @@ void WaitForAnyRoute() override; void WaitForDialogShown() override; void WaitForDialogHidden() override; - void SetLocalFile(const GURL& file_url) override; - void SetLocalFileSelectionIssue(const IssueInfo& issue) override; void OnDialogCreated() override; private:
diff --git a/chrome/test/media_router/media_router_e2e_ui_browsertest.cc b/chrome/test/media_router/media_router_e2e_ui_browsertest.cc index a40ddcf..2856388 100644 --- a/chrome/test/media_router/media_router_e2e_ui_browsertest.cc +++ b/chrome/test/media_router/media_router_e2e_ui_browsertest.cc
@@ -16,53 +16,6 @@ namespace media_router { -// TODO(crbug.com/903016) Disabled due to flakiness. -IN_PROC_BROWSER_TEST_P(MediaRouterE2EBrowserTest, - DISABLED_OpenLocalMediaFileFullscreen) { - GURL file_url = ui_test_utils::GetTestUrl( - base::FilePath(base::FilePath::kCurrentDirectory), - base::FilePath(FILE_PATH_LITERAL("media/bigbuck.webm"))); - - // Start at a new tab, the file should open in the same tab. - ASSERT_TRUE(ui_test_utils::NavigateToURL(browser(), - GURL(chrome::kChromeUINewTabURL))); - // Make sure there is 1 tab. - ASSERT_EQ(1, browser()->tab_strip_model()->count()); - - content::WebContents* web_contents = - browser()->tab_strip_model()->GetActiveWebContents(); - - test_ui_->ShowDialog(); - test_ui_->WaitForSinkAvailable(receiver_); - - // Mock out file dialog operations, as those can't be simulated. - test_ui_->SetLocalFile(file_url); - // Click on the desired mode. - test_ui_->ChooseSourceType(CastDialogView::kLocalFile); - test_ui_->WaitForSinkAvailable(receiver_); - test_ui_->StartCasting(receiver_); - - // Play the file for 10 seconds. - Wait(base::Seconds(10)); - - // Expect that the current tab has the file open in it. - ASSERT_EQ(file_url, web_contents->GetLastCommittedURL()); - - // Expect that fullscreen is active. - bool is_fullscreen = false; - std::string is_fullscreen_script = - "domAutomationController.send" - "(document.webkitCurrentFullScreenElement != null);"; - CHECK(content::ExecuteScriptAndExtractBool(web_contents, is_fullscreen_script, - &is_fullscreen)); - - ASSERT_TRUE(is_fullscreen); - test_ui_->WaitForSink(receiver_); - test_ui_->StopCasting(receiver_); - // Wait 15s for Chromecast to back to home screen and ready to use status. - Wait(base::Seconds(15)); -} - IN_PROC_BROWSER_TEST_P(MediaRouterE2EBrowserTest, MANUAL_MirrorHTML5Video) { MEDIA_ROUTER_INTEGRATION_BROWER_TEST_CAST_ONLY(); content::WebContents* web_contents =
diff --git a/chrome/test/media_router/media_router_gmc_ui_for_test.cc b/chrome/test/media_router/media_router_gmc_ui_for_test.cc index 5f08185..3f56865a 100644 --- a/chrome/test/media_router/media_router_gmc_ui_for_test.cc +++ b/chrome/test/media_router/media_router_gmc_ui_for_test.cc
@@ -82,15 +82,6 @@ NOTIMPLEMENTED(); } -void MediaRouterGmcUiForTest::SetLocalFile(const GURL& file_url) { - NOTIMPLEMENTED(); -} - -void MediaRouterGmcUiForTest::SetLocalFileSelectionIssue( - const IssueInfo& issue) { - NOTIMPLEMENTED(); -} - MediaRouterGmcUiForTest::MediaRouterGmcUiForTest( content::WebContents* web_contents) : MediaRouterUiForTestBase(web_contents),
diff --git a/chrome/test/media_router/media_router_gmc_ui_for_test.h b/chrome/test/media_router/media_router_gmc_ui_for_test.h index f1560f1..5a1dcea 100644 --- a/chrome/test/media_router/media_router_gmc_ui_for_test.h +++ b/chrome/test/media_router/media_router_gmc_ui_for_test.h
@@ -45,8 +45,6 @@ void WaitForAnyRoute() override; void WaitForDialogShown() override; void WaitForDialogHidden() override; - void SetLocalFile(const GURL& file_url) override; - void SetLocalFileSelectionIssue(const IssueInfo& issue) override; private: friend class content::WebContentsUserData<MediaRouterGmcUiForTest>;
diff --git a/chrome/test/media_router/media_router_integration_browsertest.cc b/chrome/test/media_router/media_router_integration_browsertest.cc index e8cec315..bbf007e 100644 --- a/chrome/test/media_router/media_router_integration_browsertest.cc +++ b/chrome/test/media_router/media_router_integration_browsertest.cc
@@ -20,7 +20,6 @@ #include "chrome/browser/media/router/mojo/media_router_desktop.h" #include "chrome/browser/ui/browser_finder.h" #include "chrome/browser/ui/media_router/media_cast_mode.h" -#include "chrome/browser/ui/media_router/media_router_file_dialog.h" #include "chrome/browser/ui/tabs/tab_strip_model.h" #include "chrome/browser/ui/views/frame/browser_view.h" #include "chrome/common/url_constants.h" @@ -245,31 +244,6 @@ return web_contents; } -void MediaRouterIntegrationBrowserTest::OpenDialogAndCastFile() { - GURL file_url = net::FilePathToFileURL( - media::GetTestDataFilePath("butterfly-853x480.webm")); - test_ui_->ShowDialog(); - // Mock out file dialog operations, as those can't be simulated. - test_ui_->SetLocalFile(file_url); - test_ui_->WaitForSink(receiver_); - test_ui_->ChooseSourceType(CastDialogView::kLocalFile); - ASSERT_EQ(CastDialogView::kLocalFile, test_ui_->GetChosenSourceType()); - test_ui_->WaitForSinkAvailable(receiver_); - test_ui_->StartCasting(receiver_); - ASSERT_EQ(file_url, GetActiveWebContents()->GetVisibleURL()); -} - -void MediaRouterIntegrationBrowserTest::OpenDialogAndCastFileFails() { - GURL file_url = - net::FilePathToFileURL(media::GetTestDataFilePath("easy.webm")); - test_ui_->ShowDialog(); - // Mock out file dialog operations, as those can't be simulated. - test_ui_->SetLocalFileSelectionIssue(IssueInfo()); - test_ui_->WaitForSink(receiver_); - test_ui_->ChooseSourceType(CastDialogView::kLocalFile); - test_ui_->WaitForAnyIssue(); -} - void MediaRouterIntegrationBrowserTest::OpenTestPage( base::FilePath::StringPieceType file_name) { base::FilePath full_path = GetResourceFile(file_name); @@ -491,134 +465,6 @@ RunBasicTest(); } -// Tests that creating a route with a local file opens the file in a new tab. -// -// This test was disabled because the test needs to wait until navigation is -// complete before looking for the route, but it's not clear how to do that -// without deadlocking the test. -// This test passed locally when running with native test provider, so it -// is updated to MANUAL and is allowed to run on private waterfall. -IN_PROC_BROWSER_TEST_P(MediaRouterIntegrationBrowserTest, - MANUAL_OpenLocalMediaFileInCurrentTab) { - MEDIA_ROUTER_INTEGRATION_BROWER_TEST_CAST_ONLY(); - // Start at a new tab, the file should open in the same tab. - ASSERT_TRUE(ui_test_utils::NavigateToURL(browser(), - GURL(chrome::kChromeUINewTabURL))); - // Make sure there is 1 tab. - ASSERT_EQ(1, browser()->tab_strip_model()->count()); - - OpenDialogAndCastFile(); - - // Expect that no new tab has been opened. - ASSERT_EQ(1, browser()->tab_strip_model()->count()); - - // The dialog will close from navigating to the local file within the tab, so - // open it again after it closes. - test_ui_->WaitForDialogHidden(); - test_ui_->ShowDialog(); - - // Wait for a route to be created. - test_ui_->WaitForAnyRoute(); -} - -// TODO(http://crbug.com/1095068): There maybe a crash on Linux and ChromeOS. -#if defined(OS_LINUX) || defined(OS_CHROMEOS) -#define MAYBE_OpenLocalMediaFileInNewTab MANUAL_OpenLocalMediaFileInNewTab -#else -#define MAYBE_OpenLocalMediaFileInNewTab OpenLocalMediaFileInNewTab -#endif - -// Tests that creating a route with a local file opens the file in a new tab. -IN_PROC_BROWSER_TEST_P(MediaRouterIntegrationBrowserTest, - MAYBE_OpenLocalMediaFileInNewTab) { - MEDIA_ROUTER_INTEGRATION_BROWER_TEST_CAST_ONLY(); - // Start at a tab with content in it, the file will open in a new tab. - ASSERT_TRUE( - ui_test_utils::NavigateToURL(browser(), GURL("https://google.com"))); - // Make sure there is 1 tab. - ASSERT_EQ(1, browser()->tab_strip_model()->count()); - - OpenDialogAndCastFile(); - - // Expect that a new tab has been opened. - ASSERT_EQ(2, browser()->tab_strip_model()->count()); - - test_ui_->ShowDialog(); - - // Wait for a route to be created. - test_ui_->WaitForAnyRoute(); -} - -// Tests that failing to create a route with a local file shows an issue. -// TODO(https://crbug.com/907539): Make the Views dialog show the issue. -IN_PROC_BROWSER_TEST_P(MediaRouterIntegrationBrowserTest, - DISABLED_OpenLocalMediaFileFailsAndShowsIssue) { - OpenDialogAndCastFileFails(); - // Expect that the issue is showing. - ASSERT_TRUE(IsUIShowingIssue()); -} - -// Tests that creating a route with a local file opens in fullscreen. -// TODO(https://crbug.com/903016) Disabled for being flaky in entering -// fullscreen. -IN_PROC_BROWSER_TEST_P(MediaRouterIntegrationBrowserTest, - DISABLED_OpenLocalMediaFileFullscreen) { - // Start at a new tab, the file should open in the same tab. - ASSERT_TRUE(ui_test_utils::NavigateToURL(browser(), - GURL(chrome::kChromeUINewTabURL))); - // Make sure there is 1 tab. - ASSERT_EQ(1, browser()->tab_strip_model()->count()); - - OpenDialogAndCastFile(); - - // Increment web contents capturer count so it thinks capture has started. - // This will allow the file tab to go fullscreen. - content::WebContents* web_contents = GetActiveWebContents(); - auto capture_handle = - web_contents->IncrementCapturerCount(gfx::Size(), /*stay_hidden=*/false, - /*stay_awake=*/true); - - // Wait for capture poll timer to pick up change. - Wait(base::Seconds(3)); - - // Expect that fullscreen was entered. - ASSERT_TRUE( - web_contents->GetDelegate()->IsFullscreenForTabOrPending(web_contents)); -} - -// Flaky on MSan bots: http://crbug.com/879885 -#if defined(MEMORY_SANITIZER) -#define MAYBE_OpenLocalMediaFileCastFailNoFullscreen \ - MANUAL_OpenLocalMediaFileCastFailNoFullscreen -#else -#define MAYBE_OpenLocalMediaFileCastFailNoFullscreen \ - OpenLocalMediaFileCastFailNoFullscreen -#endif -// Tests that failed route creation of local file does not enter fullscreen. -IN_PROC_BROWSER_TEST_P(MediaRouterIntegrationBrowserTest, - MAYBE_OpenLocalMediaFileCastFailNoFullscreen) { - MEDIA_ROUTER_INTEGRATION_BROWER_TEST_CAST_ONLY(); - test_provider_->set_route_error_message("Unknown error"); - // Start at a new tab, the file should open in the same tab. - ASSERT_TRUE(ui_test_utils::NavigateToURL(browser(), - GURL(chrome::kChromeUINewTabURL))); - // Make sure there is 1 tab. - ASSERT_EQ(1, browser()->tab_strip_model()->count()); - - OpenDialogAndCastFile(); - - // Wait for file to start playing (but not being captured). - Wait(base::Seconds(3)); - - // Expect no capture is ongoing. - content::WebContents* web_contents = GetActiveWebContents(); - ASSERT_FALSE(web_contents->IsBeingCaptured()); - - // Expect that fullscreen is not entered. - ASSERT_FALSE( - web_contents->GetDelegate()->IsFullscreenForTabOrPending(web_contents)); -} - // TODO(crbug.com/1238728): Test is flaky on Windows and Linux. #if defined(OS_LINUX) || defined(OS_WIN) #define MAYBE_SendAndOnMessage MANUAL_SendAndOnMessage
diff --git a/chrome/test/media_router/media_router_integration_browsertest.h b/chrome/test/media_router/media_router_integration_browsertest.h index 7d2314cf..a413c2ad 100644 --- a/chrome/test/media_router/media_router_integration_browsertest.h +++ b/chrome/test/media_router/media_router_integration_browsertest.h
@@ -116,17 +116,6 @@ // |should_succeed| is true. virtual content::WebContents* StartSessionWithTestPageAndChooseSink(); - // Opens the MR dialog and clicks through the motions of casting a - // file. Sets up the route provider to succeed or otherwise based on - // |route_success|. Note: The system dialog portion has to be mocked - // out as it cannot be simulated. - void OpenDialogAndCastFile(); - - // Opens the MR dialog and clicks through the motions of choosing to - // cast file, file returns an issue. Note: The system dialog portion - // has to be mocked out as it cannot be simulated. - void OpenDialogAndCastFileFails(); - void OpenTestPage(base::FilePath::StringPieceType file); void OpenTestPageInNewTab(base::FilePath::StringPieceType file); virtual GURL GetTestPageUrl(const base::FilePath& full_path);
diff --git a/chrome/test/media_router/media_router_ui_for_test_base.h b/chrome/test/media_router/media_router_ui_for_test_base.h index 76d7521..6526888 100644 --- a/chrome/test/media_router/media_router_ui_for_test_base.h +++ b/chrome/test/media_router/media_router_ui_for_test_base.h
@@ -56,11 +56,6 @@ std::string GetStatusTextForSink(const std::string& sink_name) const; std::string GetIssueTextForSink(const std::string& sink_name) const; - // Sets up a mock file picker that returns |file_url| as the selected file. - virtual void SetLocalFile(const GURL& file_url) = 0; - // Sets up a mock file picker that fails with |issue|. - virtual void SetLocalFileSelectionIssue(const IssueInfo& issue) = 0; - // Called by MediaRouterDialogControllerViews. virtual void OnDialogCreated();
diff --git a/chrome/test/v8/wasm_trap_handler_browsertest.cc b/chrome/test/v8/wasm_trap_handler_browsertest.cc index f16210c..2092971 100644 --- a/chrome/test/v8/wasm_trap_handler_browsertest.cc +++ b/chrome/test/v8/wasm_trap_handler_browsertest.cc
@@ -5,8 +5,9 @@ // These tests focus on Wasm out of bounds behavior to make sure trap-based // bounds checks work when integrated with all of Chrome. +#include <tuple> + #include "base/base_switches.h" -#include "base/ignore_result.h" #include "build/build_config.h" #include "chrome/browser/ui/browser.h" #include "chrome/browser/ui/tabs/tab_strip_model.h" @@ -132,7 +133,7 @@ // Sanitizers may prevent signal handler installation and thereby trap handler // setup. As there is no easy way to test if signal handler installation is // possible, we disable this test for sanitizers. - ignore_result(is_trap_handler_enabled); + std::ignore = is_trap_handler_enabled; return; #endif
diff --git a/chrome/utility/safe_browsing/mac/hfs_fuzzer.cc b/chrome/utility/safe_browsing/mac/hfs_fuzzer.cc index 6ab1b80..c3764cb 100644 --- a/chrome/utility/safe_browsing/mac/hfs_fuzzer.cc +++ b/chrome/utility/safe_browsing/mac/hfs_fuzzer.cc
@@ -6,9 +6,9 @@ #include <stdint.h> #include <memory> +#include <tuple> #include <vector> -#include "base/ignore_result.h" #include "chrome/utility/safe_browsing/mac/hfs.h" #include "chrome/utility/safe_browsing/mac/read_stream.h" #include "testing/libfuzzer/libfuzzer_exports.h" @@ -24,9 +24,9 @@ while (hfs_iterator.Next()) { // Test accessing properties. - ignore_result(hfs_iterator.IsSymbolicLink()); - ignore_result(hfs_iterator.IsDecmpfsCompressed()); - ignore_result(hfs_iterator.GetPath()); + std::ignore = hfs_iterator.IsSymbolicLink(); + std::ignore = hfs_iterator.IsDecmpfsCompressed(); + std::ignore = hfs_iterator.GetPath(); if (hfs_iterator.IsDirectory() || hfs_iterator.IsHardLink()) continue;
diff --git a/chromecast/browser/BUILD.gn b/chromecast/browser/BUILD.gn index 16b530e..4b54bf4a 100644 --- a/chromecast/browser/BUILD.gn +++ b/chromecast/browser/BUILD.gn
@@ -121,6 +121,8 @@ "cast_content_gesture_handler.h", "cast_download_manager_delegate.cc", "cast_download_manager_delegate.h", + "cast_feature_update_observer.cc", + "cast_feature_update_observer.h", "cast_http_user_agent_settings.cc", "cast_http_user_agent_settings.h", "cast_media_blocker.cc",
diff --git a/chromecast/browser/cast_browser_main_parts.cc b/chromecast/browser/cast_browser_main_parts.cc index eabc26f..507872f0 100644 --- a/chromecast/browser/cast_browser_main_parts.cc +++ b/chromecast/browser/cast_browser_main_parts.cc
@@ -41,6 +41,7 @@ #include "chromecast/browser/cast_content_browser_client.h" #include "chromecast/browser/cast_extension_url_loader_factory.h" #include "chromecast/browser/cast_feature_list_creator.h" +#include "chromecast/browser/cast_feature_update_observer.h" #include "chromecast/browser/cast_system_memory_pressure_evaluator.h" #include "chromecast/browser/cast_system_memory_pressure_evaluator_adjuster.h" #include "chromecast/browser/cast_web_service.h" @@ -769,6 +770,12 @@ cast_browser_process_->cast_service()->Start(); + if (base::CommandLine::ForCurrentProcess()->HasSwitch( + switches::kUseCastBrowserPrefConfig)) { + feature_update_observer_ = std::make_unique<CastFeatureUpdateObserver>( + connector(), cast_browser_process_->pref_service()); + } + return content::RESULT_CODE_NORMAL_EXIT; }
diff --git a/chromecast/browser/cast_browser_main_parts.h b/chromecast/browser/cast_browser_main_parts.h index 9f02df6..d3b2f95 100644 --- a/chromecast/browser/cast_browser_main_parts.h +++ b/chromecast/browser/cast_browser_main_parts.h
@@ -35,6 +35,7 @@ #endif // defined(USE_AURA) namespace chromecast { +class CastFeatureUpdateObserver; class CastSystemMemoryPressureEvaluatorAdjuster; class CastWebService; class DisplaySettingsManager; @@ -177,6 +178,8 @@ std::unique_ptr<WaylandServerController> wayland_server_controller_; #endif + std::unique_ptr<CastFeatureUpdateObserver> feature_update_observer_; + #if defined(USE_AURA) && !defined(OS_FUCHSIA) // Only used when running with --enable-ui-devtools. std::unique_ptr<CastUIDevTools> ui_devtools_;
diff --git a/chromecast/browser/cast_feature_update_observer.cc b/chromecast/browser/cast_feature_update_observer.cc new file mode 100644 index 0000000..008cff2 --- /dev/null +++ b/chromecast/browser/cast_feature_update_observer.cc
@@ -0,0 +1,52 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chromecast/browser/cast_feature_update_observer.h" + +#include "base/bind.h" +#include "base/check.h" +#include "chromecast/base/pref_names.h" +#include "chromecast/common/mojom/constants.mojom.h" +#include "chromecast/external_mojo/external_service_support/external_connector.h" +#include "components/prefs/pref_service.h" + +namespace chromecast { + +CastFeatureUpdateObserver::CastFeatureUpdateObserver( + external_service_support::ExternalConnector* connector, + PrefService* pref_service) + : connector_(connector), pref_service_(pref_service) { + DCHECK(connector_); + DCHECK(pref_service_); + + BindFeatureUpdateService(); +} + +CastFeatureUpdateObserver::~CastFeatureUpdateObserver() = default; + +void CastFeatureUpdateObserver::BindFeatureUpdateService() { + feature_update_service_.reset(); + receiver_.reset(); + connector_->BindInterface( + mojom::kChromecastServiceName, + feature_update_service_.BindNewPipeAndPassReceiver()); + feature_update_service_->RegisterFeatureUpdateObserver( + receiver_.BindNewPipeAndPassRemote()); + + // Right now we are in the process of making the `cast_service` manage the + // lifecycle of `cast_browser`. Until that is done, `cast_service` has a + // shorter lifecycle than `cast_browser`, so we need to handle disconnects + // here. + // TODO(crbug/1285360): remove once process lifecycles are inverted. + receiver_.set_disconnect_handler( + base::BindOnce(&CastFeatureUpdateObserver::BindFeatureUpdateService, + base::Unretained(this))); +} + +void CastFeatureUpdateObserver::OnFeaturesUpdated(base::Value features) { + pref_service_->Set(prefs::kLatestDCSFeatures, std::move(features)); + pref_service_->CommitPendingWrite(); +} + +} // namespace chromecast
diff --git a/chromecast/browser/cast_feature_update_observer.h b/chromecast/browser/cast_feature_update_observer.h new file mode 100644 index 0000000..21d606f1 --- /dev/null +++ b/chromecast/browser/cast_feature_update_observer.h
@@ -0,0 +1,47 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROMECAST_BROWSER_CAST_FEATURE_UPDATE_OBSERVER_H_ +#define CHROMECAST_BROWSER_CAST_FEATURE_UPDATE_OBSERVER_H_ + +#include "base/values.h" +#include "chromecast/common/mojom/feature_update.mojom.h" +#include "mojo/public/cpp/bindings/receiver.h" +#include "mojo/public/cpp/bindings/remote.h" + +class PrefService; + +namespace chromecast { + +namespace external_service_support { +class ExternalConnector; +} // namespace external_service_support + +class CastFeatureUpdateObserver + : public chromecast::mojom::FeatureUpdateObserver { + public: + CastFeatureUpdateObserver( + external_service_support::ExternalConnector* connector, + PrefService* pref_service); + CastFeatureUpdateObserver(const CastFeatureUpdateObserver&) = delete; + CastFeatureUpdateObserver& operator=(const CastFeatureUpdateObserver&) = + delete; + ~CastFeatureUpdateObserver() override; + + private: + // chromecast::mojom::FeatureUpdateObserver implementation: + void OnFeaturesUpdated(base::Value features) override; + + void BindFeatureUpdateService(); + + external_service_support::ExternalConnector* const connector_; + PrefService* const pref_service_; + + mojo::Receiver<chromecast::mojom::FeatureUpdateObserver> receiver_{this}; + mojo::Remote<chromecast::mojom::FeatureUpdateService> feature_update_service_; +}; + +} // namespace chromecast + +#endif // CHROMECAST_BROWSER_CAST_FEATURE_UPDATE_OBSERVER_H_
diff --git a/chromecast/common/mojom/BUILD.gn b/chromecast/common/mojom/BUILD.gn index c4773a3..255fae0 100644 --- a/chromecast/common/mojom/BUILD.gn +++ b/chromecast/common/mojom/BUILD.gn
@@ -15,6 +15,7 @@ "cast_demo.mojom", "constants.mojom", "feature_manager.mojom", + "feature_update.mojom", "gesture.mojom", "identification_settings.mojom", "js_channel.mojom",
diff --git a/chromecast/common/mojom/feature_update.mojom b/chromecast/common/mojom/feature_update.mojom new file mode 100644 index 0000000..1bb3ebd --- /dev/null +++ b/chromecast/common/mojom/feature_update.mojom
@@ -0,0 +1,23 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module chromecast.mojom; + +import "mojo/public/mojom/base/values.mojom"; + +// This interface is implemented by the clients of `FeatureUpdateService`. +interface FeatureUpdateObserver { + // Triggered when there is an update to base::Feature configs because a + // DCS download is completed. `features` is a dictionary which is ready to be + // persisted to disk (via PrefService). + OnFeaturesUpdated(mojo_base.mojom.DictionaryValue features); +}; + +// This interface is implemented in the Cast Service process and allows the +// observers to receive base::Feature state updates after base::Feature +// overrides are downloaded successfully from the cloud service (DCS). +interface FeatureUpdateService { + // Adds an observer to receive feature config updates. + RegisterFeatureUpdateObserver(pending_remote<FeatureUpdateObserver> observer); +};
diff --git a/chromeos/CHROMEOS_LKGM b/chromeos/CHROMEOS_LKGM index fc73a4d..20d0bdbc 100644 --- a/chromeos/CHROMEOS_LKGM +++ b/chromeos/CHROMEOS_LKGM
@@ -1 +1 @@ -14436.0.0 \ No newline at end of file +14443.0.0 \ No newline at end of file
diff --git a/chromeos/chromeos_strings.grd b/chromeos/chromeos_strings.grd index 6029b9a..df7cab35 100644 --- a/chromeos/chromeos_strings.grd +++ b/chromeos/chromeos_strings.grd
@@ -1959,6 +1959,15 @@ <message name="IDS_PERSONALIZATION_APP_SET_AS_WALLPAPER" desc="Label for the button to confirm preview wallpaper"> Set as wallpaper </message> + <message name="IDS_PERSONALIZATION_APP_THEME_LABEL" desc="Label for theme element in personalization hub"> + Theme + </message> + <message name="IDS_PERSONALIZATION_APP_THEME_DARK_COLOR_MODE" desc="Label for dark color mode setting"> + Dark + </message> + <message name="IDS_PERSONALIZATION_APP_THEME_LIGHT_COLOR_MODE" desc="Label for light color mode setting"> + Light + </message> <!-- Traffic Counters UI --> <message name="IDS_TRAFFIC_COUNTERS_UNKNOWN" desc="Traffic counters related to an unknown source">
diff --git a/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_DARK_COLOR_MODE.png.sha1 b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_DARK_COLOR_MODE.png.sha1 new file mode 100644 index 0000000..3123943 --- /dev/null +++ b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_DARK_COLOR_MODE.png.sha1
@@ -0,0 +1 @@ +df267d0c605415a09d35dd596bb1b1574148186c \ No newline at end of file
diff --git a/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_LABEL.png.sha1 b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_LABEL.png.sha1 new file mode 100644 index 0000000..3123943 --- /dev/null +++ b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_LABEL.png.sha1
@@ -0,0 +1 @@ +df267d0c605415a09d35dd596bb1b1574148186c \ No newline at end of file
diff --git a/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_LIGHT_COLOR_MODE.png.sha1 b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_LIGHT_COLOR_MODE.png.sha1 new file mode 100644 index 0000000..3123943 --- /dev/null +++ b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_THEME_LIGHT_COLOR_MODE.png.sha1
@@ -0,0 +1 @@ +df267d0c605415a09d35dd596bb1b1574148186c \ No newline at end of file
diff --git a/chromeos/dbus/fwupd/fwupd_client.cc b/chromeos/dbus/fwupd/fwupd_client.cc index 9bcd331..7682587 100644 --- a/chromeos/dbus/fwupd/fwupd_client.cc +++ b/chromeos/dbus/fwupd/fwupd_client.cc
@@ -101,8 +101,10 @@ } writer.CloseContainer(&array_writer); + // TODO(michaelcheco): Investigate whether or not the estimated install time + // multiplied by some factor can be used in place of |TIMEOUT_INFINITE|. proxy_->CallMethodWithErrorResponse( - &method_call, dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, + &method_call, dbus::ObjectProxy::TIMEOUT_INFINITE, base::BindOnce(&FwupdClientImpl::InstallUpdateCallback, weak_ptr_factory_.GetWeakPtr())); }
diff --git a/chromeos/network/policy_applicator.cc b/chromeos/network/policy_applicator.cc index 7d25520..bde95d0 100644 --- a/chromeos/network/policy_applicator.cc +++ b/chromeos/network/policy_applicator.cc
@@ -309,13 +309,11 @@ profile_, new_guid, &global_network_config_, new_policy_as_dict, user_settings); - if (features::IsESimPolicyEnabled()) { - // Copy over the value of ICCID and EID property from old entry to new shill - // properties since Shill requires ICCID and EID to create or update the - // existing service. - CopyRequiredCellularProperies(entry_properties_as_dict, - &new_shill_properties); - } + // Copy over the value of ICCID and EID property from old entry to new shill + // properties since Shill requires ICCID and EID to create or update the + // existing service. + CopyRequiredCellularProperies(entry_properties_as_dict, + &new_shill_properties); // A new policy has to be applied to this profile entry. In order to keep // implicit state of Shill like "connected successfully before", keep the
diff --git a/components/app_restore/restore_data.cc b/components/app_restore/restore_data.cc index d5477ba5..3b7ca26 100644 --- a/components/app_restore/restore_data.cc +++ b/components/app_restore/restore_data.cc
@@ -17,38 +17,31 @@ RestoreData::RestoreData() = default; RestoreData::RestoreData(std::unique_ptr<base::Value> restore_data_value) { - base::DictionaryValue* restore_data_dict = nullptr; - if (!restore_data_value || !restore_data_value->is_dict() || - !restore_data_value->GetAsDictionary(&restore_data_dict) || - !restore_data_dict) { + if (!restore_data_value || !restore_data_value->is_dict()) { DVLOG(0) << "Fail to parse full restore data. " << "Cannot find the full restore data dict."; return; } - for (base::DictionaryValue::Iterator iter(*restore_data_dict); - !iter.IsAtEnd(); iter.Advance()) { - const std::string& app_id = iter.key(); - base::Value* value = restore_data_dict->FindDictKey(app_id); - base::DictionaryValue* data_dict = nullptr; - if (!value || !value->is_dict() || !value->GetAsDictionary(&data_dict) || - !data_dict) { + for (auto iter : restore_data_value->DictItems()) { + const std::string& app_id = iter.first; + base::Value* value = restore_data_value->FindDictKey(app_id); + if (!value || !value->is_dict()) { DVLOG(0) << "Fail to parse full restore data. " << "Cannot find the app restore data dict."; continue; } - for (base::DictionaryValue::Iterator data_iter(*data_dict); - !data_iter.IsAtEnd(); data_iter.Advance()) { + for (auto data_iter : value->DictItems()) { int window_id = 0; - if (!base::StringToInt(data_iter.key(), &window_id)) { + if (!base::StringToInt(data_iter.first, &window_id)) { DVLOG(0) << "Fail to parse full restore data. " << "Cannot find the valid id."; continue; } app_id_to_launch_list_[app_id][window_id] = std::make_unique<AppRestoreData>( - std::move(*data_dict->FindDictKey(data_iter.key()))); + std::move(*value->FindDictKey(data_iter.first))); } } }
diff --git a/components/autofill/core/browser/BUILD.gn b/components/autofill/core/browser/BUILD.gn index c812fc5..a5d0663 100644 --- a/components/autofill/core/browser/BUILD.gn +++ b/components/autofill/core/browser/BUILD.gn
@@ -257,11 +257,14 @@ "payments/payments_requests/select_challenge_option_request.h", "payments/payments_requests/unmask_card_request.cc", "payments/payments_requests/unmask_card_request.h", + "payments/payments_requests/update_virtual_card_enrollment_request.cc", + "payments/payments_requests/update_virtual_card_enrollment_request.h", "payments/payments_service_url.cc", "payments/payments_service_url.h", "payments/payments_util.cc", "payments/payments_util.h", "payments/risk_data_loader.h", + "payments/virtual_card_enrollment_flow.h", "payments/virtual_card_enrollment_manager.cc", "payments/virtual_card_enrollment_manager.h", "payments/wait_for_signal_or_timeout.cc",
diff --git a/components/autofill/core/browser/form_parsing/address_field.cc b/components/autofill/core/browser/form_parsing/address_field.cc index a6a0d82..c41bd6e 100644 --- a/components/autofill/core/browser/form_parsing/address_field.cc +++ b/components/autofill/core/browser/form_parsing/address_field.cc
@@ -28,6 +28,35 @@ return true; } +// Removes the |attribute| from all |patterns|. +// TODO(crbug/1142936): This is necessary for +// AddressField::ParseNameAndLabelSeparately(). +int WithoutAttribute(int match_type, MatchAttributes attribute) { + return match_type & ~attribute; +} + +// Removes the |attribute| from all |patterns|. +// TODO(crbug/1142936): This is necessary for +// AddressField::ParseNameAndLabelSeparately(). +std::vector<MatchingPattern> WithoutAttribute( + std::vector<MatchingPattern> patterns, + MatchAttributes attribute) { + for (MatchingPattern& p : patterns) + p.match_field_attributes &= ~attribute; + return patterns; +} + +// Adds the |field_type| to all |patterns|. +// TODO(crbug/1142936): This is necessary for AddressField::ParseAddressLines() +// and AddressField::Parse(). +std::vector<MatchingPattern> WithFieldType( + std::vector<MatchingPattern> patterns, + MatchFieldTypes field_type) { + for (MatchingPattern& p : patterns) + p.match_field_input_types |= field_type; + return patterns; +} + } // namespace // Some sites use type="tel" for zip fields (to get a numerical input). @@ -90,8 +119,8 @@ // Ignore email addresses. } else if (ParseFieldSpecifics( scanner, kEmailRe, MATCH_DEFAULT | MATCH_TEXT_AREA, - email_patterns, nullptr, {log_manager, "kEmailRe"}, - {.augment_types = MATCH_TEXT_AREA})) { + WithFieldType(email_patterns, MATCH_TEXT_AREA), nullptr, + {log_manager, "kEmailRe"})) { continue; } else if (address_field->ParseAddress(scanner, page_language) || address_field->ParseDependentLocalityCityStateCountryZipCode( @@ -298,6 +327,10 @@ PatternProvider::GetInstance().GetMatchPatterns("ADDRESS_LINE_1", page_language); + // TODO(crbug.com/1121990): Remove duplicate calls when launching + // AutofillParsingPatternProvider. The old code calls ParseFieldSpecifics() + // for two different patterns, |pattern| and |label_pattern|. The new code + // handles both patterns at once in the |address_line1_patterns|. if (!ParseFieldSpecifics(scanner, pattern, MATCH_DEFAULT | MATCH_SEARCH, address_line1_patterns, &address1_, {log_manager_, "kAddressLine1Re"}) && @@ -305,17 +338,16 @@ MATCH_LABEL | MATCH_SEARCH | MATCH_TEXT, address_line1_patterns, &address1_, {log_manager_, "kAddressLine1LabelRe"}) && - !ParseFieldSpecifics(scanner, pattern, - MATCH_DEFAULT | MATCH_SEARCH | MATCH_TEXT_AREA, - address_line1_patterns, &street_address_, - {log_manager_, "kAddressLine1Re"}, - {.augment_types = MATCH_TEXT_AREA}) && - !ParseFieldSpecifics(scanner, label_pattern, - MATCH_LABEL | MATCH_SEARCH | MATCH_TEXT_AREA, - address_line1_patterns, &street_address_, - {log_manager_, "kAddressLine1LabelRe"}, - {.augment_types = MATCH_TEXT_AREA})) + !ParseFieldSpecifics( + scanner, pattern, MATCH_DEFAULT | MATCH_SEARCH | MATCH_TEXT_AREA, + WithFieldType(address_line1_patterns, MATCH_TEXT_AREA), + &street_address_, {log_manager_, "kAddressLine1Re"}) && + !ParseFieldSpecifics( + scanner, label_pattern, MATCH_LABEL | MATCH_SEARCH | MATCH_TEXT_AREA, + WithFieldType(address_line1_patterns, MATCH_TEXT_AREA), + &street_address_, {log_manager_, "kAddressLine1LabelRe"})) { return false; + } if (street_address_) return true; @@ -348,8 +380,9 @@ {log_manager_, "kAddressLinesExtraRe"}) && !ParseFieldSpecifics(scanner, label_pattern, MATCH_LABEL | MATCH_TEXT, address_line2_patterns, &address3_, - {log_manager_, "kAddressLine2LabelRe"})) + {log_manager_, "kAddressLine2LabelRe"})) { return true; + } // Try for surplus lines, which we will promptly discard. Some pages have 4 // address lines (e.g. uk/ShoesDirect2.html)! @@ -469,12 +502,12 @@ AutofillField* cur_match = nullptr; size_t saved_cursor = scanner->SaveCursor(); bool parsed_name = ParseFieldSpecifics( - scanner, pattern, match_type & ~MATCH_LABEL, patterns, &cur_match, - logging, {.restrict_attributes = MATCH_NAME}); + scanner, pattern, WithoutAttribute(match_type, MATCH_LABEL), + WithoutAttribute(patterns, MATCH_LABEL), &cur_match, logging); scanner->RewindTo(saved_cursor); bool parsed_label = ParseFieldSpecifics( - scanner, pattern, match_type & ~MATCH_NAME, patterns, &cur_match, logging, - {.restrict_attributes = MATCH_LABEL}); + scanner, pattern, WithoutAttribute(match_type, MATCH_NAME), + WithoutAttribute(patterns, MATCH_NAME), &cur_match, logging); if (parsed_name && parsed_label) { if (match) *match = cur_match;
diff --git a/components/autofill/core/browser/form_parsing/form_field.cc b/components/autofill/core/browser/form_parsing/form_field.cc index eb0cbe3..3350480 100644 --- a/components/autofill/core/browser/form_parsing/form_field.cc +++ b/components/autofill/core/browser/form_parsing/form_field.cc
@@ -277,20 +277,8 @@ int match_type, const std::vector<MatchingPattern>& patterns, AutofillField** match, - const RegExLogging& logging, - MatchFieldBitmasks match_field_bitmasks) { + const RegExLogging& logging) { if (base::FeatureList::IsEnabled(features::kAutofillParsingPatternProvider)) { - // TODO(crbug/1142936): This hack is to allow - // AddressField::ParseNameAndLabelSeparately(). - if (match_field_bitmasks.restrict_attributes != ~0 || - match_field_bitmasks.augment_types != 0) { - std::vector<MatchingPattern> modified_patterns = patterns; - for (MatchingPattern& mp : modified_patterns) { - mp.match_field_attributes &= match_field_bitmasks.restrict_attributes; - mp.match_field_input_types |= match_field_bitmasks.augment_types; - } - return ParseFieldSpecifics(scanner, modified_patterns, match, logging); - } return ParseFieldSpecifics(scanner, patterns, match, logging); } else { return ParseFieldSpecificsWithLegacyPattern(scanner, pattern, match_type,
diff --git a/components/autofill/core/browser/form_parsing/form_field.h b/components/autofill/core/browser/form_parsing/form_field.h index 4c4c86ff..33f3906b 100644 --- a/components/autofill/core/browser/form_parsing/form_field.h +++ b/components/autofill/core/browser/form_parsing/form_field.h
@@ -127,20 +127,12 @@ AutofillField** match, const RegExLogging& logging = {}); - struct MatchFieldBitmasks { - int restrict_attributes = ~0; - int augment_types = 0; - }; - static bool ParseFieldSpecifics(AutofillScanner* scanner, base::StringPiece16 pattern, int match_type, const std::vector<MatchingPattern>& patterns, AutofillField** match, - const RegExLogging& logging, - MatchFieldBitmasks match_field_bitmasks = { - .restrict_attributes = ~0, - .augment_types = 0}); + const RegExLogging& logging); // Attempts to parse a field with an empty label. Returns true // on success and fills |match| with a pointer to the field.
diff --git a/components/autofill/core/browser/payments/payments_client.cc b/components/autofill/core/browser/payments/payments_client.cc index d3258e9..a3e947f 100644 --- a/components/autofill/core/browser/payments/payments_client.cc +++ b/components/autofill/core/browser/payments/payments_client.cc
@@ -30,6 +30,7 @@ #include "components/autofill/core/browser/payments/payments_requests/payments_request.h" #include "components/autofill/core/browser/payments/payments_requests/select_challenge_option_request.h" #include "components/autofill/core/browser/payments/payments_requests/unmask_card_request.h" +#include "components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.h" #include "components/autofill/core/browser/payments/payments_service_url.h" #include "components/autofill/core/common/autofill_features.h" #include "components/autofill/core/common/autofill_payments_features.h" @@ -968,6 +969,14 @@ PaymentsClient::GetDetailsForEnrollmentResponseDetails:: ~GetDetailsForEnrollmentResponseDetails() = default; +PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails:: + UpdateVirtualCardEnrollmentRequestDetails() = default; +PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails:: + UpdateVirtualCardEnrollmentRequestDetails( + const UpdateVirtualCardEnrollmentRequestDetails&) = default; +PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails:: + ~UpdateVirtualCardEnrollmentRequestDetails() = default; + PaymentsClient::PaymentsClient( scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, signin::IdentityManager* identity_manager, @@ -1068,6 +1077,14 @@ /*authenticate=*/true); } +void PaymentsClient::UpdateVirtualCardEnrollment( + const UpdateVirtualCardEnrollmentRequestDetails& request_details, + base::OnceCallback<void(AutofillClient::PaymentsRpcResult)> callback) { + IssueRequest(std::make_unique<UpdateVirtualCardEnrollmentRequest>( + request_details, std::move(callback)), + /*authenticate=*/true); +} + void PaymentsClient::CancelRequest() { request_.reset(); resource_request_.reset();
diff --git a/components/autofill/core/browser/payments/payments_client.h b/components/autofill/core/browser/payments/payments_client.h index a7252e5..85a4eba47 100644 --- a/components/autofill/core/browser/payments/payments_client.h +++ b/components/autofill/core/browser/payments/payments_client.h
@@ -20,6 +20,7 @@ #include "components/autofill/core/browser/data_model/credit_card.h" #include "components/autofill/core/browser/payments/card_unmask_challenge_option.h" #include "components/autofill/core/browser/payments/card_unmask_delegate.h" +#include "components/autofill/core/browser/payments/virtual_card_enrollment_flow.h" #include "components/signin/public/identity_manager/access_token_fetcher.h" #include "components/signin/public/identity_manager/access_token_info.h" #include "google_apis/gaia/google_service_auth_error.h" @@ -109,7 +110,7 @@ // An opaque token used to chain consecutive payments requests together. std::string context_token; // The url origin of the website where the unmasking happened. Should be - // populated when the unmasking is for a virtual card. + // populated when the unmasking is for a virtual-card. absl::optional<GURL> last_committed_url_origin; }; @@ -133,11 +134,11 @@ std::string real_pan; std::string dcvv; // The expiration month of the card. It falls in between 1 - 12. Should be - // populated when the card is a virtual card which does not necessarily have + // populated when the card is a virtual-card which does not necessarily have // the same expiration date as its related actual card. std::string expiration_month; // The four-digit expiration year of the card. Should be populated when the - // card is a virtual card which does not necessarily have the same + // card is a virtual-card which does not necessarily have the same // expiration date as its related actual card. std::string expiration_year; // Challenge required for enrolling user into FIDO authentication for future @@ -285,7 +286,42 @@ std::string server_id; // TODO(crbug.com/1281695): Add |virtual_card_enrollment_state| and // |card_art_url| data members when integrating all of the logic for the - // virtual card enrollment flow. + // virtual-card enrollment flow. + }; + + // A collection of information needed for the + // UpdateVirtualCardEnrollmentRequest. + struct UpdateVirtualCardEnrollmentRequestDetails { + UpdateVirtualCardEnrollmentRequestDetails(); + UpdateVirtualCardEnrollmentRequestDetails( + const UpdateVirtualCardEnrollmentRequestDetails&); + UpdateVirtualCardEnrollmentRequestDetails operator=( + const UpdateVirtualCardEnrollmentRequestDetails&) = delete; + ~UpdateVirtualCardEnrollmentRequestDetails(); + // Denotes the source that the corresponding + // UpdateVirtualCardEnrollmentRequest for this + // UpdateVirtualCardEnrollmentRequestDetails originated from, i.e., a + // |virtual_card_enrollment_source| of kUpstream means the request happens + // after a user saved a card in the upstream flow. + VirtualCardEnrollmentSource virtual_card_enrollment_source = + VirtualCardEnrollmentSource::kNone; + // Denotes the type of this specific UpdateVirtualCardEnrollmentRequest, + // i.e., a type of VirtualCardEnrollmentRequestType::kEnroll would mean this + // is an enroll request. + VirtualCardEnrollmentRequestType virtual_card_enrollment_request_type = + VirtualCardEnrollmentRequestType::kNone; + // The billing customer number for the account this request is sent to. If + // |billing_customer_number| is non-zero, it means the user has a Google + // Payments account. + int64_t billing_customer_number = 0; + // Populated if it is an unenroll request. |instrument_id| lets the server + // know which card to unenroll from VCN. + absl::optional<int64_t> instrument_id; + // Populated if it is an enroll request. Based on the |vcn_context_token| + // the server is able to retrieve the instrument id, and using + // |vcn_context_token| for enroll allows the server to link a + // GetDetailsForEnroll call with the corresponding Enroll call. + absl::optional<std::string> vcn_context_token; }; // TODO(crbug.com/1281695): Add GetDetailsForEnrollRequest. @@ -303,10 +339,10 @@ // and link this specific GetDetailsForEnroll call with its corresponding // enroll call. std::string vcn_context_token; - // Google's legal message lines in the virtual card enroll flow for this + // Google's legal message lines in the virtual-card enroll flow for this // specific card based on |vcn_context_token|. LegalMessageLines google_legal_message; - // The issuer's legal message lines in the virtual card enroll flow for this + // The issuer's legal message lines in the virtual-card enroll flow for this // specific card based on |vcn_context_token|. LegalMessageLines issuer_legal_message; }; @@ -402,6 +438,14 @@ base::OnceCallback<void(AutofillClient::PaymentsRpcResult, const std::string&)> callback); + // The user has chosen to change the virtual-card enrollment of a credit card. + // Send the necessary information for the server to identify the credit card + // for which virtual-card enrollment will be updated, as well as metadata so + // that the server understands the context for the request. + virtual void UpdateVirtualCardEnrollment( + const UpdateVirtualCardEnrollmentRequestDetails& request_details, + base::OnceCallback<void(AutofillClient::PaymentsRpcResult)> callback); + // Cancels and clears the current |request_|. void CancelRequest();
diff --git a/components/autofill/core/browser/payments/payments_client_unittest.cc b/components/autofill/core/browser/payments/payments_client_unittest.cc index 6f71767..b48076b 100644 --- a/components/autofill/core/browser/payments/payments_client_unittest.cc +++ b/components/autofill/core/browser/payments/payments_client_unittest.cc
@@ -23,6 +23,7 @@ #include "components/autofill/core/browser/payments/credit_card_save_manager.h" #include "components/autofill/core/browser/payments/local_card_migration_manager.h" #include "components/autofill/core/browser/payments/payments_client.h" +#include "components/autofill/core/browser/payments/virtual_card_enrollment_flow.h" #include "components/autofill/core/browser/test_personal_data_manager.h" #include "components/autofill/core/common/autofill_clock.h" #include "components/autofill/core/common/autofill_features.h" @@ -224,6 +225,11 @@ context_token_ = updated_context_token; } + void OnDidGetUpdateVirtualCardEnrollmentResponse( + AutofillClient::PaymentsRpcResult result) { + result_ = result; + } + protected: base::test::ScopedFeatureList scoped_feature_list_; @@ -1597,5 +1603,129 @@ EXPECT_EQ(AutofillClient::PaymentsRpcResult::kPermanentFailure, result_); } +typedef std::tuple<VirtualCardEnrollmentSource, + VirtualCardEnrollmentRequestType, + AutofillClient::PaymentsRpcResult> + UpdateVirtualCardEnrollmentTestData; + +class UpdateVirtualCardEnrollmentTest + : public PaymentsClientTest, + public ::testing::WithParamInterface< + UpdateVirtualCardEnrollmentTestData> { + public: + UpdateVirtualCardEnrollmentTest() = default; + ~UpdateVirtualCardEnrollmentTest() override = default; + + void TriggerFlow() { + VirtualCardEnrollmentSource virtual_card_enrollment_source = + std::get<0>(GetParam()); + VirtualCardEnrollmentRequestType virtual_card_enrollment_request_type = + std::get<1>(GetParam()); + StartUpdateVirtualCardEnrollment(virtual_card_enrollment_source, + virtual_card_enrollment_request_type); + IssueOAuthToken(); + + // |response_type_for_test| is the AutofillClient::PaymentsRpcResult + // response type we want to test for the combination of + // |virtual_card_enrollment_source| and + // |virtual_card_enrollment_request_type| we are currently on. + AutofillClient::PaymentsRpcResult response_type_for_test = + std::get<2>(GetParam()); + switch (response_type_for_test) { + case AutofillClient::PaymentsRpcResult::kSuccess: + if (virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kEnroll) { + ReturnResponse(net::HTTP_OK, + "{ \"enroll_result\": \"ENROLL_SUCCESS\" }"); + } else if (virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kUnenroll) { + ReturnResponse(net::HTTP_OK, "{}"); + } + break; + case AutofillClient::PaymentsRpcResult::kVcnRetrievalTryAgainFailure: + ReturnResponse( + net::HTTP_OK, + "{ \"error\": { \"code\": \"ANYTHING_ELSE\", " + "\"api_error_reason\": \"virtual_card_temporary_error\"} }"); + break; + case AutofillClient::PaymentsRpcResult::kTryAgainFailure: + ReturnResponse(net::HTTP_OK, + "{ \"error\": { \"code\": \"INTERNAL\", " + "\"api_error_reason\": \"ANYTHING_ELSE\"} }"); + break; + case AutofillClient::PaymentsRpcResult::kVcnRetrievalPermanentFailure: + ReturnResponse( + net::HTTP_OK, + "{ \"error\": { \"code\": \"ANYTHING_ELSE\", " + "\"api_error_reason\": \"virtual_card_permanent_error\"} }"); + break; + case AutofillClient::PaymentsRpcResult::kPermanentFailure: + ReturnResponse(net::HTTP_OK, + "{ \"error\": { \"code\": \"ANYTHING_ELSE\" } }"); + break; + case AutofillClient::PaymentsRpcResult::kNetworkError: + ReturnResponse(net::HTTP_REQUEST_TIMEOUT, ""); + break; + case AutofillClient::PaymentsRpcResult::kNone: + NOTREACHED(); + break; + } + EXPECT_EQ(response_type_for_test, result_); + } + + private: + void StartUpdateVirtualCardEnrollment( + VirtualCardEnrollmentSource virtual_card_enrollment_source, + VirtualCardEnrollmentRequestType virtual_card_enrollment_request_type) { + PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails request_details; + request_details.virtual_card_enrollment_request_type = + virtual_card_enrollment_request_type; + request_details.virtual_card_enrollment_source = + virtual_card_enrollment_source; + request_details.billing_customer_number = 555666777888; + if (virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kEnroll) { + request_details.vcn_context_token = "fake context token"; + } else if (virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kUnenroll) { + request_details.instrument_id = 12345678; + } + client_->UpdateVirtualCardEnrollment( + request_details, + base::BindOnce( + &PaymentsClientTest::OnDidGetUpdateVirtualCardEnrollmentResponse, + weak_ptr_factory_.GetWeakPtr())); + } +}; + +// Initializes the parameterized test suite with all possible values of +// VirtualCardEnrollmentSource, VirtualCardEnrollmentRequestType, and +// AutofillClient::PaymentsRpcResult. +INSTANTIATE_TEST_SUITE_P( + , + UpdateVirtualCardEnrollmentTest, + testing::Combine( + testing::Values(VirtualCardEnrollmentSource::kUpstream, + VirtualCardEnrollmentSource::kDownstream, + VirtualCardEnrollmentSource::kSettingsPage), + testing::Values(VirtualCardEnrollmentRequestType::kEnroll, + VirtualCardEnrollmentRequestType::kUnenroll), + testing::Values( + AutofillClient::PaymentsRpcResult::kSuccess, + AutofillClient::PaymentsRpcResult::kVcnRetrievalTryAgainFailure, + AutofillClient::PaymentsRpcResult::kTryAgainFailure, + AutofillClient::PaymentsRpcResult::kVcnRetrievalPermanentFailure, + AutofillClient::PaymentsRpcResult::kPermanentFailure, + AutofillClient::PaymentsRpcResult::kNetworkError))); + +// Parameterized test that tests all combinations of +// VirtualCardEnrollmentSource and VirtualCardEnrollmentRequestType against all +// possible server responses in the UpdateVirtualCardEnrollmentFlow. This test +// will be run once for each combination. +TEST_P(UpdateVirtualCardEnrollmentTest, + UpdateVirtualCardEnrollmentTest_TestAllFlows) { + TriggerFlow(); +} + } // namespace payments } // namespace autofill
diff --git a/components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.cc b/components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.cc new file mode 100644 index 0000000..fc0b87d --- /dev/null +++ b/components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.cc
@@ -0,0 +1,181 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.h" + +#include <string> + +#include "base/json/json_writer.h" +#include "base/values.h" +#include "components/autofill/core/browser/payments/virtual_card_enrollment_flow.h" + +namespace autofill { +namespace payments { + +namespace { +const char kUpdateVirtualCardEnrollmentRequestPath[] = + "payments/apis/virtualcardservice/enroll"; +} // namespace + +UpdateVirtualCardEnrollmentRequest::UpdateVirtualCardEnrollmentRequest( + const PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails& + request_details, + base::OnceCallback<void(AutofillClient::PaymentsRpcResult)> callback) + : request_details_(request_details), callback_(std::move(callback)) {} + +UpdateVirtualCardEnrollmentRequest::~UpdateVirtualCardEnrollmentRequest() = + default; + +std::string UpdateVirtualCardEnrollmentRequest::GetRequestUrlPath() { + return kUpdateVirtualCardEnrollmentRequestPath; +} + +std::string UpdateVirtualCardEnrollmentRequest::GetRequestContentType() { + return "application/json"; +} + +std::string UpdateVirtualCardEnrollmentRequest::GetRequestContent() { + base::Value request_dict(base::Value::Type::DICTIONARY); + + switch (request_details_.virtual_card_enrollment_request_type) { + case VirtualCardEnrollmentRequestType::kEnroll: + BuildEnrollRequestDictionary(&request_dict); + break; + case VirtualCardEnrollmentRequestType::kUnenroll: + BuildUnenrollRequestDictionary(&request_dict); + break; + case VirtualCardEnrollmentRequestType::kNone: + NOTREACHED(); + break; + } + + std::string request_content; + base::JSONWriter::Write(request_dict, &request_content); + VLOG(3) << "UpdateVirtualCardEnrollmentRequest Body: " << request_content; + return request_content; +} + +void UpdateVirtualCardEnrollmentRequest::ParseResponse( + const base::Value& response) { + // Only enroll requests have a response to parse, unenroll request responses + // are empty except for possible errors which are parsed in PaymentsClient. + if (request_details_.virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kEnroll) { + auto* enroll_result = + response.FindKeyOfType("enroll_result", base::Value::Type::STRING); + if (enroll_result) { + enroll_result_ = enroll_result->GetString(); + } + } +} + +bool UpdateVirtualCardEnrollmentRequest::IsResponseComplete() { + switch (request_details_.virtual_card_enrollment_request_type) { + case VirtualCardEnrollmentRequestType::kEnroll: + // If it is an enroll request, we know the response is complete if the + // response has an enroll result that is ENROLL_SUCCESS, as that is the + // only field in an enroll response other than the possible error. + return enroll_result_.has_value() && enroll_result_ == "ENROLL_SUCCESS"; + case VirtualCardEnrollmentRequestType::kUnenroll: + // Unenroll responses are empty except for having an error. In + // PaymentsClient, if the response has an error it will be handled before + // we check IsResponseComplete(), so if we ever reach this branch we know + // the response completed successfully as there is no error. Thus, we + // always return true. + return true; + case VirtualCardEnrollmentRequestType::kNone: + NOTREACHED(); + return false; + } +} + +void UpdateVirtualCardEnrollmentRequest::RespondToDelegate( + AutofillClient::PaymentsRpcResult result) { + std::move(callback_).Run(result); +} + +void UpdateVirtualCardEnrollmentRequest::BuildEnrollRequestDictionary( + base::Value* request_dict) { + DCHECK(request_details_.virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kEnroll); + + // If it is an enroll request, we should always have a context token from the + // previous GetDetailsForEnroll request and we should not have an instrument + // id set in |request_details_|. + DCHECK(request_details_.vcn_context_token.has_value() && + !request_details_.instrument_id.has_value()); + + // Builds the context and channel_type for this enroll request. + base::Value context(base::Value::Type::DICTIONARY); + switch (request_details_.virtual_card_enrollment_source) { + case VirtualCardEnrollmentSource::kUpstream: + context.SetKey("billable_service", + base::Value(kUploadCardBillableServiceNumber)); + request_dict->SetKey("channel_type", base::Value("CHROME_UPSTREAM")); + break; + case VirtualCardEnrollmentSource::kDownstream: + // Downstream enroll is treated the same as settings page enroll because + // chrome client should already have a card synced from the server. + // Fall-through. + case VirtualCardEnrollmentSource::kSettingsPage: + context.SetKey("billable_service", + base::Value(kUnmaskCardBillableServiceNumber)); + request_dict->SetKey("channel_type", base::Value("CHROME_DOWNSTREAM")); + break; + case VirtualCardEnrollmentSource::kNone: + NOTREACHED(); + break; + } + if (request_details_.billing_customer_number != 0) { + context.SetKey("customer_context", + BuildCustomerContextDictionary( + request_details_.billing_customer_number)); + } + request_dict->SetKey("context", std::move(context)); + + // Sets the virtual_card_enrollment_flow field in this enroll request which + // lets the server know whether the enrollment is happening with ToS or not. + // Chrome client requests will always be ENROLL_WITH_TOS. This field is + // necessary because virtual card enroll through other platforms enrolls + // without ToS, for example Web Push Provisioning. + request_dict->SetKey("virtual_card_enrollment_flow", + base::Value("ENROLL_WITH_TOS")); + + // Sets the context_token field in this enroll request which is used by the + // server to link this enroll request to the previous + // GetDetailsForEnrollRequest, as well as to retrieve the specific credit card + // to enroll. + request_dict->SetKey("context_token", + base::Value(request_details_.vcn_context_token.value())); +} + +void UpdateVirtualCardEnrollmentRequest::BuildUnenrollRequestDictionary( + base::Value* request_dict) { + DCHECK(request_details_.virtual_card_enrollment_request_type == + VirtualCardEnrollmentRequestType::kUnenroll); + + // If it is an unenroll request, we should always have an instrument id and we + // should not have a context token set in |request_details_|. + DCHECK(request_details_.instrument_id.has_value() && + !request_details_.vcn_context_token.has_value()); + + // Builds the context for this unenroll request if a billing customer number + // is present. + if (request_details_.billing_customer_number != 0) { + base::Value context(base::Value::Type::DICTIONARY); + context.SetKey("customer_context", + BuildCustomerContextDictionary( + request_details_.billing_customer_number)); + request_dict->SetKey("context", std::move(context)); + } + + // Sets the instrument_id field in this unenroll request which is used by + // the server to get the appropriate credit card to unenroll. + request_dict->SetKey("instrument_id", + base::Value(base::NumberToString( + request_details_.instrument_id.value()))); +} + +} // namespace payments +} // namespace autofill
diff --git a/components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.h b/components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.h new file mode 100644 index 0000000..4438448 --- /dev/null +++ b/components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.h
@@ -0,0 +1,64 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_AUTOFILL_CORE_BROWSER_PAYMENTS_PAYMENTS_REQUESTS_UPDATE_VIRTUAL_CARD_ENROLLMENT_REQUEST_H_ +#define COMPONENTS_AUTOFILL_CORE_BROWSER_PAYMENTS_PAYMENTS_REQUESTS_UPDATE_VIRTUAL_CARD_ENROLLMENT_REQUEST_H_ + +#include <string> + +#include "base/callback.h" +#include "components/autofill/core/browser/autofill_client.h" +#include "components/autofill/core/browser/payments/payments_client.h" +#include "components/autofill/core/browser/payments/payments_requests/payments_request.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace base { +class Value; +} // namespace base + +namespace autofill { +namespace payments { + +// Payments request to enroll or unenroll a credit card into a virtual card. +// Virtual Card Numbers allow a user to check out online with a credit +// card using a credit card number that is different from the credit card's +// original number. +class UpdateVirtualCardEnrollmentRequest : public PaymentsRequest { + public: + UpdateVirtualCardEnrollmentRequest( + const PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails& + request_details, + base::OnceCallback<void(AutofillClient::PaymentsRpcResult)> callback); + UpdateVirtualCardEnrollmentRequest( + const UpdateVirtualCardEnrollmentRequest&) = delete; + UpdateVirtualCardEnrollmentRequest& operator=( + const UpdateVirtualCardEnrollmentRequest&) = delete; + ~UpdateVirtualCardEnrollmentRequest() override; + + // PaymentsRequest: + std::string GetRequestUrlPath() override; + std::string GetRequestContentType() override; + std::string GetRequestContent() override; + void ParseResponse(const base::Value& response) override; + bool IsResponseComplete() override; + void RespondToDelegate(AutofillClient::PaymentsRpcResult result) override; + + private: + // Modifies the base::Value that |request_dict| points to by setting all of + // the fields needed for an Enroll request. + void BuildEnrollRequestDictionary(base::Value* request_dict); + + // Modifies the base::Value that |request_dict| points to by setting all of + // the fields needed for an Unenroll request. + void BuildUnenrollRequestDictionary(base::Value* request_dict); + + PaymentsClient::UpdateVirtualCardEnrollmentRequestDetails request_details_; + base::OnceCallback<void(AutofillClient::PaymentsRpcResult)> callback_; + absl::optional<std::string> enroll_result_; +}; + +} // namespace payments +} // namespace autofill + +#endif // COMPONENTS_AUTOFILL_CORE_BROWSER_PAYMENTS_PAYMENTS_REQUESTS_UPDATE_VIRTUAL_CARD_ENROLLMENT_REQUEST_H_
diff --git a/components/autofill/core/browser/payments/virtual_card_enrollment_flow.h b/components/autofill/core/browser/payments/virtual_card_enrollment_flow.h new file mode 100644 index 0000000..9c66297c --- /dev/null +++ b/components/autofill/core/browser/payments/virtual_card_enrollment_flow.h
@@ -0,0 +1,43 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_AUTOFILL_CORE_BROWSER_PAYMENTS_VIRTUAL_CARD_ENROLLMENT_FLOW_H_ +#define COMPONENTS_AUTOFILL_CORE_BROWSER_PAYMENTS_VIRTUAL_CARD_ENROLLMENT_FLOW_H_ + +namespace autofill { + +// This enum is used to denote the specific source that the virtual card +// enrollment process originated from. +enum class VirtualCardEnrollmentSource { + // Default value, should never be used. + kNone = 0, + // Offering VCN Enrollment after Upstream, i.e., saving a card to Google + // Payments. + kUpstream = 1, + // Offering VCN Enrollment after Downstream, i.e., unmasking a card from + // Google Payments. + kDownstream = 2, + // Offering VCN Enrollment from the payment methods settings page. + kSettingsPage = 3, + // Max value, needs to be updated every time a new enum is added. + kMaxValue = kSettingsPage, +}; + +// Denotes the request type for an UpdateVirtualCardEnrollmentRequest. +enum class VirtualCardEnrollmentRequestType { + // Default value, should never be used. + kNone = 0, + // The corresponding UpdateVirtualCardEnrollmentRequest is an enroll + // request. + kEnroll = 1, + // The corresponding UpdateVirtualCardEnrollmentRequest is an unenroll + // request. + kUnenroll = 2, + // Max value, needs to be updated every time a new enum is added. + kMaxValue = kUnenroll, +}; + +} // namespace autofill + +#endif // COMPONENTS_AUTOFILL_CORE_BROWSER_PAYMENTS_VIRTUAL_CARD_ENROLLMENT_FLOW_H_
diff --git a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc index 56d1dc5..c381427 100644 --- a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc +++ b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc
@@ -22,7 +22,7 @@ void VirtualCardEnrollmentManager::OfferVirtualCardEnroll( raw_ptr<CreditCard> credit_card, - VirtualCardEnrollmentFlow virtual_card_enrollment_flow) {} + VirtualCardEnrollmentSource virtual_card_enrollment_source) {} void VirtualCardEnrollmentManager::Unenroll(int64_t instrument_id) {}
diff --git a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h index 9139271..3749bba8 100644 --- a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h +++ b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h
@@ -10,28 +10,12 @@ #include "base/memory/raw_ptr.h" #include "components/autofill/core/browser/autofill_client.h" #include "components/autofill/core/browser/payments/payments_client.h" +#include "components/autofill/core/browser/payments/virtual_card_enrollment_flow.h" namespace autofill { class CreditCard; -// This enum is used to denote the specific flow that the virtual card -// enrollment process is a part of. -enum class VirtualCardEnrollmentFlow { - // Default value, should never be used. - kNone = 0, - // Offering VCN Enrollment after Upstream flow, i.e., saving a card to - // Google Payments. - kUpstream = 1, - // Offering VCN Enrollment after Downstream flow, i.e., unmasking a card - // from Google Payments. - kDownstream = 2, - // Offering VCN Enrollment from the payment methods settings page. - kSettingsPage = 3, - // Max value, needs to be updated every time a new enum is added. - kMaxValue = kSettingsPage, -}; - // This struct is passed into the controller when we show the // VirtualCardEnrollmentBubble, and it lets the controller customize the // bubble based on the fields in this struct. For example, we will show @@ -52,9 +36,9 @@ // The legal message lines for the footer of the // VirtualCardEnrollmentBubble. LegalMessageLines legal_message_lines; - // The flow for which the VirtualCardEnrollmentBubble will be shown. - VirtualCardEnrollmentFlow virtual_card_enrollment_flow = - VirtualCardEnrollmentFlow::kNone; + // The source for which the VirtualCardEnrollmentBubble will be shown. + VirtualCardEnrollmentSource virtual_card_enrollment_source = + VirtualCardEnrollmentSource::kNone; }; // This struct is used to track the state of the virtual card enrollment @@ -70,7 +54,7 @@ // Only populated once the risk engine responded. absl::optional<std::string> risk_data; // |virtual_card_enrollment_fields|'s |credit_card| and - // |virtual_card_enrollment_flow| are populated in the beginning of the + // |virtual_card_enrollment_source| are populated in the beginning of the // virtual card enrollment flow, but the rest of the fields are only populated // before showing the VirtualCardEnrollmentBubble. VirtualCardEnrollmentFields virtual_card_enrollment_fields; @@ -97,12 +81,12 @@ // Starting point for the VCN enroll flow. The fields in |credit_card| will // be used throughout the flow, such as for request fields as well as credit // card specific fields for the bubble to display. - // |virtual_card_enrollment_flow| will be used by + // |virtual_card_enrollment_source| will be used by // ShowVirtualCardEnrollBubble() to differentiate different bubbles based on - // the flow we are in. + // the source we originated from. void OfferVirtualCardEnroll( raw_ptr<CreditCard> credit_card, - VirtualCardEnrollmentFlow virtual_card_enrollment_flow); + VirtualCardEnrollmentSource virtual_card_enrollment_source); // Unenrolls the card mapped to the given |instrument_id|. void Unenroll(int64_t instrument_id); @@ -111,7 +95,7 @@ // Called once the risk data is loaded. The |risk_data| will be used with // |state|'s |virtual_card_enrollment_fields|'s |credit_card|'s // |instrument_id_| field to make a GetDetailsForEnroll request, and - // |state|'s |virtual_card_enrollment_flow| will be passed down to when we + // |state|'s |virtual_card_enrollment_source| will be passed down to when we // show the bubble so that we show the correct bubble version. void OnRiskDataLoadedForVirtualCard( std::unique_ptr<VirtualCardEnrollmentProcessState> state, @@ -123,7 +107,7 @@ // |virtual_card_enrollment_fields|'s |credit_card|'s |instrument_id| are the // fields the server requires for the GetDetailsForEnrollRequest, and will be // used by |client_|'s |payments_client_|. |state|'s - // |virtual_card_enrollment_fields_|'s |virtual_card_enrollment_flow| is + // |virtual_card_enrollment_fields_|'s |virtual_card_enrollment_source| is // passed here so that it can be forwarded to ShowVirtualCardEnrollBubble. void GetDetailsForEnroll( std::unique_ptr<VirtualCardEnrollmentProcessState> state);
diff --git a/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialog.java b/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialog.java index 05862c1..bc707af 100644 --- a/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialog.java +++ b/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialog.java
@@ -142,11 +142,24 @@ @Override public void onLayoutChange(View v, int left, int top, int right, int bottom, int oldLeft, int oldTop, int oldRight, int oldBottom) { + // // If the layout size does not change (e.g. call due to #forceLayout), do nothing + // // because we don't want to dismiss the context menu. + if (left == oldLeft && right == oldRight && top == oldTop && bottom == oldBottom) { + return; + } + if (mIsPopup) { // If the menu is a popup, wait for the layout to be measured, then proceed with // showing the popup window. if (v.getMeasuredHeight() == 0) return; + // If dialog is showing and the layout changes, we might lost the anchor point. + // We'll dismiss the context menu and remove the listener. + if (mPopupWindow != null && mPopupWindow.isShowing()) { + dismiss(); + return; + } + final int posX = (int) mTouchPointXPx; final int posY = (int) (mTouchPointYPx + mTopContentOffsetPx); final Rect rect = new Rect(posX, posY, posX, posY); @@ -167,9 +180,9 @@ if (v.getMeasuredHeight() == 0) return; startEnterAnimation(); + v.removeOnLayoutChangeListener(this); + mOnLayoutChangeListener = null; } - v.removeOnLayoutChangeListener(this); - mOnLayoutChangeListener = null; } }; (mIsPopup ? mLayout : mContentView).addOnLayoutChangeListener(mOnLayoutChangeListener);
diff --git a/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialogUnitTest.java b/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialogUnitTest.java index c31766e..a57f27999 100644 --- a/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialogUnitTest.java +++ b/components/browser_ui/widget/android/java/src/org/chromium/components/browser_ui/widget/ContextMenuDialogUnitTest.java
@@ -71,6 +71,7 @@ Mockito.doNothing() .when(mSpyPopupWindow) .showAtLocation(any(View.class), anyInt(), anyInt(), anyInt()); + Mockito.doNothing().when(mSpyPopupWindow).dismiss(); } @Test @@ -111,7 +112,8 @@ public void testShowPopupWindow() { mDialog = createContextMenuDialog(/*isPopup=*/true, /*shouldRemoveScrim=*/false); mDialog.show(); - // Request layout so #onLayoutChange is triggered. + // Change layout params and request layout so #onLayoutChange is triggered. + mRootView.setRight(mRootView.getRight() + 1); mRootView.requestLayout(); final ArgumentCaptor<Integer> gravityCaptor = ArgumentCaptor.forClass(Integer.class); @@ -128,6 +130,25 @@ Mockito.verify(mSpyPopupWindow).dismiss(); } + @Test + public void testShowPopupWindow_2ndLayout() { + mDialog = createContextMenuDialog(/*isPopup=*/true, /*shouldRemoveScrim=*/false); + mDialog.show(); + // Change layout params and request layout so #onLayoutChange is triggered. + mRootView.setRight(mRootView.getRight() + 1); + mRootView.requestLayout(); + Mockito.verify(mSpyPopupWindow) + .showAtLocation(eq(mRootView.getRootView()), anyInt(), anyInt(), anyInt()); + + // Mock up popup window is showing. + Mockito.doReturn(true).when(mSpyPopupWindow).isShowing(); + + // Change layout params and request layout so #onLayoutChange is triggered. + mRootView.setRight(mRootView.getRight() - 1); + mRootView.requestLayout(); + Mockito.verify(mSpyPopupWindow).dismiss(); + } + /** * Inspired by https://crbug.com/1281011. If popup context menu is dismissed before * #onLayoutRequest for the root view, popup menu should not get invoked.
diff --git a/components/browsing_data/content/cookie_helper_unittest.cc b/components/browsing_data/content/cookie_helper_unittest.cc index 58d8a73..586ad6f 100644 --- a/components/browsing_data/content/cookie_helper_unittest.cc +++ b/components/browsing_data/content/cookie_helper_unittest.cc
@@ -10,6 +10,7 @@ #include "base/run_loop.h" #include "base/test/bind.h" #include "base/time/time.h" +#include "build/build_config.h" #include "content/public/browser/cookie_access_details.h" #include "content/public/browser/storage_partition.h" #include "content/public/test/browser_task_environment.h" @@ -591,7 +592,13 @@ base::Unretained(this))); } -TEST_F(CookieHelperTest, CannedGetCookieCount) { +#if defined(OS_LINUX) && defined(THREAD_SANITIZER) +#define MAYBE_CannedGetCookieCount DISABLED_CannedGetCookieCount +#else +#define MAYBE_CannedGetCookieCount CannedGetCookieCount +#endif +// Flaky on Linux TSan: https://crbug.com/1285414. +TEST_F(CookieHelperTest, MAYBE_CannedGetCookieCount) { // The URL in the omnibox is a frame URL. This is not necessarily the request // URL, since websites usually include other resources. GURL frame1_url("http://www.google.com");
diff --git a/components/browsing_data/content/service_worker_helper_unittest.cc b/components/browsing_data/content/service_worker_helper_unittest.cc index c325b5b..534f101d 100644 --- a/components/browsing_data/content/service_worker_helper_unittest.cc +++ b/components/browsing_data/content/service_worker_helper_unittest.cc
@@ -9,6 +9,7 @@ #include "base/memory/scoped_refptr.h" #include "base/strings/utf_string_conversions.h" #include "base/threading/thread_task_runner_handle.h" +#include "build/build_config.h" #include "content/public/browser/browser_context.h" #include "content/public/browser/storage_partition.h" #include "content/public/test/browser_task_environment.h" @@ -30,7 +31,13 @@ content::TestBrowserContext browser_context_; }; -TEST_F(CannedServiceWorkerHelperTest, Empty) { +#if defined(OS_LINUX) && defined(THREAD_SANITIZER) +#define MAYBE_Empty DISABLED_Empty +#else +#define MAYBE_Empty Empty +#endif +// Flaky on Linux TSan: https://crbug.com/1285414 +TEST_F(CannedServiceWorkerHelperTest, MAYBE_Empty) { const GURL origin("https://host1:1/"); std::vector<GURL> scopes; scopes.push_back(GURL("https://host1:1/*"));
diff --git a/components/certificate_transparency/data/log_list.json b/components/certificate_transparency/data/log_list.json index 8b87a8a..b697f47 100644 --- a/components/certificate_transparency/data/log_list.json +++ b/components/certificate_transparency/data/log_list.json
@@ -1,6 +1,6 @@ { - "version": "4.81", - "log_list_timestamp": "2022-01-06T01:34:23Z", + "version": "4.82", + "log_list_timestamp": "2022-01-07T01:35:26Z", "operators": [ { "name": "Google",
diff --git a/components/enterprise/browser/reporting/report_scheduler.cc b/components/enterprise/browser/reporting/report_scheduler.cc index 250f049..77cf5e9d 100644 --- a/components/enterprise/browser/reporting/report_scheduler.cc +++ b/components/enterprise/browser/reporting/report_scheduler.cc
@@ -65,6 +65,21 @@ LOG(ERROR) << "Extension request failed to be added to the pipeline."; } +ReportType TriggerToReportType(ReportScheduler::ReportTrigger trigger) { + switch (trigger) { + case ReportScheduler::kTriggerNone: + case ReportScheduler::kTriggerExtensionRequestRealTime: + NOTREACHED(); + [[fallthrough]]; + case ReportScheduler::kTriggerTimer: + return ReportType::kFull; + case ReportScheduler::kTriggerUpdate: + return ReportType::kBrowserVersion; + case ReportScheduler::kTriggerNewVersion: + return ReportType::kBrowserVersion; + } +} + } // namespace ReportScheduler::Delegate::Delegate() = default; @@ -229,28 +244,11 @@ } active_trigger_ = trigger; - ReportType report_type = ReportType::kFull; - switch (trigger) { - case kTriggerNone: - case kTriggerExtensionRequestRealTime: - NOTREACHED(); - [[fallthrough]]; - case kTriggerTimer: - VLOG(1) << "Generating enterprise report."; - break; - case kTriggerUpdate: - VLOG(1) << "Generating basic enterprise report upon update."; - report_type = ReportType::kBrowserVersion; - break; - case kTriggerNewVersion: - VLOG(1) << "Generating basic enterprise report upon new version."; - report_type = ReportType::kBrowserVersion; - break; - } report_generator_->Generate( - report_type, base::BindOnce(&ReportScheduler::OnReportGenerated, - base::Unretained(this))); + TriggerToReportType(trigger), + base::BindOnce(&ReportScheduler::OnReportGenerated, + base::Unretained(this))); } void ReportScheduler::GenerateAndUploadRealtimeReport( @@ -280,8 +278,9 @@ } RecordUploadTrigger(active_trigger_); report_uploader_->SetRequestAndUpload( - std::move(requests), base::BindOnce(&ReportScheduler::OnReportUploaded, - base::Unretained(this))); + TriggerToReportType(active_trigger_), std::move(requests), + base::BindOnce(&ReportScheduler::OnReportUploaded, + base::Unretained(this))); } void ReportScheduler::OnReportUploaded(ReportUploader::ReportStatus status) {
diff --git a/components/enterprise/browser/reporting/report_uploader.cc b/components/enterprise/browser/reporting/report_uploader.cc index 73c6e4ef..b9488df0 100644 --- a/components/enterprise/browser/reporting/report_uploader.cc +++ b/components/enterprise/browser/reporting/report_uploader.cc
@@ -9,6 +9,7 @@ #include "base/metrics/histogram_functions.h" #include "base/time/time.h" #include "build/chromeos_buildflags.h" +#include "components/enterprise/browser/reporting/report_type.h" #include "components/policy/core/common/cloud/cloud_policy_client.h" #include "device_management_backend.pb.h" @@ -41,8 +42,10 @@ maximum_number_of_retries_(maximum_number_of_retries) {} ReportUploader::~ReportUploader() = default; -void ReportUploader::SetRequestAndUpload(ReportRequestQueue requests, +void ReportUploader::SetRequestAndUpload(ReportType report_type, + ReportRequestQueue requests, ReportCallback callback) { + report_type_ = report_type; requests_ = std::move(requests); callback_ = std::move(callback); Upload(); @@ -52,17 +55,28 @@ auto callback = base::BindRepeating(&ReportUploader::OnRequestFinished, weak_ptr_factory_.GetWeakPtr()); + switch (report_type_) { + case ReportType::kFull: + case ReportType::kBrowserVersion: { + auto request = std::make_unique<ReportRequest::DeviceReportRequestProto>( + requests_.front()->GetDeviceReportRequest()); #if BUILDFLAG(IS_CHROMEOS_ASH) - auto request = - std::make_unique<enterprise_management::ChromeOsUserReportRequest>( - requests_.front()->GetDeviceReportRequest()); - client_->UploadChromeOsUserReport(std::move(request), std::move(callback)); + client_->UploadChromeOsUserReport(std::move(request), + std::move(callback)); #else - auto request = - std::make_unique<enterprise_management::ChromeDesktopReportRequest>( - requests_.front()->GetDeviceReportRequest()); - client_->UploadChromeDesktopReport(std::move(request), std::move(callback)); + client_->UploadChromeDesktopReport(std::move(request), + std::move(callback)); #endif + break; + } + case ReportType::kProfileReport: { + client_->UploadChromeProfileReport( + std::make_unique<em::ChromeProfileReportRequest>( + requests_.front()->GetChromeProfileReportRequest()), + std::move(callback)); + break; + } + } } // BUILDFLAG(IS_CHROMEOS_ASH) void ReportUploader::OnRequestFinished(bool status) {
diff --git a/components/enterprise/browser/reporting/report_uploader.h b/components/enterprise/browser/reporting/report_uploader.h index 7fb17e5..98d85289 100644 --- a/components/enterprise/browser/reporting/report_uploader.h +++ b/components/enterprise/browser/reporting/report_uploader.h
@@ -53,7 +53,8 @@ // Sets a list of requests and upload it. Request will be uploaded one after // another. - virtual void SetRequestAndUpload(ReportRequestQueue requests, + virtual void SetRequestAndUpload(ReportType report_type, + ReportRequestQueue requests, ReportCallback callback); private: @@ -77,6 +78,7 @@ raw_ptr<policy::CloudPolicyClient> client_; ReportCallback callback_; ReportRequestQueue requests_; + ReportType report_type_; net::BackoffEntry backoff_entry_; base::OneShotTimer backoff_request_timer_;
diff --git a/components/enterprise/browser/reporting/report_uploader_unittest.cc b/components/enterprise/browser/reporting/report_uploader_unittest.cc index d71d481..f526eff 100644 --- a/components/enterprise/browser/reporting/report_uploader_unittest.cc +++ b/components/enterprise/browser/reporting/report_uploader_unittest.cc
@@ -13,6 +13,7 @@ #include "components/enterprise/browser/reporting/report_request.h" #include "components/enterprise/browser/reporting/report_type.h" #include "components/policy/core/common/cloud/mock_cloud_policy_client.h" +#include "device_management_backend.pb.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -59,19 +60,31 @@ << "Please update kBrowserVersionNames above."; ReportRequestQueue requests; for (int i = 0; i < number_of_request; i++) { - auto request = std::make_unique<ReportRequest>(ReportType::kFull); - request->GetDeviceReportRequest() - .mutable_browser_report() - ->set_browser_version(kBrowserVersionNames[i]); + auto request = std::make_unique<ReportRequest>(GetReportType()); + em::BrowserReport* browser_report; + switch (GetReportType()) { + case ReportType::kFull: + case ReportType::kBrowserVersion: + browser_report = + request->GetDeviceReportRequest().mutable_browser_report(); + break; + case ReportType::kProfileReport: + browser_report = + request->GetChromeProfileReportRequest().mutable_browser_report(); + break; + } + browser_report->set_browser_version(kBrowserVersionNames[i]); requests.push(std::move(request)); } has_responded_ = false; uploader_->SetRequestAndUpload( - std::move(requests), + GetReportType(), std::move(requests), base::BindOnce(&ReportUploaderTest::OnReportUploaded, base::Unretained(this), expected_status)); } + virtual ReportType GetReportType() { return ReportType::kFull; } + void OnReportUploaded(ReportUploader::ReportStatus expected_status, ReportUploader::ReportStatus actuall_status) { EXPECT_EQ(expected_status, actuall_status); @@ -104,7 +117,7 @@ base::test::TaskEnvironment task_environment_; std::unique_ptr<ReportUploader> uploader_; - policy::MockCloudPolicyClient client_; + ::testing::StrictMock<policy::MockCloudPolicyClient> client_; bool has_responded_ = false; base::HistogramTester histogram_tester_; }; @@ -113,17 +126,12 @@ : public ReportUploaderTest, public ::testing::WithParamInterface<policy::DeviceManagementStatus> {}; -TEST_F(ReportUploaderTest, Success) { - EXPECT_CALL(client_, UploadReportProxy(_, _)) - .WillOnce(WithArgs<1>(policy::ScheduleStatusCallback(true))); - UploadReportAndSetExpectation(/*number_of_request=*/1, - ReportUploader::kSuccess); - RunNextTask(); - EXPECT_TRUE(has_responded_); - histogram_tester_.ExpectUniqueSample( - kResponseMetricsName, ReportResponseMetricsStatus::kSuccess, 1); - ::testing::Mock::VerifyAndClearExpectations(&client_); -} +class ReportUploaderTestWithReportType + : public ReportUploaderTest, + public ::testing::WithParamInterface<ReportType> { + public: + ReportType GetReportType() override { return GetParam(); } +}; TEST_F(ReportUploaderTest, PersistentError) { CreateUploader(/* retry_count = */ 1); @@ -304,4 +312,32 @@ policy::DM_STATUS_TEMPORARY_UNAVAILABLE, policy::DM_STATUS_SERVICE_TOO_MANY_REQUESTS)); +TEST_P(ReportUploaderTestWithReportType, Success) { + switch (GetReportType()) { + case ReportType::kFull: + case ReportType::kBrowserVersion: + EXPECT_CALL(client_, UploadReportProxy(_, _)) + .WillOnce(WithArgs<1>(policy::ScheduleStatusCallback(true))); + break; + case ReportType::kProfileReport: + EXPECT_CALL(client_, UploadChromeProfileReportProxy(_, _)) + .WillOnce(WithArgs<1>(policy::ScheduleStatusCallback(true))); + break; + } + + UploadReportAndSetExpectation(/*number_of_request=*/1, + ReportUploader::kSuccess); + RunNextTask(); + EXPECT_TRUE(has_responded_); + histogram_tester_.ExpectUniqueSample( + kResponseMetricsName, ReportResponseMetricsStatus::kSuccess, 1); + ::testing::Mock::VerifyAndClearExpectations(&client_); +} + +INSTANTIATE_TEST_SUITE_P(All, + ReportUploaderTestWithReportType, + ::testing::Values(ReportType::kFull, + ReportType::kBrowserVersion, + ReportType::kProfileReport)); + } // namespace enterprise_reporting
diff --git a/components/exo/display_unittest.cc b/components/exo/display_unittest.cc index 28e1a948..3469fdc 100644 --- a/components/exo/display_unittest.cc +++ b/components/exo/display_unittest.cc
@@ -84,15 +84,15 @@ }; TEST_F(DisplayTest, CreateSurface) { - std::unique_ptr<Display> display(new Display); + Display display; // Creating a surface should succeed. - std::unique_ptr<Surface> surface = display->CreateSurface(); + std::unique_ptr<Surface> surface = display.CreateSurface(); EXPECT_TRUE(surface); } TEST_F(DisplayTest, CreateSharedMemory) { - std::unique_ptr<Display> display(new Display); + Display display; int shm_size = 8192; base::UnsafeSharedMemoryRegion shared_memory = @@ -101,12 +101,12 @@ // Creating a shared memory instance from a valid region should succeed. std::unique_ptr<SharedMemory> shm1 = - display->CreateSharedMemory(std::move(shared_memory)); + display.CreateSharedMemory(std::move(shared_memory)); EXPECT_TRUE(shm1); // Creating a shared memory instance from a invalid region should fail. std::unique_ptr<SharedMemory> shm2 = - display->CreateSharedMemory(base::UnsafeSharedMemoryRegion()); + display.CreateSharedMemory(base::UnsafeSharedMemoryRegion()); EXPECT_FALSE(shm2); } @@ -115,7 +115,7 @@ TEST_F(DisplayTest, DISABLED_CreateLinuxDMABufBuffer) { const gfx::Size buffer_size(256, 256); - std::unique_ptr<Display> display(new Display); + Display display; // Creating a prime buffer from a native pixmap handle should succeed. scoped_refptr<gfx::NativePixmap> pixmap = ui::OzonePlatform::GetInstance() @@ -124,9 +124,9 @@ buffer_size, gfx::BufferFormat::RGBA_8888, gfx::BufferUsage::GPU_READ); gfx::NativePixmapHandle native_pixmap_handle = pixmap->ExportHandle(); - std::unique_ptr<Buffer> buffer1 = display->CreateLinuxDMABufBuffer( - buffer_size, gfx::BufferFormat::RGBA_8888, - std::move(native_pixmap_handle), false); + std::unique_ptr<Buffer> buffer1 = + display.CreateLinuxDMABufBuffer(buffer_size, gfx::BufferFormat::RGBA_8888, + std::move(native_pixmap_handle), false); EXPECT_TRUE(buffer1); // Create a handle without a file descriptor. @@ -134,9 +134,9 @@ native_pixmap_handle.planes[0].fd.reset(); // Creating a prime buffer using an invalid fd should fail. - std::unique_ptr<Buffer> buffer2 = display->CreateLinuxDMABufBuffer( - buffer_size, gfx::BufferFormat::RGBA_8888, - std::move(native_pixmap_handle), false); + std::unique_ptr<Buffer> buffer2 = + display.CreateLinuxDMABufBuffer(buffer_size, gfx::BufferFormat::RGBA_8888, + std::move(native_pixmap_handle), false); EXPECT_FALSE(buffer2); } @@ -146,37 +146,37 @@ #endif TEST_F(DisplayTest, CreateShellSurface) { - std::unique_ptr<Display> display(new Display); + Display display; // Create two surfaces. - std::unique_ptr<Surface> surface1 = display->CreateSurface(); + std::unique_ptr<Surface> surface1 = display.CreateSurface(); ASSERT_TRUE(surface1); - std::unique_ptr<Surface> surface2 = display->CreateSurface(); + std::unique_ptr<Surface> surface2 = display.CreateSurface(); ASSERT_TRUE(surface2); // Create a shell surface for surface1. std::unique_ptr<ShellSurface> shell_surface1 = - display->CreateShellSurface(surface1.get()); + display.CreateShellSurface(surface1.get()); EXPECT_TRUE(shell_surface1); // Create a shell surface for surface2. std::unique_ptr<ShellSurface> shell_surface2 = - display->CreateShellSurface(surface2.get()); + display.CreateShellSurface(surface2.get()); EXPECT_TRUE(shell_surface2); } TEST_F(DisplayTest, CreateClientControlledShellSurface) { - std::unique_ptr<Display> display(new Display); + Display display; // Create two surfaces. - std::unique_ptr<Surface> surface1 = display->CreateSurface(); + std::unique_ptr<Surface> surface1 = display.CreateSurface(); ASSERT_TRUE(surface1); - std::unique_ptr<Surface> surface2 = display->CreateSurface(); + std::unique_ptr<Surface> surface2 = display.CreateSurface(); ASSERT_TRUE(surface2); // Create a remote shell surface for surface1. std::unique_ptr<ClientControlledShellSurface> shell_surface1 = - display->CreateOrGetClientControlledShellSurface( + display.CreateOrGetClientControlledShellSurface( surface1.get(), ash::kShellWindowId_SystemModalContainer, /*default_scale_factor=*/2.0, /*default_scale_cancellation=*/true); @@ -185,7 +185,7 @@ // Create a remote shell surface for surface2. std::unique_ptr<ShellSurfaceBase> shell_surface2 = - display->CreateOrGetClientControlledShellSurface( + display.CreateOrGetClientControlledShellSurface( surface2.get(), ash::desks_util::GetActiveDeskContainerId(), /*default_scale_factor=*/1.0, /*default_scale_cancellation=*/true); @@ -193,25 +193,26 @@ } TEST_F(DisplayTest, GetClientControlledShellSurface) { - std::unique_ptr<Display> display(new Display); + Display display; // Create a external surface, bind with a window id. + std::unique_ptr<Surface> surface = display.CreateSurface(); ClientControlledShellSurface* external_shell_surface = new ClientControlledShellSurface( - new Surface, + surface.get(), /*can_minimize=*/true, ash::desks_util::GetActiveDeskContainerId(), /*default_scale_cancellation=*/true); property_resolver()->PutClientControlledShellSurface( /*window_session_id=*/10001, base::WrapUnique(external_shell_surface)); // Create surface with specific window id. - std::unique_ptr<Surface> surface_with_id = display->CreateSurface(); + std::unique_ptr<Surface> surface_with_id = display.CreateSurface(); ASSERT_TRUE(surface_with_id); surface_with_id->SetWindowSessionId(10001); // Get a remote shell surface by external source. std::unique_ptr<ClientControlledShellSurface> shell_surface = - display->CreateOrGetClientControlledShellSurface( + display.CreateOrGetClientControlledShellSurface( surface_with_id.get(), ash::desks_util::GetActiveDeskContainerId(), /*default_scale_factor=*/2.0, /*default_scale_cancellation=*/true); @@ -220,68 +221,68 @@ } TEST_F(DisplayTest, CreateSubSurface) { - std::unique_ptr<Display> display(new Display); + Display display; // Create child, parent and toplevel surfaces. - std::unique_ptr<Surface> child = display->CreateSurface(); + std::unique_ptr<Surface> child = display.CreateSurface(); ASSERT_TRUE(child); - std::unique_ptr<Surface> parent = display->CreateSurface(); + std::unique_ptr<Surface> parent = display.CreateSurface(); ASSERT_TRUE(parent); - std::unique_ptr<Surface> toplevel = display->CreateSurface(); + std::unique_ptr<Surface> toplevel = display.CreateSurface(); ASSERT_TRUE(toplevel); // Attempting to create a sub surface for child with child as its parent // should fail. - EXPECT_FALSE(display->CreateSubSurface(child.get(), child.get())); + EXPECT_FALSE(display.CreateSubSurface(child.get(), child.get())); // Create a sub surface for child. std::unique_ptr<SubSurface> child_sub_surface = - display->CreateSubSurface(child.get(), toplevel.get()); + display.CreateSubSurface(child.get(), toplevel.get()); EXPECT_TRUE(child_sub_surface); // Attempting to create another sub surface when already assigned the role of // sub surface should fail. - EXPECT_FALSE(display->CreateSubSurface(child.get(), parent.get())); + EXPECT_FALSE(display.CreateSubSurface(child.get(), parent.get())); // Deleting the sub surface should allow a new sub surface to be created. child_sub_surface.reset(); - child_sub_surface = display->CreateSubSurface(child.get(), parent.get()); + child_sub_surface = display.CreateSubSurface(child.get(), parent.get()); EXPECT_TRUE(child_sub_surface); - std::unique_ptr<Surface> sibling = display->CreateSurface(); + std::unique_ptr<Surface> sibling = display.CreateSurface(); ASSERT_TRUE(sibling); // Create a sub surface for sibiling. std::unique_ptr<SubSurface> sibling_sub_surface = - display->CreateSubSurface(sibling.get(), parent.get()); + display.CreateSubSurface(sibling.get(), parent.get()); EXPECT_TRUE(sibling_sub_surface); // Create a shell surface for toplevel surface. std::unique_ptr<ShellSurface> shell_surface = - display->CreateShellSurface(toplevel.get()); + display.CreateShellSurface(toplevel.get()); EXPECT_TRUE(shell_surface); // Attempting to create a sub surface when already assigned the role of // shell surface should fail. - EXPECT_FALSE(display->CreateSubSurface(toplevel.get(), parent.get())); + EXPECT_FALSE(display.CreateSubSurface(toplevel.get(), parent.get())); - std::unique_ptr<Surface> grandchild = display->CreateSurface(); + std::unique_ptr<Surface> grandchild = display.CreateSurface(); ASSERT_TRUE(grandchild); // Create a sub surface for grandchild. std::unique_ptr<SubSurface> grandchild_sub_surface = - display->CreateSubSurface(grandchild.get(), child.get()); + display.CreateSubSurface(grandchild.get(), child.get()); EXPECT_TRUE(grandchild_sub_surface); // Attempting to create a sub surface for parent with child as its parent // should fail. - EXPECT_FALSE(display->CreateSubSurface(parent.get(), child.get())); + EXPECT_FALSE(display.CreateSubSurface(parent.get(), child.get())); // Attempting to create a sub surface for parent with grandchild as its parent // should fail. - EXPECT_FALSE(display->CreateSubSurface(parent.get(), grandchild.get())); + EXPECT_FALSE(display.CreateSubSurface(parent.get(), grandchild.get())); // Create a sub surface for parent. - EXPECT_TRUE(display->CreateSubSurface(parent.get(), toplevel.get())); + EXPECT_TRUE(display.CreateSubSurface(parent.get(), toplevel.get())); } class TestDataDeviceDelegate : public DataDeviceDelegate {
diff --git a/components/external_intents/android/java/src/org/chromium/components/external_intents/RedirectHandler.java b/components/external_intents/android/java/src/org/chromium/components/external_intents/RedirectHandler.java index 8e6f619..845de114 100644 --- a/components/external_intents/android/java/src/org/chromium/components/external_intents/RedirectHandler.java +++ b/components/external_intents/android/java/src/org/chromium/components/external_intents/RedirectHandler.java
@@ -29,28 +29,51 @@ public static final int INVALID_ENTRY_INDEX = -1; public static final long INVALID_TIME = -1; - private static final int NAVIGATION_TYPE_NONE = 0; private static final int NAVIGATION_TYPE_FROM_INTENT = 1; private static final int NAVIGATION_TYPE_FROM_USER_TYPING = 2; private static final int NAVIGATION_TYPE_FROM_LINK_WITHOUT_USER_GESTURE = 3; private static final int NAVIGATION_TYPE_FROM_RELOAD = 4; private static final int NAVIGATION_TYPE_OTHER = 5; - private Intent mInitialIntent; - // A resolver list which includes all resolvers of |mInitialIntent|. - private final HashSet<ComponentName> mCachedResolvers = new HashSet<ComponentName>(); - private boolean mIsInitialIntentHeadingToChrome; - private boolean mIsCustomTabIntent; + private static class IntentState { + final Intent mInitialIntent; + final boolean mIsCustomTabIntent; + final boolean mIsInitialIntentHeadingToChrome; + final boolean mExternalIntentStartedTask; + + // A resolver list which includes all resolvers of |mInitialIntent|. + HashSet<ComponentName> mCachedResolvers = new HashSet<ComponentName>(); + + IntentState(Intent initialIntent, boolean isInitialIntentHeadingToChrome, + boolean isCustomTabIntent, boolean externalIntentStartedTask) { + mInitialIntent = initialIntent; + mIsInitialIntentHeadingToChrome = isInitialIntentHeadingToChrome; + mIsCustomTabIntent = isCustomTabIntent; + mExternalIntentStartedTask = externalIntentStartedTask; + } + } + + private static class NavigationState { + final int mInitialNavigationType; + final int mLastCommittedEntryIndexBeforeStartingNavigation; + final boolean mHasUserStartedNonInitialNavigation; + boolean mIsOnEffectiveRedirectChain; + boolean mShouldNotOverrideUrlLoadingOnCurrentRedirectChain; + boolean mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain; + + NavigationState(int initialNavigationType, + int lastCommittedEntryIndexBeforeStartingNavigation, + boolean hasUserStartedNonInitialNavigation) { + mInitialNavigationType = initialNavigationType; + mLastCommittedEntryIndexBeforeStartingNavigation = + lastCommittedEntryIndexBeforeStartingNavigation; + mHasUserStartedNonInitialNavigation = hasUserStartedNonInitialNavigation; + } + } private long mLastNewUrlLoadingTime = INVALID_TIME; - private boolean mIsOnEffectiveRedirectChain; - private int mInitialNavigationType; - private int mLastCommittedEntryIndexBeforeStartingNavigation; - private boolean mHasUserStartedNonInitialNavigation; - - private boolean mShouldNotOverrideUrlLoadingOnCurrentRedirectChain; - private boolean mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain; - private boolean mExternalIntentStartedTask; + private IntentState mIntentState; + private NavigationState mNavigationState; public static RedirectHandler create() { return new RedirectHandler(); @@ -59,36 +82,31 @@ protected RedirectHandler() {} /** - * Updates |mIntentHistory| and |mLastIntentUpdatedTime|. If |intent| comes from chrome and - * currently |mIsOnEffectiveIntentRedirectChain| is true, that means |intent| was sent from - * this tab because only the front tab or a new tab can receive an intent from chrome. In that - * case, |intent| is added to |mIntentHistory|. - * Otherwise, |mIntentHistory| and |mPreviousResolvers| are cleared, and then |intent| is put - * into |mIntentHistory|. + * Resets |mIntentState| for the newly received Intent. */ public void updateIntent(Intent intent, boolean isCustomTabIntent, boolean sendToExternalApps, boolean isCCTExternalLinkHandlingEnabled, boolean externalIntentStartedTask) { - clear(); - if (intent == null || !Intent.ACTION_VIEW.equals(intent.getAction())) { + mIntentState = null; return; } - mIsCustomTabIntent = isCustomTabIntent; - mExternalIntentStartedTask = externalIntentStartedTask; + boolean isInitialIntentHeadingToChrome = false; boolean checkIsToChrome = true; // All custom tabs VIEW intents are by design explicit intents, so the presence of package // name doesn't imply they have to be handled by Chrome explicitly. Check if external apps // should be checked for handling the initial redirect chain. - if (mIsCustomTabIntent) { + if (isCustomTabIntent) { checkIsToChrome = !(sendToExternalApps && isCCTExternalLinkHandlingEnabled); } - if (checkIsToChrome) mIsInitialIntentHeadingToChrome = isIntentToChrome(intent); + if (checkIsToChrome) isInitialIntentHeadingToChrome = isIntentToChrome(intent); // A sanitized copy of the initial intent for detecting if resolvers have changed. - mInitialIntent = new Intent(intent); - ExternalNavigationHandler.sanitizeQueryIntentActivitiesIntent(mInitialIntent); + Intent initialIntent = new Intent(intent); + ExternalNavigationHandler.sanitizeQueryIntentActivitiesIntent(initialIntent); + mIntentState = new IntentState(initialIntent, isInitialIntentHeadingToChrome, + isCustomTabIntent, externalIntentStartedTask); } private static boolean isIntentToChrome(Intent intent) { @@ -98,25 +116,12 @@ IntentUtils.safeGetStringExtra(intent, Browser.EXTRA_APPLICATION_ID)); } - private void clearIntentHistory() { - mIsInitialIntentHeadingToChrome = false; - mIsCustomTabIntent = false; - mInitialIntent = null; - mExternalIntentStartedTask = false; - mCachedResolvers.clear(); - } - /** * Resets all variables except timestamps. */ public void clear() { - clearIntentHistory(); - mInitialNavigationType = NAVIGATION_TYPE_NONE; - mIsOnEffectiveRedirectChain = false; - mLastCommittedEntryIndexBeforeStartingNavigation = 0; - mHasUserStartedNonInitialNavigation = false; - mShouldNotOverrideUrlLoadingOnCurrentRedirectChain = false; - mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain = false; + mIntentState = null; + mNavigationState = null; } /** @@ -124,7 +129,7 @@ * occurs. */ public void setShouldNotOverrideUrlLoadingOnCurrentRedirectChain() { - mShouldNotOverrideUrlLoadingOnCurrentRedirectChain = true; + mNavigationState.mShouldNotOverrideUrlLoadingOnCurrentRedirectChain = true; } /** @@ -132,7 +137,7 @@ * a new user-initiated navigation occurs. */ public void setShouldNotBlockUrlLoadingOverrideOnCurrentRedirectionChain() { - mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain = true; + mNavigationState.mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain = true; } /** @@ -142,7 +147,7 @@ * swiped away or timed out). */ public boolean wasTaskStartedByExternalIntent() { - return mExternalIntentStartedTask; + return mIntentState != null && mIntentState.mExternalIntentStartedTask; } /** @@ -185,42 +190,39 @@ isNewLoadingStartedByUser = true; } } - - if (isNewLoadingStartedByUser) { - // Updates mInitialNavigationType for a new loading started by a user's gesture. - if (isFromIntent && mInitialIntent != null) { - mInitialNavigationType = NAVIGATION_TYPE_FROM_INTENT; - } else { - clearIntentHistory(); - if (pageTransitionCore == PageTransition.TYPED) { - mInitialNavigationType = NAVIGATION_TYPE_FROM_USER_TYPING; - } else if (pageTransitionCore == PageTransition.RELOAD - || (pageTransType & PageTransition.FORWARD_BACK) != 0) { - mInitialNavigationType = NAVIGATION_TYPE_FROM_RELOAD; - } else if (pageTransitionCore == PageTransition.LINK && !hasUserGesture) { - mInitialNavigationType = NAVIGATION_TYPE_FROM_LINK_WITHOUT_USER_GESTURE; - } else { - mInitialNavigationType = NAVIGATION_TYPE_OTHER; - } - } - mIsOnEffectiveRedirectChain = false; - mLastCommittedEntryIndexBeforeStartingNavigation = lastCommittedEntryIndex; - if (!isInitialNavigation) { - mHasUserStartedNonInitialNavigation = true; - } - mShouldNotOverrideUrlLoadingOnCurrentRedirectChain = false; - mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain = false; - } else if (mInitialNavigationType != NAVIGATION_TYPE_NONE) { + if (!isNewLoadingStartedByUser) { // Redirect chain starts from the second url loading. - mIsOnEffectiveRedirectChain = true; + mNavigationState.mIsOnEffectiveRedirectChain = true; + return; } + + // Create the NavigationState for a new Navigation chain. + int mInitialNavigationType; + if (isFromIntent && mIntentState != null) { + mInitialNavigationType = NAVIGATION_TYPE_FROM_INTENT; + } else { + mIntentState = null; + if (pageTransitionCore == PageTransition.TYPED) { + mInitialNavigationType = NAVIGATION_TYPE_FROM_USER_TYPING; + } else if (pageTransitionCore == PageTransition.RELOAD + || (pageTransType & PageTransition.FORWARD_BACK) != 0) { + mInitialNavigationType = NAVIGATION_TYPE_FROM_RELOAD; + } else if (pageTransitionCore == PageTransition.LINK && !hasUserGesture) { + mInitialNavigationType = NAVIGATION_TYPE_FROM_LINK_WITHOUT_USER_GESTURE; + } else { + mInitialNavigationType = NAVIGATION_TYPE_OTHER; + } + } + mNavigationState = new NavigationState( + mInitialNavigationType, lastCommittedEntryIndex, !isInitialNavigation); } /** * @return whether on effective intent redirect chain or not. */ public boolean isOnEffectiveIntentRedirectChain() { - return mInitialNavigationType == NAVIGATION_TYPE_FROM_INTENT && mIsOnEffectiveRedirectChain; + return mNavigationState.mInitialNavigationType == NAVIGATION_TYPE_FROM_INTENT + && mNavigationState.mIsOnEffectiveRedirectChain; } /** @@ -240,7 +242,8 @@ public boolean shouldStayInApp(boolean hasExternalProtocol, boolean isForTrustedCallingApp) { // http://crbug/424029 : Need to stay in Chrome for an intent heading explicitly to Chrome. // http://crbug/881740 : Relax stay in Chrome restriction for Custom Tabs. - return (mIsInitialIntentHeadingToChrome && !hasExternalProtocol) + return (mIntentState != null && mIntentState.mIsInitialIntentHeadingToChrome + && !hasExternalProtocol) || shouldNavigationTypeStayInApp(isForTrustedCallingApp); } @@ -253,35 +256,36 @@ private boolean shouldNavigationTypeStayInApp(boolean isForTrustedCallingApp) { // http://crbug.com/162106: Never leave Chrome from a refresh. - if (mInitialNavigationType == NAVIGATION_TYPE_FROM_RELOAD) return true; + if (mNavigationState.mInitialNavigationType == NAVIGATION_TYPE_FROM_RELOAD) return true; // If the app we would navigate to is trusted and what launched Chrome, allow the // navigation. if (isForTrustedCallingApp) return false; // Otherwise allow navigation out of the app only with a user gesture. - return mInitialNavigationType == NAVIGATION_TYPE_FROM_LINK_WITHOUT_USER_GESTURE; + return mNavigationState.mInitialNavigationType + == NAVIGATION_TYPE_FROM_LINK_WITHOUT_USER_GESTURE; } /** * @return Whether this navigation is initiated by a Custom Tabs {@link Intent}. */ public boolean isFromCustomTabIntent() { - return mIsCustomTabIntent; + return mIntentState != null && mIntentState.mIsCustomTabIntent; } /** * @return whether navigation is from a user's typing or not. */ public boolean isNavigationFromUserTyping() { - return mInitialNavigationType == NAVIGATION_TYPE_FROM_USER_TYPING; + return mNavigationState.mInitialNavigationType == NAVIGATION_TYPE_FROM_USER_TYPING; } /** * @return whether we should stay in Chrome or not. */ public boolean shouldNotOverrideUrlLoading() { - return mShouldNotOverrideUrlLoadingOnCurrentRedirectChain; + return mNavigationState.mShouldNotOverrideUrlLoadingOnCurrentRedirectChain; } /** @@ -289,8 +293,8 @@ * chain. */ public boolean getAndClearShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain() { - boolean value = mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain; - mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain = false; + boolean value = mNavigationState.mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain; + mNavigationState.mShouldNotBlockOverrideUrlLoadingOnCurrentRedirectionChain = false; return value; } @@ -298,21 +302,21 @@ * @return whether on navigation or not. */ public boolean isOnNavigation() { - return mInitialNavigationType != NAVIGATION_TYPE_NONE; + return mNavigationState != null; } /** * @return the last committed entry index which was saved before starting this navigation. */ public int getLastCommittedEntryIndexBeforeStartingNavigation() { - return mLastCommittedEntryIndexBeforeStartingNavigation; + return mNavigationState.mLastCommittedEntryIndexBeforeStartingNavigation; } /** * @return whether the user has started a non-initial navigation. */ public boolean hasUserStartedNonInitialNavigation() { - return mHasUserStartedNonInitialNavigation; + return mNavigationState != null && mNavigationState.mHasUserStartedNonInitialNavigation; } /** @@ -320,19 +324,17 @@ */ public boolean hasNewResolver(List<ResolveInfo> resolvingInfos, Function<Intent, List<ResolveInfo>> queryIntentActivitiesFunction) { - if (mInitialIntent == null) { - return !resolvingInfos.isEmpty(); - } + if (mIntentState == null) return !resolvingInfos.isEmpty(); - if (mCachedResolvers.isEmpty()) { - for (ResolveInfo r : queryIntentActivitiesFunction.apply(mInitialIntent)) { - mCachedResolvers.add( + if (mIntentState.mCachedResolvers.isEmpty()) { + for (ResolveInfo r : queryIntentActivitiesFunction.apply(mIntentState.mInitialIntent)) { + mIntentState.mCachedResolvers.add( new ComponentName(r.activityInfo.packageName, r.activityInfo.name)); } } - if (resolvingInfos.size() > mCachedResolvers.size()) return true; + if (resolvingInfos.size() > mIntentState.mCachedResolvers.size()) return true; for (ResolveInfo r : resolvingInfos) { - if (!mCachedResolvers.contains( + if (!mIntentState.mCachedResolvers.contains( new ComponentName(r.activityInfo.packageName, r.activityInfo.name))) { return true; } @@ -344,6 +346,6 @@ * @return The initial intent of a redirect chain, if available. */ public Intent getInitialIntent() { - return mInitialIntent; + return mIntentState != null ? mIntentState.mInitialIntent : null; } }
diff --git a/components/external_intents/android/javatests/src/org/chromium/components/external_intents/RedirectHandlerTest.java b/components/external_intents/android/javatests/src/org/chromium/components/external_intents/RedirectHandlerTest.java index 4926f24..8ef164c 100644 --- a/components/external_intents/android/javatests/src/org/chromium/components/external_intents/RedirectHandlerTest.java +++ b/components/external_intents/android/javatests/src/org/chromium/components/external_intents/RedirectHandlerTest.java
@@ -4,6 +4,7 @@ package org.chromium.components.external_intents; +import android.content.Context; import android.content.Intent; import android.content.pm.ActivityInfo; import android.content.pm.PackageManager; @@ -14,17 +15,18 @@ import androidx.test.filters.SmallTest; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.chromium.base.CommandLine; import org.chromium.base.ContextUtils; import org.chromium.base.Function; import org.chromium.base.PackageManagerUtils; import org.chromium.base.test.BaseJUnit4ClassRunner; import org.chromium.base.test.util.AdvancedMockContext; +import org.chromium.base.test.util.Batch; import org.chromium.base.test.util.Feature; import org.chromium.ui.base.PageTransition; @@ -36,6 +38,7 @@ * Unittests for tab redirect handler. */ @RunWith(BaseJUnit4ClassRunner.class) +@Batch(Batch.UNIT_TESTS) public class RedirectHandlerTest { private static final int TRANS_TYPE_OF_LINK_FROM_INTENT = PageTransition.LINK | PageTransition.FROM_API; @@ -47,6 +50,8 @@ private Function<Intent, List<ResolveInfo>> mQueryIntentFunction = (Intent intent) -> queryIntentActivities(intent); + private Context mContextToRestore; + static { try { sYtIntent = Intent.parseUri("http://youtube.com/", Intent.URI_INTENT_SCHEME); @@ -59,10 +64,15 @@ @Before public void setUp() { - CommandLine.init(new String[0]); + mContextToRestore = ContextUtils.getApplicationContext(); ContextUtils.initApplicationContextForTests(new TestContext()); } + @After + public void tearDown() { + ContextUtils.initApplicationContextForTests(mContextToRestore); + } + private List<ResolveInfo> queryIntentActivities(Intent intent) { return PackageManagerUtils.queryIntentActivities(intent, 0); } @@ -253,8 +263,6 @@ fooIntent.putExtra(Browser.EXTRA_APPLICATION_ID, TEST_PACKAGE_NAME); handler.updateIntent(fooIntent, false, false, false, false); Assert.assertFalse(handler.isOnNavigation()); - Assert.assertTrue(handler.shouldStayInApp(false)); - Assert.assertFalse(handler.shouldStayInApp(true)); handler.updateNewUrlLoading(TRANS_TYPE_OF_LINK_FROM_INTENT, false, false, 0, 0, false); Assert.assertTrue(handler.shouldStayInApp(false)); @@ -283,7 +291,6 @@ RedirectHandler handler = RedirectHandler.create(); handler.updateIntent(sYtIntent, false, false, false, false); Assert.assertFalse(handler.isOnNavigation()); - Assert.assertFalse(handler.isNavigationFromUserTyping()); handler.updateNewUrlLoading(PageTransition.TYPED, false, false, 0, 0, false); Assert.assertTrue(handler.isNavigationFromUserTyping()); @@ -311,8 +318,6 @@ fooIntent.setPackage(TEST_PACKAGE_NAME); handler.updateIntent(fooIntent, false, false, false, false); Assert.assertFalse(handler.isOnNavigation()); - Assert.assertTrue(handler.shouldStayInApp(false)); - Assert.assertFalse(handler.shouldStayInApp(true)); handler.updateNewUrlLoading(TRANS_TYPE_OF_LINK_FROM_INTENT, false, false, 0, 0, false); Assert.assertTrue(handler.shouldStayInApp(false)); @@ -343,9 +348,9 @@ ///////////////////////////////////////////////////// RedirectHandler handler = RedirectHandler.create(); handler.updateIntent(sYtIntent, false, false, false, false); - Assert.assertFalse(handler.shouldNotOverrideUrlLoading()); handler.updateNewUrlLoading(PageTransition.LINK, false, true, 0, 0, false); + Assert.assertFalse(handler.shouldNotOverrideUrlLoading()); handler.setShouldNotOverrideUrlLoadingOnCurrentRedirectChain(); handler.updateNewUrlLoading(PageTransition.LINK, true, false, 0, 0, false); @@ -357,9 +362,9 @@ ///////////////////////////////////////////////////// handler = RedirectHandler.create(); handler.updateIntent(sYtIntent, false, false, false, false); - Assert.assertFalse(handler.shouldNotOverrideUrlLoading()); handler.updateNewUrlLoading(PageTransition.LINK, false, true, 0, 0, false); + Assert.assertFalse(handler.shouldNotOverrideUrlLoading()); handler.setShouldNotOverrideUrlLoadingOnCurrentRedirectChain(); // Effective redirection occurred. @@ -384,8 +389,6 @@ RedirectHandler handler = RedirectHandler.create(); handler.updateIntent(sYtIntent, false, false, false, false); Assert.assertFalse(handler.isOnNavigation()); - Assert.assertFalse(handler.shouldStayInApp(false)); - Assert.assertFalse(handler.shouldStayInApp(true)); long lastUserInteractionTime = SystemClock.elapsedRealtime(); handler.updateNewUrlLoading( @@ -417,8 +420,6 @@ RedirectHandler handler = RedirectHandler.create(); handler.updateIntent(sYtIntent, false, false, false, false); Assert.assertFalse(handler.isOnNavigation()); - Assert.assertFalse(handler.shouldStayInApp(false)); - Assert.assertFalse(handler.shouldStayInApp(true)); long lastUserInteractionTime = SystemClock.elapsedRealtime(); handler.updateNewUrlLoading( @@ -450,15 +451,14 @@ RedirectHandler handler = RedirectHandler.create(); handler.updateIntent(sYtIntent, false, false, false, false); Assert.assertFalse(handler.isOnNavigation()); - Assert.assertFalse(handler.shouldStayInApp(false)); - Assert.assertFalse(handler.shouldStayInApp(true)); - Assert.assertFalse(handler.hasUserStartedNonInitialNavigation()); long lastUserInteractionTime = SystemClock.elapsedRealtime(); handler.updateNewUrlLoading(PageTransition.FORM_SUBMIT | PageTransition.FORWARD_BACK, false, true, lastUserInteractionTime, 0, false); Assert.assertTrue(handler.shouldStayInApp(false)); Assert.assertTrue(handler.shouldStayInApp(true)); + Assert.assertTrue(handler.hasUserStartedNonInitialNavigation()); + handler.updateNewUrlLoading(PageTransition.LINK, false, false, lastUserInteractionTime, 1, false /* isInitialNavigation */); Assert.assertTrue(handler.shouldStayInApp(false));
diff --git a/components/history_clusters/core/BUILD.gn b/components/history_clusters/core/BUILD.gn index 7437a17..01e2e92 100644 --- a/components/history_clusters/core/BUILD.gn +++ b/components/history_clusters/core/BUILD.gn
@@ -41,6 +41,8 @@ "content_annotations_cluster_processor.h", "content_visibility_cluster_finalizer.cc", "content_visibility_cluster_finalizer.h", + "keyword_cluster_finalizer.cc", + "keyword_cluster_finalizer.h", "noisy_cluster_finalizer.cc", "noisy_cluster_finalizer.h", "on_device_clustering_backend.cc", @@ -63,7 +65,6 @@ ":history_clusters_buildflags", "//base", "//components/history/core/browser", - "//components/history_clusters/core/proto", "//components/keyed_service/core", "//components/optimization_guide/core:entities", "//components/pref_registry", @@ -91,6 +92,7 @@ "clusterer_unittest.cc", "content_annotations_cluster_processor_unittest.cc", "content_visibility_cluster_finalizer_unittest.cc", + "keyword_cluster_finalizer_unittest.cc", "noisy_cluster_finalizer_unittest.cc", "on_device_clustering_backend_unittest.cc", "on_device_clustering_util_unittest.cc", @@ -106,7 +108,6 @@ "//base/test:test_support", "//components/history/core/browser", "//components/history/core/test", - "//components/history_clusters/core/proto", "//components/optimization_guide/core:entities", "//components/search_engines", "//services/network:test_support",
diff --git a/components/history_clusters/core/clusterer.cc b/components/history_clusters/core/clusterer.cc index 6f7c087..292055f6 100644 --- a/components/history_clusters/core/clusterer.cc +++ b/components/history_clusters/core/clusterer.cc
@@ -4,32 +4,11 @@ #include "components/history_clusters/core/clusterer.h" -#include "base/strings/utf_string_conversions.h" #include "components/history/core/browser/history_types.h" #include "components/history_clusters/core/on_device_clustering_features.h" namespace history_clusters { -namespace { - -void AddKeywordsForVisitToCluster(history::Cluster& cluster, - const history::ClusterVisit& visit) { - base::flat_set<std::u16string> keywords_set(cluster.keywords.begin(), - cluster.keywords.end()); - for (const auto& entity : - visit.annotated_visit.content_annotations.model_annotations.entities) { - keywords_set.insert(base::UTF8ToUTF16(entity.id)); - } - for (const auto& category : - visit.annotated_visit.content_annotations.model_annotations.categories) { - keywords_set.insert(base::UTF8ToUTF16(category.id)); - } - cluster.keywords = - std::vector<std::u16string>(keywords_set.begin(), keywords_set.end()); -} - -} // namespace - Clusterer::Clusterer() = default; Clusterer::~Clusterer() = default; @@ -98,15 +77,12 @@ default_scored_visit.score = 1.0; if (cluster_idx) { clusters[*cluster_idx].visits.push_back(default_scored_visit); - AddKeywordsForVisitToCluster(clusters[*cluster_idx], - default_scored_visit); } else { // Add to new cluster. cluster_idx = clusters.size(); history::Cluster new_cluster; new_cluster.visits = {default_scored_visit}; - AddKeywordsForVisitToCluster(new_cluster, default_scored_visit); clusters.push_back(std::move(new_cluster)); } visit_id_to_cluster_map[visit.annotated_visit.visit_row.visit_id] =
diff --git a/components/history_clusters/core/content_annotations_cluster_processor.cc b/components/history_clusters/core/content_annotations_cluster_processor.cc index 9835485e..08f9075 100644 --- a/components/history_clusters/core/content_annotations_cluster_processor.cc +++ b/components/history_clusters/core/content_annotations_cluster_processor.cc
@@ -132,7 +132,7 @@ CreateBoWsForClusters(clusters, &cluster_idx_to_entity_bows, &cluster_idx_to_category_bows); - // Now cluster on the keywords in each BoW between clusters. + // Now cluster on the entries in each BoW between clusters. std::vector<history::Cluster> aggregated_clusters; base::flat_set<int> merged_cluster_indices; for (size_t i = 0; i < clusters.size(); i++) { @@ -142,8 +142,6 @@ // Greedily combine clusters by checking if this cluster is similar to any // other unmerged clusters. history::Cluster aggregated_cluster = clusters[i]; - base::flat_set<std::u16string> aggregated_cluster_keywords( - clusters[i].keywords.begin(), clusters[i].keywords.end()); for (size_t j = i + 1; j < clusters.size(); j++) { if (merged_cluster_indices.find(j) != merged_cluster_indices.end()) { continue; @@ -153,18 +151,13 @@ float category_similarity = CalculateSimilarityScore( cluster_idx_to_category_bows[i], cluster_idx_to_category_bows[j]); if (ShouldMergeClusters(entity_similarity, category_similarity)) { - // Add the visits and keywords to the aggregated cluster. + // Add the visits to the aggregated cluster. merged_cluster_indices.insert(j); aggregated_cluster.visits.insert(aggregated_cluster.visits.end(), clusters[j].visits.begin(), clusters[j].visits.end()); - aggregated_cluster_keywords.insert(clusters[j].keywords.begin(), - clusters[j].keywords.end()); } } - aggregated_cluster.keywords = - std::vector<std::u16string>({aggregated_cluster_keywords.begin(), - aggregated_cluster_keywords.end()}); aggregated_clusters.push_back(std::move(aggregated_cluster)); } return aggregated_clusters;
diff --git a/components/history_clusters/core/content_annotations_cluster_processor_unittest.cc b/components/history_clusters/core/content_annotations_cluster_processor_unittest.cc index 7d5762c..38b7aa4 100644 --- a/components/history_clusters/core/content_annotations_cluster_processor_unittest.cc +++ b/components/history_clusters/core/content_annotations_cluster_processor_unittest.cc
@@ -62,7 +62,6 @@ cluster1.visits = {testing::CreateClusterVisit(visit), testing::CreateClusterVisit(visit2), testing::CreateClusterVisit(visit4)}; - cluster1.keywords = {std::u16string(u"github"), std::u16string(u"google")}; clusters.push_back(cluster1); // After the context clustering, visit5 will not be in the same cluster as @@ -74,8 +73,6 @@ visit5.content_annotations.model_annotations.categories = {{"category", 1}}; history::Cluster cluster2; cluster2.visits = {testing::CreateClusterVisit(visit5)}; - cluster2.keywords = {std::u16string(u"github"), - std::u16string(u"otherkeyword")}; clusters.push_back(cluster2); std::vector<history::Cluster> result_clusters = ProcessClusters(clusters); @@ -85,10 +82,6 @@ testing::VisitResult(1, 1.0), testing::VisitResult(2, 1.0), testing::VisitResult(4, 1.0), testing::VisitResult(10, 1.0)))); ASSERT_EQ(result_clusters.size(), 1u); - EXPECT_THAT( - result_clusters.at(0).keywords, - UnorderedElementsAre(std::u16string(u"github"), std::u16string(u"google"), - std::u16string(u"otherkeyword"))); } TEST_F(ContentAnnotationsClusterProcessorTest, BelowThreshold) { @@ -105,7 +98,6 @@ history::Cluster cluster1; cluster1.visits = {testing::CreateClusterVisit(visit), testing::CreateClusterVisit(visit2)}; - cluster1.keywords = {std::u16string(u"github"), std::u16string(u"google")}; clusters.push_back(cluster1); // After the context clustering, visit4 will not be in the same cluster as @@ -117,7 +109,6 @@ visit4.content_annotations.model_annotations.entities = {{"github", 1}}; history::Cluster cluster2; cluster2.visits = {testing::CreateClusterVisit(visit4)}; - cluster2.keywords = {std::u16string(u"github")}; clusters.push_back(cluster2); // This visit has the same entities but no categories and shouldn't be @@ -127,7 +118,6 @@ visit5.content_annotations.model_annotations.entities = {{"github", 1}}; history::Cluster cluster3; cluster3.visits = {testing::CreateClusterVisit(visit5)}; - cluster3.keywords = {std::u16string(u"irrelevant")}; clusters.push_back(cluster3); // This visit has the same categories but no entities and shouldn't be @@ -137,7 +127,6 @@ visit6.content_annotations.model_annotations.categories = {{"category", 1}}; history::Cluster cluster4; cluster4.visits = {testing::CreateClusterVisit(visit6)}; - cluster4.keywords = {std::u16string(u"category")}; clusters.push_back(cluster4); // This visit has no content annotations and shouldn't be grouped with the @@ -157,14 +146,6 @@ ElementsAre(testing::VisitResult(11, 1.0)), ElementsAre(testing::VisitResult(12, 1.0)))); EXPECT_THAT(result_clusters.size(), 4u); - EXPECT_THAT(result_clusters.at(0).keywords, - UnorderedElementsAre(std::u16string(u"github"), - std::u16string(u"google"))); - EXPECT_THAT(result_clusters.at(1).keywords, - UnorderedElementsAre(std::u16string(u"irrelevant"))); - EXPECT_THAT(result_clusters.at(2).keywords, - UnorderedElementsAre(std::u16string(u"category"))); - EXPECT_TRUE(result_clusters.at(3).keywords.empty()); } class ContentAnnotationsIntersectionMetricTest @@ -199,7 +180,6 @@ cluster1.visits = {testing::CreateClusterVisit(visit), testing::CreateClusterVisit(visit2), testing::CreateClusterVisit(visit4)}; - cluster1.keywords = {std::u16string(u"github"), std::u16string(u"google")}; clusters.push_back(cluster1); // After the context clustering, visit5 will not be in the same cluster as @@ -216,8 +196,6 @@ history::Cluster cluster2; cluster2.visits = {testing::CreateClusterVisit(visit5), testing::CreateClusterVisit(visit6)}; - cluster2.keywords = {std::u16string(u"github"), std::u16string(u"google"), - std::u16string(u"otherkeyword")}; clusters.push_back(cluster2); std::vector<history::Cluster> result_clusters = ProcessClusters(clusters); @@ -227,10 +205,6 @@ testing::VisitResult(4, 1.0), testing::VisitResult(10, 1.0), testing::VisitResult(11, 1.0)))); ASSERT_EQ(result_clusters.size(), 1u); - EXPECT_THAT( - result_clusters.at(0).keywords, - UnorderedElementsAre(std::u16string(u"github"), std::u16string(u"google"), - std::u16string(u"otherkeyword"))); } TEST_F(ContentAnnotationsIntersectionMetricTest, BelowThreshold) { @@ -251,7 +225,6 @@ cluster1.visits = {testing::CreateClusterVisit(visit), testing::CreateClusterVisit(visit2), testing::CreateClusterVisit(visit4)}; - cluster1.keywords = {std::u16string(u"github"), std::u16string(u"google")}; clusters.push_back(cluster1); // After the context clustering, visit5 will not be in the same cluster as @@ -263,15 +236,10 @@ visit5.content_annotations.model_annotations.categories = {{"category", 1}}; history::Cluster cluster2; cluster2.visits = {testing::CreateClusterVisit(visit5)}; - cluster2.keywords = {std::u16string(u"github"), std::u16string(u"google"), - std::u16string(u"otherkeyword")}; clusters.push_back(cluster2); std::vector<history::Cluster> result_clusters = ProcessClusters(clusters); ASSERT_EQ(result_clusters.size(), 2u); - EXPECT_THAT(result_clusters.at(0).keywords, - UnorderedElementsAre(std::u16string(u"github"), - std::u16string(u"google"))); } } // namespace
diff --git a/components/history_clusters/core/keyword_cluster_finalizer.cc b/components/history_clusters/core/keyword_cluster_finalizer.cc new file mode 100644 index 0000000..b4e43d2 --- /dev/null +++ b/components/history_clusters/core/keyword_cluster_finalizer.cc
@@ -0,0 +1,42 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/history_clusters/core/keyword_cluster_finalizer.h" + +#include "base/containers/flat_set.h" +#include "base/strings/utf_string_conversions.h" +#include "components/history_clusters/core/on_device_clustering_features.h" +#include "components/history_clusters/core/on_device_clustering_util.h" + +namespace history_clusters { + +KeywordClusterFinalizer::KeywordClusterFinalizer() = default; +KeywordClusterFinalizer::~KeywordClusterFinalizer() = default; + +void KeywordClusterFinalizer::FinalizeCluster(history::Cluster& cluster) { + base::flat_set<std::u16string> keywords_set; + for (const auto& visit : cluster.visits) { + if (features::ShouldExcludeKeywordsFromNoisyVisits() && + IsNoisyVisit(visit)) { + // Do not put keywords if user visits the page a lot and it's not a + // search-like visit. + continue; + } + + for (const auto& entity : + visit.annotated_visit.content_annotations.model_annotations.entities) { + keywords_set.insert(base::UTF8ToUTF16(entity.id)); + } + if (features::ShouldIncludeCategoriesInKeywords()) { + for (const auto& category : visit.annotated_visit.content_annotations + .model_annotations.categories) { + keywords_set.insert(base::UTF8ToUTF16(category.id)); + } + } + } + cluster.keywords = + std::vector<std::u16string>(keywords_set.begin(), keywords_set.end()); +} + +} // namespace history_clusters
diff --git a/components/history_clusters/core/keyword_cluster_finalizer.h b/components/history_clusters/core/keyword_cluster_finalizer.h new file mode 100644 index 0000000..9781abf --- /dev/null +++ b/components/history_clusters/core/keyword_cluster_finalizer.h
@@ -0,0 +1,24 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_HISTORY_CLUSTERS_CORE_KEYWORD_CLUSTER_FINALIZER_H_ +#define COMPONENTS_HISTORY_CLUSTERS_CORE_KEYWORD_CLUSTER_FINALIZER_H_ + +#include "components/history_clusters/core/cluster_finalizer.h" + +namespace history_clusters { + +// A cluster finalizer that determines the set of keywords for a given cluster. +class KeywordClusterFinalizer : public ClusterFinalizer { + public: + KeywordClusterFinalizer(); + ~KeywordClusterFinalizer() override; + + // ClusterFinalizer: + void FinalizeCluster(history::Cluster& cluster) override; +}; + +} // namespace history_clusters + +#endif // COMPONENTS_HISTORY_CLUSTERS_CORE_KEYWORD_CLUSTER_FINALIZER_H_
diff --git a/components/history_clusters/core/keyword_cluster_finalizer_unittest.cc b/components/history_clusters/core/keyword_cluster_finalizer_unittest.cc new file mode 100644 index 0000000..0143ce0 --- /dev/null +++ b/components/history_clusters/core/keyword_cluster_finalizer_unittest.cc
@@ -0,0 +1,134 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/history_clusters/core/keyword_cluster_finalizer.h" + +#include "base/test/scoped_feature_list.h" +#include "base/test/task_environment.h" +#include "components/history_clusters/core/clustering_test_utils.h" +#include "components/history_clusters/core/on_device_clustering_features.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace history_clusters { +namespace { + +using ::testing::UnorderedElementsAre; + +class KeywordClusterFinalizerTest : public ::testing::Test { + public: + void SetUp() override { + cluster_finalizer_ = std::make_unique<KeywordClusterFinalizer>(); + scoped_feature_list_.InitAndEnableFeatureWithParameters( + features::kOnDeviceClustering, + {{"exclude_keywords_from_noisy_visits", "true"}, + {"include_categories_in_keywords", "false"}}); + } + + void TearDown() override { cluster_finalizer_.reset(); } + + void FinalizeCluster(history::Cluster& cluster) { + cluster_finalizer_->FinalizeCluster(cluster); + } + + private: + base::test::ScopedFeatureList scoped_feature_list_; + std::unique_ptr<KeywordClusterFinalizer> cluster_finalizer_; + base::test::TaskEnvironment task_environment_; +}; + +TEST_F(KeywordClusterFinalizerTest, IncludesKeywordsBasedOnFeatureParameters) { + history::ClusterVisit visit = testing::CreateClusterVisit( + testing::CreateDefaultAnnotatedVisit(1, GURL("https://foo.com/"))); + visit.engagement_score = 1.0; + visit.annotated_visit.content_annotations.model_annotations.entities = { + {"github", 1}}; + visit.annotated_visit.content_annotations.model_annotations.categories = { + {"category", 1}}; + + history::ClusterVisit visit2 = + testing::CreateClusterVisit(testing::CreateDefaultAnnotatedVisit( + 2, GURL("https://engagementtoohigh.com/"))); + visit2.engagement_score = 25.0; + visit2.annotated_visit.content_annotations.model_annotations.entities = { + {"github", 1}, {"onlyinnoisyvisit", 1}}; + visit2.annotated_visit.content_annotations.model_annotations.categories = { + {"category", 1}}; + + history::ClusterVisit visit3 = testing::CreateClusterVisit( + testing::CreateDefaultAnnotatedVisit(2, GURL("https://baz.com/"))); + visit3.duplicate_visit_ids.push_back(1); + visit3.engagement_score = 1.0; + visit3.annotated_visit.content_annotations.model_annotations.entities = { + {"github", 1}, {"otherentity", 1}}; + visit3.annotated_visit.content_annotations.model_annotations.categories = { + {"category", 1}}; + + history::Cluster cluster; + cluster.visits = {visit, visit2, visit3}; + FinalizeCluster(cluster); + EXPECT_THAT(cluster.keywords, + UnorderedElementsAre(u"github", u"otherentity")); +} + +class KeywordClusterFinalizerIncludeAllTest : public ::testing::Test { + public: + void SetUp() override { + cluster_finalizer_ = std::make_unique<KeywordClusterFinalizer>(); + scoped_feature_list_.InitAndEnableFeatureWithParameters( + features::kOnDeviceClustering, + {{"exclude_keywords_from_noisy_visits", "false"}, + {"include_categories_in_keywords", "true"}}); + } + + void TearDown() override { cluster_finalizer_.reset(); } + + void FinalizeCluster(history::Cluster& cluster) { + cluster_finalizer_->FinalizeCluster(cluster); + } + + private: + base::test::ScopedFeatureList scoped_feature_list_; + std::unique_ptr<KeywordClusterFinalizer> cluster_finalizer_; + base::test::TaskEnvironment task_environment_; +}; + +TEST_F(KeywordClusterFinalizerIncludeAllTest, + IncludesKeywordsBasedOnFeatureParameters) { + history::ClusterVisit visit = testing::CreateClusterVisit( + testing::CreateDefaultAnnotatedVisit(1, GURL("https://foo.com/"))); + visit.engagement_score = 1.0; + visit.annotated_visit.content_annotations.model_annotations.entities = { + {"github", 1}}; + visit.annotated_visit.content_annotations.model_annotations.categories = { + {"category", 1}}; + + history::ClusterVisit visit2 = + testing::CreateClusterVisit(testing::CreateDefaultAnnotatedVisit( + 2, GURL("https://engagementtoohigh.com/"))); + visit2.engagement_score = 25.0; + visit2.annotated_visit.content_annotations.model_annotations.entities = { + {"github", 1}, {"onlyinnoisyvisit", 1}}; + visit2.annotated_visit.content_annotations.model_annotations.categories = { + {"category", 1}}; + + history::ClusterVisit visit3 = testing::CreateClusterVisit( + testing::CreateDefaultAnnotatedVisit(2, GURL("https://baz.com/"))); + visit3.duplicate_visit_ids.push_back(1); + visit3.engagement_score = 1.0; + visit3.annotated_visit.content_annotations.model_annotations.entities = { + {"github", 1}, {"otherentity", 1}}; + visit3.annotated_visit.content_annotations.model_annotations.categories = { + {"category", 1}}; + + history::Cluster cluster; + cluster.visits = {visit, visit2, visit3}; + FinalizeCluster(cluster); + EXPECT_THAT(cluster.keywords, + UnorderedElementsAre(u"github", u"category", u"onlyinnoisyvisit", + u"otherentity")); +} + +} // namespace +} // namespace history_clusters
diff --git a/components/history_clusters/core/noisy_cluster_finalizer.cc b/components/history_clusters/core/noisy_cluster_finalizer.cc index ff5b8ef..3bc2dae 100644 --- a/components/history_clusters/core/noisy_cluster_finalizer.cc +++ b/components/history_clusters/core/noisy_cluster_finalizer.cc
@@ -15,8 +15,7 @@ void NoisyClusterFinalizer::FinalizeCluster(history::Cluster& cluster) { size_t interesting_visit_cnt = 0; for (const auto& visit : cluster.visits) { - if (visit.engagement_score < - features::NoisyClusterVisitEngagementThreshold()) { + if (!IsNoisyVisit(visit)) { interesting_visit_cnt += 1; } if (interesting_visit_cnt >=
diff --git a/components/history_clusters/core/on_device_clustering_backend.cc b/components/history_clusters/core/on_device_clustering_backend.cc index eefe333..b6db99b1 100644 --- a/components/history_clusters/core/on_device_clustering_backend.cc +++ b/components/history_clusters/core/on_device_clustering_backend.cc
@@ -18,6 +18,7 @@ #include "components/history/core/browser/history_types.h" #include "components/history_clusters/core/content_annotations_cluster_processor.h" #include "components/history_clusters/core/content_visibility_cluster_finalizer.h" +#include "components/history_clusters/core/keyword_cluster_finalizer.h" #include "components/history_clusters/core/noisy_cluster_finalizer.h" #include "components/history_clusters/core/on_device_clustering_features.h" #include "components/history_clusters/core/on_device_clustering_util.h" @@ -272,6 +273,7 @@ if (engagement_score_provider_ && features::ShouldFilterNoisyClusters()) { cluster_finalizers.push_back(std::make_unique<NoisyClusterFinalizer>()); } + cluster_finalizers.push_back(std::make_unique<KeywordClusterFinalizer>()); // Group visits into clusters. std::vector<history::Cluster> clusters =
diff --git a/components/history_clusters/core/on_device_clustering_backend_unittest.cc b/components/history_clusters/core/on_device_clustering_backend_unittest.cc index c1a5499..2f38753 100644 --- a/components/history_clusters/core/on_device_clustering_backend_unittest.cc +++ b/components/history_clusters/core/on_device_clustering_backend_unittest.cc
@@ -72,7 +72,9 @@ features::kOnDeviceClustering, {{"content_clustering_enabled", "false"}, {"dedupe_similar_visits", "false"}, - {"min_page_topics_model_version_for_visibility", "125"}}); + {"min_page_topics_model_version_for_visibility", "125"}, + {"include_categories_in_keywords", "true"}, + {"exclude_keywords_from_noisy_visits", "false"}}); } void SetUp() override { @@ -356,8 +358,11 @@ public: OnDeviceClusteringWithContentBackendTest() { scoped_feature_list_.InitAndEnableFeatureWithParameters( - features::kOnDeviceClustering, {{"content_clustering_enabled", "true"}, - {"dedupe_similar_visits", "false"}}); + features::kOnDeviceClustering, + {{"content_clustering_enabled", "true"}, + {"dedupe_similar_visits", "false"}, + {"include_categories_in_keywords", "true"}, + {"exclude_keywords_from_noisy_visits", "false"}}); } private:
diff --git a/components/history_clusters/core/on_device_clustering_features.cc b/components/history_clusters/core/on_device_clustering_features.cc index 24f818f..e49f8428 100644 --- a/components/history_clusters/core/on_device_clustering_features.cc +++ b/components/history_clusters/core/on_device_clustering_features.cc
@@ -125,5 +125,15 @@ kOnDeviceClustering, "content_clustering_intersection_threshold", 2); } +bool ShouldIncludeCategoriesInKeywords() { + return GetFieldTrialParamByFeatureAsBool( + kOnDeviceClustering, "include_categories_in_keywords", true); +} + +bool ShouldExcludeKeywordsFromNoisyVisits() { + return GetFieldTrialParamByFeatureAsBool( + kOnDeviceClustering, "exclude_keywords_from_noisy_visits", false); +} + } // namespace features } // namespace history_clusters
diff --git a/components/history_clusters/core/on_device_clustering_features.h b/components/history_clusters/core/on_device_clustering_features.h index 0b1d46de..4923484 100644 --- a/components/history_clusters/core/on_device_clustering_features.h +++ b/components/history_clusters/core/on_device_clustering_features.h
@@ -92,6 +92,13 @@ // when clustering based on intersection score. int ClusterIntersectionThreshold(); +// Whether to include category names in the keywords for a cluster. +bool ShouldIncludeCategoriesInKeywords(); + +// Whether to exclude keywords from visits that may be considered "noisy" to the +// user (i.e. highly engaged, non-SRP). +bool ShouldExcludeKeywordsFromNoisyVisits(); + } // namespace features } // namespace history_clusters
diff --git a/components/history_clusters/core/on_device_clustering_util.cc b/components/history_clusters/core/on_device_clustering_util.cc index 3c7b7ca3..6621c0b 100644 --- a/components/history_clusters/core/on_device_clustering_util.cc +++ b/components/history_clusters/core/on_device_clustering_util.cc
@@ -5,6 +5,7 @@ #include "components/history_clusters/core/on_device_clustering_util.h" #include "base/containers/contains.h" +#include "components/history_clusters/core/on_device_clustering_features.h" namespace history_clusters { @@ -109,4 +110,10 @@ }); } +bool IsNoisyVisit(const history::ClusterVisit& visit) { + return visit.engagement_score > + features::NoisyClusterVisitEngagementThreshold() && + !visit.is_search_visit; +} + } // namespace history_clusters
diff --git a/components/history_clusters/core/on_device_clustering_util.h b/components/history_clusters/core/on_device_clustering_util.h index 38106938..a0c6355 100644 --- a/components/history_clusters/core/on_device_clustering_util.h +++ b/components/history_clusters/core/on_device_clustering_util.h
@@ -23,6 +23,10 @@ // by-score sorting of visits within clusters. Exposed for testing. void SortClusters(std::vector<history::Cluster>* clusters); +// Whether the visit is considered a noisy visit (i.e. high engagement, +// non-SRP). +bool IsNoisyVisit(const history::ClusterVisit& visit); + } // namespace history_clusters #endif // COMPONENTS_HISTORY_CLUSTERS_CORE_ON_DEVICE_CLUSTERING_UTIL_H_
diff --git a/components/history_clusters/core/on_device_clustering_util_unittest.cc b/components/history_clusters/core/on_device_clustering_util_unittest.cc index 97998a416..00d7bc26 100644 --- a/components/history_clusters/core/on_device_clustering_util_unittest.cc +++ b/components/history_clusters/core/on_device_clustering_util_unittest.cc
@@ -140,5 +140,33 @@ EXPECT_FLOAT_EQ(visits[1].score, 0.5); } +TEST_F(OnDeviceClusteringUtilTest, IsNoisyVisitSearchHighEngagementVisit) { + history::ClusterVisit visit; + visit.is_search_visit = true; + visit.engagement_score = 90.0; + EXPECT_FALSE(IsNoisyVisit(visit)); +} + +TEST_F(OnDeviceClusteringUtilTest, IsNoisyVisitNotSearchHighEngagementVisit) { + history::ClusterVisit visit; + visit.is_search_visit = false; + visit.engagement_score = 90.0; + EXPECT_TRUE(IsNoisyVisit(visit)); +} + +TEST_F(OnDeviceClusteringUtilTest, IsNoisyVisitNotSearchLowEngagementVisit) { + history::ClusterVisit visit; + visit.is_search_visit = false; + visit.engagement_score = 1.0; + EXPECT_FALSE(IsNoisyVisit(visit)); +} + +TEST_F(OnDeviceClusteringUtilTest, IsNoisyVisitSearchLowEngagementVisit) { + history::ClusterVisit visit; + visit.is_search_visit = true; + visit.engagement_score = 1.0; + EXPECT_FALSE(IsNoisyVisit(visit)); +} + } // namespace } // namespace history_clusters
diff --git a/components/history_clusters/core/proto/BUILD.gn b/components/history_clusters/core/proto/BUILD.gn deleted file mode 100644 index 20e7e6e..0000000 --- a/components/history_clusters/core/proto/BUILD.gn +++ /dev/null
@@ -1,9 +0,0 @@ -# Copyright 2021 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -import("//third_party/protobuf/proto_library.gni") - -proto_library("proto") { - sources = [ "clusters.proto" ] -}
diff --git a/components/history_clusters/core/proto/clusters.proto b/components/history_clusters/core/proto/clusters.proto deleted file mode 100644 index 23faed2..0000000 --- a/components/history_clusters/core/proto/clusters.proto +++ /dev/null
@@ -1,67 +0,0 @@ -// Copyright 2021 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. -syntax = "proto3"; - -option optimize_for = LITE_RUNTIME; - -package history_clusters.proto; - -// Below this line sourced from Google3 for debug (non-production) usage only. - -message GetClustersRequest { - // Represents a set of visits. - repeated AnnotatedVisit visits = 1; - // The experiment name that controls the clustering behavior served for - // this request. - string experiment_name = 2; -} - -message GetClustersResponse { - // Represents a set of clusters calculated from the request. - repeated Cluster clusters = 1; -} - -message AnnotatedVisit { - // The ID associated with this visit. - int64 visit_id = 1; - // The URL for the visit. - string url = 2; - // The origin for the visit. - string origin = 3; - // The amount of time the page load was in the foreground in seconds. - // TODO(tommycli): Kind of a misnomer. We are just using visit_duration, - // which may or may not be in the foreground. - int64 foreground_time_secs = 4; - // Relative time of navigation for this page load. - int64 navigation_time_ms = 5; - // Site engagement score rounded to nearest ten. - int64 site_engagement_score = 6; - // The page end reason. - int64 page_end_reason = 7; - // Page transition. - int64 page_transition = 8; - // Whether the page load originated from Google Search. - bool is_from_google_search = 9; - // The visit ID that referred this visit, if not a new navigation. - int64 referring_visit_id = 10; -} - -message ClusterVisit { - // The ID of the visit where visit_id corresponds to the visit in the history - // table. - int64 visit_id = 1; - // The score associated with this visit. - // - // Used for calculating ordering of visits within a cluster. - float score = 2; -} - -message Cluster { - reserved 2; - - // The keywords that the cluster contains/is related to. - repeated string keywords = 1; - // The visits that are attached to this cluster. - repeated ClusterVisit cluster_visits = 3; -}
diff --git a/components/js_injection/browser/js_to_browser_messaging.cc b/components/js_injection/browser/js_to_browser_messaging.cc index e2cc9cf..7867deb 100644 --- a/components/js_injection/browser/js_to_browser_messaging.cc +++ b/components/js_injection/browser/js_to_browser_messaging.cc
@@ -56,6 +56,7 @@ return render_frame_host_->GetLifecycleState() == content::RenderFrameHost::LifecycleState::kInBackForwardCache; } + content::Page& GetPage() override { return render_frame_host_->GetPage(); } private: raw_ptr<content::RenderFrameHost> render_frame_host_;
diff --git a/components/js_injection/browser/web_message_reply_proxy.h b/components/js_injection/browser/web_message_reply_proxy.h index 2eb52975..3b887ee 100644 --- a/components/js_injection/browser/web_message_reply_proxy.h +++ b/components/js_injection/browser/web_message_reply_proxy.h
@@ -5,6 +5,10 @@ #ifndef COMPONENTS_JS_INJECTION_BROWSER_WEB_MESSAGE_REPLY_PROXY_H_ #define COMPONENTS_JS_INJECTION_BROWSER_WEB_MESSAGE_REPLY_PROXY_H_ +namespace content { +class Page; +} + namespace js_injection { struct WebMessage; @@ -18,6 +22,9 @@ // forward cache. virtual bool IsInBackForwardCache() = 0; + // Returns the page the messages are sent to. + virtual content::Page& GetPage() = 0; + protected: virtual ~WebMessageReplyProxy() = default; };
diff --git a/components/keep_alive_registry/keep_alive_types.cc b/components/keep_alive_registry/keep_alive_types.cc index ea9bc27..950be06c 100644 --- a/components/keep_alive_registry/keep_alive_types.cc +++ b/components/keep_alive_registry/keep_alive_types.cc
@@ -68,6 +68,8 @@ return out << "CREDENTIAL_PROVIDER_SIGNIN_DIALOG"; case KeepAliveOrigin::WEB_APP_INTENT_PICKER: return out << "WEB_APP_INTENT_PICKER"; + case KeepAliveOrigin::WEB_APP_UNINSTALL: + return out << "WEB_APP_UNINSTALL"; case KeepAliveOrigin::APP_MANIFEST_UPDATE: return out << "APP_MANIFEST_UPDATE"; case KeepAliveOrigin::APP_START_URL_MIGRATION:
diff --git a/components/keep_alive_registry/keep_alive_types.h b/components/keep_alive_registry/keep_alive_types.h index 9957017..15b4f9c 100644 --- a/components/keep_alive_registry/keep_alive_types.h +++ b/components/keep_alive_registry/keep_alive_types.h
@@ -64,6 +64,9 @@ CREDENTIAL_PROVIDER_SIGNIN_DIALOG, WEB_APP_INTENT_PICKER, + // c/b/ui/web_applications + WEB_APP_UNINSTALL, + // c/b/web_applications APP_MANIFEST_UPDATE, APP_START_URL_MIGRATION,
diff --git a/components/management_strings.grdp b/components/management_strings.grdp index b0c2cdb..ee4c67d 100644 --- a/components/management_strings.grdp +++ b/components/management_strings.grdp
@@ -137,6 +137,9 @@ <message name="IDS_MANAGEMENT_REPORT_LOGIN_LOGOUT" desc="Message stating that administrators can see device Login and Logout events."> Device login/logout history, including timestamps and failed attempts </message> + <message name="IDS_MANAGEMENT_REPORT_CRD_SESSIONS" desc="Message stating that administrators can see Chrome Remote Desktop events."> + Chrome Remote Desktop history, including timestamps, hosts and client session ids + </message> <message name="IDS_MANAGEMENT_CROSTINI" desc="Message stating that administrators can see Crostini usage"> Linux apps installed and when they were last used </message>
diff --git a/components/management_strings_grdp/IDS_MANAGEMENT_REPORT_CRD_SESSIONS.png.sha1 b/components/management_strings_grdp/IDS_MANAGEMENT_REPORT_CRD_SESSIONS.png.sha1 new file mode 100644 index 0000000..52a082e --- /dev/null +++ b/components/management_strings_grdp/IDS_MANAGEMENT_REPORT_CRD_SESSIONS.png.sha1
@@ -0,0 +1 @@ +7dc395c7a54d4206cf4004cd34553177330f565d \ No newline at end of file
diff --git a/components/messages/android/messages_feature.cc b/components/messages/android/messages_feature.cc index 59e9117b..a8313df 100644 --- a/components/messages/android/messages_feature.cc +++ b/components/messages/android/messages_feature.cc
@@ -18,7 +18,7 @@ "MessagesForAndroidInfrastructure", base::FEATURE_ENABLED_BY_DEFAULT}; const base::Feature kMessagesForAndroidNearOomReduction{ - "MessagesForAndroidNearOomReduction", base::FEATURE_DISABLED_BY_DEFAULT}; + "MessagesForAndroidNearOomReduction", base::FEATURE_ENABLED_BY_DEFAULT}; const base::Feature kMessagesForAndroidNotificationBlocked{ "MessagesForAndroidNotificationBlocked", base::FEATURE_DISABLED_BY_DEFAULT};
diff --git a/components/omnibox/browser/omnibox_edit_model.cc b/components/omnibox/browser/omnibox_edit_model.cc index 83ecadd..c94ba46e 100644 --- a/components/omnibox/browser/omnibox_edit_model.cc +++ b/components/omnibox/browser/omnibox_edit_model.cc
@@ -1202,9 +1202,10 @@ // Send the textfield contents exactly as-is, as otherwise the verbatim // match can be wrong. The full page URL is anyways in set_current_url(). + // Don't attempt to use https as the default scheme for these requests. input_ = AutocompleteInput(view_->GetText(), GetPageClassification(), client_->GetSchemeClassifier(), - client_->ShouldDefaultTypedNavigationsToHttps(), + /*should_use_https_as_default_scheme=*/false, client_->GetHttpsPortForTesting()); input_.set_current_url(client_->GetURL()); input_.set_current_title(client_->GetTitle());
diff --git a/components/optimization_guide/core/bert_model_executor.cc b/components/optimization_guide/core/bert_model_executor.cc index 9e8e7b0..fcd4df0 100644 --- a/components/optimization_guide/core/bert_model_executor.cc +++ b/components/optimization_guide/core/bert_model_executor.cc
@@ -7,7 +7,7 @@ #include "base/trace_event/trace_event.h" #include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/tflite_op_resolver.h" -#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" +#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" namespace optimization_guide { @@ -28,18 +28,21 @@ GetStringNameForOptimizationTarget(optimization_target_), "input_length", input.size()); *out_status = ExecutionStatus::kSuccess; - return static_cast<tflite::task::text::nlclassifier::BertNLClassifier*>( - execution_task) + return static_cast<tflite::task::text::BertNLClassifier*>(execution_task) ->Classify(input); } std::unique_ptr<BertModelExecutor::ModelExecutionTask> BertModelExecutor::BuildModelExecutionTask(base::MemoryMappedFile* model_file, ExecutionStatus* out_status) { + tflite::task::text::BertNLClassifierOptions options; + *options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_content() = std::string( + reinterpret_cast<const char*>(model_file->data()), model_file->length()); auto maybe_nl_classifier = - tflite::task::text::nlclassifier::BertNLClassifier::CreateFromBuffer( - reinterpret_cast<const char*>(model_file->data()), - model_file->length(), std::make_unique<TFLiteOpResolver>()); + tflite::task::text::BertNLClassifier::CreateFromOptions( + std::move(options), std::make_unique<TFLiteOpResolver>()); if (maybe_nl_classifier.ok()) return std::move(maybe_nl_classifier.value()); *out_status = ExecutionStatus::kErrorModelFileNotValid;
diff --git a/components/optimization_guide/core/model_validator.cc b/components/optimization_guide/core/model_validator.cc index b936f747..65a9bd5f 100644 --- a/components/optimization_guide/core/model_validator.cc +++ b/components/optimization_guide/core/model_validator.cc
@@ -64,7 +64,12 @@ float ModelValidatorExecutor::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) { std::vector<float> data; - tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + absl::Status status = + tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + if (!status.ok()) { + NOTREACHED(); + return -1; + } return data[0]; }
diff --git a/components/optimization_guide/core/test_tflite_model_executor.cc b/components/optimization_guide/core/test_tflite_model_executor.cc index 9b7ea803..09697f7d 100644 --- a/components/optimization_guide/core/test_tflite_model_executor.cc +++ b/components/optimization_guide/core/test_tflite_model_executor.cc
@@ -11,14 +11,15 @@ absl::Status TestTFLiteModelExecutor::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, const std::vector<float>& input) { - tflite::task::core::PopulateTensor<float>(input, input_tensors[0]); - return absl::OkStatus(); + return tflite::task::core::PopulateTensor<float>(input, input_tensors[0]); } std::vector<float> TestTFLiteModelExecutor::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) { std::vector<float> data; - tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + absl::Status status = + tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + DCHECK(status.ok()); return data; }
diff --git a/components/payments/content/secure_payment_confirmation_app.cc b/components/payments/content/secure_payment_confirmation_app.cc index 239801e..088234f0 100644 --- a/components/payments/content/secure_payment_confirmation_app.cc +++ b/components/payments/content/secure_payment_confirmation_app.cc
@@ -218,8 +218,7 @@ response->secure_payment_confirmation = mojom::SecurePaymentConfirmationResponse::New( response_->info.Clone(), response_->signature, - response_->has_transport, response_->transport, - response_->user_handle); + response_->authenticator_attachment, response_->user_handle); return response; }
diff --git a/components/permissions/prediction_service/prediction_model_executor.cc b/components/permissions/prediction_service/prediction_model_executor.cc index 8cea2a3..1fd7298 100644 --- a/components/permissions/prediction_service/prediction_model_executor.cc +++ b/components/permissions/prediction_service/prediction_model_executor.cc
@@ -28,43 +28,90 @@ NOTREACHED(); } - tflite::task::core::PopulateTensor<float>( + absl::Status status = tflite::task::core::PopulateTensor<float>( input.client_features().client_stats().avg_deny_rate(), input_tensors[0]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.client_features().client_stats().avg_dismiss_rate(), input_tensors[1]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.client_features().client_stats().avg_grant_rate(), input_tensors[2]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.client_features().client_stats().avg_ignore_rate(), input_tensors[3]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.permission_features()[0].permission_stats().avg_deny_rate(), input_tensors[4]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.permission_features()[0].permission_stats().avg_dismiss_rate(), input_tensors[5]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.permission_features()[0].permission_stats().avg_grant_rate(), input_tensors[6]); - tflite::task::core::PopulateTensor<float>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<float>( input.permission_features()[0].permission_stats().avg_ignore_rate(), input_tensors[7]); - tflite::task::core::PopulateTensor<int64_t>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<int64_t>( static_cast<int64_t>( input.permission_features()[0].permission_stats().prompts_count()), input_tensors[8]); - tflite::task::core::PopulateTensor<int64_t>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<int64_t>( static_cast<int64_t>( input.client_features().client_stats().prompts_count()), input_tensors[9]); - tflite::task::core::PopulateTensor<int64_t>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<int64_t>( static_cast<int64_t>(input.client_features().gesture_enum()), input_tensors[10]); - tflite::task::core::PopulateTensor<int64_t>( + if (!status.ok()) { + return status; + } + + status = tflite::task::core::PopulateTensor<int64_t>( static_cast<int64_t>(input.client_features().platform_enum()), input_tensors[11]); + if (!status.ok()) { + return status; + } return absl::OkStatus(); } @@ -73,7 +120,10 @@ DCHECK(request_type_ == RequestType::kNotifications || request_type_ == RequestType::kGeolocation); std::vector<float> data; - tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + absl::Status status = + tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + DCHECK(status.ok()); + GeneratePredictionsResponse response; float threshold = request_type_ == RequestType::kNotifications ? kNotificationPredictionsThreshold
diff --git a/components/policy/core/common/cloud/cloud_policy_util.cc b/components/policy/core/common/cloud/cloud_policy_util.cc index 8dbf3f8..5008676 100644 --- a/components/policy/core/common/cloud/cloud_policy_util.cc +++ b/components/policy/core/common/cloud/cloud_policy_util.cc
@@ -18,7 +18,8 @@ #include <wincred.h> #endif -#if defined(OS_LINUX) || BUILDFLAG(IS_CHROMEOS_LACROS) || defined(OS_APPLE) +#if defined(OS_LINUX) || BUILDFLAG(IS_CHROMEOS_LACROS) || defined(OS_APPLE) || \ + defined(OS_FUCHSIA) #include <pwd.h> #include <sys/types.h> #include <unistd.h> @@ -64,11 +65,6 @@ #include "base/mac/scoped_cftyperef.h" #include "base/strings/string_util.h" #include "base/strings/sys_string_conversions.h" -#include "base/system/sys_info.h" -#endif - -#if defined(OS_LINUX) || BUILDFLAG(IS_CHROMEOS_LACROS) -#include "base/system/sys_info.h" #endif #if defined(OS_IOS) @@ -80,7 +76,7 @@ namespace em = enterprise_management; std::string GetMachineName() { -#if defined(OS_LINUX) || BUILDFLAG(IS_CHROMEOS_LACROS) +#if defined(OS_LINUX) || BUILDFLAG(IS_CHROMEOS_LACROS) || defined(OS_FUCHSIA) char hostname[HOST_NAME_MAX]; if (gethostname(hostname, HOST_NAME_MAX) == 0) // Success. return hostname; @@ -126,13 +122,13 @@ return result; } return std::string(); -#elif defined(OS_ANDROID) || defined(OS_FUCHSIA) - // TODO(crbug.com/1257674): This should be fully implemented when there is - // support in fuchsia. +#elif defined(OS_ANDROID) return std::string(); -#else +#elif defined(OS_CHROMEOS) NOTREACHED(); return std::string(); +#else +#error Unsupported platform #endif }
diff --git a/components/policy/core/common/cloud/mock_cloud_policy_client.h b/components/policy/core/common/cloud/mock_cloud_policy_client.h index 7442f17..32231e2 100644 --- a/components/policy/core/common/cloud/mock_cloud_policy_client.h +++ b/components/policy/core/common/cloud/mock_cloud_policy_client.h
@@ -13,6 +13,7 @@ #include "components/policy/core/common/cloud/cloud_policy_client.h" #include "components/policy/core/common/cloud/device_management_service.h" #include "components/reporting/proto/synced/record.pb.h" +#include "device_management_backend.pb.h" #include "testing/gmock/include/gmock/gmock.h" namespace network { @@ -101,6 +102,17 @@ void(enterprise_management::ChromeOsUserReportRequest*, StatusCallback&)); + void UploadChromeProfileReport( + std::unique_ptr<enterprise_management::ChromeProfileReportRequest> + request, + StatusCallback callback) override { + UploadChromeProfileReportProxy(request.get(), callback); + } + // Use Proxy function because unique_ptr can't be used in mock function. + MOCK_METHOD2(UploadChromeProfileReportProxy, + void(enterprise_management::ChromeProfileReportRequest*, + StatusCallback&)); + void UploadSecurityEventReport(content::BrowserContext* context, bool include_device_info, base::Value value,
diff --git a/components/policy/resources/policy_templates.json b/components/policy/resources/policy_templates.json index 32c3bde..00b8ed2 100644 --- a/components/policy/resources/policy_templates.json +++ b/components/policy/resources/policy_templates.json
@@ -9453,7 +9453,7 @@ 'chrome_os:86-', 'android:86-', 'webview_android:86-', - 'ios:91-', + 'ios:98-', ], 'features': { 'dynamic_refresh': True, @@ -9485,7 +9485,7 @@ 'chrome_os:86-', 'android:86-', 'webview_android:86-', - 'ios:91-', + 'ios:98-', ], 'features': { 'dynamic_refresh': True,
diff --git a/components/power_metrics/BUILD.gn b/components/power_metrics/BUILD.gn index 96a8786..ecc0c9b 100644 --- a/components/power_metrics/BUILD.gn +++ b/components/power_metrics/BUILD.gn
@@ -9,6 +9,9 @@ "energy_impact_mac.mm", "iopm_power_source_sampling_event_source.cc", "iopm_power_source_sampling_event_source.h", + "m1_sensors_internal_types_mac.h", + "m1_sensors_mac.h", + "m1_sensors_mac.mm", "mach_time_mac.h", "mach_time_mac.mm", "resource_coalition_internal_types_mac.h",
diff --git a/components/power_metrics/m1_sensors_internal_types_mac.h b/components/power_metrics/m1_sensors_internal_types_mac.h new file mode 100644 index 0000000..039d84c --- /dev/null +++ b/components/power_metrics/m1_sensors_internal_types_mac.h
@@ -0,0 +1,26 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_POWER_METRICS_M1_SENSORS_INTERNAL_TYPES_MAC_H_ +#define COMPONENTS_POWER_METRICS_M1_SENSORS_INTERNAL_TYPES_MAC_H_ + +#include <stdint.h> + +// From: +// https://opensource.apple.com/source/IOHIDFamily/IOHIDFamily-421.6/IOHIDFamily/IOHIDEventTypes.h.auto.html + +#define IOHIDEventFieldBase(type) (type << 16) + +constexpr int64_t kIOHIDEventTypeTemperature = 15; + +// From: +// https://opensource.apple.com/source/IOHIDFamily/IOHIDFamily-421.6/IOHIDFamily/AppleHIDUsageTables.h + +// Usage pages +constexpr int kHIDPage_AppleVendor = 0xff00; + +// Usage keys for `kHIDPage_AppleVendor` +constexpr int kHIDUsage_AppleVendor_TemperatureSensor = 0x0005; + +#endif // COMPONENTS_POWER_METRICS_M1_SENSORS_INTERNAL_TYPES_MAC_H_
diff --git a/components/power_metrics/m1_sensors_mac.h b/components/power_metrics/m1_sensors_mac.h new file mode 100644 index 0000000..45703ad --- /dev/null +++ b/components/power_metrics/m1_sensors_mac.h
@@ -0,0 +1,48 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// The Apple M1 chip has sensors to monitor its power consumption and +// temperature. This file defines a class to retrieve data from these sensors. + +#ifndef COMPONENTS_POWER_METRICS_M1_SENSORS_MAC_H_ +#define COMPONENTS_POWER_METRICS_M1_SENSORS_MAC_H_ + +#include <memory> + +#include <IOKit/hidsystem/IOHIDEventSystemClient.h> + +#include "base/mac/scoped_cftyperef.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace power_metrics { + +class M1SensorsReader { + public: + struct TemperaturesCelsius { + TemperaturesCelsius(); + TemperaturesCelsius(const TemperaturesCelsius&) noexcept; + ~TemperaturesCelsius(); + + absl::optional<double> p_cores; + absl::optional<double> e_cores; + }; + + virtual ~M1SensorsReader(); + + // Creates an M1SensorsReader. Returns nullptr on failure. + static std::unique_ptr<M1SensorsReader> Create(); + + // Reads temperature sensors. Virtual for testing. + virtual TemperaturesCelsius ReadTemperatures(); + + protected: + M1SensorsReader(base::ScopedCFTypeRef<IOHIDEventSystemClientRef> system); + + private: + base::ScopedCFTypeRef<IOHIDEventSystemClientRef> system_; +}; + +} // namespace power_metrics + +#endif // COMPONENTS_POWER_METRICS_M1_SENSORS_MAC_H_
diff --git a/components/power_metrics/m1_sensors_mac.mm b/components/power_metrics/m1_sensors_mac.mm new file mode 100644 index 0000000..0477e77 --- /dev/null +++ b/components/power_metrics/m1_sensors_mac.mm
@@ -0,0 +1,123 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/power_metrics/m1_sensors_mac.h" + +#import <Foundation/Foundation.h> +#import <IOKit/hid/IOHIDDeviceKeys.h> +#import <IOKit/hidsystem/IOHIDServiceClient.h> + +#include <utility> + +#include "base/mac/foundation_util.h" +#include "base/memory/ptr_util.h" +#include "components/power_metrics/m1_sensors_internal_types_mac.h" + +extern "C" { + +extern IOHIDEventSystemClientRef IOHIDEventSystemClientCreate(CFAllocatorRef); +extern int IOHIDEventSystemClientSetMatching(IOHIDEventSystemClientRef client, + CFDictionaryRef match); +extern CFTypeRef IOHIDServiceClientCopyEvent(IOHIDServiceClientRef, + int64_t, + int32_t, + int64_t); +extern double IOHIDEventGetFloatValue(CFTypeRef, int32_t); +} + +namespace power_metrics { + +namespace { + +absl::optional<double> GetEventFloatValue(IOHIDServiceClientRef service, + int64_t event_type) { + base::ScopedCFTypeRef<CFTypeRef> event( + IOHIDServiceClientCopyEvent(service, event_type, 0, 0)); + if (!event) + return absl::nullopt; + return IOHIDEventGetFloatValue(event, IOHIDEventFieldBase(event_type)); +} + +} // namespace + +M1SensorsReader::TemperaturesCelsius::TemperaturesCelsius() = default; +M1SensorsReader::TemperaturesCelsius::TemperaturesCelsius( + const TemperaturesCelsius&) noexcept = default; +M1SensorsReader::TemperaturesCelsius::~TemperaturesCelsius() = default; + +M1SensorsReader::~M1SensorsReader() = default; + +// static +std::unique_ptr<M1SensorsReader> M1SensorsReader::Create() { + base::ScopedCFTypeRef<IOHIDEventSystemClientRef> system( + IOHIDEventSystemClientCreate(kCFAllocatorDefault)); + + if (system == nil) + return nullptr; + + NSDictionary* filter = @{ + @kIOHIDPrimaryUsagePageKey : [NSNumber numberWithInt:kHIDPage_AppleVendor], + @kIOHIDPrimaryUsageKey : + [NSNumber numberWithInt:kHIDUsage_AppleVendor_TemperatureSensor], + }; + IOHIDEventSystemClientSetMatching(system, base::mac::NSToCFCast(filter)); + + return base::WrapUnique(new M1SensorsReader(std::move(system))); +} + +M1SensorsReader::TemperaturesCelsius M1SensorsReader::ReadTemperatures() { + base::ScopedCFTypeRef<CFArrayRef> services( + IOHIDEventSystemClientCopyServices(system_.get())); + + // There are multiple temperature sensors on P-Cores and E-Cores. Count and + // sum values to compute average later. + int num_p_core_temp = 0; + int num_e_core_temp = 0; + double sum_p_core_temp = 0; + double sum_e_core_temp = 0; + + for (id service_obj in base::mac::CFToNSCast(services.get())) { + IOHIDServiceClientRef service = (IOHIDServiceClientRef)service_obj; + + base::ScopedCFTypeRef<CFStringRef> product_cf( + base::mac::CFCast<CFStringRef>( + IOHIDServiceClientCopyProperty(service, CFSTR(kIOHIDProductKey)))); + if (product_cf == nil) + continue; + + if ([base::mac::CFToNSCast(product_cf.get()) + hasPrefix:@"pACC MTR Temp Sensor"]) { + absl::optional<double> temp = + GetEventFloatValue(service, kIOHIDEventTypeTemperature); + if (temp.has_value()) { + num_p_core_temp += 1; + sum_p_core_temp += temp.value(); + } + } + + if ([base::mac::CFToNSCast(product_cf.get()) + hasPrefix:@"eACC MTR Temp Sensor"]) { + absl::optional<double> temp = + GetEventFloatValue(service, kIOHIDEventTypeTemperature); + if (temp.has_value()) { + num_e_core_temp += 1; + sum_e_core_temp += temp.value(); + } + } + } + + TemperaturesCelsius temperatures; + if (num_p_core_temp > 0) + temperatures.p_cores = sum_p_core_temp / num_p_core_temp; + if (num_e_core_temp > 0) + temperatures.e_cores = sum_e_core_temp / num_e_core_temp; + + return temperatures; +} + +M1SensorsReader::M1SensorsReader( + base::ScopedCFTypeRef<IOHIDEventSystemClientRef> system) + : system_(std::move(system)) {} + +} // namespace power_metrics
diff --git a/components/printing/test/print_render_frame_helper_browsertest.cc b/components/printing/test/print_render_frame_helper_browsertest.cc index cf44f23c..4d1d51e 100644 --- a/components/printing/test/print_render_frame_helper_browsertest.cc +++ b/components/printing/test/print_render_frame_helper_browsertest.cc
@@ -118,6 +118,29 @@ // A web page to simulate the print preview page. const char kPrintPreviewHTML[] = "<body><p id=\"pdf-viewer\">Hello World!</p></body>"; + +const char kHTMLWithManyLinesOfText[] = + "<html><head><style>" + "p { font-size: 24px; }" + "</style></head><body>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "<p>The quick brown fox jumped over the lazy dog.</p>" + "</body></html>"; #endif // BUILDFLAG(ENABLE_PRINT_PREVIEW) #endif // !BUILDFLAG(IS_CHROMEOS_ASH) @@ -1022,6 +1045,13 @@ preview_ui()->has_custom_page_size_style()); } + void SetLetterMediaSize() { + base::Value* media_size_value = print_settings().SetKey( + kSettingMediaSize, base::Value(base::Value::Type::DICTIONARY)); + media_size_value->SetIntKey(kSettingMediaSizeWidthMicrons, 215900); + media_size_value->SetIntKey(kSettingMediaSizeHeightMicrons, 279400); + } + base::Value& print_settings() { return print_settings_; } private: @@ -1526,6 +1556,91 @@ OnClosePrintPreviewDialog(); } +TEST_F(PrintRenderFrameHelperPreviewTest, PrintPreviewForManyLinesOfText) { + LoadHTML(kHTMLWithManyLinesOfText); + + SetLetterMediaSize(); + + OnPrintPreview(); + + EXPECT_EQ(0u, preview_ui()->print_preview_pages_remaining()); + VerifyDidPreviewPage(true, 0); + VerifyPreviewPageCount(1); + VerifyPrintPreviewCancelled(false); + VerifyPrintPreviewFailed(false); + VerifyPrintPreviewGenerated(true); + VerifyPagesPrinted(false); + + OnClosePrintPreviewDialog(); +} + +TEST_F(PrintRenderFrameHelperPreviewTest, + PrintPreviewForManyLinesOfTextWithScaling) { + LoadHTML(kHTMLWithManyLinesOfText); + + SetLetterMediaSize(); + print_settings().SetIntKey(kSettingScaleFactor, 200); + + OnPrintPreview(); + + constexpr int kExpectedPageCount = 3; + EXPECT_EQ(0u, preview_ui()->print_preview_pages_remaining()); + for (int i = 0; i < kExpectedPageCount; ++i) + VerifyDidPreviewPage(true, i); + VerifyPreviewPageCount(kExpectedPageCount); + VerifyPrintPreviewCancelled(false); + VerifyPrintPreviewFailed(false); + VerifyPrintPreviewGenerated(true); + VerifyPagesPrinted(false); + + OnClosePrintPreviewDialog(); +} + +TEST_F(PrintRenderFrameHelperPreviewTest, + PrintPreviewForManyLinesOfTextWithTextSelection) { + LoadHTML(kHTMLWithManyLinesOfText); + GetMainFrame()->ExecuteCommand("SelectAll"); + + SetLetterMediaSize(); + print_settings().SetBoolKey(kSettingShouldPrintSelectionOnly, true); + + OnPrintPreview(); + + EXPECT_EQ(0u, preview_ui()->print_preview_pages_remaining()); + VerifyDidPreviewPage(true, 0); + VerifyPreviewPageCount(1); + VerifyPrintPreviewCancelled(false); + VerifyPrintPreviewFailed(false); + VerifyPrintPreviewGenerated(true); + VerifyPagesPrinted(false); + + OnClosePrintPreviewDialog(); +} + +TEST_F(PrintRenderFrameHelperPreviewTest, + PrintPreviewForManyLinesOfTextWithTextSelectionAndScaling) { + LoadHTML(kHTMLWithManyLinesOfText); + GetMainFrame()->ExecuteCommand("SelectAll"); + + SetLetterMediaSize(); + print_settings().SetBoolKey(kSettingShouldPrintSelectionOnly, true); + print_settings().SetIntKey(kSettingScaleFactor, 200); + + OnPrintPreview(); + + EXPECT_EQ(0u, preview_ui()->print_preview_pages_remaining()); + VerifyDidPreviewPage(true, 0); + // TODO(crbug.com/1023416): The preview should contain 3 pages, like the + // PrintPreviewForManyLinesOfTextWithScaling test case. + VerifyPreviewPageCount(1); + VerifyPrintPreviewCancelled(false); + VerifyPrintPreviewFailed(false); + VerifyPrintPreviewGenerated(true); + VerifyPagesPrinted(false); + + OnClosePrintPreviewDialog(); +} + // Tests that cancelling print preview works. TEST_F(PrintRenderFrameHelperPreviewTest, PrintPreviewCancel) { LoadHTML(kLongPageHTML);
diff --git a/components/reporting/proto/synced/metric_data.proto b/components/reporting/proto/synced/metric_data.proto index 33dbe00..3108fdc3 100644 --- a/components/reporting/proto/synced/metric_data.proto +++ b/components/reporting/proto/synced/metric_data.proto
@@ -268,6 +268,20 @@ optional string input_device_name = 6; } +// Boot Performance telemetry data +message BootPerformanceTelemetry { + // Total time when to boot up. + optional int64 boot_up_seconds = 1; + // The Timestamp when power came on. + optional int64 boot_up_timestamp_seconds = 2; + // Total time since shutdown start to power off. + optional int64 shutdown_seconds = 3; + // Timestamp when shutdown. + optional int64 shutdown_timestamp_seconds = 4; + // Shutdown reason. + optional string shutdown_reason = 5; +} + // Data that can change over time, collected and reported every specific period // of time or when an event occur. message TelemetryData { @@ -279,6 +293,8 @@ optional AudioTelemetry audio_telemetry = 2; // Usb telemetry data optional UsbTelemetry usb_telemetry = 3; + // Boot Performance telemetry data. + optional BootPerformanceTelemetry boot_performance_telemetry = 4; } enum MetricEventType {
diff --git a/components/segmentation_platform/internal/execution/segmentation_model_executor.cc b/components/segmentation_platform/internal/execution/segmentation_model_executor.cc index e03fd07..f4fa1fe 100644 --- a/components/segmentation_platform/internal/execution/segmentation_model_executor.cc +++ b/components/segmentation_platform/internal/execution/segmentation_model_executor.cc
@@ -31,8 +31,7 @@ "length of input data does not match length of tensor"); } - tflite::task::core::PopulateTensor<float>(input, input_tensors[0]); - return absl::OkStatus(); + return tflite::task::core::PopulateTensor<float>(input, input_tensors[0]); } float SegmentationModelExecutor::Postprocess( @@ -43,7 +42,12 @@ DCHECK_EQ(1u, output_tensors[0]->bytes / sizeof(output_tensors[0]->type)); std::vector<float> data; - tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + absl::Status status = + tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + if (!status.ok()) { + NOTREACHED(); + return -1; + } DCHECK_EQ(1u, data.size()); return data[0]; }
diff --git a/components/signin/internal/identity_manager/BUILD.gn b/components/signin/internal/identity_manager/BUILD.gn index 15d8659..08370fd 100644 --- a/components/signin/internal/identity_manager/BUILD.gn +++ b/components/signin/internal/identity_manager/BUILD.gn
@@ -35,6 +35,7 @@ "primary_account_manager.h", "primary_account_mutator_impl.cc", "primary_account_mutator_impl.h", + "primary_account_policy_manager.h", "profile_oauth2_token_service.cc", "profile_oauth2_token_service.h", "profile_oauth2_token_service_builder.cc", @@ -99,6 +100,13 @@ if (is_chromeos_ash) { deps += [ "//ash/constants" ] public_deps += [ "//ash/components/account_manager" ] + } else { + sources += [ + "primary_account_policy_manager_impl.cc", + "primary_account_policy_manager_impl.h", + ] + + public_deps += [ "//components/prefs" ] } if (!is_android && !is_ios) { @@ -181,6 +189,8 @@ "//ash/constants", "//components/account_manager_core:test_support", ] + } else { + sources += [ "primary_account_policy_manager_impl_unittest.cc" ] } if (is_ios) {
diff --git a/components/signin/internal/identity_manager/primary_account_manager.cc b/components/signin/internal/identity_manager/primary_account_manager.cc index 89b922cc..741a2b97 100644 --- a/components/signin/internal/identity_manager/primary_account_manager.cc +++ b/components/signin/internal/identity_manager/primary_account_manager.cc
@@ -17,6 +17,7 @@ #include "components/prefs/pref_registry_simple.h" #include "components/prefs/pref_service.h" #include "components/signin/internal/identity_manager/account_tracker_service.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager.h" #include "components/signin/internal/identity_manager/profile_oauth2_token_service.h" #include "components/signin/public/base/account_consistency_method.h" #include "components/signin/public/base/signin_client.h" @@ -29,10 +30,13 @@ PrimaryAccountManager::PrimaryAccountManager( SigninClient* client, ProfileOAuth2TokenService* token_service, - AccountTrackerService* account_tracker_service) + AccountTrackerService* account_tracker_service, + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager) : client_(client), token_service_(token_service), - account_tracker_service_(account_tracker_service) { + account_tracker_service_(account_tracker_service), + initialized_(false), + policy_manager_(std::move(policy_manager)) { DCHECK(client_); DCHECK(account_tracker_service_); } @@ -115,6 +119,9 @@ SetPrimaryAccountInternal(account_info, consented); } + if (policy_manager_) { + policy_manager_->InitializePolicy(local_state, this); + } // It is important to only load credentials after starting to observe the // token service. token_service_->AddObserver(this);
diff --git a/components/signin/internal/identity_manager/primary_account_manager.h b/components/signin/internal/identity_manager/primary_account_manager.h index dab27fa..04671a7 100644 --- a/components/signin/internal/identity_manager/primary_account_manager.h +++ b/components/signin/internal/identity_manager/primary_account_manager.h
@@ -18,6 +18,8 @@ #ifndef COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_MANAGER_H_ #define COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_MANAGER_H_ +#include <memory> + #include "base/memory/raw_ptr.h" #include "base/observer_list.h" #include "base/observer_list_types.h" @@ -31,6 +33,7 @@ class AccountTrackerService; class PrefRegistrySimple; class PrefService; +class PrimaryAccountPolicyManager; class ProfileOAuth2TokenService; namespace signin_metrics { @@ -56,9 +59,11 @@ kRemoveAllAccounts, }; - PrimaryAccountManager(SigninClient* client, - ProfileOAuth2TokenService* token_service, - AccountTrackerService* account_tracker_service); + PrimaryAccountManager( + SigninClient* client, + ProfileOAuth2TokenService* token_service, + AccountTrackerService* account_tracker_service, + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager); PrimaryAccountManager(const PrimaryAccountManager&) = delete; PrimaryAccountManager& operator=(const PrimaryAccountManager&) = delete; @@ -179,6 +184,7 @@ // this field. CoreAccountInfo primary_account_info_; + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager_; base::ObserverList<Observer> observers_; };
diff --git a/components/signin/internal/identity_manager/primary_account_manager_unittest.cc b/components/signin/internal/identity_manager/primary_account_manager_unittest.cc index 2100f76..d63af1fd92 100644 --- a/components/signin/internal/identity_manager/primary_account_manager_unittest.cc +++ b/components/signin/internal/identity_manager/primary_account_manager_unittest.cc
@@ -23,6 +23,7 @@ #include "components/signin/internal/identity_manager/account_fetcher_service.h" #include "components/signin/internal/identity_manager/account_tracker_service.h" #include "components/signin/internal/identity_manager/fake_profile_oauth2_token_service_delegate.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager.h" #include "components/signin/internal/identity_manager/profile_oauth2_token_service.h" #include "components/signin/public/base/signin_pref_names.h" #include "components/signin/public/base/signin_switches.h" @@ -30,6 +31,10 @@ #include "components/sync_preferences/testing_pref_service_syncable.h" #include "testing/gtest/include/gtest/gtest.h" +#if !BUILDFLAG(IS_CHROMEOS_ASH) +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" +#endif + using signin::ConsentLevel; class PrimaryAccountManagerTest : public testing::Test, @@ -74,8 +79,20 @@ void CreatePrimaryAccountManager() { DCHECK(!manager_); + // Supply the primary account manager with a policy manager to reflect + // production usage: null on ChromeOS, a PrimaryAccountPolicyManagerImpl on + // other platforms. + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager; +#if !BUILDFLAG(IS_CHROMEOS_ASH) + policy_manager = + std::make_unique<PrimaryAccountPolicyManagerImpl>(&test_signin_client_); + policy_manager_ = + static_cast<PrimaryAccountPolicyManagerImpl*>(policy_manager.get()); +#endif + manager_ = std::make_unique<PrimaryAccountManager>( - &test_signin_client_, &token_service_, &account_tracker_); + &test_signin_client_, &token_service_, &account_tracker_, + std::move(policy_manager)); manager_->Initialize(&local_state_); manager_->AddObserver(this); } @@ -122,6 +139,9 @@ ProfileOAuth2TokenService token_service_; AccountTrackerService account_tracker_; AccountFetcherService account_fetcher_; +#if !BUILDFLAG(IS_CHROMEOS_ASH) + raw_ptr<PrimaryAccountPolicyManagerImpl> policy_manager_; +#endif std::unique_ptr<PrimaryAccountManager> manager_; std::vector<std::string> oauth_tokens_fetched_; std::vector<std::string> cookies_; @@ -222,6 +242,33 @@ signin_metrics::SignoutDelete::kIgnoreMetric); EXPECT_FALSE(manager_->HasPrimaryAccount(ConsentLevel::kSignin)); } + +TEST_F(PrimaryAccountManagerTest, ProhibitedAtStartup) { + CoreAccountId account_id = AddToAccountTracker("gaia_id", "user@gmail.com"); + user_prefs_.SetString(prefs::kGoogleServicesAccountId, account_id.ToString()); + local_state_.SetString(prefs::kGoogleServicesUsernamePattern, + ".*@google.com"); + CreatePrimaryAccountManager(); + // Currently signed in user is prohibited by policy, so should be signed out. + EXPECT_EQ("", manager_->GetPrimaryAccountInfo(ConsentLevel::kSync).email); + EXPECT_EQ(CoreAccountId(), + manager_->GetPrimaryAccountId(ConsentLevel::kSync)); +} + +TEST_F(PrimaryAccountManagerTest, ProhibitedAfterStartup) { + CoreAccountId account_id = AddToAccountTracker("gaia_id", "user@gmail.com"); + user_prefs_.SetString(prefs::kGoogleServicesAccountId, account_id.ToString()); + CreatePrimaryAccountManager(); + EXPECT_EQ("user@gmail.com", + manager_->GetPrimaryAccountInfo(ConsentLevel::kSync).email); + EXPECT_EQ(account_id, manager_->GetPrimaryAccountId(ConsentLevel::kSync)); + // Update the profile - user should be signed out. + local_state_.SetString(prefs::kGoogleServicesUsernamePattern, + ".*@google.com"); + EXPECT_EQ("", manager_->GetPrimaryAccountInfo(ConsentLevel::kSync).email); + EXPECT_EQ(CoreAccountId(), + manager_->GetPrimaryAccountId(ConsentLevel::kSync)); +} #endif // Regression test for https://crbug.com/1155519. @@ -291,6 +338,19 @@ EXPECT_EQ(account_id, manager_->GetPrimaryAccountId(ConsentLevel::kSync)); } +#if !BUILDFLAG(IS_CHROMEOS_ASH) +TEST_F(PrimaryAccountManagerTest, SigninNotAllowed) { + std::string user("user@google.com"); + CoreAccountId account_id = AddToAccountTracker("gaia_id", user); + user_prefs_.SetString(prefs::kGoogleServicesAccountId, account_id.ToString()); + user_prefs_.SetBoolean(prefs::kSigninAllowed, false); + CreatePrimaryAccountManager(); + // Currently signing in is prohibited by policy, so should be signed out. + EXPECT_EQ("", manager_->GetPrimaryAccountInfo(ConsentLevel::kSync).email); + EXPECT_TRUE(manager_->GetPrimaryAccountId(ConsentLevel::kSync).empty()); +} +#endif + TEST_F(PrimaryAccountManagerTest, GaiaIdMigration) { #if BUILDFLAG(IS_CHROMEOS_ASH) base::test::ScopedFeatureList scoped_feature_list;
diff --git a/components/signin/internal/identity_manager/primary_account_policy_manager.h b/components/signin/internal/identity_manager/primary_account_policy_manager.h new file mode 100644 index 0000000..f5dedd0 --- /dev/null +++ b/components/signin/internal/identity_manager/primary_account_policy_manager.h
@@ -0,0 +1,29 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_POLICY_MANAGER_H_ +#define COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_POLICY_MANAGER_H_ + +class PrefService; +class PrimaryAccountManager; + +class PrimaryAccountPolicyManager { + public: + PrimaryAccountPolicyManager() = default; + + PrimaryAccountPolicyManager(const PrimaryAccountPolicyManager&) = delete; + PrimaryAccountPolicyManager& operator=(const PrimaryAccountPolicyManager&) = + delete; + + virtual ~PrimaryAccountPolicyManager() = default; + + // On platforms where PrimaryAccountManager is responsible for dealing with + // invalid username policy updates, we need to check this during + // initialization and sign the user out. + virtual void InitializePolicy( + PrefService* local_state, + PrimaryAccountManager* primary_account_manager) = 0; +}; + +#endif // COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_POLICY_MANAGER_H_
diff --git a/components/signin/internal/identity_manager/primary_account_policy_manager_impl.cc b/components/signin/internal/identity_manager/primary_account_policy_manager_impl.cc new file mode 100644 index 0000000..9116698 --- /dev/null +++ b/components/signin/internal/identity_manager/primary_account_policy_manager_impl.cc
@@ -0,0 +1,114 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" + +#include "base/bind.h" +#include "base/logging.h" +#include "build/build_config.h" +#include "components/signin/internal/identity_manager/primary_account_manager.h" +#include "components/signin/public/base/signin_client.h" +#include "components/signin/public/base/signin_metrics.h" +#include "components/signin/public/base/signin_pref_names.h" +#include "components/signin/public/identity_manager/account_info.h" +#include "components/signin/public/identity_manager/identity_utils.h" + +PrimaryAccountPolicyManagerImpl::PrimaryAccountPolicyManagerImpl( + SigninClient* client) + : client_(client) {} + +PrimaryAccountPolicyManagerImpl::~PrimaryAccountPolicyManagerImpl() { + local_state_pref_registrar_.RemoveAll(); +} + +void PrimaryAccountPolicyManagerImpl::InitializePolicy( + PrefService* local_state, + PrimaryAccountManager* primary_account_manager) { + // local_state can be null during unit tests. + if (local_state) { + local_state_pref_registrar_.Init(local_state); + local_state_pref_registrar_.Add( + prefs::kGoogleServicesUsernamePattern, + base::BindRepeating(&PrimaryAccountPolicyManagerImpl:: + OnGoogleServicesUsernamePatternChanged, + weak_pointer_factory_.GetWeakPtr(), + primary_account_manager)); + } + signin_allowed_.Init( + prefs::kSigninAllowed, client_->GetPrefs(), + base::BindRepeating( + &PrimaryAccountPolicyManagerImpl::OnSigninAllowedPrefChanged, + base::Unretained(this), primary_account_manager)); + + CoreAccountInfo account_info = primary_account_manager->GetPrimaryAccountInfo( + signin::ConsentLevel::kSync); + if (!account_info.account_id.empty() && + (!IsAllowedUsername(account_info.email) || !IsSigninAllowed())) { + // User is signed in, but the username is invalid or signin is no longer + // allowed, so the user must be sign out. + // + // This may happen in the following cases: + // a. The user has toggled off signin allowed in settings. + // b. The administrator changed the policy since the last signin. + // + // Note: The token service has not yet loaded its credentials, so accounts + // cannot be revoked here. + // + // On desktop, when PrimaryAccountManager is initializing, the profile was + // not yet marked with sign out allowed. Therefore sign out is not allowed + // and all calls to RevokeSyncConsent() and ClearPrimaryAccount() methods + // are no-op. + // + // TODO(msarda): RevokeSyncConsent() method do not guarantee that the sync + // consent can really be revoked (this depends on whether sign out is + // allowed). Add a check here on desktop to make it clear that + // RevokeSyncConsent() does not do anything. + primary_account_manager->RevokeSyncConsent( + signin_metrics::SIGNIN_PREF_CHANGED_DURING_SIGNIN, + signin_metrics::SignoutDelete::kIgnoreMetric); + } +} + +void PrimaryAccountPolicyManagerImpl::OnGoogleServicesUsernamePatternChanged( + PrimaryAccountManager* primary_account_manager) { + if (primary_account_manager->HasPrimaryAccount(signin::ConsentLevel::kSync) && + !IsAllowedUsername( + primary_account_manager + ->GetPrimaryAccountInfo(signin::ConsentLevel::kSync) + .email)) { + // Signed in user is invalid according to the current policy so sign + // the user out. + primary_account_manager->ClearPrimaryAccount( + signin_metrics::GOOGLE_SERVICE_NAME_PATTERN_CHANGED, + signin_metrics::SignoutDelete::kIgnoreMetric); + } +} + +bool PrimaryAccountPolicyManagerImpl::IsSigninAllowed() const { + return signin_allowed_.GetValue(); +} + +void PrimaryAccountPolicyManagerImpl::OnSigninAllowedPrefChanged( + PrimaryAccountManager* primary_account_manager) { + if (!IsSigninAllowed() && + primary_account_manager->HasPrimaryAccount(signin::ConsentLevel::kSync)) { + VLOG(0) << "IsSigninAllowed() set to false, signing out the user"; + primary_account_manager->ClearPrimaryAccount( + signin_metrics::SIGNOUT_PREF_CHANGED, + signin_metrics::SignoutDelete::kIgnoreMetric); + } +} + +bool PrimaryAccountPolicyManagerImpl::IsAllowedUsername( + const std::string& username) const { + const PrefService* local_state = local_state_pref_registrar_.prefs(); + + // TODO(crbug.com/908121): We need to deal for now with the fact that many + // unit tests have a null |local_state| passed to InitializePolicy(), in which + // case all usernames are considered 'allowed'. + if (!local_state) + return true; + + return signin::IsUsernameAllowedByPatternFromPrefs(local_state, username); +}
diff --git a/components/signin/internal/identity_manager/primary_account_policy_manager_impl.h b/components/signin/internal/identity_manager/primary_account_policy_manager_impl.h new file mode 100644 index 0000000..fa720cc --- /dev/null +++ b/components/signin/internal/identity_manager/primary_account_policy_manager_impl.h
@@ -0,0 +1,66 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_POLICY_MANAGER_IMPL_H_ +#define COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_POLICY_MANAGER_IMPL_H_ + +#include <string> + +#include "base/gtest_prod_util.h" +#include "base/memory/raw_ptr.h" +#include "base/memory/weak_ptr.h" +#include "components/prefs/pref_change_registrar.h" +#include "components/prefs/pref_member.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager.h" + +class PrefService; +class PrimaryAccountManager; +class SigninClient; + +class PrimaryAccountPolicyManagerImpl : public PrimaryAccountPolicyManager { + public: + explicit PrimaryAccountPolicyManagerImpl(SigninClient* client); + + PrimaryAccountPolicyManagerImpl(const PrimaryAccountPolicyManagerImpl&) = + delete; + PrimaryAccountPolicyManagerImpl& operator=( + const PrimaryAccountPolicyManagerImpl&) = delete; + + ~PrimaryAccountPolicyManagerImpl() override; + + // PrimaryAccountPolicyManager: + void InitializePolicy( + PrefService* local_state, + PrimaryAccountManager* primary_account_manager) override; + + private: + FRIEND_TEST_ALL_PREFIXES(PrimaryAccountPolicyManagerImplTest, Prohibited); + FRIEND_TEST_ALL_PREFIXES(PrimaryAccountPolicyManagerImplTest, + TestAlternateWildcard); + + // Returns true if a signin to Chrome is allowed (by policy or pref). + bool IsSigninAllowed() const; + + void OnSigninAllowedPrefChanged( + PrimaryAccountManager* primary_account_manager); + void OnGoogleServicesUsernamePatternChanged( + PrimaryAccountManager* primary_account_manager); + + // Returns true if the passed username is allowed by policy. + bool IsAllowedUsername(const std::string& username) const; + + raw_ptr<SigninClient> client_; + + // Helper object to listen for changes to signin preferences stored in non- + // profile-specific local prefs (like kGoogleServicesUsernamePattern). + PrefChangeRegistrar local_state_pref_registrar_; + + // Helper object to listen for changes to the signin allowed preference. + BooleanPrefMember signin_allowed_; + + base::WeakPtrFactory<PrimaryAccountPolicyManagerImpl> weak_pointer_factory_{ + this}; +}; + +#endif // COMPONENTS_SIGNIN_INTERNAL_IDENTITY_MANAGER_PRIMARY_ACCOUNT_POLICY_MANAGER_IMPL_H_
diff --git a/components/signin/internal/identity_manager/primary_account_policy_manager_impl_unittest.cc b/components/signin/internal/identity_manager/primary_account_policy_manager_impl_unittest.cc new file mode 100644 index 0000000..83df3659 --- /dev/null +++ b/components/signin/internal/identity_manager/primary_account_policy_manager_impl_unittest.cc
@@ -0,0 +1,74 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" + +#include <memory> +#include <string> + +#include "base/test/task_environment.h" +#include "components/prefs/testing_pref_service.h" +#include "components/signin/internal/identity_manager/account_tracker_service.h" +#include "components/signin/internal/identity_manager/fake_profile_oauth2_token_service_delegate.h" +#include "components/signin/internal/identity_manager/primary_account_manager.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" +#include "components/signin/internal/identity_manager/profile_oauth2_token_service.h" +#include "components/signin/public/base/account_consistency_method.h" +#include "components/signin/public/base/signin_pref_names.h" +#include "components/signin/public/base/test_signin_client.h" +#include "components/sync_preferences/testing_pref_service_syncable.h" +#include "testing/gtest/include/gtest/gtest.h" + +class PrimaryAccountPolicyManagerImplTest : public testing::Test { + public: + PrimaryAccountPolicyManagerImplTest() + : test_signin_client_(&user_prefs_), + token_service_( + &user_prefs_, + std::make_unique<FakeProfileOAuth2TokenServiceDelegate>()), + primary_account_manager_(&test_signin_client_, + &token_service_, + &account_tracker_, + nullptr /*policy_manager*/), + policy_manager_(&test_signin_client_) { + PrimaryAccountManager::RegisterProfilePrefs(user_prefs_.registry()); + PrimaryAccountManager::RegisterPrefs(local_state_.registry()); + + policy_manager_.InitializePolicy(&local_state_, &primary_account_manager_); + } + + ~PrimaryAccountPolicyManagerImplTest() override { + test_signin_client_.Shutdown(); + } + + base::test::TaskEnvironment task_environment_; + sync_preferences::TestingPrefServiceSyncable user_prefs_; + TestingPrefServiceSimple local_state_; + TestSigninClient test_signin_client_; + ProfileOAuth2TokenService token_service_; + AccountTrackerService account_tracker_; + PrimaryAccountManager primary_account_manager_; + PrimaryAccountPolicyManagerImpl policy_manager_; +}; + +TEST_F(PrimaryAccountPolicyManagerImplTest, Prohibited) { + local_state_.SetString(prefs::kGoogleServicesUsernamePattern, + ".*@google.com"); + EXPECT_TRUE(policy_manager_.IsAllowedUsername("test@google.com")); + EXPECT_TRUE(policy_manager_.IsAllowedUsername("happy@google.com")); + EXPECT_FALSE(policy_manager_.IsAllowedUsername("test@invalid.com")); + EXPECT_FALSE(policy_manager_.IsAllowedUsername("test@notgoogle.com")); + EXPECT_FALSE(policy_manager_.IsAllowedUsername(std::string())); +} + +TEST_F(PrimaryAccountPolicyManagerImplTest, TestAlternateWildcard) { + // Test to make sure we accept "*@google.com" as a pattern (treat it as if + // the admin entered ".*@google.com"). + local_state_.SetString(prefs::kGoogleServicesUsernamePattern, "*@google.com"); + EXPECT_TRUE(policy_manager_.IsAllowedUsername("test@google.com")); + EXPECT_TRUE(policy_manager_.IsAllowedUsername("happy@google.com")); + EXPECT_FALSE(policy_manager_.IsAllowedUsername("test@invalid.com")); + EXPECT_FALSE(policy_manager_.IsAllowedUsername("test@notgoogle.com")); + EXPECT_FALSE(policy_manager_.IsAllowedUsername(std::string())); +}
diff --git a/components/signin/public/identity_manager/access_token_fetcher_unittest.cc b/components/signin/public/identity_manager/access_token_fetcher_unittest.cc index d622989..e9db325 100644 --- a/components/signin/public/identity_manager/access_token_fetcher_unittest.cc +++ b/components/signin/public/identity_manager/access_token_fetcher_unittest.cc
@@ -14,6 +14,7 @@ #include "components/prefs/testing_pref_service.h" #include "components/signin/internal/identity_manager/account_tracker_service.h" #include "components/signin/internal/identity_manager/fake_profile_oauth2_token_service.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" #include "components/signin/internal/identity_manager/profile_oauth2_token_service_delegate.h" #include "components/signin/public/base/signin_pref_names.h" #include "components/signin/public/base/test_signin_client.h" @@ -66,7 +67,8 @@ account_tracker_(std::make_unique<AccountTrackerService>()), primary_account_manager_(&signin_client_, &token_service_, - account_tracker_.get()) { + account_tracker_.get(), + nullptr /* policy_manager */) { AccountTrackerService::RegisterPrefs(pref_service_.registry()); PrimaryAccountManager::RegisterProfilePrefs(pref_service_.registry());
diff --git a/components/signin/public/identity_manager/identity_manager_builder.cc b/components/signin/public/identity_manager/identity_manager_builder.cc index 4dc5694..f948d70 100644 --- a/components/signin/public/identity_manager/identity_manager_builder.cc +++ b/components/signin/public/identity_manager/identity_manager_builder.cc
@@ -17,6 +17,7 @@ #include "components/signin/internal/identity_manager/gaia_cookie_manager_service.h" #include "components/signin/internal/identity_manager/primary_account_manager.h" #include "components/signin/internal/identity_manager/primary_account_mutator_impl.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager.h" #include "components/signin/internal/identity_manager/profile_oauth2_token_service.h" #include "components/signin/internal/identity_manager/profile_oauth2_token_service_builder.h" #include "components/signin/public/base/account_consistency_method.h" @@ -40,6 +41,10 @@ #include "components/signin/internal/identity_manager/accounts_mutator_impl.h" #endif +#if !BUILDFLAG(IS_CHROMEOS_ASH) +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" +#endif + namespace signin { namespace { @@ -58,8 +63,13 @@ ProfileOAuth2TokenService* token_service, PrefService* local_state) { std::unique_ptr<PrimaryAccountManager> primary_account_manager; + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager; +#if !BUILDFLAG(IS_CHROMEOS_ASH) && !defined(OS_IOS) + policy_manager = std::make_unique<PrimaryAccountPolicyManagerImpl>(client); +#endif primary_account_manager = std::make_unique<PrimaryAccountManager>( - client, token_service, account_tracker_service); + client, token_service, account_tracker_service, + std::move(policy_manager)); primary_account_manager->Initialize(local_state); return primary_account_manager; }
diff --git a/components/signin/public/identity_manager/identity_manager_unittest.cc b/components/signin/public/identity_manager/identity_manager_unittest.cc index e04923c..7449b5fe4 100644 --- a/components/signin/public/identity_manager/identity_manager_unittest.cc +++ b/components/signin/public/identity_manager/identity_manager_unittest.cc
@@ -31,6 +31,7 @@ #include "components/signin/internal/identity_manager/gaia_cookie_manager_service.h" #include "components/signin/internal/identity_manager/primary_account_manager.h" #include "components/signin/internal/identity_manager/primary_account_mutator_impl.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" #include "components/signin/internal/identity_manager/profile_oauth2_token_service_delegate.h" #include "components/signin/public/base/account_consistency_method.h" #include "components/signin/public/base/consent_level.h" @@ -395,8 +396,14 @@ DCHECK_EQ(account_consistency, AccountConsistencyMethod::kDisabled) << "AccountConsistency is not used by PrimaryAccountManager"; + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager; +#if !BUILDFLAG(IS_CHROMEOS_ASH) + policy_manager = + std::make_unique<PrimaryAccountPolicyManagerImpl>(&signin_client_); +#endif auto primary_account_manager = std::make_unique<PrimaryAccountManager>( - &signin_client_, token_service.get(), account_tracker_service.get()); + &signin_client_, token_service.get(), account_tracker_service.get(), + std::move(policy_manager)); // Passing this switch ensures that the new PrimaryAccountManager starts // with a clean slate. Otherwise PrimaryAccountManager::Initialize will use
diff --git a/components/signin/public/identity_manager/identity_test_environment.cc b/components/signin/public/identity_manager/identity_test_environment.cc index ca66309..6275bcf 100644 --- a/components/signin/public/identity_manager/identity_test_environment.cc +++ b/components/signin/public/identity_manager/identity_test_environment.cc
@@ -26,6 +26,7 @@ #include "components/signin/internal/identity_manager/gaia_cookie_manager_service.h" #include "components/signin/internal/identity_manager/primary_account_manager.h" #include "components/signin/internal/identity_manager/primary_account_mutator_impl.h" +#include "components/signin/internal/identity_manager/primary_account_policy_manager_impl.h" #include "components/signin/public/base/consent_level.h" #include "components/signin/public/base/test_signin_client.h" #include "components/signin/public/identity_manager/accounts_mutator.h" @@ -336,9 +337,15 @@ signin_client, token_service.get(), account_tracker_service.get(), std::make_unique<image_fetcher::FakeImageDecoder>()); + std::unique_ptr<PrimaryAccountPolicyManager> policy_manager; +#if !BUILDFLAG(IS_CHROMEOS_ASH) + policy_manager = + std::make_unique<PrimaryAccountPolicyManagerImpl>(signin_client); +#endif std::unique_ptr<PrimaryAccountManager> primary_account_manager = std::make_unique<PrimaryAccountManager>( - signin_client, token_service.get(), account_tracker_service.get()); + signin_client, token_service.get(), account_tracker_service.get(), + std::move(policy_manager)); primary_account_manager->Initialize(pref_service); std::unique_ptr<GaiaCookieManagerService> gaia_cookie_manager_service =
diff --git a/components/translate/core/language_detection/language_detection_model.cc b/components/translate/core/language_detection/language_detection_model.cc index b0f322fe..4435e064 100644 --- a/components/translate/core/language_detection/language_detection_model.cc +++ b/components/translate/core/language_detection/language_detection_model.cc
@@ -4,7 +4,7 @@ #include "components/translate/core/language_detection/language_detection_model.h" -#include "base/files/memory_mapped_file.h" +#include "base/cxx17_backports.h" #include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros_local.h" #include "base/strings/utf_string_conversions.h" @@ -73,19 +73,27 @@ if (!model_file.IsValid()) return; - if (!model_fb_.Initialize(std::move(model_file))) - return; - recorder.set_state( LanguageDetectionModelState::kModelFileValidAndMemoryMapped); - auto statusor_classifier = tflite::task::text::nlclassifier::NLClassifier:: - CreateFromBufferAndOptions( - reinterpret_cast<const char*>(model_fb_.data()), model_fb_.length(), - {.input_tensor_index = 0, - .output_score_tensor_index = 0, - .output_label_tensor_index = 2}, - CreateLangIdResolver()); + tflite::task::text::NLClassifierOptions options; + options.set_input_tensor_index(0); + options.set_output_score_tensor_index(0); + options.set_output_label_tensor_index(2); + + std::string file_content(model_file.GetLength(), '\0'); + int bytes_read = + model_file.Read(0, base::data(file_content), model_file.GetLength()); + if (bytes_read != model_file.GetLength()) { + return; + } + *options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_content() = std::move(file_content); + + auto statusor_classifier = + tflite::task::text::nlclassifier::NLClassifier::CreateFromOptions( + options, CreateLangIdResolver()); if (!statusor_classifier.ok()) { LOCAL_HISTOGRAM_BOOLEAN("LanguageDetection.TFLiteModel.InvalidModelFile", true);
diff --git a/components/translate/core/language_detection/language_detection_model.h b/components/translate/core/language_detection/language_detection_model.h index aa7ca26..5d28006 100644 --- a/components/translate/core/language_detection/language_detection_model.h +++ b/components/translate/core/language_detection/language_detection_model.h
@@ -6,7 +6,8 @@ #define COMPONENTS_TRANSLATE_CORE_LANGUAGE_DETECTION_LANGUAGE_DETECTION_MODEL_H_ #include <string> -#include "base/files/memory_mapped_file.h" + +#include "base/files/file.h" namespace tflite { namespace task { @@ -70,11 +71,6 @@ std::pair<std::string, float> DetectTopLanguage( const std::string& sampled_str) const; - // A memory-mapped file that contains the TFLite model used for - // determining the language of a page. This must be valid in order - // to evaluate the model owned by |this|. - base::MemoryMappedFile model_fb_; - // The tflite classifier that can determine the language of text. std::unique_ptr<tflite::task::text::nlclassifier::NLClassifier> lang_detection_model_;
diff --git a/components/translate/core/language_detection/language_detection_model_unittest.cc b/components/translate/core/language_detection/language_detection_model_unittest.cc index 1f70c31..d9e664a 100644 --- a/components/translate/core/language_detection/language_detection_model_unittest.cc +++ b/components/translate/core/language_detection/language_detection_model_unittest.cc
@@ -55,8 +55,7 @@ LanguageDetectionModelState::kModelFileInvalid, 1); } -// TODO(crbug.com/1240561): Fix flaky test. -TEST(LanguageDetectionModelTest, DISABLED_UnsupportedModelFileProvided) { +TEST(LanguageDetectionModelTest, UnsupportedModelFileProvided) { base::HistogramTester histogram_tester; base::File file = CreateInvalidModelFile();
diff --git a/components/viz/service/display/overlay_processor_ozone.cc b/components/viz/service/display/overlay_processor_ozone.cc index 2bda8ac5..341a536d 100644 --- a/components/viz/service/display/overlay_processor_ozone.cc +++ b/components/viz/service/display/overlay_processor_ozone.cc
@@ -245,6 +245,12 @@ return ToEnclosedRect(overlay.display_rect); } +void OverlayProcessorOzone::RegisterOverlayRequirement(bool requires_overlay) { + // This can be null in unit tests. + if (overlay_candidates_) + overlay_candidates_->RegisterOverlayRequirement(requires_overlay); +} + bool OverlayProcessorOzone::SetNativePixmapForCandidate( ui::OverlaySurfaceCandidate* candidate, const gpu::Mailbox& mailbox,
diff --git a/components/viz/service/display/overlay_processor_ozone.h b/components/viz/service/display/overlay_processor_ozone.h index c3112b6..337775b 100644 --- a/components/viz/service/display/overlay_processor_ozone.h +++ b/components/viz/service/display/overlay_processor_ozone.h
@@ -36,6 +36,7 @@ OverlayCandidateList* surfaces) override; gfx::Rect GetOverlayDamageRectForOutputSurface( const OverlayCandidate& candidate) const override; + void RegisterOverlayRequirement(bool requires_overlay) override; private: // Populates |native_pixmap| and |native_pixmap_unique_id| in |candidate|
diff --git a/components/viz/service/display/overlay_processor_using_strategy.cc b/components/viz/service/display/overlay_processor_using_strategy.cc index a4b1cfc..58658cd 100644 --- a/components/viz/service/display/overlay_processor_using_strategy.cc +++ b/components/viz/service/display/overlay_processor_using_strategy.cc
@@ -517,12 +517,15 @@ render_pass, *candidates); } + bool has_required_overlay = false; for (auto&& candidate : proposed_candidates) { // Underlays change the material so we save it here to record proper UMA. DrawQuad::Material quad_material = candidate.strategy->GetUMAEnum() != OverlayStrategy::kUnknown ? candidate.quad_iter->material : DrawQuad::Material::kInvalid; + if (candidate.candidate.requires_overlay) + has_required_overlay = true; bool used_overlay = candidate.strategy->AttemptPrioritized( output_color_matrix, render_pass_backdrop_filters, resource_provider, @@ -581,9 +584,11 @@ UpdateDownscalingCapabilities(scale_factor, /*success=*/true); } } + RegisterOverlayRequirement(has_required_overlay); return true; } } + RegisterOverlayRequirement(has_required_overlay); if (proposed_candidates.size() == 0) { LogStrategyEnumUMA(num_proposed_pre_sort != 0
diff --git a/components/viz/service/display/overlay_processor_using_strategy.h b/components/viz/service/display/overlay_processor_using_strategy.h index f2165277..a4220a74 100644 --- a/components/viz/service/display/overlay_processor_using_strategy.h +++ b/components/viz/service/display/overlay_processor_using_strategy.h
@@ -183,6 +183,11 @@ const OverlayProcessorInterface::OutputSurfaceOverlayPlane* primary_plane, OverlayCandidateList* candidate_list); + // This should be called during overlay processing to register whether or not + // there is a candidate that requires an overlay so that the manager can allow + // the overlay on the display with the requirement only. + virtual void RegisterOverlayRequirement(bool requires_overlay) {} + protected: virtual gfx::Rect GetOverlayDamageRectForOutputSurface( const OverlayCandidate& overlay) const;
diff --git a/components/viz/service/frame_sinks/external_begin_frame_source_android.cc b/components/viz/service/frame_sinks/external_begin_frame_source_android.cc index 49f8b53..e189e0fd 100644 --- a/components/viz/service/frame_sinks/external_begin_frame_source_android.cc +++ b/components/viz/service/frame_sinks/external_begin_frame_source_android.cc
@@ -4,19 +4,215 @@ #include "components/viz/service/frame_sinks/external_begin_frame_source_android.h" +#include <dlfcn.h> +#include <sys/types.h> + +#include "base/android/build_info.h" #include "base/android/jni_android.h" +#include "base/logging.h" +#include "base/trace_event/trace_event.h" #include "components/viz/service/service_jni_headers/ExternalBeginFrameSourceAndroid_jni.h" +extern "C" { +typedef struct AChoreographer AChoreographer; +typedef void (*AChoreographer_frameCallback64)(int64_t, void*); +typedef void (*AChoreographer_refreshRateCallback)(int64_t, void*); + +using pAChoreographer_getInstance = AChoreographer* (*)(); +using pAChoreographer_postFrameCallback64 = + void (*)(AChoreographer*, AChoreographer_frameCallback64, void*); +using pAChoreographer_registerRefreshRateCallback = + void (*)(AChoreographer*, AChoreographer_refreshRateCallback, void*); +using pAChoreographer_unregisterRefreshRateCallback = + void (*)(AChoreographer*, AChoreographer_refreshRateCallback, void*); +} + +namespace { + +#define LOAD_FUNCTION(lib, func) \ + do { \ + func##Fn = reinterpret_cast<p##func>(dlsym(lib, #func)); \ + if (!func##Fn) { \ + supported = false; \ + LOG(ERROR) << "Unable to load function " << #func; \ + } \ + } while (0) + +struct AChoreographerMethods { + static const AChoreographerMethods& Get() { + static AChoreographerMethods instance; + return instance; + } + + bool supported = true; + pAChoreographer_getInstance AChoreographer_getInstanceFn; + pAChoreographer_postFrameCallback64 AChoreographer_postFrameCallback64Fn; + pAChoreographer_registerRefreshRateCallback + AChoreographer_registerRefreshRateCallbackFn; + pAChoreographer_unregisterRefreshRateCallback + AChoreographer_unregisterRefreshRateCallbackFn; + + private: + AChoreographerMethods() { + void* main_dl_handle = dlopen("libandroid.so", RTLD_NOW); + if (!main_dl_handle) { + LOG(ERROR) << "Couldnt load libandroid.so"; + supported = false; + return; + } + + LOAD_FUNCTION(main_dl_handle, AChoreographer_getInstance); + LOAD_FUNCTION(main_dl_handle, AChoreographer_postFrameCallback64); + LOAD_FUNCTION(main_dl_handle, AChoreographer_registerRefreshRateCallback); + LOAD_FUNCTION(main_dl_handle, AChoreographer_unregisterRefreshRateCallback); + } + ~AChoreographerMethods() = default; +}; + +} // namespace + namespace viz { +class ExternalBeginFrameSourceAndroid::AChoreographerImpl { + public: + static std::unique_ptr<AChoreographerImpl> Create( + ExternalBeginFrameSourceAndroid* client); + + AChoreographerImpl(ExternalBeginFrameSourceAndroid* client, + AChoreographer* choreographer); + ~AChoreographerImpl(); + + void SetEnabled(bool enabled); + + private: + static void FrameCallback64(int64_t frame_time_nanos, void* data); + static void RefershRateCallback(int64_t vsync_period_nanos, void* data); + + void OnVSync(int64_t frame_time_nanos, + base::WeakPtr<AChoreographerImpl>* self); + void SetVsyncPeriod(int64_t vsync_period_nanos); + void RequestVsyncIfNeeded(); + + ExternalBeginFrameSourceAndroid* const client_; + AChoreographer* const achoreographer_; + + base::TimeDelta vsync_period_; + bool vsync_notification_enabled_ = false; + // This is a heap-allocated WeakPtr to this object. The WeakPtr is either + // * passed to `postFrameCallback` if there is one (and exactly one) callback + // pending. This is in case this is deleted before a pending callback + // fires, in which case the callback is responsible for deleting the + // WeakPtr. + // * or owned by this member variable when there is no pending callback. + // Thus whether this is nullptr also indicates whether there is a pending + // frame callback. + std::unique_ptr<base::WeakPtr<AChoreographerImpl>> self_for_frame_callback_; + base::WeakPtrFactory<AChoreographerImpl> weak_ptr_factory_{this}; +}; + +// static +std::unique_ptr<ExternalBeginFrameSourceAndroid::AChoreographerImpl> +ExternalBeginFrameSourceAndroid::AChoreographerImpl::Create( + ExternalBeginFrameSourceAndroid* client) { + if (base::android::BuildInfo::GetInstance()->sdk_int() < + base::android::SDK_VERSION_R) { + return nullptr; + } + if (!AChoreographerMethods::Get().supported) + return nullptr; + + AChoreographer* choreographer = + AChoreographerMethods::Get().AChoreographer_getInstanceFn(); + if (!choreographer) + return nullptr; + + return std::make_unique<AChoreographerImpl>(client, choreographer); +} + +ExternalBeginFrameSourceAndroid::AChoreographerImpl::AChoreographerImpl( + ExternalBeginFrameSourceAndroid* client, + AChoreographer* choreographer) + : client_(client), + achoreographer_(choreographer), + vsync_period_(base::Microseconds(16666)) { + AChoreographerMethods::Get().AChoreographer_registerRefreshRateCallbackFn( + achoreographer_, &RefershRateCallback, this); + self_for_frame_callback_ = + std::make_unique<base::WeakPtr<AChoreographerImpl>>( + weak_ptr_factory_.GetWeakPtr()); +} + +ExternalBeginFrameSourceAndroid::AChoreographerImpl::~AChoreographerImpl() { + AChoreographerMethods::Get().AChoreographer_unregisterRefreshRateCallbackFn( + achoreographer_, &RefershRateCallback, this); +} + +void ExternalBeginFrameSourceAndroid::AChoreographerImpl::SetEnabled( + bool enabled) { + if (vsync_notification_enabled_ == enabled) + return; + vsync_notification_enabled_ = enabled; + RequestVsyncIfNeeded(); +} + +// static +void ExternalBeginFrameSourceAndroid::AChoreographerImpl::FrameCallback64( + int64_t frame_time_nanos, + void* data) { + TRACE_EVENT0("toplevel,viz", "VSync"); + auto* self = static_cast<base::WeakPtr<AChoreographerImpl>*>(data); + if (!(*self)) { + delete self; + return; + } + (*self)->OnVSync(frame_time_nanos, self); +} + +// static +void ExternalBeginFrameSourceAndroid::AChoreographerImpl::RefershRateCallback( + int64_t vsync_period_nanos, + void* data) { + static_cast<AChoreographerImpl*>(data)->SetVsyncPeriod(vsync_period_nanos); +} + +void ExternalBeginFrameSourceAndroid::AChoreographerImpl::OnVSync( + int64_t frame_time_nanos, + base::WeakPtr<AChoreographerImpl>* self) { + DCHECK(!self_for_frame_callback_); + DCHECK(self); + self_for_frame_callback_.reset(self); + if (vsync_notification_enabled_) { + client_->OnVSyncImpl(frame_time_nanos, vsync_period_); + RequestVsyncIfNeeded(); + } +} + +void ExternalBeginFrameSourceAndroid::AChoreographerImpl::SetVsyncPeriod( + int64_t vsync_period_nanos) { + vsync_period_ = base::Nanoseconds(vsync_period_nanos); +} + +void ExternalBeginFrameSourceAndroid::AChoreographerImpl:: + RequestVsyncIfNeeded() { + if (!vsync_notification_enabled_ || !self_for_frame_callback_) + return; + AChoreographerMethods::Get().AChoreographer_postFrameCallback64Fn( + achoreographer_, &FrameCallback64, self_for_frame_callback_.release()); +} + +// ============================================================================ + ExternalBeginFrameSourceAndroid::ExternalBeginFrameSourceAndroid( uint32_t restart_id, float refresh_rate) - : ExternalBeginFrameSource(this, restart_id), - j_object_(Java_ExternalBeginFrameSourceAndroid_Constructor( - base::android::AttachCurrentThread(), - reinterpret_cast<jlong>(this), - refresh_rate)) {} + : ExternalBeginFrameSource(this, restart_id) { + achoreographer_ = AChoreographerImpl::Create(this); + if (!achoreographer_) { + j_object_ = Java_ExternalBeginFrameSourceAndroid_Constructor( + base::android::AttachCurrentThread(), reinterpret_cast<jlong>(this), + refresh_rate); + } +} ExternalBeginFrameSourceAndroid::~ExternalBeginFrameSourceAndroid() { SetEnabled(false); @@ -27,6 +223,12 @@ const base::android::JavaParamRef<jobject>& obj, jlong time_micros, jlong period_micros) { + OnVSyncImpl(time_micros * 1000, base::Microseconds(period_micros)); +} + +void ExternalBeginFrameSourceAndroid::OnVSyncImpl( + int64_t time_nanos, + base::TimeDelta vsync_period) { // Warning: It is generally unsafe to manufacture TimeTicks values. The // following assumption is being made, AND COULD EASILY BREAK AT ANY TIME: // Upstream, Java code is providing "System.nanos() / 1000," and this is the @@ -34,8 +236,7 @@ DCHECK_EQ(base::TimeTicks::GetClock(), base::TimeTicks::Clock::LINUX_CLOCK_MONOTONIC); base::TimeTicks frame_time = - base::TimeTicks() + base::Microseconds(time_micros); - base::TimeDelta vsync_period(base::Microseconds(period_micros)); + base::TimeTicks() + base::Nanoseconds(time_nanos); // Calculate the next frame deadline: base::TimeTicks deadline = frame_time + vsync_period; @@ -45,8 +246,10 @@ } void ExternalBeginFrameSourceAndroid::UpdateRefreshRate(float refresh_rate) { - Java_ExternalBeginFrameSourceAndroid_updateRefreshRate( - base::android::AttachCurrentThread(), j_object_, refresh_rate); + if (j_object_) { + Java_ExternalBeginFrameSourceAndroid_updateRefreshRate( + base::android::AttachCurrentThread(), j_object_, refresh_rate); + } } void ExternalBeginFrameSourceAndroid::SetDynamicBeginFrameDeadlineOffsetSource( @@ -62,8 +265,13 @@ } void ExternalBeginFrameSourceAndroid::SetEnabled(bool enabled) { - Java_ExternalBeginFrameSourceAndroid_setEnabled( - base::android::AttachCurrentThread(), j_object_, enabled); + if (achoreographer_) { + achoreographer_->SetEnabled(enabled); + } else { + DCHECK(j_object_); + Java_ExternalBeginFrameSourceAndroid_setEnabled( + base::android::AttachCurrentThread(), j_object_, enabled); + } } } // namespace viz
diff --git a/components/viz/service/frame_sinks/external_begin_frame_source_android.h b/components/viz/service/frame_sinks/external_begin_frame_source_android.h index c8c7f7f..b7ed105 100644 --- a/components/viz/service/frame_sinks/external_begin_frame_source_android.h +++ b/components/viz/service/frame_sinks/external_begin_frame_source_android.h
@@ -6,8 +6,10 @@ #define COMPONENTS_VIZ_SERVICE_FRAME_SINKS_EXTERNAL_BEGIN_FRAME_SOURCE_ANDROID_H_ #include <jni.h> +#include <memory> #include "base/android/jni_weak_ref.h" +#include "base/time/time.h" #include "components/viz/common/frame_sinks/begin_frame_source.h" #include "components/viz/service/viz_service_export.h" @@ -40,11 +42,15 @@ dynamic_begin_frame_deadline_offset_source) override; private: + class AChoreographerImpl; + // ExternalBeginFrameSourceClient implementation. void OnNeedsBeginFrames(bool needs_begin_frames) override; void SetEnabled(bool enabled); + void OnVSyncImpl(int64_t time_nanos, base::TimeDelta vsync_period); + std::unique_ptr<AChoreographerImpl> achoreographer_; base::android::ScopedJavaGlobalRef<jobject> j_object_; BeginFrameArgsGenerator begin_frame_args_generator_; };
diff --git a/components/viz/service/surfaces/surface_saved_frame.cc b/components/viz/service/surfaces/surface_saved_frame.cc index bb309039..3ebb6e43 100644 --- a/components/viz/service/surfaces/surface_saved_frame.cc +++ b/components/viz/service/surfaces/surface_saved_frame.cc
@@ -82,6 +82,19 @@ void SurfaceSavedFrame::RequestCopyOfOutput(Surface* surface) { DCHECK(surface->HasActiveFrame()); + if (surface->GetActiveFrame().metadata.has_shared_element_resources) { + // TODO(khushalsagar) : This should be the only mode once renderer based SET + // lands. + copy_root_render_pass_ = false; + CopyUsingOriginalFrame(surface); + } else { + CopyUsingCleanFrame(surface); + } + + DCHECK_EQ(copy_request_count_, ExpectedResultCount()); +} + +void SurfaceSavedFrame::CopyUsingCleanFrame(Surface* surface) { const auto& root_draw_data = GetRootRenderPassDrawData(surface); // Bind kRoot and root geometry information to the callback. auto root_request = std::make_unique<CopyOutputRequest>( @@ -100,20 +113,6 @@ return; } - if (surface->GetActiveFrame().metadata.has_shared_element_resources) { - // TODO(khushalsagar) : This should be the only mode once renderer based SET - // lands. - CopyUsingOriginalFrame(surface, std::move(root_request)); - } else { - CopyUsingCleanFrame(surface, std::move(root_request)); - } - - DCHECK_EQ(copy_request_count_, ExpectedResultCount()); -} - -void SurfaceSavedFrame::CopyUsingCleanFrame( - Surface* surface, - std::unique_ptr<CopyOutputRequest> root_request) { // If the directive includes shared elements then we need to create a new // CompositorFrame with render passes that remove these elements. The strategy // is as follows : @@ -187,9 +186,7 @@ clean_surface_.emplace(surface, std::move(clean_frame)); } -void SurfaceSavedFrame::CopyUsingOriginalFrame( - Surface* surface, - std::unique_ptr<CopyOutputRequest> root_request) { +void SurfaceSavedFrame::CopyUsingOriginalFrame(Surface* surface) { const auto& active_frame = surface->GetActiveFrame(); for (const auto& render_pass : active_frame.render_pass_list) { if (auto request = CreateCopyRequestIfNeeded( @@ -199,12 +196,6 @@ copy_request_count_++; } } - - // TODO(khushalsagar) : The root element should be an intermediate render pass - // in the renderer's frame. We could optimize it if there are no shared - // elements. See crbug.com/1265700. - surface->RequestCopyOfOutputOnRootRenderPass(std::move(root_request)); - copy_request_count_++; } std::unique_ptr<CopyOutputRequest> SurfaceSavedFrame::CreateCopyRequestIfNeeded( @@ -288,7 +279,7 @@ size_t SurfaceSavedFrame::ExpectedResultCount() const { // Start with 1 for the root render pass. - size_t count = 1; + size_t count = copy_root_render_pass_ ? 1 : 0; for (auto& shared_element : directive_.shared_elements()) count += !shared_element.render_pass_id.is_null(); return count;
diff --git a/components/viz/service/surfaces/surface_saved_frame.h b/components/viz/service/surfaces/surface_saved_frame.h index cac37b2..90f7ddfc 100644 --- a/components/viz/service/surfaces/surface_saved_frame.h +++ b/components/viz/service/surfaces/surface_saved_frame.h
@@ -114,14 +114,12 @@ // Queues copy requests by creating a copy of the CompositorFrame as specified // in ScopedCleanSurface. - void CopyUsingCleanFrame(Surface* surface, - std::unique_ptr<CopyOutputRequest> root_request); + void CopyUsingCleanFrame(Surface* surface); // Queues copy requests from the original CompositorFrame. This mode is used // when the frame produced by the renderer already has independent render // passes for each shared element. - void CopyUsingOriginalFrame(Surface* surface, - std::unique_ptr<CopyOutputRequest> root_request); + void CopyUsingOriginalFrame(Surface* surface); std::unique_ptr<CopyOutputRequest> CreateCopyRequestIfNeeded( const CompositorRenderPass& render_pass, @@ -177,6 +175,9 @@ // whether the SurfaceSavedFrame is "valid". size_t valid_result_count_ = 0; + // Tracks whether the root render pass should be copied. + bool copy_root_render_pass_ = true; + absl::optional<ScopedCleanSurface> clean_surface_; base::WeakPtrFactory<SurfaceSavedFrame> weak_factory_{this};
diff --git a/components/viz/service/transitions/surface_animation_manager.cc b/components/viz/service/transitions/surface_animation_manager.cc index 30749e3..3f4727c 100644 --- a/components/viz/service/transitions/surface_animation_manager.cc +++ b/components/viz/service/transitions/surface_animation_manager.cc
@@ -1187,11 +1187,11 @@ shared_element_quad.resource_id); if (texture_it != saved_textures_->element_id_to_resource.end()) { - resource_list->push_back(saved_textures_->element_id_to_resource.at( - shared_element_quad.resource_id)); + const auto& transferable_resource = texture_it->second; + resource_list->push_back(transferable_resource); // GPU textures are flipped but software bitmaps are not. - bool y_flipped = !saved_textures_->root.resource.is_software; + bool y_flipped = !transferable_resource.is_software; ReplaceSharedElementWithTexture(©_pass, shared_element_quad, y_flipped, resource_list->back().id); return true;
diff --git a/content/app/content_main_runner_impl.cc b/content/app/content_main_runner_impl.cc index 3999a567..b21e6cb 100644 --- a/content/app/content_main_runner_impl.cc +++ b/content/app/content_main_runner_impl.cc
@@ -10,6 +10,7 @@ #include <memory> #include <string> +#include <tuple> #include <utility> #include <vector> @@ -25,7 +26,6 @@ #include "base/debug/stack_trace.h" #include "base/files/file_path.h" #include "base/i18n/icu_util.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/location.h" #include "base/logging.h" @@ -1044,7 +1044,7 @@ base::FieldTrialList* leaked_field_trial_list = SetUpFieldTrialsAndFeatureList().release(); ANNOTATE_LEAKING_OBJECT_PTR(leaked_field_trial_list); - ignore_result(leaked_field_trial_list); + std::ignore = leaked_field_trial_list; delegate_->PostFieldTrialInitialization(); mojo::core::InitFeatures(); }
diff --git a/content/app_shim_remote_cocoa/ns_view_bridge_factory_impl.mm b/content/app_shim_remote_cocoa/ns_view_bridge_factory_impl.mm index 6675ae22..2cd40c2 100644 --- a/content/app_shim_remote_cocoa/ns_view_bridge_factory_impl.mm +++ b/content/app_shim_remote_cocoa/ns_view_bridge_factory_impl.mm
@@ -8,7 +8,6 @@ #include <vector> #include "base/bind.h" -#include "base/ignore_result.h" #include "content/app_shim_remote_cocoa/render_widget_host_ns_view_bridge.h" #include "content/app_shim_remote_cocoa/render_widget_host_ns_view_host_helper.h" #include "content/app_shim_remote_cocoa/web_contents_ns_view_bridge.h" @@ -154,10 +153,10 @@ // Create a RenderWidgetHostNSViewBridgeOwner. The resulting object will be // destroyed when its underlying pipe is closed. - ignore_result(new RenderWidgetHostNSViewBridgeOwner( + std::ignore = new RenderWidgetHostNSViewBridgeOwner( std::move(host), mojo::PendingAssociatedReceiver<mojom::RenderWidgetHostNSView>( - std::move(view_receiver_handle)))); + std::move(view_receiver_handle))); } void CreateWebContentsNSView(
diff --git a/content/app_shim_remote_cocoa/render_widget_host_view_cocoa.mm b/content/app_shim_remote_cocoa/render_widget_host_view_cocoa.mm index 4d83201..050e234 100644 --- a/content/app_shim_remote_cocoa/render_widget_host_view_cocoa.mm +++ b/content/app_shim_remote_cocoa/render_widget_host_view_cocoa.mm
@@ -5,13 +5,14 @@ #import "content/app_shim_remote_cocoa/render_widget_host_view_cocoa.h" #include <Carbon/Carbon.h> // for <HIToolbox/Events.h> + #include <limits> +#include <tuple> #include <utility> #include "base/containers/contains.h" #include "base/cxx17_backports.h" #include "base/debug/crash_logging.h" -#include "base/ignore_result.h" #import "base/mac/foundation_util.h" #include "base/strings/sys_string_conversions.h" #import "content/browser/accessibility/browser_accessibility_cocoa.h" @@ -619,7 +620,7 @@ - (void)setHostDisconnected { // Set the host to be an abandoned message pipe, and set the hostHelper // to forward messages to that host. - ignore_result(_dummyHost.BindNewPipeAndPassReceiver()); + std::ignore = _dummyHost.BindNewPipeAndPassReceiver(); _dummyHostHelper = std::make_unique<DummyHostHelper>(); _host = _dummyHost.get(); _hostHelper = _dummyHostHelper.get();
diff --git a/content/browser/accessibility/touch_accessibility_aura_browsertest.cc b/content/browser/accessibility/touch_accessibility_aura_browsertest.cc index 4622759d..baefdf0 100644 --- a/content/browser/accessibility/touch_accessibility_aura_browsertest.cc +++ b/content/browser/accessibility/touch_accessibility_aura_browsertest.cc
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" +#include <tuple> + #include "base/strings/string_number_conversions.h" #include "content/browser/accessibility/browser_accessibility.h" #include "content/browser/renderer_host/render_widget_host_view_child_frame.h" @@ -58,7 +59,7 @@ std::unique_ptr<ui::Event> mouse_move_event( new ui::MouseEvent(ui::ET_MOUSE_MOVED, location, location, ui::EventTimeForNow(), flags, 0)); - ignore_result(sink->OnEventFromSource(mouse_move_event.get())); + std::ignore = sink->OnEventFromSource(mouse_move_event.get()); } };
diff --git a/content/browser/android/selection/selection_popup_controller.cc b/content/browser/android/selection/selection_popup_controller.cc index 71d47433..8c72c07 100644 --- a/content/browser/android/selection/selection_popup_controller.cc +++ b/content/browser/android/selection/selection_popup_controller.cc
@@ -214,7 +214,8 @@ Java_SelectionPopupControllerImpl_onSelectAroundCaretFailure(env, obj); } else { Java_SelectionPopupControllerImpl_onSelectAroundCaretSuccess( - env, obj, result->extended_start_adjust, result->extended_end_adjust); + env, obj, result->extended_start_adjust, result->extended_end_adjust, + result->word_start_adjust, result->word_end_adjust); } }
diff --git a/content/browser/attribution_reporting/BUILD.gn b/content/browser/attribution_reporting/BUILD.gn index fc6f3d3..ea116da 100644 --- a/content/browser/attribution_reporting/BUILD.gn +++ b/content/browser/attribution_reporting/BUILD.gn
@@ -13,11 +13,19 @@ { types = [ { + mojom = "content.mojom.AttributionReportID" + cpp = "::content::EventAttributionReport::Id" + }, + { mojom = "content.mojom.SourceType" cpp = "::content::StorableSource::SourceType" }, ] - traits_headers = [ "attribution_internals_mojom_traits.h" ] + traits_headers = [ + "attribution_internals_mojom_traits.h", + "event_attribution_report.h", + ] + traits_sources = [ "attribution_internals_mojom_traits.cc" ] }, ] }
diff --git a/content/browser/attribution_reporting/attribution_internals.mojom b/content/browser/attribution_reporting/attribution_internals.mojom index 4b272f86..b4bfc58 100644 --- a/content/browser/attribution_reporting/attribution_internals.mojom +++ b/content/browser/attribution_reporting/attribution_internals.mojom
@@ -13,9 +13,16 @@ kEvent, }; +struct AttributionReportID { + int64 value; +}; + // Struct containing stored data that will be sent in a future attribution // report. struct WebUIAttributionReport { + // Allows the WebUI to issue commands for individual reports. + // Not intended to be displayed. + AttributionReportID? id; string attribution_destination; url.mojom.Url report_url; double trigger_time; @@ -100,10 +107,10 @@ // being sent. GetReports() => (array<WebUIAttributionReport> reports); - // Sends all stored reports, ignoring delay, returning when the - // operation has been completed and all reports have been cleared from + // Sends the given reports, ignoring delay, returning when the + // operation has been completed and the reports have been cleared from // storage. - SendPendingReports() => (); + SendReports(array<AttributionReportID> ids) => (); // Deletes all persisted data for the attribution reporting API, returning when the // operation has been completed.
diff --git a/content/browser/attribution_reporting/attribution_internals_browsertest.cc b/content/browser/attribution_reporting/attribution_internals_browsertest.cc index 346a162..0367b26 100644 --- a/content/browser/attribution_reporting/attribution_internals_browsertest.cc +++ b/content/browser/attribution_reporting/attribution_internals_browsertest.cc
@@ -39,6 +39,7 @@ using DeactivatedSource = ::content::AttributionStorage::DeactivatedSource; using ::testing::_; +using ::testing::ElementsAre; using ::testing::IsNull; using ::testing::Return; @@ -389,22 +390,22 @@ let table = document.querySelector("#report-table-wrapper tbody"); let obs = new MutationObserver(() => { if (table.children.length === 7 && - table.children[0].children[1].innerText === "https://conversion.test" && - table.children[0].children[2].innerText === + table.children[0].children[2].innerText === "https://conversion.test" && + table.children[0].children[3].innerText === "https://report.test/.well-known/attribution-reporting/report-attribution" && - table.children[0].children[5].innerText === "13" && - table.children[0].children[6].innerText === "yes" && - table.children[0].children[7].innerText === "Pending" && - table.children[1].children[5].innerText === "11" && - table.children[1].children[7].innerText === "Dropped due to low priority" && - table.children[2].children[5].innerText === "12" && - table.children[2].children[7].innerText === "Dropped for noise" && - table.children[3].children[5].innerText === "0" && - table.children[3].children[6].innerText === "no" && - table.children[3].children[7].innerText === "Sent: HTTP 200" && - table.children[4].children[7].innerText === "Prohibited by browser policy" && - table.children[5].children[7].innerText === "Network error" && - table.children[6].children[7].innerText === "Dropped due to rate-limiting") { + table.children[0].children[6].innerText === "13" && + table.children[0].children[7].innerText === "yes" && + table.children[0].children[8].innerText === "Pending" && + table.children[1].children[6].innerText === "11" && + table.children[1].children[8].innerText === "Dropped due to low priority" && + table.children[2].children[6].innerText === "12" && + table.children[2].children[8].innerText === "Dropped for noise" && + table.children[3].children[6].innerText === "0" && + table.children[3].children[7].innerText === "no" && + table.children[3].children[8].innerText === "Sent: HTTP 200" && + table.children[4].children[8].innerText === "Prohibited by browser policy" && + table.children[5].children[8].innerText === "Network error" && + table.children[6].children[8].innerText === "Dropped due to rate-limiting") { document.title = $1; } }); @@ -421,22 +422,22 @@ let table = document.querySelector("#report-table-wrapper tbody"); let obs = new MutationObserver(() => { if (table.children.length === 7 && - table.children[6].children[1].innerText === "https://conversion.test" && - table.children[6].children[2].innerText === + table.children[6].children[2].innerText === "https://conversion.test" && + table.children[6].children[3].innerText === "https://report.test/.well-known/attribution-reporting/report-attribution" && - table.children[6].children[5].innerText === "13" && - table.children[6].children[6].innerText === "yes" && - table.children[6].children[7].innerText === "Pending" && - table.children[5].children[5].innerText === "12" && - table.children[5].children[7].innerText === "Dropped for noise" && - table.children[4].children[5].innerText === "11" && - table.children[4].children[7].innerText === "Dropped due to low priority" && - table.children[3].children[5].innerText === "0" && - table.children[3].children[6].innerText === "no" && - table.children[3].children[7].innerText === "Sent: HTTP 200" && - table.children[2].children[7].innerText === "Prohibited by browser policy" && - table.children[1].children[7].innerText === "Network error" && - table.children[0].children[7].innerText === "Dropped due to rate-limiting") { + table.children[6].children[6].innerText === "13" && + table.children[6].children[7].innerText === "yes" && + table.children[6].children[8].innerText === "Pending" && + table.children[5].children[6].innerText === "12" && + table.children[5].children[8].innerText === "Dropped for noise" && + table.children[4].children[6].innerText === "11" && + table.children[4].children[8].innerText === "Dropped due to low priority" && + table.children[3].children[6].innerText === "0" && + table.children[3].children[7].innerText === "no" && + table.children[3].children[8].innerText === "Sent: HTTP 200" && + table.children[2].children[8].innerText === "Prohibited by browser policy" && + table.children[1].children[8].innerText === "Network error" && + table.children[0].children[8].innerText === "Dropped due to rate-limiting") { document.title = $1; } }); @@ -446,7 +447,7 @@ TitleWatcher title_watcher(shell()->web_contents(), kCompleteTitle2); // Sort by priority ascending. EXPECT_TRUE(ExecJsInWebUI( - "document.querySelectorAll('#report-table-wrapper th')[5].click();")); + "document.querySelectorAll('#report-table-wrapper th')[6].click();")); EXPECT_EQ(kCompleteTitle2, title_watcher.WaitAndGetTitle()); } @@ -455,22 +456,22 @@ let table = document.querySelector("#report-table-wrapper tbody"); let obs = new MutationObserver(() => { if (table.children.length === 7 && - table.children[0].children[1].innerText === "https://conversion.test" && - table.children[0].children[2].innerText === + table.children[0].children[2].innerText === "https://conversion.test" && + table.children[0].children[3].innerText === "https://report.test/.well-known/attribution-reporting/report-attribution" && - table.children[0].children[5].innerText === "13" && - table.children[0].children[6].innerText === "yes" && - table.children[0].children[7].innerText === "Pending" && - table.children[1].children[5].innerText === "12" && - table.children[1].children[7].innerText === "Dropped for noise" && - table.children[2].children[5].innerText === "11" && - table.children[2].children[7].innerText === "Dropped due to low priority" && - table.children[3].children[5].innerText === "0" && - table.children[3].children[6].innerText === "no" && - table.children[3].children[7].innerText === "Sent: HTTP 200" && - table.children[4].children[7].innerText === "Prohibited by browser policy" && - table.children[5].children[7].innerText === "Network error" && - table.children[6].children[7].innerText === "Dropped due to rate-limiting") { + table.children[0].children[6].innerText === "13" && + table.children[0].children[7].innerText === "yes" && + table.children[0].children[8].innerText === "Pending" && + table.children[1].children[6].innerText === "12" && + table.children[1].children[8].innerText === "Dropped for noise" && + table.children[2].children[6].innerText === "11" && + table.children[2].children[8].innerText === "Dropped due to low priority" && + table.children[3].children[6].innerText === "0" && + table.children[3].children[7].innerText === "no" && + table.children[3].children[8].innerText === "Sent: HTTP 200" && + table.children[4].children[8].innerText === "Prohibited by browser policy" && + table.children[5].children[8].innerText === "Network error" && + table.children[6].children[8].innerText === "Dropped due to rate-limiting") { document.title = $1; } }); @@ -480,7 +481,7 @@ TitleWatcher title_watcher(shell()->web_contents(), kCompleteTitle3); // Sort by priority descending. EXPECT_TRUE(ExecJsInWebUI( - "document.querySelectorAll('#report-table-wrapper th')[5].click();")); + "document.querySelectorAll('#report-table-wrapper th')[6].click();")); EXPECT_EQ(kCompleteTitle3, title_watcher.WaitAndGetTitle()); } @@ -515,8 +516,8 @@ let table = document.querySelector("#report-table-wrapper tbody"); let obs = new MutationObserver(() => { if (table.children.length === 2 && - table.children[0].children[5].innerText === "7" && - table.children[1].children[7].innerText === "Sent: HTTP 200") { + table.children[0].children[6].innerText === "7" && + table.children[1].children[8].innerText === "Sent: HTTP 200") { document.title = $1; } }); @@ -603,11 +604,16 @@ EXPECT_CALL(manager_, GetPendingReportsForWebUI) .WillOnce(InvokeCallback<std::vector<EventAttributionReport>>( - {ReportBuilder(SourceBuilder().Build()).SetPriority(7).Build()})) + {ReportBuilder(SourceBuilder().Build()) + .SetPriority(7) + .SetReportId(EventAttributionReport::Id(5)) + .Build()})) .WillOnce(InvokeCallback<std::vector<EventAttributionReport>>({})); - EXPECT_CALL(manager_, SendReportsForWebUI) - .WillOnce([](base::OnceClosure done) { std::move(done).Run(); }); + EXPECT_CALL(manager_, SendReportsForWebUI( + ElementsAre(EventAttributionReport::Id(5)), _)) + .WillOnce([](const std::vector<EventAttributionReport::Id>& ids, + base::OnceClosure done) { std::move(done).Run(); }); OverrideWebUIAttributionManager(); @@ -615,7 +621,7 @@ let table = document.querySelector("#report-table-wrapper tbody"); let obs = new MutationObserver(() => { if (table.children.length === 1 && - table.children[0].children[5].innerText === "7") { + table.children[0].children[6].innerText === "7") { document.title = $1; } }); @@ -632,6 +638,8 @@ TitleWatcher sent_title_watcher(shell()->web_contents(), kSentTitle); SetTitleOnReportsTableEmpty(kSentTitle); + EXPECT_TRUE(ExecJsInWebUI( + R"(document.querySelector('input[type="checkbox"]').click();)")); EXPECT_TRUE( ExecJsInWebUI("document.getElementById('send-reports').click();"));
diff --git a/content/browser/attribution_reporting/attribution_internals_handler_impl.cc b/content/browser/attribution_reporting/attribution_internals_handler_impl.cc index 1efa8e62..49d90169 100644 --- a/content/browser/attribution_reporting/attribution_internals_handler_impl.cc +++ b/content/browser/attribution_reporting/attribution_internals_handler_impl.cc
@@ -82,7 +82,8 @@ int http_response_code, mojom::WebUIAttributionReport::Status status) { return mojom::WebUIAttributionReport::New( - report.source().ConversionDestination().Serialize(), report.ReportURL(), + report.report_id(), report.source().ConversionDestination().Serialize(), + report.ReportURL(), /*trigger_time=*/report.conversion_time().ToJsTime(), /*report_time=*/report.report_time().ToJsTime(), report.priority(), report.ReportBody(/*pretty_print=*/true), @@ -154,11 +155,12 @@ } } -void AttributionInternalsHandlerImpl::SendPendingReports( - mojom::AttributionInternalsHandler::SendPendingReportsCallback callback) { +void AttributionInternalsHandlerImpl::SendReports( + const std::vector<EventAttributionReport::Id>& ids, + mojom::AttributionInternalsHandler::SendReportsCallback callback) { if (AttributionManager* manager = manager_provider_->GetManager(web_ui_->GetWebContents())) { - manager->SendReportsForWebUI(std::move(callback)); + manager->SendReportsForWebUI(ids, std::move(callback)); } else { std::move(callback).Run(); }
diff --git a/content/browser/attribution_reporting/attribution_internals_handler_impl.h b/content/browser/attribution_reporting/attribution_internals_handler_impl.h index dfb1348..221e621 100644 --- a/content/browser/attribution_reporting/attribution_internals_handler_impl.h +++ b/content/browser/attribution_reporting/attribution_internals_handler_impl.h
@@ -10,6 +10,7 @@ #include "content/browser/attribution_reporting/attribution_internals.mojom.h" #include "content/browser/attribution_reporting/attribution_manager.h" #include "content/browser/attribution_reporting/attribution_storage.h" +#include "content/browser/attribution_reporting/event_attribution_report.h" #include "mojo/public/cpp/bindings/pending_receiver.h" #include "mojo/public/cpp/bindings/receiver.h" #include "mojo/public/cpp/bindings/remote_set.h" @@ -49,9 +50,9 @@ override; void GetReports( mojom::AttributionInternalsHandler::GetReportsCallback callback) override; - void SendPendingReports( - mojom::AttributionInternalsHandler::SendPendingReportsCallback callback) - override; + void SendReports(const std::vector<EventAttributionReport::Id>& ids, + mojom::AttributionInternalsHandler::SendReportsCallback + callback) override; void ClearStorage(mojom::AttributionInternalsHandler::ClearStorageCallback callback) override; void AddObserver(
diff --git a/content/browser/attribution_reporting/attribution_internals_mojom_traits.cc b/content/browser/attribution_reporting/attribution_internals_mojom_traits.cc new file mode 100644 index 0000000..cbc8265 --- /dev/null +++ b/content/browser/attribution_reporting/attribution_internals_mojom_traits.cc
@@ -0,0 +1,18 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "content/browser/attribution_reporting/attribution_internals_mojom_traits.h" + +namespace mojo { + +// static +bool StructTraits<content::mojom::AttributionReportIDDataView, + content::EventAttributionReport::Id>:: + Read(content::mojom::AttributionReportIDDataView data, + content::EventAttributionReport::Id* out) { + *out = content::EventAttributionReport::Id(data.value()); + return true; +} + +} // namespace mojo
diff --git a/content/browser/attribution_reporting/attribution_internals_mojom_traits.h b/content/browser/attribution_reporting/attribution_internals_mojom_traits.h index 8750fc5..2d3187a8 100644 --- a/content/browser/attribution_reporting/attribution_internals_mojom_traits.h +++ b/content/browser/attribution_reporting/attribution_internals_mojom_traits.h
@@ -5,9 +5,13 @@ #ifndef CONTENT_BROWSER_ATTRIBUTION_REPORTING_ATTRIBUTION_INTERNALS_MOJOM_TRAITS_H_ #define CONTENT_BROWSER_ATTRIBUTION_REPORTING_ATTRIBUTION_INTERNALS_MOJOM_TRAITS_H_ +#include <stdint.h> + #include "content/browser/attribution_reporting/attribution_internals.mojom.h" +#include "content/browser/attribution_reporting/event_attribution_report.h" #include "content/browser/attribution_reporting/storable_source.h" #include "mojo/public/cpp/bindings/enum_traits.h" +#include "mojo/public/cpp/bindings/struct_traits.h" namespace mojo { @@ -40,6 +44,18 @@ } }; +template <> +class StructTraits<content::mojom::AttributionReportIDDataView, + content::EventAttributionReport::Id> { + public: + static int64_t value(const content::EventAttributionReport::Id& id) { + return *id; + } + + static bool Read(content::mojom::AttributionReportIDDataView data, + content::EventAttributionReport::Id* out); +}; + } // namespace mojo #endif // CONTENT_BROWSER_ATTRIBUTION_REPORTING_ATTRIBUTION_INTERNALS_MOJOM_TRAITS_H_
diff --git a/content/browser/attribution_reporting/attribution_manager.h b/content/browser/attribution_reporting/attribution_manager.h index 4e21516..5eb787768 100644 --- a/content/browser/attribution_reporting/attribution_manager.h +++ b/content/browser/attribution_reporting/attribution_manager.h
@@ -91,9 +91,11 @@ base::OnceCallback<void(std::vector<EventAttributionReport>)> callback) = 0; - // Sends all pending reports immediately, and runs |done| once they have all + // Sends the given reports immediately, and runs |done| once they have all // been sent. - virtual void SendReportsForWebUI(base::OnceClosure done) = 0; + virtual void SendReportsForWebUI( + const std::vector<EventAttributionReport::Id>& ids, + base::OnceClosure done) = 0; // Returns the AttributionPolicy that is used to control API policies such // as noise.
diff --git a/content/browser/attribution_reporting/attribution_manager_impl.cc b/content/browser/attribution_reporting/attribution_manager_impl.cc index dd83689..8500138a 100644 --- a/content/browser/attribution_reporting/attribution_manager_impl.cc +++ b/content/browser/attribution_reporting/attribution_manager_impl.cc
@@ -288,12 +288,13 @@ /*max_report_time=*/base::Time::Max(), /*limit=*/1000); } -void AttributionManagerImpl::SendReportsForWebUI(base::OnceClosure done) { - GetAndHandleReports( - base::BindOnce(&AttributionManagerImpl::OnGetReportsToSendFromWebUI, - weak_factory_.GetWeakPtr(), std::move(done)), - /*max_report_time=*/base::Time::Max(), - /*limit=*/-1); +void AttributionManagerImpl::SendReportsForWebUI( + const std::vector<EventAttributionReport::Id>& ids, + base::OnceClosure done) { + attribution_storage_.AsyncCall(&AttributionStorage::GetReports) + .WithArgs(ids) + .Then(base::BindOnce(&AttributionManagerImpl::OnGetReportsToSendFromWebUI, + weak_factory_.GetWeakPtr(), std::move(done))); } const AttributionPolicy& AttributionManagerImpl::GetAttributionPolicy() const {
diff --git a/content/browser/attribution_reporting/attribution_manager_impl.h b/content/browser/attribution_reporting/attribution_manager_impl.h index 604132a5..a98245d 100644 --- a/content/browser/attribution_reporting/attribution_manager_impl.h +++ b/content/browser/attribution_reporting/attribution_manager_impl.h
@@ -112,7 +112,8 @@ void GetPendingReportsForWebUI( base::OnceCallback<void(std::vector<EventAttributionReport>)> callback) override; - void SendReportsForWebUI(base::OnceClosure done) override; + void SendReportsForWebUI(const std::vector<EventAttributionReport::Id>& ids, + base::OnceClosure done) override; const AttributionPolicy& GetAttributionPolicy() const override; void ClearData(base::Time delete_begin, base::Time delete_end,
diff --git a/content/browser/attribution_reporting/attribution_manager_impl_unittest.cc b/content/browser/attribution_reporting/attribution_manager_impl_unittest.cc index 9128f50..17e6f71 100644 --- a/content/browser/attribution_reporting/attribution_manager_impl_unittest.cc +++ b/content/browser/attribution_reporting/attribution_manager_impl_unittest.cc
@@ -738,9 +738,12 @@ attribution_manager_->HandleSource( SourceBuilder().SetExpiry(kImpressionExpiry).Build()); attribution_manager_->HandleTrigger(DefaultTrigger()); + std::vector<EventAttributionReport> reports = StoredReports(); + EXPECT_THAT(reports, SizeIs(1)); EXPECT_THAT(network_sender_->calls(), IsEmpty()); - attribution_manager_->SendReportsForWebUI(base::DoNothing()); + attribution_manager_->SendReportsForWebUI({*reports.front().report_id()}, + base::DoNothing()); task_environment_.FastForwardBy(base::TimeDelta()); EXPECT_THAT(network_sender_->calls(), SizeIs(1)); } @@ -753,9 +756,12 @@ SourceBuilder().SetExpiry(kImpressionExpiry).Build()); attribution_manager_->HandleTrigger(DefaultTrigger()); attribution_manager_->HandleTrigger(DefaultTrigger()); + std::vector<EventAttributionReport> reports = StoredReports(); + EXPECT_THAT(reports, SizeIs(2)); EXPECT_THAT(network_sender_->calls(), IsEmpty()); attribution_manager_->SendReportsForWebUI( + {*reports.front().report_id(), *reports.back().report_id()}, base::BindLambdaForTesting([&]() { callback_calls++; })); task_environment_.FastForwardBy(base::TimeDelta()); EXPECT_THAT(network_sender_->calls(), SizeIs(2)); @@ -1163,7 +1169,8 @@ SourceBuilder().SetExpiry(kImpressionExpiry).Build()); attribution_manager_->HandleTrigger(DefaultTrigger()); - attribution_manager_->SendReportsForWebUI(base::DoNothing()); + attribution_manager_->SendReportsForWebUI({EventAttributionReport::Id(1)}, + base::DoNothing()); task_environment_.FastForwardBy(base::TimeDelta()); EXPECT_THAT(network_sender_->calls(), SizeIs(1));
diff --git a/content/browser/attribution_reporting/attribution_storage.h b/content/browser/attribution_reporting/attribution_storage.h index a64baea..2fdf6ee 100644 --- a/content/browser/attribution_reporting/attribution_storage.h +++ b/content/browser/attribution_reporting/attribution_storage.h
@@ -222,6 +222,12 @@ virtual absl::optional<base::Time> GetNextReportTime(base::Time time) WARN_UNUSED_RESULT = 0; + // Returns the reports with the given IDs. This call is logically const, and + // does not modify the underlying storage. + virtual std::vector<EventAttributionReport> GetReports( + const std::vector<EventAttributionReport::Id>& ids) + WARN_UNUSED_RESULT = 0; + // Returns all active sources in storage. Active sources are all // sources that can still convert. Sources that: are past expiry, // reached the attribution limit, or was marked inactive due to having
diff --git a/content/browser/attribution_reporting/attribution_storage_sql.cc b/content/browser/attribution_reporting/attribution_storage_sql.cc index 265f5d5..389845a 100644 --- a/content/browser/attribution_reporting/attribution_storage_sql.cc +++ b/content/browser/attribution_reporting/attribution_storage_sql.cc
@@ -5,7 +5,9 @@ #include "content/browser/attribution_reporting/attribution_storage_sql.h" #include <stdint.h> + #include <string> +#include <tuple> #include <utility> #include "base/bind.h" @@ -15,7 +17,6 @@ #include "base/containers/flat_set.h" #include "base/files/file_util.h" #include "base/guid.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros.h" @@ -923,6 +924,21 @@ return absl::nullopt; } +std::vector<EventAttributionReport> AttributionStorageSql::GetReports( + const std::vector<EventAttributionReport::Id>& ids) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit(DbCreationPolicy::kIgnoreIfAbsent)) + return {}; + + std::vector<EventAttributionReport> reports; + for (EventAttributionReport::Id id : ids) { + absl::optional<EventAttributionReport> report = GetReport(id); + if (report.has_value()) + reports.push_back(std::move(*report)); + } + return reports; +} + absl::optional<EventAttributionReport> AttributionStorageSql::GetReport( EventAttributionReport::Id conversion_id) { static constexpr char kGetReportSql[] = @@ -1697,7 +1713,7 @@ // or hardware issues, not coding errors at the client level, so displaying // the error would probably lead to confusion. The ignored call signals the // test-expectation framework that the error was handled. - ignore_result(sql::Database::IsExpectedSqliteError(extended_error)); + std::ignore = sql::Database::IsExpectedSqliteError(extended_error); return; }
diff --git a/content/browser/attribution_reporting/attribution_storage_sql.h b/content/browser/attribution_reporting/attribution_storage_sql.h index ec1fa86e3..89b3356 100644 --- a/content/browser/attribution_reporting/attribution_storage_sql.h +++ b/content/browser/attribution_reporting/attribution_storage_sql.h
@@ -87,6 +87,8 @@ base::Time max_report_time, int limit = -1) override; absl::optional<base::Time> GetNextReportTime(base::Time time) override; + std::vector<EventAttributionReport> GetReports( + const std::vector<EventAttributionReport::Id>& ids) override; std::vector<StorableSource> GetActiveSources(int limit = -1) override; bool DeleteReport(EventAttributionReport::Id report_id) override; bool UpdateReportForSendFailure(EventAttributionReport::Id report_id,
diff --git a/content/browser/attribution_reporting/attribution_storage_sql_migrations_unittest.cc b/content/browser/attribution_reporting/attribution_storage_sql_migrations_unittest.cc index fc781ab..2b13614 100644 --- a/content/browser/attribution_reporting/attribution_storage_sql_migrations_unittest.cc +++ b/content/browser/attribution_reporting/attribution_storage_sql_migrations_unittest.cc
@@ -8,7 +8,6 @@ #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/guid.h" -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/strings/string_util.h" #include "base/test/metrics/histogram_tester.h" @@ -46,9 +45,9 @@ std::make_unique<ConfigurableStorageDelegate>()); // We need to run an operation on storage to force the lazy initialization. - ignore_result( + std::ignore = static_cast<AttributionStorage*>(&storage)->GetAttributionsToReport( - base::Time::Min())); + base::Time::Min()); } base::FilePath DbPath() {
diff --git a/content/browser/attribution_reporting/attribution_test_utils.h b/content/browser/attribution_reporting/attribution_test_utils.h index 1352101..4ef2d5e8 100644 --- a/content/browser/attribution_reporting/attribution_test_utils.h +++ b/content/browser/attribution_reporting/attribution_test_utils.h
@@ -178,7 +178,11 @@ (base::OnceCallback<void(std::vector<EventAttributionReport>)> callback), (override)); - MOCK_METHOD(void, SendReportsForWebUI, (base::OnceClosure done), (override)); + MOCK_METHOD(void, + SendReportsForWebUI, + (const std::vector<EventAttributionReport::Id>& ids, + base::OnceClosure done), + (override)); MOCK_METHOD(void, ClearData,
diff --git a/content/browser/bluetooth/frame_connected_bluetooth_devices_unittest.cc b/content/browser/bluetooth/frame_connected_bluetooth_devices_unittest.cc index bbbdf5d..847c85d 100644 --- a/content/browser/bluetooth/frame_connected_bluetooth_devices_unittest.cc +++ b/content/browser/bluetooth/frame_connected_bluetooth_devices_unittest.cc
@@ -4,7 +4,8 @@ #include "content/browser/bluetooth/frame_connected_bluetooth_devices.h" -#include "base/ignore_result.h" +#include <tuple> + #include "base/memory/raw_ptr.h" #include "base/memory/ref_counted.h" #include "content/browser/bluetooth/web_bluetooth_service_impl.h" @@ -44,7 +45,7 @@ mojo::AssociatedRemote<blink::mojom::WebBluetoothServerClient> CreateServerClient() { mojo::AssociatedRemote<blink::mojom::WebBluetoothServerClient> client; - ignore_result(client.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = client.BindNewEndpointAndPassDedicatedReceiver(); return client; }
diff --git a/content/browser/browser_main_loop.cc b/content/browser/browser_main_loop.cc index 44dc30b..6846ca3 100644 --- a/content/browser/browser_main_loop.cc +++ b/content/browser/browser_main_loop.cc
@@ -9,6 +9,7 @@ #include <algorithm> #include <memory> #include <string> +#include <tuple> #include <utility> #include <vector> @@ -17,7 +18,6 @@ #include "base/command_line.h" #include "base/cxx17_backports.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/location.h" #include "base/logging.h" #include "base/memory/memory_pressure_monitor.h" @@ -1178,15 +1178,15 @@ if (audio_manager_ && !audio_manager_->Shutdown()) { // Intentionally leak AudioManager if shutdown failed. // We might run into various CHECK(s) in AudioManager destructor. - ignore_result(audio_manager_.release()); + std::ignore = audio_manager_.release(); // |user_input_monitor_| may be in use by stray streams in case // AudioManager shutdown failed. - ignore_result(user_input_monitor_.release()); + std::ignore = user_input_monitor_.release(); } // Leaking AudioSystem: we cannot correctly destroy it since Audio service // connection in there is bound to IO thread. - ignore_result(audio_system_.release()); + std::ignore = audio_system_.release(); } if (parts_) {
diff --git a/content/browser/child_process_launcher_helper_android.cc b/content/browser/child_process_launcher_helper_android.cc index a90558b5..feb2d49 100644 --- a/content/browser/child_process_launcher_helper_android.cc +++ b/content/browser/child_process_launcher_helper_android.cc
@@ -3,13 +3,13 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include "base/android/apk_assets.h" #include "base/android/application_status_listener.h" #include "base/android/jni_array.h" #include "base/bind.h" #include "base/i18n/icu_util.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/metrics/field_trial.h" #include "base/task/post_task.h" @@ -133,7 +133,7 @@ const auto& region = files_to_register->GetRegionAt(i); bool auto_close = files_to_register->OwnsFD(fd); if (auto_close) { - ignore_result(files_to_register->ReleaseFD(fd).release()); + std::ignore = files_to_register->ReleaseFD(fd).release(); } ScopedJavaLocalRef<jobject> j_file_info =
diff --git a/content/browser/content_security_policy_browsertest.cc b/content/browser/content_security_policy_browsertest.cc index 22d3033f..eb4e163 100644 --- a/content/browser/content_security_policy_browsertest.cc +++ b/content/browser/content_security_policy_browsertest.cc
@@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> + #include "base/files/file_path.h" -#include "base/ignore_result.h" #include "base/path_service.h" #include "base/threading/thread_restrictions.h" #include "content/browser/renderer_host/render_frame_host_impl.h" @@ -200,7 +201,7 @@ // On windows, if `document_url` contains the host part "localhost", the // actual committed URL does not. So we omit EXPECT_TRUE and ignore the // result value here. - ignore_result(NavigateToURL(shell(), document_url)); + std::ignore = NavigateToURL(shell(), document_url); GURL element_url = net::FilePathToFileURL(TestFilePath( testCase.element_name == "iframe" ? "empty.html" : "blank.jpg"));
diff --git a/content/browser/file_system_access/file_system_access_directory_handle_impl_unittest.cc b/content/browser/file_system_access/file_system_access_directory_handle_impl_unittest.cc index ef02afca..0f664820 100644 --- a/content/browser/file_system_access/file_system_access_directory_handle_impl_unittest.cc +++ b/content/browser/file_system_access/file_system_access_directory_handle_impl_unittest.cc
@@ -6,11 +6,11 @@ #include <iterator> #include <string> +#include <tuple> #include "base/bind.h" #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/run_loop.h" #include "base/test/bind.h" @@ -211,7 +211,7 @@ // ignore any failures writing these files to disk. EXPECT_TRUE(success) << "Failed to create file " << file_path; #else - ignore_result(success); + std::ignore = success; #endif }
diff --git a/content/browser/find_request_manager.cc b/content/browser/find_request_manager.cc index 4b3ad9b..57b78a3 100644 --- a/content/browser/find_request_manager.cc +++ b/content/browser/find_request_manager.cc
@@ -174,6 +174,16 @@ rfh->GetLastCommittedOrigin()); } +// kMinKeystrokesWithoutDelay should be high enough that script in the page +// can't provide every possible search result at the same time. +constexpr int kMinKeystrokesWithoutDelay = 4; + +// The delay for very short queries, before sending find requests. This should +// be higher than the duration in between two keystrokes. This is based on +// WebCore.FindInPage.DurationBetweenKeystrokes metrics, this is higher than +// 90% of them. +constexpr int kDelayMs = 400; + } // namespace // Observes searched WebContentses for RenderFrameHost state updates, including @@ -286,7 +296,8 @@ void FindRequestManager::Find(int request_id, const std::u16string& search_text, - blink::mojom::FindOptionsPtr options) { + blink::mojom::FindOptionsPtr options, + bool skip_delay) { // Every find request must have a unique ID, and these IDs must strictly // increase so that newer requests always have greater IDs than older // requests. @@ -311,6 +322,41 @@ last_searched_text_ = search_text; } + if (skip_delay) { + delayed_find_task_.Cancel(); + EmitFindRequest(request_id, search_text, std::move(options)); + return; + } + + if (!options->new_session) { + // If the user presses enter while we are waiting for a delayed find, then + // run the find now to improve responsiveness. + if (!delayed_find_task_.IsCancelled()) { + delayed_find_task_.callback().Run(); + } else { + EmitFindRequest(request_id, search_text, std::move(options)); + } + return; + } + + if (search_text.length() < kMinKeystrokesWithoutDelay) { + delayed_find_task_.Reset(base::BindOnce( + &FindRequestManager::EmitFindRequest, weak_factory_.GetWeakPtr(), + request_id, search_text, std::move(options))); + base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( + FROM_HERE, delayed_find_task_.callback(), base::Milliseconds(kDelayMs)); + return; + } + + // If we aren't going to delay, then clear any previous attempts to delay. + delayed_find_task_.Cancel(); + + EmitFindRequest(request_id, search_text, std::move(options)); +} + +void FindRequestManager::EmitFindRequest(int request_id, + const std::u16string& search_text, + blink::mojom::FindOptionsPtr options) { // If this is a new find session, clear any queued requests from last session. if (options->new_session) find_request_queue_ = base::queue<FindRequest>();
diff --git a/content/browser/find_request_manager.h b/content/browser/find_request_manager.h index f9867fc..9e02495 100644 --- a/content/browser/find_request_manager.h +++ b/content/browser/find_request_manager.h
@@ -10,6 +10,7 @@ #include <unordered_set> #include <vector> +#include "base/cancelable_callback.h" #include "base/containers/queue.h" #include "base/memory/raw_ptr.h" #include "build/build_config.h" @@ -45,7 +46,8 @@ // |options|. |request_id| uniquely identifies the find request. void Find(int request_id, const std::u16string& search_text, - blink::mojom::FindOptionsPtr options); + blink::mojom::FindOptionsPtr options, + bool skip_delay = false); // Stops the active find session and clears the general highlighting of the // matches. |action| determines whether the last active match (if any) will be @@ -216,6 +218,10 @@ // callback if the each RenderFrameHost is alive and active. void ForEachAddedFindInPageRenderFrameHost(FrameIterationCallback callback); + void EmitFindRequest(int request_id, + const std::u16string& search_text, + blink::mojom::FindOptionsPtr options); + #if defined(OS_ANDROID) // Called when a nearest find result reply is no longer pending for a frame. void RemoveNearestFindResultPendingReply(RenderFrameHost* rfh); @@ -364,8 +370,12 @@ base::TimeTicks last_time_typed_; std::u16string last_searched_text_; + base::CancelableOnceClosure delayed_find_task_; + CreateFindInPageClientFunction create_find_in_page_client_for_testing_ = nullptr; + + base::WeakPtrFactory<FindRequestManager> weak_factory_{this}; }; } // namespace content
diff --git a/content/browser/indexed_db/indexed_db_backing_store.cc b/content/browser/indexed_db/indexed_db_backing_store.cc index bd4e676..a73549d 100644 --- a/content/browser/indexed_db/indexed_db_backing_store.cc +++ b/content/browser/indexed_db/indexed_db_backing_store.cc
@@ -5,13 +5,13 @@ #include "content/browser/indexed_db/indexed_db_backing_store.h" #include <algorithm> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/dcheck_is_on.h" #include "base/files/file_path.h" #include "base/format_macros.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" #include "base/logging.h" @@ -713,11 +713,11 @@ if (!found) { // Initialize new backing store. db_schema_version = indexed_db::kLatestKnownSchemaVersion; - ignore_result( - PutInt(write_batch.get(), schema_version_key, db_schema_version)); + std::ignore = + PutInt(write_batch.get(), schema_version_key, db_schema_version); db_data_version = latest_known_data_version; - ignore_result( - PutInt(write_batch.get(), data_version_key, db_data_version.Encode())); + std::ignore = + PutInt(write_batch.get(), data_version_key, db_data_version.Encode()); // If a blob directory already exists for this database, blow it away. It's // leftover from a partially-purged previous generation of data. if (filesystem_proxy_ && @@ -775,8 +775,8 @@ // Up to date. Nothing to do. } else if (latest_known_data_version.IsAtLeast(db_data_version)) { db_data_version = latest_known_data_version; - ignore_result( - PutInt(write_batch.get(), data_version_key, db_data_version.Encode())); + std::ignore = + PutInt(write_batch.get(), data_version_key, db_data_version.Encode()); } else { // |db_data_version| is in the future according to at least one component. INTERNAL_CONSISTENCY_ERROR(SET_UP_METADATA); @@ -3071,7 +3071,7 @@ const std::string data_version_key = DataVersionKey::Encode(); Status s; - ignore_result(PutInt(write_batch, schema_version_key, db_schema_version)); + std::ignore = PutInt(write_batch, schema_version_key, db_schema_version); const std::string start_key = DatabaseNameKey::EncodeMinKeyForOrigin(origin_identifier_); const std::string stop_key = @@ -3094,8 +3094,8 @@ } std::string version_key = DatabaseMetaDataKey::Encode( database_id, DatabaseMetaDataKey::USER_VERSION); - ignore_result(PutVarInt(write_batch, version_key, - IndexedDBDatabaseMetadata::DEFAULT_VERSION)); + std::ignore = PutVarInt(write_batch, version_key, + IndexedDBDatabaseMetadata::DEFAULT_VERSION); } return s; @@ -3109,9 +3109,9 @@ const std::string data_version_key = DataVersionKey::Encode(); Status s; - ignore_result(PutInt(write_batch, schema_version_key, db_schema_version)); - ignore_result(PutInt(write_batch, data_version_key, - IndexedDBDataFormatVersion::GetCurrent().Encode())); + std::ignore = PutInt(write_batch, schema_version_key, db_schema_version); + std::ignore = PutInt(write_batch, data_version_key, + IndexedDBDataFormatVersion::GetCurrent().Encode()); return s; } @@ -3146,7 +3146,7 @@ if (storage_key_.origin().host() != "docs.google.com") return InternalInconsistencyStatus(); } else { - ignore_result(PutInt(write_batch, schema_version_key, db_schema_version)); + std::ignore = PutInt(write_batch, schema_version_key, db_schema_version); } return s; @@ -3165,7 +3165,7 @@ INTERNAL_CONSISTENCY_ERROR(SET_UP_METADATA); return InternalInconsistencyStatus(); } - ignore_result(PutInt(write_batch, schema_version_key, db_schema_version)); + std::ignore = PutInt(write_batch, schema_version_key, db_schema_version); // Delete all empty files that resulted from the migration to v4. If this // fails it's not a big deal. @@ -3196,7 +3196,7 @@ return InternalInconsistencyStatus(); } } - ignore_result(PutInt(write_batch, schema_version_key, db_schema_version)); + std::ignore = PutInt(write_batch, schema_version_key, db_schema_version); return s; }
diff --git a/content/browser/indexed_db/indexed_db_backing_store_unittest.cc b/content/browser/indexed_db/indexed_db_backing_store_unittest.cc index 0fe57acb..687607e0 100644 --- a/content/browser/indexed_db/indexed_db_backing_store_unittest.cc +++ b/content/browser/indexed_db/indexed_db_backing_store_unittest.cc
@@ -8,6 +8,7 @@ #include <stdint.h> #include <string> +#include <tuple> #include <utility> #include "base/barrier_closure.h" @@ -19,7 +20,6 @@ #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/guid.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/notreached.h" #include "base/strings/string_number_conversions.h" @@ -1786,7 +1786,7 @@ // Set the schema to 2, which was before blob support. std::unique_ptr<LevelDBWriteBatch> write_batch = LevelDBWriteBatch::Create(); const std::string schema_version_key = SchemaVersionKey::Encode(); - ignore_result(indexed_db::PutInt(write_batch.get(), schema_version_key, 2)); + std::ignore = indexed_db::PutInt(write_batch.get(), schema_version_key, 2); ASSERT_TRUE(backing_store()->db()->Write(write_batch.get()).ok()); task_environment_.RunUntilIdle(); @@ -1902,7 +1902,7 @@ // Set the schema to 2, which was before blob support. std::unique_ptr<LevelDBWriteBatch> write_batch = LevelDBWriteBatch::Create(); const std::string schema_version_key = SchemaVersionKey::Encode(); - ignore_result(indexed_db::PutInt(write_batch.get(), schema_version_key, 2)); + std::ignore = indexed_db::PutInt(write_batch.get(), schema_version_key, 2); ASSERT_TRUE(backing_store()->db()->Write(write_batch.get()).ok()); // Clean up on the IDB sequence.
diff --git a/content/browser/indexed_db/indexed_db_context_impl.cc b/content/browser/indexed_db/indexed_db_context_impl.cc index 6a7d303..24577f10 100644 --- a/content/browser/indexed_db/indexed_db_context_impl.cc +++ b/content/browser/indexed_db/indexed_db_context_impl.cc
@@ -870,8 +870,10 @@ const blink::StorageKey& storage_key, mojo::PendingReceiver<blink::mojom::IDBFactory> receiver, storage::QuotaErrorOr<storage::BucketInfo> result) { - DCHECK(result.ok()); - dispatcher_host_.AddReceiver(storage_key, std::move(receiver)); + absl::optional<storage::BucketLocator> bucket = + result.ok() ? absl::make_optional(result->ToBucketLocator()) + : absl::nullopt; + dispatcher_host_.AddReceiver(storage_key, bucket, std::move(receiver)); } void IndexedDBContextImpl::ShutdownOnIDBSequence() {
diff --git a/content/browser/indexed_db/indexed_db_context_unittest.cc b/content/browser/indexed_db/indexed_db_context_unittest.cc index e1a6c10..8401f2c 100644 --- a/content/browser/indexed_db/indexed_db_context_unittest.cc +++ b/content/browser/indexed_db/indexed_db_context_unittest.cc
@@ -2,9 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <memory> + #include "base/barrier_closure.h" #include "base/files/scoped_temp_dir.h" #include "base/run_loop.h" +#include "base/test/bind.h" +#include "base/test/gmock_callback_support.h" #include "base/test/task_environment.h" #include "base/threading/thread.h" #include "base/time/default_clock.h" @@ -14,10 +18,13 @@ #include "content/browser/indexed_db/indexed_db_context_impl.h" #include "content/browser/indexed_db/indexed_db_factory_impl.h" #include "content/browser/indexed_db/mock_indexed_db_callbacks.h" +#include "content/browser/indexed_db/mock_mojo_indexed_db_callbacks.h" +#include "content/browser/indexed_db/mock_mojo_indexed_db_database_callbacks.h" #include "mojo/public/cpp/bindings/remote.h" #include "storage/browser/test/mock_quota_manager.h" #include "storage/browser/test/mock_quota_manager_proxy.h" #include "storage/browser/test/quota_manager_proxy_sync.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "third_party/blink/public/common/storage_key/storage_key.h" @@ -103,4 +110,84 @@ EXPECT_GT(result->id.value(), 0); } +TEST_F(IndexedDBContextTest, GetDefaultBucketError) { + // Disable database so it will return errors when getting the default bucket. + quota_manager_->SetDisableDatabase(true); + + mojo::Remote<blink::mojom::IDBFactory> example_remote; + indexed_db_context_->BindIndexedDB( + example_storage_key_, example_remote.BindNewPipeAndPassReceiver()); + + // IDBFactory::GetDatabaseInfo + base::RunLoop loop_1; + auto mock_callbacks = + std::make_unique<testing::StrictMock<MockMojoIndexedDBCallbacks>>(); + EXPECT_CALL(*mock_callbacks, + Error(blink::mojom::IDBException::kUnknownError, + std::u16string( + u"Internal error retrieving bucket data directory."))) + .Times(1) + .WillOnce(base::test::RunClosure(loop_1.QuitClosure())); + + example_remote->GetDatabaseInfo(mock_callbacks->CreateInterfacePtrAndBind()); + loop_1.Run(); + + testing::Mock::VerifyAndClear(&mock_callbacks); + + // IDBFactory::Open + base::RunLoop loop_2; + mock_callbacks = + std::make_unique<testing::StrictMock<MockMojoIndexedDBCallbacks>>(); + auto database_callbacks = + std::make_unique<MockMojoIndexedDBDatabaseCallbacks>(); + auto transaction_remote = + mojo::AssociatedRemote<blink::mojom::IDBTransaction>(); + EXPECT_CALL(*mock_callbacks, + Error(blink::mojom::IDBException::kUnknownError, + std::u16string( + u"Internal error retrieving bucket data directory."))) + .Times(1) + .WillOnce(base::test::RunClosure(loop_2.QuitClosure())); + + example_remote->Open(mock_callbacks->CreateInterfacePtrAndBind(), + database_callbacks->CreateInterfacePtrAndBind(), + u"database_name", /*version=*/1, + transaction_remote.BindNewEndpointAndPassReceiver(), + /*transaction_id=*/0); + loop_2.Run(); + + // IDBFactory::DeleteDatabase + base::RunLoop loop_3; + mock_callbacks = + std::make_unique<testing::StrictMock<MockMojoIndexedDBCallbacks>>(); + EXPECT_CALL(*mock_callbacks, + Error(blink::mojom::IDBException::kUnknownError, + std::u16string( + u"Internal error retrieving bucket data directory."))) + .Times(1) + .WillOnce(base::test::RunClosure(loop_3.QuitClosure())); + + example_remote->DeleteDatabase(mock_callbacks->CreateInterfacePtrAndBind(), + u"database_name", /*force_close=*/true); + loop_3.Run(); + + // IDBFactory::AbortTransactionsAndCompactDatabase + base::RunLoop loop_4; + example_remote->AbortTransactionsAndCompactDatabase( + base::BindLambdaForTesting([&](blink::mojom::IDBStatus status) { + EXPECT_EQ(status, blink::mojom::IDBStatus::NotFound); + loop_4.Quit(); + })); + loop_4.Run(); + + // IDBFactory::AbortTransactionsForDatabase + base::RunLoop loop_5; + example_remote->AbortTransactionsForDatabase( + base::BindLambdaForTesting([&](blink::mojom::IDBStatus status) { + EXPECT_EQ(status, blink::mojom::IDBStatus::NotFound); + loop_5.Quit(); + })); + loop_5.Run(); +} + } // namespace content
diff --git a/content/browser/indexed_db/indexed_db_dispatcher_host.cc b/content/browser/indexed_db/indexed_db_dispatcher_host.cc index f36e6666..51ec8705 100644 --- a/content/browser/indexed_db/indexed_db_dispatcher_host.cc +++ b/content/browser/indexed_db/indexed_db_dispatcher_host.cc
@@ -214,10 +214,13 @@ void IndexedDBDispatcherHost::AddReceiver( const blink::StorageKey& storage_key, + absl::optional<storage::BucketLocator> bucket, mojo::PendingReceiver<blink::mojom::IDBFactory> pending_receiver) { DCHECK(IDBTaskRunner()->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - receivers_.Add(this, std::move(pending_receiver), storage_key); + if (bucket.has_value()) + DCHECK_EQ(bucket->storage_key, storage_key); + receivers_.Add(this, std::move(pending_receiver), bucket); } void IndexedDBDispatcherHost::AddDatabaseBinding( @@ -274,10 +277,23 @@ pending_callbacks) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - const auto& storage_key = receivers_.current_context(); + // Return error if failed to retrieve bucket from the QuotaManager. + if (!receivers_.current_context().has_value()) { + auto callbacks = base::MakeRefCounted<IndexedDBCallbacks>( + this->AsWeakPtr(), blink::StorageKey(), std::move(pending_callbacks), + IDBTaskRunner()); + IndexedDBDatabaseError error = IndexedDBDatabaseError( + blink::mojom::IDBException::kUnknownError, + u"Internal error retrieving bucket data directory."); + callbacks->OnError(error); + return; + } + + const auto storage_key = receivers_.current_context()->storage_key; auto callbacks = base::MakeRefCounted<IndexedDBCallbacks>( this->AsWeakPtr(), storage_key, std::move(pending_callbacks), IDBTaskRunner()); + base::FilePath indexed_db_path = indexed_db_context_->data_path(); indexed_db_context_->GetIDBFactory()->GetDatabaseInfo( std::move(callbacks), storage_key, indexed_db_path); @@ -294,7 +310,19 @@ int64_t transaction_id) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - const auto& storage_key = receivers_.current_context(); + // Return error if failed to retrieve bucket from the QuotaManager. + if (!receivers_.current_context().has_value()) { + auto callbacks = base::MakeRefCounted<IndexedDBCallbacks>( + this->AsWeakPtr(), blink::StorageKey(), std::move(pending_callbacks), + IDBTaskRunner()); + IndexedDBDatabaseError error = IndexedDBDatabaseError( + blink::mojom::IDBException::kUnknownError, + u"Internal error retrieving bucket data directory."); + callbacks->OnError(error); + return; + } + + const auto storage_key = receivers_.current_context()->storage_key; auto callbacks = base::MakeRefCounted<IndexedDBCallbacks>( this->AsWeakPtr(), storage_key, std::move(pending_callbacks), IDBTaskRunner()); @@ -310,6 +338,7 @@ std::make_unique<IndexedDBPendingConnection>( std::move(callbacks), std::move(database_callbacks), transaction_id, version, std::move(create_transaction_callback)); + // TODO(dgrogan): Don't let a non-existing database be opened (and therefore // created) if this origin is already over quota. indexed_db_context_->GetIDBFactory()->Open(name, std::move(connection), @@ -322,10 +351,23 @@ bool force_close) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - const auto& storage_key = receivers_.current_context(); + // Return error if failed to retrieve bucket from the QuotaManager. + if (!receivers_.current_context().has_value()) { + auto callbacks = base::MakeRefCounted<IndexedDBCallbacks>( + this->AsWeakPtr(), blink::StorageKey(), std::move(pending_callbacks), + IDBTaskRunner()); + IndexedDBDatabaseError error = IndexedDBDatabaseError( + blink::mojom::IDBException::kUnknownError, + u"Internal error retrieving bucket data directory."); + callbacks->OnError(error); + return; + } + + const auto storage_key = receivers_.current_context()->storage_key; auto callbacks = base::MakeRefCounted<IndexedDBCallbacks>( this->AsWeakPtr(), storage_key, std::move(pending_callbacks), IDBTaskRunner()); + base::FilePath indexed_db_path = indexed_db_context_->data_path(); indexed_db_context_->GetIDBFactory()->DeleteDatabase( name, std::move(callbacks), storage_key, indexed_db_path, force_close); @@ -335,9 +377,16 @@ AbortTransactionsAndCompactDatabaseCallback mojo_callback) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - const auto& storage_key = receivers_.current_context(); + // Return error if failed to retrieve bucket from the QuotaManager. + if (!receivers_.current_context().has_value()) { + std::move(mojo_callback).Run(blink::mojom::IDBStatus::NotFound); + return; + } + + const auto storage_key = receivers_.current_context()->storage_key; base::OnceCallback<void(leveldb::Status)> callback_on_io = base::BindOnce( &CallCompactionStatusCallbackOnIDBThread, std::move(mojo_callback)); + indexed_db_context_->GetIDBFactory()->AbortTransactionsAndCompactDatabase( std::move(callback_on_io), storage_key); } @@ -346,9 +395,16 @@ AbortTransactionsForDatabaseCallback mojo_callback) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - const auto& storage_key = receivers_.current_context(); + // Return error if failed to retrieve bucket from the QuotaManager. + if (!receivers_.current_context().has_value()) { + std::move(mojo_callback).Run(blink::mojom::IDBStatus::NotFound); + return; + } + + const auto storage_key = receivers_.current_context()->storage_key; base::OnceCallback<void(leveldb::Status)> callback_on_io = base::BindOnce( &CallAbortStatusCallbackOnIDBThread, std::move(mojo_callback)); + indexed_db_context_->GetIDBFactory()->AbortTransactionsForDatabase( std::move(callback_on_io), storage_key); }
diff --git a/content/browser/indexed_db/indexed_db_dispatcher_host.h b/content/browser/indexed_db/indexed_db_dispatcher_host.h index 4768706..c96115e 100644 --- a/content/browser/indexed_db/indexed_db_dispatcher_host.h +++ b/content/browser/indexed_db/indexed_db_dispatcher_host.h
@@ -14,6 +14,7 @@ #include <vector> #include "base/sequence_checker.h" +#include "components/services/storage/public/cpp/buckets/bucket_locator.h" #include "components/services/storage/public/mojom/blob_storage_context.mojom-forward.h" #include "components/services/storage/public/mojom/file_system_access_context.mojom-forward.h" #include "content/browser/indexed_db/indexed_db_external_object.h" @@ -51,6 +52,7 @@ void AddReceiver( const blink::StorageKey& storage_key, + absl::optional<storage::BucketLocator> bucket, mojo::PendingReceiver<blink::mojom::IDBFactory> pending_receiver); void AddDatabaseBinding( @@ -136,7 +138,9 @@ // Shared task runner used to read blob files on. const scoped_refptr<base::TaskRunner> file_task_runner_; - mojo::ReceiverSet<blink::mojom::IDBFactory, blink::StorageKey> receivers_; + mojo::ReceiverSet<blink::mojom::IDBFactory, + absl::optional<storage::BucketLocator>> + receivers_; mojo::UniqueAssociatedReceiverSet<blink::mojom::IDBDatabase> database_receivers_; mojo::UniqueAssociatedReceiverSet<blink::mojom::IDBCursor> cursor_receivers_;
diff --git a/content/browser/indexed_db/indexed_db_dispatcher_host_unittest.cc b/content/browser/indexed_db/indexed_db_dispatcher_host_unittest.cc index c63e213..fc60b471 100644 --- a/content/browser/indexed_db/indexed_db_dispatcher_host_unittest.cc +++ b/content/browser/indexed_db/indexed_db_dispatcher_host_unittest.cc
@@ -4,11 +4,12 @@ #include "content/browser/indexed_db/indexed_db_dispatcher_host.h" +#include <tuple> + #include "base/barrier_closure.h" #include "base/bind.h" #include "base/callback.h" #include "base/files/scoped_temp_dir.h" -#include "base/ignore_result.h" #include "base/memory/ref_counted.h" #include "base/run_loop.h" #include "base/strings/utf_string_conversions.h" @@ -535,7 +536,7 @@ mojo::PendingRemote<blink::mojom::Blob> blob; // Ignore the result of InitWithNewPipeAndPassReceiver, to end up with // an invalid blob. - ignore_result(blob.InitWithNewPipeAndPassReceiver()); + std::ignore = blob.InitWithNewPipeAndPassReceiver(); external_objects.push_back( blink::mojom::IDBExternalObject::NewBlobOrFile( blink::mojom::IDBBlobInfo::New(std::move(blob), "fakeUUID",
diff --git a/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkey_fuzzer.cc b/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkey_fuzzer.cc index c27e56f..c44f6252 100644 --- a/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkey_fuzzer.cc +++ b/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkey_fuzzer.cc
@@ -5,14 +5,15 @@ #include <stddef.h> #include <stdint.h> -#include "base/ignore_result.h" +#include <tuple> + #include "content/browser/indexed_db/indexed_db_leveldb_coding.h" #include "third_party/blink/public/common/indexeddb/indexeddb_key.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { base::StringPiece key_str_piece(reinterpret_cast<const char*>(data), size); auto indexed_db_key = std::make_unique<blink::IndexedDBKey>(); - ignore_result(content::DecodeIDBKey(&key_str_piece, &indexed_db_key)); + std::ignore = content::DecodeIDBKey(&key_str_piece, &indexed_db_key); // Ensure that encoding |indexed_db_key| produces the same result. std::string result;
diff --git a/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkeypath_fuzzer.cc b/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkeypath_fuzzer.cc index 739362c0..bbee894e 100644 --- a/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkeypath_fuzzer.cc +++ b/content/browser/indexed_db/indexed_db_leveldb_coding_decodeidbkeypath_fuzzer.cc
@@ -5,7 +5,8 @@ #include <stddef.h> #include <stdint.h> -#include "base/ignore_result.h" +#include <tuple> + #include "content/browser/indexed_db/indexed_db_leveldb_coding.h" #include "third_party/blink/public/common/indexeddb/indexeddb_key_path.h" @@ -13,8 +14,8 @@ base::StringPiece key_path_str_piece(reinterpret_cast<const char*>(data), size); blink::IndexedDBKeyPath indexed_db_key_path; - ignore_result( - content::DecodeIDBKeyPath(&key_path_str_piece, &indexed_db_key_path)); + std::ignore = + content::DecodeIDBKeyPath(&key_path_str_piece, &indexed_db_key_path); // Ensure that encoding |indexed_db_key_path| produces the same result. std::string result;
diff --git a/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkey_fuzzer.cc b/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkey_fuzzer.cc index 43bdf877..670e8e1 100644 --- a/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkey_fuzzer.cc +++ b/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkey_fuzzer.cc
@@ -7,7 +7,8 @@ #include <fuzzer/FuzzedDataProvider.h> -#include "base/ignore_result.h" +#include <tuple> + #include "base/strings/string_piece.h" #include "base/strings/utf_string_conversions.h" #include "content/browser/indexed_db/indexed_db_leveldb_coding.h" @@ -100,7 +101,7 @@ // Ensure that |result| can be decoded back into the original key. auto decoded_key = std::make_unique<IndexedDBKey>(); auto result_str_piece = base::StringPiece(result); - ignore_result(content::DecodeIDBKey(&result_str_piece, &decoded_key)); + std::ignore = content::DecodeIDBKey(&result_str_piece, &decoded_key); assert(decoded_key->Equals(key)); return 0; }
diff --git a/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkeypath_fuzzer.cc b/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkeypath_fuzzer.cc index 641ac7f..b48b541 100644 --- a/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkeypath_fuzzer.cc +++ b/content/browser/indexed_db/indexed_db_leveldb_coding_encodeidbkeypath_fuzzer.cc
@@ -7,7 +7,8 @@ #include <fuzzer/FuzzedDataProvider.h> -#include "base/ignore_result.h" +#include <tuple> + #include "base/strings/utf_string_conversions.h" #include "content/browser/indexed_db/indexed_db_leveldb_coding.h" #include "third_party/blink/public/common/indexeddb/indexeddb_key_path.h" @@ -56,8 +57,7 @@ // Ensure that |result| can be decoded back into the original key path. IndexedDBKeyPath decoded_key_path; auto result_str_piece = base::StringPiece(result); - ignore_result( - content::DecodeIDBKeyPath(&result_str_piece, &decoded_key_path)); + std::ignore = content::DecodeIDBKeyPath(&result_str_piece, &decoded_key_path); assert(decoded_key_path == key_path); return 0; }
diff --git a/content/browser/isolated_origin_browsertest.cc b/content/browser/isolated_origin_browsertest.cc index 2d859c2..81abc90 100644 --- a/content/browser/isolated_origin_browsertest.cc +++ b/content/browser/isolated_origin_browsertest.cc
@@ -3,11 +3,11 @@ // found in the LICENSE file. #include <sstream> +#include <tuple> #include <vector> #include "base/bind.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" @@ -3437,7 +3437,7 @@ // Bind the real DomStorage implementation. mojo::PendingRemote<blink::mojom::DomStorageClient> unused_client; - ignore_result(unused_client.InitWithNewPipeAndPassReceiver()); + std::ignore = unused_client.InitWithNewPipeAndPassReceiver(); mojo::ReceiverId receiver_id = storage_partition->BindDomStorage( rph->GetID(), std::move(receiver), std::move(unused_client)); @@ -3559,11 +3559,11 @@ content::RenderProcessHostBadIpcMessageWaiter kill_waiter( web_contents()->GetMainFrame()->GetProcess()); - // Use ignore_result here, since on Android the renderer process is + // Use std::ignore here, since on Android the renderer process is // terminated, but ExecuteScript still returns true. It properly returns // false on all other platforms. - ignore_result( - ExecJs(web_contents()->GetMainFrame(), "sessionStorage.length;")); + std::ignore = + ExecJs(web_contents()->GetMainFrame(), "sessionStorage.length;"); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); } @@ -3634,7 +3634,7 @@ EXPECT_EQ(1, EvalJs(web_contents()->GetMainFrame(), "sessionStorage.length")); content::RenderProcessHostBadIpcMessageWaiter kill_waiter( ChildFrameAt(shell(), 0)->GetProcess()); - ignore_result(ExecJs(ChildFrameAt(shell(), 0), "sessionStorage.length")); + std::ignore = ExecJs(ChildFrameAt(shell(), 0), "sessionStorage.length"); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); // The subframe has crashed, but the main frame should still be alive and // working. @@ -3706,7 +3706,7 @@ EXPECT_EQ(1, EvalJs(web_contents()->GetMainFrame(), "localStorage.length")); content::RenderProcessHostBadIpcMessageWaiter kill_waiter( ChildFrameAt(shell(), 0)->GetProcess()); - ignore_result(ExecJs(ChildFrameAt(shell(), 0), "localStorage.length")); + std::ignore = ExecJs(ChildFrameAt(shell(), 0), "localStorage.length"); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); // The subframe has crashed, but the main frame should still be alive and // working. @@ -3733,11 +3733,11 @@ content::RenderProcessHostBadIpcMessageWaiter kill_waiter( shell()->web_contents()->GetMainFrame()->GetProcess()); - // Use ignore_result here, since on Android the renderer process is + // Use std::ignore here, since on Android the renderer process is // terminated, but ExecuteScript still returns true. It properly returns // false on all other platforms. - ignore_result( - ExecJs(shell()->web_contents()->GetMainFrame(), "localStorage.length;")); + std::ignore = + ExecJs(shell()->web_contents()->GetMainFrame(), "localStorage.length;"); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); } @@ -3774,11 +3774,11 @@ content::RenderProcessHostBadIpcMessageWaiter kill_waiter( shell()->web_contents()->GetMainFrame()->GetProcess()); - // Use ignore_result here, since on Android the renderer process is + // Use std::ignore here, since on Android the renderer process is // terminated, but ExecuteScript still returns true. It properly returns // false on all other platforms. - ignore_result( - ExecJs(shell()->web_contents()->GetMainFrame(), "localStorage.length;")); + std::ignore = + ExecJs(shell()->web_contents()->GetMainFrame(), "localStorage.length;"); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); } @@ -3801,11 +3801,11 @@ content::RenderProcessHostBadIpcMessageWaiter kill_waiter( shell()->web_contents()->GetMainFrame()->GetProcess()); - // Use ignore_result here, since on Android the renderer process is + // Use std::ignore here, since on Android the renderer process is // terminated, but ExecuteScript still returns true. It properly returns // false on all other platforms. - ignore_result( - ExecJs(shell()->web_contents()->GetMainFrame(), "localStorage.length;")); + std::ignore = + ExecJs(shell()->web_contents()->GetMainFrame(), "localStorage.length;"); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); }
diff --git a/content/browser/media/forwarding_audio_stream_factory_unittest.cc b/content/browser/media/forwarding_audio_stream_factory_unittest.cc index 935e6db6..4a337de 100644 --- a/content/browser/media/forwarding_audio_stream_factory_unittest.cc +++ b/content/browser/media/forwarding_audio_stream_factory_unittest.cc
@@ -5,11 +5,11 @@ #include "content/browser/media/forwarding_audio_stream_factory.h" #include <memory> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/run_loop.h" #include "base/test/mock_callback.h" @@ -277,7 +277,7 @@ std::move(broker_factory_)); EXPECT_CALL(*broker, CreateStream(NotNull())); - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream(main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -301,7 +301,7 @@ std::make_unique<MockBrokerFactory>()); EXPECT_CALL(*broker, CreateStream(NotNull())); - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateLoopbackStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), source_factory.core(), kParams, kSharedMemoryCount, kMuteSource, @@ -318,7 +318,7 @@ std::move(broker_factory_)); EXPECT_CALL(*broker, CreateStream(NotNull())); - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client)); @@ -339,7 +339,7 @@ EXPECT_CALL(*main_rfh_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -350,7 +350,7 @@ EXPECT_CALL(*other_rfh_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -384,7 +384,7 @@ EXPECT_CALL(*main_rfh_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateLoopbackStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), source_factory.core(), kParams, kSharedMemoryCount, kMuteSource, @@ -395,7 +395,7 @@ EXPECT_CALL(*other_rfh_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateLoopbackStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), source_factory.core(), kParams, kSharedMemoryCount, kMuteSource, @@ -423,7 +423,7 @@ { EXPECT_CALL(*main_rfh_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client)); @@ -432,7 +432,7 @@ { EXPECT_CALL(*other_rfh_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client)); @@ -475,7 +475,7 @@ EXPECT_CALL(*main_rfh_input_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> input_client; - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -486,7 +486,7 @@ EXPECT_CALL(*other_rfh_input_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> input_client; - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -498,7 +498,7 @@ EXPECT_CALL(*main_rfh_loopback_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> input_client; - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateLoopbackStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), source_factory.core(), kParams, kSharedMemoryCount, kMuteSource, @@ -509,7 +509,7 @@ EXPECT_CALL(*other_rfh_loopback_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> input_client; - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateLoopbackStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), source_factory.core(), kParams, kSharedMemoryCount, kMuteSource, @@ -521,7 +521,7 @@ EXPECT_CALL(*main_rfh_output_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> output_client; - ignore_result(output_client.InitWithNewPipeAndPassReceiver()); + std::ignore = output_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(output_client)); @@ -531,7 +531,7 @@ EXPECT_CALL(*other_rfh_output_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> output_client; - ignore_result(output_client.InitWithNewPipeAndPassReceiver()); + std::ignore = output_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(output_client)); @@ -568,14 +568,14 @@ std::move(broker_factory_)); EXPECT_CALL(*input_broker, CreateStream(NotNull())); - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream(main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, std::move(input_client)); EXPECT_CALL(*output_broker, CreateStream(NotNull())); - ignore_result(output_client.InitWithNewPipeAndPassReceiver()); + std::ignore = output_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(output_client)); @@ -608,7 +608,7 @@ EXPECT_CALL(*main_rfh_input_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> input_client; - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -619,7 +619,7 @@ EXPECT_CALL(*other_rfh_input_broker, CreateStream(NotNull())); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> input_client; - ignore_result(input_client.InitWithNewPipeAndPassReceiver()); + std::ignore = input_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateInputStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), kInputDeviceId, kParams, kSharedMemoryCount, kEnableAgc, @@ -631,7 +631,7 @@ EXPECT_CALL(*main_rfh_output_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> output_client; - ignore_result(output_client.InitWithNewPipeAndPassReceiver()); + std::ignore = output_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(output_client)); @@ -641,7 +641,7 @@ EXPECT_CALL(*other_rfh_output_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> output_client; - ignore_result(output_client.InitWithNewPipeAndPassReceiver()); + std::ignore = output_client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( other_rfh()->GetProcess()->GetID(), other_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(output_client)); @@ -697,7 +697,7 @@ std::move(broker_factory_)); EXPECT_CALL(*broker, CreateStream(NotNull())); - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client)); @@ -738,7 +738,7 @@ EXPECT_FALSE(stream_factory_.IsMuterConnected()); EXPECT_CALL(*broker, CreateStream(NotNull())); - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client)); @@ -768,7 +768,7 @@ { EXPECT_CALL(*broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client)); @@ -787,7 +787,7 @@ { EXPECT_CALL(*another_broker, CreateStream(NotNull())); mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory.core()->CreateOutputStream( main_rfh()->GetProcess()->GetID(), main_rfh()->GetRoutingID(), kOutputDeviceId, kParams, std::move(client));
diff --git a/content/browser/media/frameless_media_interface_proxy.cc b/content/browser/media/frameless_media_interface_proxy.cc index 482d1827..fcaa137e 100644 --- a/content/browser/media/frameless_media_interface_proxy.cc +++ b/content/browser/media/frameless_media_interface_proxy.cc
@@ -4,8 +4,9 @@ #include "content/browser/media/frameless_media_interface_proxy.h" +#include <tuple> + #include "base/bind.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "content/public/browser/media_service.h" #include "media/base/cdm_context.h" @@ -114,7 +115,7 @@ DCHECK(!interface_factory_remote_); mojo::PendingRemote<media::mojom::FrameInterfaceFactory> interfaces; - ignore_result(interfaces.InitWithNewPipeAndPassReceiver()); + std::ignore = interfaces.InitWithNewPipeAndPassReceiver(); GetMediaService().CreateInterfaceFactory( interface_factory_remote_.BindNewPipeAndPassReceiver(),
diff --git a/content/browser/media/key_system_support_impl.cc b/content/browser/media/key_system_support_impl.cc index 5ff77df..f8684e5d 100644 --- a/content/browser/media/key_system_support_impl.cc +++ b/content/browser/media/key_system_support_impl.cc
@@ -4,12 +4,12 @@ #include "content/browser/media/key_system_support_impl.h" +#include <tuple> #include <vector> #include "base/command_line.h" #include "base/containers/flat_set.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/metrics/histogram_functions.h" #include "base/strings/string_split.h" @@ -265,9 +265,8 @@ // parallel `IsKeySystemSupported()` calls from different renderer processes. // This is okay and won't cause collision or corruption of data. if (lazy_initialize) { - ignore_result(CdmRegistryImpl::GetInstance()->FinalizeCdmCapability( - key_system, CdmInfo::Robustness::kHardwareSecure, - hw_secure_capability)); + std::ignore = CdmRegistryImpl::GetInstance()->FinalizeCdmCapability( + key_system, CdmInfo::Robustness::kHardwareSecure, hw_secure_capability); } auto capability = media::mojom::KeySystemCapability::New();
diff --git a/content/browser/media/media_interface_proxy.cc b/content/browser/media/media_interface_proxy.cc index ecca0cb..84bc823 100644 --- a/content/browser/media/media_interface_proxy.cc +++ b/content/browser/media/media_interface_proxy.cc
@@ -7,9 +7,9 @@ #include <map> #include <memory> #include <string> +#include <tuple> #include "base/bind.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/memory/raw_ptr.h" #include "base/no_destructor.h" @@ -118,7 +118,7 @@ } else { // The embedder doesn't provide a secondary Media Service instance. Bind // permanently to a disconnected pipe which discards all calls. - ignore_result(remote->BindNewPipeAndPassReceiver()); + std::ignore = remote->BindNewPipeAndPassReceiver(); } }
diff --git a/content/browser/portal/portal_browsertest.cc b/content/browser/portal/portal_browsertest.cc index 258c45e7..b2e4112 100644 --- a/content/browser/portal/portal_browsertest.cc +++ b/content/browser/portal/portal_browsertest.cc
@@ -9,7 +9,6 @@ #include "base/bind.h" #include "base/callback.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/memory/ptr_util.h" #include "base/memory/raw_ptr.h" #include "base/run_loop.h" @@ -810,7 +809,7 @@ portal)); RenderProcessHostBadIpcMessageWaiter kill_waiter(main_frame->GetProcess()); GURL a_url(embedded_test_server()->GetURL("a.com", "/title1.html")); - ignore_result(ExecJs(main_frame, JsReplace("portal.src = $1;", a_url))); + std::ignore = ExecJs(main_frame, JsReplace("portal.src = $1;", a_url)); EXPECT_EQ(bad_message::RPH_MOJO_PROCESS_ERROR, kill_waiter.Wait()); }
diff --git a/content/browser/prerender/prerender_browsertest.cc b/content/browser/prerender/prerender_browsertest.cc index 701da2f..a75cf0a 100644 --- a/content/browser/prerender/prerender_browsertest.cc +++ b/content/browser/prerender/prerender_browsertest.cc
@@ -3,6 +3,7 @@ // found in the LICENSE file. #include <cstdint> +#include <tuple> #include "base/barrier_closure.h" #include "base/base_switches.h" @@ -11,7 +12,6 @@ #include "base/containers/flat_set.h" #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/memory/weak_ptr.h" #include "base/metrics/metrics_hashes.h" @@ -579,10 +579,10 @@ // Fetch a subframe that requires authentication. const GURL kAuthIFrameUrl = GetUrl("/auth-basic"); RenderFrameHost* prerender_rfh = GetPrerenderedMainFrameHost(host_id); - ignore_result(ExecJs(prerender_rfh, - "var i = document.createElement('iframe'); i.src = '" + - kAuthIFrameUrl.spec() + - "'; document.body.appendChild(i);")); + std::ignore = + ExecJs(prerender_rfh, + "var i = document.createElement('iframe'); i.src = '" + + kAuthIFrameUrl.spec() + "'; document.body.appendChild(i);"); // The prerender should be destroyed. host_observer.WaitForDestroyed(); @@ -618,8 +618,8 @@ imgElement.src = '/auth-basic/favicon.gif'; document.body.appendChild(imgElement); )"; - ignore_result( - ExecJs(GetPrerenderedMainFrameHost(host_id), fetch_subresource_script)); + std::ignore = + ExecJs(GetPrerenderedMainFrameHost(host_id), fetch_subresource_script); // The prerender should be destroyed. host_observer.WaitForDestroyed(); @@ -2143,8 +2143,8 @@ imgElement.src = '/load_image/image.png'; document.body.appendChild(imgElement); )"; - ignore_result(ExecJs(prerender_helper()->GetPrerenderedMainFrameHost(host_id), - fetch_subresource_script)); + std::ignore = ExecJs(prerender_helper()->GetPrerenderedMainFrameHost(host_id), + fetch_subresource_script); // The prerender should be destroyed. host_observer.WaitForDestroyed(); @@ -2215,8 +2215,8 @@ // server should ask for a client certificate or respond with an expired // certificate, which leads to the cancellation of prerendering. std::string resource_url = GetUrl("/workers/empty.js?intercept").spec(); - ignore_result(ExecJs(prerender_helper()->GetPrerenderedMainFrameHost(host_id), - JsReplace("fetch($1);", resource_url))); + std::ignore = ExecJs(prerender_helper()->GetPrerenderedMainFrameHost(host_id), + JsReplace("fetch($1);", resource_url)); // Check the prerender was destroyed. host_observer.WaitForDestroyed(); @@ -2609,8 +2609,8 @@ // Executing `navigator.getGamepads()` to start binding the GamepadMonitor // interface. - ignore_result(EvalJs(prerender_render_frame_host, "navigator.getGamepads()", - EvalJsOptions::EXECUTE_SCRIPT_NO_USER_GESTURE)); + std::ignore = EvalJs(prerender_render_frame_host, "navigator.getGamepads()", + EvalJsOptions::EXECUTE_SCRIPT_NO_USER_GESTURE); // Verify Mojo capability control cancels prerendering. EXPECT_FALSE(HasHostForUrl(kPrerenderingUrl)); histogram_tester.ExpectUniqueSample( @@ -2883,8 +2883,8 @@ // Whether using the EXECUTE_SCRIPT_NO_USER_GESTURE flag or not does not // affect the test result. The purpose of using it is to simulate real // scenarios since prerendering pages cannot have user gestures. - ignore_result(ExecJs(prerender_rfh, "const context = new AudioContext();", - EvalJsOptions::EXECUTE_SCRIPT_NO_USER_GESTURE)); + std::ignore = ExecJs(prerender_rfh, "const context = new AudioContext();", + EvalJsOptions::EXECUTE_SCRIPT_NO_USER_GESTURE); host_observer.WaitForDestroyed(); EXPECT_FALSE(HasHostForUrl(kPrerenderingUrl)); histogram_tester.ExpectUniqueSample(
diff --git a/content/browser/presentation/presentation_service_impl_unittest.cc b/content/browser/presentation/presentation_service_impl_unittest.cc index 9ef865d..21e4306 100644 --- a/content/browser/presentation/presentation_service_impl_unittest.cc +++ b/content/browser/presentation/presentation_service_impl_unittest.cc
@@ -10,12 +10,12 @@ #include <iterator> #include <memory> #include <string> +#include <tuple> #include <utility> #include <vector> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/strings/stringprintf.h" #include "content/browser/presentation/presentation_test_utils.h" @@ -251,8 +251,7 @@ mojo::PendingRemote<PresentationConnection> presentation_connection_remote; mojo::Remote<PresentationConnection> controller_remote; - ignore_result( - presentation_connection_remote.InitWithNewPipeAndPassReceiver()); + std::ignore = presentation_connection_remote.InitWithNewPipeAndPassReceiver(); std::move(callback).Run(PresentationConnectionResult::New( blink::mojom::PresentationInfo::New(presentation_url2_, kPresentationId), std::move(presentation_connection_remote),
diff --git a/content/browser/renderer_host/frame_tree_browsertest.cc b/content/browser/renderer_host/frame_tree_browsertest.cc index 42295e8..654cf4a 100644 --- a/content/browser/renderer_host/frame_tree_browsertest.cc +++ b/content/browser/renderer_host/frame_tree_browsertest.cc
@@ -1778,15 +1778,8 @@ EXPECT_TRUE(ExecJs(root, "localStorage[\"foo\"] = \"c\"")); EXPECT_EQ("c", EvalJs(root, "localStorage[\"foo\"]")); - // TODO(crbug.com/1199077) This should return "a" once StorageKey starts - // using the nonce for partitioning. Also remove the shadowDOM specific check - // once nonce support is complete (for MPArch, possibly due to a separate - // process and incomplete nonce support, it is returning "a" on some - // platforms). - if (GetParam() == - blink::features::FencedFramesImplementationType::kShadowDOM) { - EXPECT_EQ("c", EvalJs(fenced_frame, "localStorage[\"foo\"]")); - } + // This shouldn't impact the fenced frame's local storage: + EXPECT_EQ("a", EvalJs(fenced_frame, "localStorage[\"foo\"]")); } IN_PROC_BROWSER_TEST_P(FencedFrameTreeBrowserTest,
diff --git a/content/browser/renderer_host/input/autoscroll_browsertest.cc b/content/browser/renderer_host/input/autoscroll_browsertest.cc index 8c8f189f..43f7768 100644 --- a/content/browser/renderer_host/input/autoscroll_browsertest.cc +++ b/content/browser/renderer_host/input/autoscroll_browsertest.cc
@@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> + #include "base/feature_list.h" -#include "base/ignore_result.h" #include "build/build_config.h" #include "content/browser/web_contents/web_contents_impl.h" #include "content/public/common/content_switches.h" @@ -119,7 +120,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); MainThreadFrameObserver main_thread_sync(host); main_thread_sync.Wait();
diff --git a/content/browser/renderer_host/input/composited_scrolling_browsertest.cc b/content/browser/renderer_host/input/composited_scrolling_browsertest.cc index d179bb16..ecb5ab8 100644 --- a/content/browser/renderer_host/input/composited_scrolling_browsertest.cc +++ b/content/browser/renderer_host/input/composited_scrolling_browsertest.cc
@@ -2,11 +2,11 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> #include <utility> #include "base/bind.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/strings/utf_string_conversions.h" #include "base/test/metrics/histogram_tester.h" @@ -107,7 +107,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); // Wait for the hit test data to be ready after initiating URL loading // before returning
diff --git a/content/browser/renderer_host/input/compositor_event_ack_browsertest.cc b/content/browser/renderer_host/input/compositor_event_ack_browsertest.cc index 7ced1c5..9a4c91d 100644 --- a/content/browser/renderer_host/input/compositor_event_ack_browsertest.cc +++ b/content/browser/renderer_host/input/compositor_event_ack_browsertest.cc
@@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> #include <utility> #include "base/auto_reset.h" #include "base/bind.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/strings/utf_string_conversions.h" #include "build/build_config.h" @@ -137,7 +137,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); // SetSize triggers an animation of the size, leading to a a new // viz::LocalSurfaceId being generated. Since this was done right after
diff --git a/content/browser/renderer_host/input/fling_browsertest.cc b/content/browser/renderer_host/input/fling_browsertest.cc index d6e2a62..91757d4b 100644 --- a/content/browser/renderer_host/input/fling_browsertest.cc +++ b/content/browser/renderer_host/input/fling_browsertest.cc
@@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> + #include "base/bind.h" -#include "base/ignore_result.h" #include "base/memory/raw_ptr.h" #include "base/test/scoped_feature_list.h" #include "build/build_config.h" @@ -104,7 +105,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); SynchronizeThreads(); }
diff --git a/content/browser/renderer_host/input/input_event_browsertest.cc b/content/browser/renderer_host/input/input_event_browsertest.cc index 6ecd6ed..abafe046 100644 --- a/content/browser/renderer_host/input/input_event_browsertest.cc +++ b/content/browser/renderer_host/input/input_event_browsertest.cc
@@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> + #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/strings/utf_string_conversions.h" #include "base/test/bind.h" @@ -95,7 +96,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); // We need to wait until hit test data is available. We use our own // HitTestRegionObserver here because we have the RenderWidgetHostImpl
diff --git a/content/browser/renderer_host/input/main_thread_event_queue_browsertest.cc b/content/browser/renderer_host/input/main_thread_event_queue_browsertest.cc index ac3ff27..2cff438 100644 --- a/content/browser/renderer_host/input/main_thread_event_queue_browsertest.cc +++ b/content/browser/renderer_host/input/main_thread_event_queue_browsertest.cc
@@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> #include <utility> #include "base/auto_reset.h" #include "base/bind.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/strings/utf_string_conversions.h" #include "build/build_config.h" @@ -105,7 +105,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); HitTestRegionObserver observer(host->GetFrameSinkId()); observer.WaitForHitTestData();
diff --git a/content/browser/renderer_host/input/scroll_behavior_browsertest.cc b/content/browser/renderer_host/input/scroll_behavior_browsertest.cc index bfb1458..f0ca2f6 100644 --- a/content/browser/renderer_host/input/scroll_behavior_browsertest.cc +++ b/content/browser/renderer_host/input/scroll_behavior_browsertest.cc
@@ -3,9 +3,9 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include "base/bind.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "build/build_config.h" #include "cc/base/switches.h" @@ -184,7 +184,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); HitTestRegionObserver observer(host->GetFrameSinkId()); // Wait for the hit test data to be ready
diff --git a/content/browser/renderer_host/input/synthetic_input_browsertest.cc b/content/browser/renderer_host/input/synthetic_input_browsertest.cc index d7f815b..daf4f1b 100644 --- a/content/browser/renderer_host/input/synthetic_input_browsertest.cc +++ b/content/browser/renderer_host/input/synthetic_input_browsertest.cc
@@ -3,10 +3,10 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include "base/bind.h" #include "base/callback.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/test/test_timeouts.h" #include "build/build_config.h" @@ -58,7 +58,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); // Wait for the hit test data to be ready after initiating URL loading // before returning
diff --git a/content/browser/renderer_host/input/touch_action_browsertest.cc b/content/browser/renderer_host/input/touch_action_browsertest.cc index af00296..32cc566d 100644 --- a/content/browser/renderer_host/input/touch_action_browsertest.cc +++ b/content/browser/renderer_host/input/touch_action_browsertest.cc
@@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <tuple> #include <utility> #include "base/auto_reset.h" #include "base/bind.h" #include "base/callback_helpers.h" #include "base/command_line.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/run_loop.h" #include "base/strings/stringprintf.h" @@ -170,7 +170,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); // We need to wait until hit test data is available. We use our own // HitTestRegionObserver here because we have the RenderWidgetHostImpl
diff --git a/content/browser/renderer_host/input/wheel_event_listener_browsertest.cc b/content/browser/renderer_host/input/wheel_event_listener_browsertest.cc index c40c2607..91ada9d 100644 --- a/content/browser/renderer_host/input/wheel_event_listener_browsertest.cc +++ b/content/browser/renderer_host/input/wheel_event_listener_browsertest.cc
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" +#include <tuple> + #include "base/test/scoped_feature_list.h" #include "content/browser/web_contents/web_contents_impl.h" #include "content/public/common/content_features.h" @@ -79,7 +80,7 @@ std::u16string ready_title(u"ready"); TitleWatcher watcher(shell()->web_contents(), ready_title); - ignore_result(watcher.WaitAndGetTitle()); + std::ignore = watcher.WaitAndGetTitle(); MainThreadFrameObserver main_thread_sync(host); main_thread_sync.Wait();
diff --git a/content/browser/renderer_host/media/media_devices_dispatcher_host.cc b/content/browser/renderer_host/media/media_devices_dispatcher_host.cc index 073b8291..15a05c73 100644 --- a/content/browser/renderer_host/media/media_devices_dispatcher_host.cc +++ b/content/browser/renderer_host/media/media_devices_dispatcher_host.cc
@@ -278,6 +278,7 @@ // so we execute directly. bad_message::ReceivedBadMessage(render_process_id, bad_message::MDDH_NOT_TOP_LEVEL); + return; } rfhi->delegate()->SetCaptureHandleConfig(std::move(config)); },
diff --git a/content/browser/renderer_host/media/render_frame_audio_input_stream_factory_unittest.cc b/content/browser/renderer_host/media/render_frame_audio_input_stream_factory_unittest.cc index 6966c4b..214da21 100644 --- a/content/browser/renderer_host/media/render_frame_audio_input_stream_factory_unittest.cc +++ b/content/browser/renderer_host/media/render_frame_audio_input_stream_factory_unittest.cc
@@ -5,11 +5,11 @@ #include "content/browser/renderer_host/media/render_frame_audio_input_stream_factory.h" #include <string> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "build/build_config.h" #include "build/chromeos_buildflags.h" @@ -217,7 +217,7 @@ mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory_remote->CreateStream(std::move(client), session_id, kParams, kAGC, kSharedMemoryCount); @@ -245,7 +245,7 @@ mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory_remote->CreateStream(std::move(client), session_id, kParams, kAGC, kSharedMemoryCount); @@ -274,7 +274,7 @@ source_contents.reset(); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory_remote->CreateStream(std::move(client), session_id, kParams, kAGC, kSharedMemoryCount); @@ -293,7 +293,7 @@ base::UnguessableToken session_id = base::UnguessableToken::Create(); mojo::PendingRemote<blink::mojom::RendererAudioInputStreamFactoryClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); factory_remote->CreateStream(std::move(client), session_id, kParams, kAGC, kSharedMemoryCount);
diff --git a/content/browser/renderer_host/media/render_frame_audio_output_stream_factory_unittest.cc b/content/browser/renderer_host/media/render_frame_audio_output_stream_factory_unittest.cc index 1ba8195..d924cb10 100644 --- a/content/browser/renderer_host/media/render_frame_audio_output_stream_factory_unittest.cc +++ b/content/browser/renderer_host/media/render_frame_audio_output_stream_factory_unittest.cc
@@ -6,11 +6,11 @@ #include <memory> #include <string> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/test/mock_callback.h" #include "base/test/task_environment.h" @@ -217,7 +217,7 @@ kDefaultDeviceId, mock_callback.Get()); { mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); provider_remote->Acquire(kParams, std::move(client)); } @@ -256,7 +256,7 @@ // Now factory is destructed. Trying to create a stream should fail. { mojo::PendingRemote<media::mojom::AudioOutputStreamProviderClient> client; - ignore_result(client.InitWithNewPipeAndPassReceiver()); + std::ignore = client.InitWithNewPipeAndPassReceiver(); provider_remote->Acquire(kParams, std::move(client)); }
diff --git a/content/browser/renderer_host/pepper/quota_reservation.cc b/content/browser/renderer_host/pepper/quota_reservation.cc index 286b832..f768ce1 100644 --- a/content/browser/renderer_host/pepper/quota_reservation.cc +++ b/content/browser/renderer_host/pepper/quota_reservation.cc
@@ -5,10 +5,10 @@ #include "content/browser/renderer_host/pepper/quota_reservation.h" #include <memory> +#include <tuple> #include "base/bind.h" #include "base/callback.h" -#include "base/ignore_result.h" #include "content/public/browser/browser_task_traits.h" #include "content/public/browser/browser_thread.h" #include "storage/browser/file_system/file_system_operation_runner.h" @@ -75,7 +75,7 @@ files_.insert(std::make_pair(id, file_handle.get())); if (insert_result.second) { int64_t max_written_offset = file_handle->GetMaxWrittenOffset(); - ignore_result(file_handle.release()); + std::ignore = file_handle.release(); return max_written_offset; } NOTREACHED();
diff --git a/content/browser/renderer_host/render_frame_host_impl.cc b/content/browser/renderer_host/render_frame_host_impl.cc index 0d42081..4d44187a 100644 --- a/content/browser/renderer_host/render_frame_host_impl.cc +++ b/content/browser/renderer_host/render_frame_host_impl.cc
@@ -6,6 +6,7 @@ #include <algorithm> #include <memory> +#include <tuple> #include <unordered_map> #include <utility> @@ -17,7 +18,6 @@ #include "base/debug/crash_logging.h" #include "base/debug/dump_without_crashing.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/memory/memory_pressure_monitor.h" #include "base/memory/ptr_util.h" @@ -2454,8 +2454,7 @@ } else { // The channel may not be initialized in some tests environments. In this // case we set up a dummy interface provider. - ignore_result( - remote_interfaces.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = remote_interfaces.BindNewEndpointAndPassDedicatedReceiver(); } remote_associated_interfaces_ = std::make_unique<blink::AssociatedInterfaceProvider>(
diff --git a/content/browser/renderer_host/render_frame_host_impl_browsertest.cc b/content/browser/renderer_host/render_frame_host_impl_browsertest.cc index 1356a72..51b5031 100644 --- a/content/browser/renderer_host/render_frame_host_impl_browsertest.cc +++ b/content/browser/renderer_host/render_frame_host_impl_browsertest.cc
@@ -7,12 +7,12 @@ #include <memory> #include <set> #include <string> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/callback_helpers.h" #include "base/files/file_path.h" -#include "base/ignore_result.h" #include "base/memory/ptr_util.h" #include "base/memory/raw_ptr.h" #include "base/path_service.h" @@ -3133,7 +3133,7 @@ // soon, let's wait until the document.readyState finalizes. We don't really // care if that succeeds since, in the failing case, the renderer is crashing. EXPECT_TRUE(NavigateToURL(shell(), url)); - ignore_result(WaitForRenderFrameReady(web_contents()->GetMainFrame())); + std::ignore = WaitForRenderFrameReady(web_contents()->GetMainFrame()); EXPECT_TRUE(crash_observer.did_exit_normally()); }
diff --git a/content/browser/renderer_host/render_frame_proxy_host.cc b/content/browser/renderer_host/render_frame_proxy_host.cc index 71ead2f..0c456ae 100644 --- a/content/browser/renderer_host/render_frame_proxy_host.cc +++ b/content/browser/renderer_host/render_frame_proxy_host.cc
@@ -5,6 +5,7 @@ #include "content/browser/renderer_host/render_frame_proxy_host.h" #include <memory> +#include <tuple> #include <unordered_map> #include <utility> #include <vector> @@ -13,7 +14,6 @@ #include "base/containers/circular_deque.h" #include "base/containers/contains.h" #include "base/hash/hash.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/no_destructor.h" #include "base/stl_util.h" @@ -345,8 +345,7 @@ } else { // The channel may not be initialized in some tests environments. In this // case we set up a dummy interface provider. - ignore_result( - remote_interfaces.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = remote_interfaces.BindNewEndpointAndPassDedicatedReceiver(); } remote_associated_interfaces_ = std::make_unique<blink::AssociatedInterfaceProvider>(
diff --git a/content/browser/renderer_host/render_widget_host_view_mac.mm b/content/browser/renderer_host/render_widget_host_view_mac.mm index 9b276a3..86da1b5 100644 --- a/content/browser/renderer_host/render_widget_host_view_mac.mm +++ b/content/browser/renderer_host/render_widget_host_view_mac.mm
@@ -8,13 +8,13 @@ #include <limits> #include <memory> +#include <tuple> #include <utility> #include "base/auto_reset.h" #include "base/bind.h" #include "base/command_line.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/mac/mac_util.h" #include "base/mac/scoped_cftyperef.h" @@ -226,7 +226,7 @@ // first to rebaseline some unreliable web tests. // NOTE: This will not be run for child frame widgets, which do not have // an owner delegate and won't get a RenderViewHost here. - ignore_result(owner_delegate->GetWebkitPreferencesForWidget()); + std::ignore = owner_delegate->GetWebkitPreferencesForWidget(); } cursor_manager_ = std::make_unique<CursorManager>(this);
diff --git a/content/browser/resources/attribution_reporting/attribution_internals.html b/content/browser/resources/attribution_reporting/attribution_internals.html index 0287749..09f9caf 100644 --- a/content/browser/resources/attribution_reporting/attribution_internals.html +++ b/content/browser/resources/attribution_reporting/attribution_internals.html
@@ -33,7 +33,7 @@ <h2> Sent and Pending Reports </h2> <div class="content"> <div> - <button id="send-reports">Send All Pending Reports</button> + <button id="send-reports" disabled>Send Selected Pending Reports</button> </div> <div> <span id="debug-mode-content"></span>
diff --git a/content/browser/resources/attribution_reporting/attribution_internals.js b/content/browser/resources/attribution_reporting/attribution_internals.js index b6d91aff..faf237b 100644 --- a/content/browser/resources/attribution_reporting/attribution_internals.js +++ b/content/browser/resources/attribution_reporting/attribution_internals.js
@@ -27,9 +27,34 @@ /** * @template T - * @template V + * @abstract */ class Column { + constructor() { + /** @type {?function(!T, !T): number} */ + this.compare; + } + + /** + * @param {!Element} td + * @param {!T} row + * @abstract + */ + render(td, row) {} + + /** + * @param {!Element} th + * @abstract + */ + renderHeader(th) {} +} + +/** + * @template T + * @template V + * @extends {Column<T>} + */ +class ValueColumn extends Column { /** * @param {string} header * @param {function(!T): !V} getValue @@ -38,6 +63,8 @@ constructor( header, getValue, compare = (a, b) => compareDefault(getValue(a), getValue(b))) { + super(); + this.header = header; /** @protected */ @@ -46,20 +73,22 @@ this.compare = compare; } - /** - * @param {!Element} td - * @param {!T} row - */ + /** @override */ render(td, row) { td.innerText = this.getValue(row); } + + /** @override */ + renderHeader(th) { + th.innerText = this.header; + } } /** * @template T - * @extends {Column<T, Date>} + * @extends {ValueColumn<T, Date>} */ -class DateColumn extends Column { +class DateColumn extends ValueColumn { /** * @param {string} header * @param {function(!T): Date} getValue @@ -76,9 +105,9 @@ /** * @template T - * @extends {Column<T, string>} + * @extends {ValueColumn<T, string>} */ -class CodeColumn extends Column { +class CodeColumn extends ValueColumn { /** * @param {string} header * @param {function(!T): string} getValue @@ -105,10 +134,7 @@ */ class TableModel { constructor() { - /** @type {?Table<T>} */ - this.table; - - /** @type {!Array<Column<T, ?>>} */ + /** @type {!Array<Column<T>>} */ this.cols; /** @type {string} */ @@ -116,6 +142,9 @@ /** @type {number} */ this.sortIdx = -1; + + /** @type {!Set<function()>} */ + this.rowsChangedListeners = new Set(); } /** @@ -129,6 +158,95 @@ * @return {!Array<!T>} */ getRows() {} + + notifyRowsChanged() { + this.rowsChangedListeners.forEach((f) => f()); + } +} + +class Selectable { + constructor() { + this.input = document.createElement('input'); + this.input.type = 'checkbox'; + } +} + +/** + * @template T + * @extends {Column<T>} + */ +class SelectionColumn extends Column { + /** + * @param {!TableModel<T>} model + */ + constructor(model) { + super(); + + this.model = model; + + this.selectAll = document.createElement('input'); + this.selectAll.type = 'checkbox'; + this.selectAll.addEventListener('input', () => { + const checked = this.selectAll.checked; + this.model.getRows().forEach((row) => { + if (!row.input.disabled) { + row.input.checked = checked; + } + }); + this.notifySelectionChanged(checked); + }); + + this.listener = () => this.onChange(); + this.model.rowsChangedListeners.add(this.listener); + + /** @type {!Set<function(boolean)>} */ + this.selectionChangedListeners = new Set(); + } + + /** @override */ + render(td, row) { + td.appendChild(row.input); + } + + /** @override */ + renderHeader(th) { + th.appendChild(this.selectAll); + } + + onChange() { + let anySelectable = false; + let anySelected = false; + let anyUnselected = false; + + this.model.getRows().forEach((row) => { + // addEventListener deduplicates, so only one event will be fired per + // input. + row.input.addEventListener('input', this.listener); + + if (row.input.disabled) { + return; + } + + anySelectable = true; + + if (row.input.checked) { + anySelected = true; + } else { + anyUnselected = true; + } + }); + + this.selectAll.disabled = !anySelectable; + this.selectAll.checked = anySelected && !anyUnselected; + this.selectAll.indeterminate = anySelected && anyUnselected; + + this.notifySelectionChanged(anySelected); + } + + /** @param {boolean} anySelected */ + notifySelectionChanged(anySelected) { + this.selectionChangedListeners.forEach((f) => f(anySelected)); + } } /** @@ -162,8 +280,6 @@ self.__proto__ = Table.prototype; self = /** @type {!Table} */ (self); - model.table = self; - self.model = model; self.sortDesc = false; @@ -171,7 +287,7 @@ self.model.cols.forEach((col, idx) => { const th = self.ownerDocument.createElement('th'); th.scope = 'col'; - th.innerText = col.header; + col.renderHeader(th); if (col.compare) { th.role = 'button'; @@ -193,6 +309,8 @@ table.appendChild(self.tbody); self.appendChild(table); + + self.model.rowsChangedListeners.add(() => self.updateTbody()); } /** @@ -319,16 +437,16 @@ super(); this.cols = [ - new Column('Source Event ID', (e) => e.sourceEventId), - new Column('Source Origin', (e) => e.impressionOrigin), - new Column('Destination', (e) => e.attributionDestination), - new Column('Report To', (e) => e.reportingOrigin), + new ValueColumn('Source Event ID', (e) => e.sourceEventId), + new ValueColumn('Source Origin', (e) => e.impressionOrigin), + new ValueColumn('Destination', (e) => e.attributionDestination), + new ValueColumn('Report To', (e) => e.reportingOrigin), new DateColumn('Source Registration Time', (e) => e.impressionTime), new DateColumn('Expiry Time', (e) => e.expiryTime), - new Column('Source Type', (e) => e.sourceType), - new Column('Priority', (e) => e.priority), - new Column('Dedup Keys', (e) => e.dedupKeys, /*compare=*/ null), - new Column('Status', (e) => e.status), + new ValueColumn('Source Type', (e) => e.sourceType), + new ValueColumn('Priority', (e) => e.priority), + new ValueColumn('Dedup Keys', (e) => e.dedupKeys, /*compare=*/ null), + new ValueColumn('Status', (e) => e.status), ]; this.emptyRowText = 'No sources.'; @@ -351,7 +469,7 @@ /** @param {!Array<!Source>} storedSources */ setStoredSources(storedSources) { this.storedSources = storedSources; - this.table.updateTbody(); + this.notifyRowsChanged(); } /** @param {!Source} source */ @@ -363,21 +481,24 @@ } this.deactivatedSources.push(source); - this.table.updateTbody(); + this.notifyRowsChanged(); } clear() { this.storedSources = []; this.deactivatedSources = []; - this.table.updateTbody(); + this.notifyRowsChanged(); } } -class Report { +class Report extends Selectable { /** * @param {!WebUIAttributionReport} mojo */ constructor(mojo) { + super(); + + this.id = mojo.id; this.reportBody = mojo.reportBody; this.attributionDestination = mojo.attributionDestination; this.reportUrl = mojo.reportUrl.url; @@ -386,6 +507,12 @@ this.reportPriority = mojo.priority; this.attributedTruthfully = mojo.attributedTruthfully; + // Only pending reports are selectable. + if (this.id === null || + mojo.status !== WebUIAttributionReport_Status.kPending) { + this.input.disabled = true; + } + switch (mojo.status) { case WebUIAttributionReport_Status.kSent: this.status = `Sent: HTTP ${mojo.httpResponseCode}`; @@ -418,21 +545,25 @@ constructor() { super(); + this.selectionColumn = new SelectionColumn(this); + this.cols = [ + this.selectionColumn, new CodeColumn('Report Body', (e) => e.reportBody), - new Column('Destination', (e) => e.attributionDestination), - new Column('Report URL', (e) => e.reportUrl), + new ValueColumn('Destination', (e) => e.attributionDestination), + new ValueColumn('Report URL', (e) => e.reportUrl), new DateColumn('Trigger Time', (e) => e.triggerTime), new DateColumn('Report Time', (e) => e.reportTime), - new Column('Report Priority', (e) => e.reportPriority), - new Column('Fake Report', (e) => e.attributedTruthfully ? 'no' : 'yes'), - new Column('Status', (e) => e.status), + new ValueColumn('Report Priority', (e) => e.reportPriority), + new ValueColumn( + 'Fake Report', (e) => e.attributedTruthfully ? 'no' : 'yes'), + new ValueColumn('Status', (e) => e.status), ]; this.emptyRowText = 'No sent or pending reports.'; // Sort by report time by default. - this.sortIdx = 4; + this.sortIdx = 5; /** @type {!Array<!Report>} */ this.sentOrDroppedReports = []; @@ -456,7 +587,7 @@ /** @param {!Array<!Report>} storedReports */ setStoredReports(storedReports) { this.storedReports = storedReports; - this.table.updateTbody(); + this.notifyRowsChanged(); } /** @param {!Report} report */ @@ -468,13 +599,13 @@ } this.sentOrDroppedReports.push(report); - this.table.updateTbody(); + this.notifyRowsChanged(); } clear() { this.storedReports = []; this.sentOrDroppedReports = []; - this.table.updateTbody(); + this.notifyRowsChanged(); } } @@ -605,13 +736,24 @@ * the data on completion. */ function sendReports() { + const ids = []; + reportTableModel.storedReports.forEach((report) => { + if (!report.input.disabled && report.input.checked && report.id !== null) { + ids.push(report.id); + } + }); + + if (ids.length === 0) { + return; + } + const button = $('send-reports'); const previousText = $('send-reports').innerText; button.disabled = true; button.innerText = 'Sending...'; - pageHandler.sendPendingReports().then(() => { - button.disabled = false; + + pageHandler.sendReports(ids).then(() => { button.innerText = previousText; }); } @@ -653,7 +795,13 @@ $('refresh').addEventListener('click', updatePageData); $('clear-data').addEventListener('click', clearStorage); - $('send-reports').addEventListener('click', sendReports); + + const sendReportsButton = $('send-reports'); + sendReportsButton.addEventListener('click', sendReports); + reportTableModel.selectionColumn.selectionChangedListeners.add( + (anySelected) => { + sendReportsButton.disabled = !anySelected; + }); Table.decorate(getRequiredElement('source-table-wrapper'), sourceTableModel); Table.decorate(getRequiredElement('report-table-wrapper'), reportTableModel);
diff --git a/content/browser/security_exploit_browsertest.cc b/content/browser/security_exploit_browsertest.cc index 90e22cb..dc3cbbd 100644 --- a/content/browser/security_exploit_browsertest.cc +++ b/content/browser/security_exploit_browsertest.cc
@@ -4,11 +4,12 @@ #include <stdint.h> +#include <tuple> + #include "base/bind.h" #include "base/command_line.h" #include "base/feature_list.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/memory/ptr_util.h" #include "base/memory/raw_ptr.h" #include "base/strings/stringprintf.h" @@ -510,8 +511,8 @@ shell()->web_contents()->GetMainFrame()->GetProcess()); // ExecJs will sometimes finish before the renderer gets killed, so we must // ignore the result. - ignore_result(ExecJs(shell()->web_contents()->GetMainFrame(), - "history.pushState({}, '', location.href);")); + std::ignore = ExecJs(shell()->web_contents()->GetMainFrame(), + "history.pushState({}, '', location.href);"); EXPECT_EQ(bad_message::RFH_INVALID_ORIGIN_ON_COMMIT, kill_waiter.Wait()); } @@ -538,7 +539,7 @@ RenderProcessHostBadIpcMessageWaiter kill_waiter(subframe->GetProcess()); // ExecJs will sometimes finish before the renderer gets killed, so we must // ignore the result. - ignore_result(ExecJs(subframe, "location.hash='foo';")); + std::ignore = ExecJs(subframe, "location.hash='foo';"); EXPECT_EQ(bad_message::RFH_INVALID_ORIGIN_ON_COMMIT, kill_waiter.Wait()); } @@ -568,8 +569,8 @@ shell()->web_contents()->GetMainFrame()->GetProcess()); // ExecJs will sometimes finish before the renderer gets killed, so we must // ignore the result. - ignore_result(ExecJs(shell()->web_contents()->GetMainFrame(), - "history.pushState({}, '', location.href);")); + std::ignore = ExecJs(shell()->web_contents()->GetMainFrame(), + "history.pushState({}, '', location.href);"); EXPECT_EQ(bad_message::RFH_INVALID_ORIGIN_ON_COMMIT, kill_waiter.Wait()); } @@ -763,7 +764,7 @@ // The renderer should always get killed, but sometimes ExecuteScript returns // true anyway, so just ignore the result. - ignore_result(ExecJs(rfh, "URL.createObjectURL(new Blob(['foo']))")); + std::ignore = ExecJs(rfh, "URL.createObjectURL(new Blob(['foo']))"); // If the process is killed, this test passes. EXPECT_EQ( @@ -1689,8 +1690,8 @@ // // It also can't EXPECT_TRUE or EXPECT_FALSE: sometimes the ExecJs call will // finish before the renderer gets killed, and sometimes it won't. - ignore_result( - ExecJs(child, JsReplace("location.href=$1;", GURL("/title2.html")))); + std::ignore = + ExecJs(child, JsReplace("location.href=$1;", GURL("/title2.html"))); EXPECT_THAT(kill_waiter.Wait(), Optional(HasSubstr("Permissions Policy feature is absent"))); @@ -1735,8 +1736,8 @@ // // It also can't EXPECT_TRUE or EXPECT_FALSE: sometimes the ExecJs call will // finish before the renderer gets killed, and sometimes it won't. - ignore_result( - ExecJs(child, JsReplace("location.href=$1;", GURL("/title2.html")))); + std::ignore = + ExecJs(child, JsReplace("location.href=$1;", GURL("/title2.html"))); EXPECT_THAT(kill_waiter.Wait(), Optional(HasSubstr("Permissions Policy feature is absent"))); @@ -1763,8 +1764,8 @@ // Can't use NavigateToURL here because it would hang. Additionally, we can't // EXPECT_TRUE or EXPECT_FALSE: sometimes the ExecJs call will finish // before the renderer gets killed, and sometimes it won't. - ignore_result(ExecJs(compromised_renderer, - JsReplace("location.href=$1", GURL("/title2.html")))); + std::ignore = ExecJs(compromised_renderer, + JsReplace("location.href=$1", GURL("/title2.html"))); EXPECT_THAT(kill_waiter.Wait(), Optional(HasSubstr("Trust Token params in main frame nav"))); @@ -1844,10 +1845,10 @@ compromised_renderer->GetProcess()); replacer.Activate(); - ignore_result(ExecJs( + std::ignore = ExecJs( compromised_renderer, JsReplace("location.href=$1", - embedded_test_server()->GetURL("/fenced_frames/title1.html")))); + embedded_test_server()->GetURL("/fenced_frames/title1.html"))); absl::optional<std::string> result = kill_waiter.Wait(); EXPECT_THAT(result,
diff --git a/content/browser/service_worker/service_worker_subresource_filter_browsertest.cc b/content/browser/service_worker/service_worker_subresource_filter_browsertest.cc deleted file mode 100644 index 4e4aec9..0000000 --- a/content/browser/service_worker/service_worker_subresource_filter_browsertest.cc +++ /dev/null
@@ -1,194 +0,0 @@ -// Copyright 2021 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "base/strings/strcat.h" -#include "content/public/test/browser_test.h" -#include "content/public/test/browser_test_utils.h" -#include "content/public/test/content_browser_test.h" -#include "content/public/test/content_browser_test_utils.h" -#include "content/public/test/url_loader_interceptor.h" -#include "content/shell/browser/shell.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace content { -namespace { - -constexpr char kBaseDataDir[] = "content/test/data/"; - -// Generate token with the command: -// generate_token.py https://service-worker-subresource-filter.test:443 -// ServiceWorkerSubresourceFilter -// --expire-timestamp=2000000000 -base::StringPiece origin_trial_token = - "A1CKeg8m2+M4knvICqx+5ELaI1Bh17J1+2cAfNSKCgL4zmPh4hXikI4YGxbR/QQo" - "zQyH6JOw/fqwNdWxman2RgQAAACDeyJvcmlnaW4iOiAiaHR0cHM6Ly9zZXJ2aWNl" - "LXdvcmtlci1zdWJyZXNvdXJjZS1maWx0ZXIudGVzdDo0NDMiLCAiZmVhdHVyZSI6" - "ICJTZXJ2aWNlV29ya2VyU3VicmVzb3VyY2VGaWx0ZXIiLCAiZXhwaXJ5IjogMjAw" - "MDAwMDAwMH0="; - -const std::string script = R"( - (async () => { - const saw_message = new Promise(resolve => { - navigator.serviceWorker.onmessage = event => { - resolve(event.data); - }; - }); - const registration = await navigator.serviceWorker.ready; - registration.active.postMessage(''); - return await saw_message; - })(); - )"; - -class ServiceWorkerSubresourceFilterBrowserTest - : public ContentBrowserTest, - public testing::WithParamInterface<bool> { - protected: - ServiceWorkerSubresourceFilterBrowserTest() {} - - void SetUpOnMainThread() override { - ContentBrowserTest::SetUpOnMainThread(); - - // We use a URLLoaderInterceptor, rather than the EmbeddedTestServer, since - // the origin trial token in the response is associated with a fixed origin, - // whereas EmbeddedTestServer serves content on a random port. - url_loader_interceptor_.emplace(base::BindRepeating( - &ServiceWorkerSubresourceFilterBrowserTest::InterceptRequest, - base::Unretained(this))); - } - - void TearDownOnMainThread() override { url_loader_interceptor_.reset(); } - - void NavigateAndFetch(std::string url, bool expect_only_matching_filter) { - EXPECT_TRUE(NavigateToURL( - shell(), GetUrl("/service_worker/create_service_worker.html"))); - EXPECT_EQ( - "DONE", - EvalJs( - shell(), - "register('/service_worker/fetch_event_pass_through.js', '/');")); - - GURL page_url = GetUrl(url); - GURL fetch_url = GetUrl("/echo"); - GURL fetch_url_with_fragment = GetUrl("/echo#foo"); - GURL fetch_url_with_fragment_substring = GetUrl("/echo#afooz"); - GURL fetch_url_with_other_fragment = GetUrl("/echo#bar"); - - EXPECT_TRUE(NavigateToURL(shell(), page_url)); - EXPECT_EQ("Echo", - EvalJs(shell(), JsReplace("fetch_from_page($1)", fetch_url))); - EXPECT_EQ("Echo", EvalJs(shell(), JsReplace("fetch_from_page($1)", - fetch_url_with_fragment))); - EXPECT_EQ("Echo", - EvalJs(shell(), JsReplace("fetch_from_page($1)", - fetch_url_with_fragment_substring))); - EXPECT_EQ("Echo", - EvalJs(shell(), JsReplace("fetch_from_page($1)", - fetch_url_with_other_fragment))); - - base::Value list(base::Value::Type::LIST); - if (expect_only_matching_filter) { - list.Append(page_url.spec()); - list.Append(fetch_url_with_fragment.spec()); - list.Append(fetch_url_with_fragment_substring.spec()); - } else { - list.Append(page_url.spec()); - list.Append(fetch_url.spec()); - list.Append(fetch_url_with_fragment.spec()); - list.Append(fetch_url_with_fragment_substring.spec()); - list.Append(fetch_url_with_other_fragment.spec()); - } - - EXPECT_EQ(list, EvalJs(shell(), script)); - } - - bool FeatureIsEnabled() { return GetParam(); } - - private: - static GURL GetUrl(const std::string& path) { - return GURL("https://service-worker-subresource-filter.test/") - .Resolve(path); - } - - bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { - std::string headers = - "HTTP/1.1 200 OK\n" - "Content-Type: text/html\n"; - - if (params->url_request.url.path() == "/echo") { - URLLoaderInterceptor::WriteResponse(headers, std::string("Echo"), - params->client.get(), - absl::optional<net::SSLInfo>()); - return true; - } - - if (FeatureIsEnabled()) { - base::StrAppend(&headers, {"Origin-Trial: ", origin_trial_token, "\n"}); - headers += '\n'; - } - - if (params->url_request.url.path() == "/filter") { - base::StrAppend(&headers, {"Service-Worker-Subresource-Filter: foo"}); - headers += '\n'; - URLLoaderInterceptor::WriteResponse( - base::StrCat({kBaseDataDir, "/service_worker/fetch_from_page.html"}), - params->client.get(), &headers, absl::optional<net::SSLInfo>()); - - return true; - } - - if (params->url_request.url.path() == "/nofilter") { - // Do not add any additional headers, but intercept the request. - URLLoaderInterceptor::WriteResponse( - base::StrCat({kBaseDataDir, "/service_worker/fetch_from_page.html"}), - params->client.get(), &headers, absl::optional<net::SSLInfo>()); - - return true; - } - - if (params->url_request.url.path() == "/emptyfilter") { - base::StrAppend(&headers, {"Service-Worker-Subresource-Filter:"}); - headers += '\n'; - URLLoaderInterceptor::WriteResponse( - base::StrCat({kBaseDataDir, "/service_worker/fetch_from_page.html"}), - params->client.get(), &headers, absl::optional<net::SSLInfo>()); - - return true; - } - - URLLoaderInterceptor::WriteResponse( - base::StrCat({kBaseDataDir, params->url_request.url.path_piece()}), - params->client.get()); - return true; - } - - absl::optional<URLLoaderInterceptor> url_loader_interceptor_; -}; - -INSTANTIATE_TEST_SUITE_P(EnabledDisabled, - ServiceWorkerSubresourceFilterBrowserTest, - testing::Bool()); - -IN_PROC_BROWSER_TEST_P(ServiceWorkerSubresourceFilterBrowserTest, WithFilter) { - // If the feature is disabled, all URLs should be seen by the Service Worker. - // If the feature is enabled, only the initial navigation URL and URLs - // matching the filter should be seen by the Service Worker. - NavigateAndFetch("/filter", FeatureIsEnabled()); -} - -IN_PROC_BROWSER_TEST_P(ServiceWorkerSubresourceFilterBrowserTest, - WithoutFilter) { - // All URLs should be seen by the Service Worker regardless of whether or not - // the feature is enabled. - NavigateAndFetch("/nofilter", false); -} - -IN_PROC_BROWSER_TEST_P(ServiceWorkerSubresourceFilterBrowserTest, - WithEmptyFilter) { - // All URLs should be seen by the Service Worker regardless of whether or not - // the feature is enabled. - NavigateAndFetch("/emptyfilter", false); -} - -} // namespace -} // namespace content \ No newline at end of file
diff --git a/content/browser/site_per_process_browsertest.cc b/content/browser/site_per_process_browsertest.cc index 4a9baa42..c594526a 100644 --- a/content/browser/site_per_process_browsertest.cc +++ b/content/browser/site_per_process_browsertest.cc
@@ -25,7 +25,6 @@ #include "base/containers/contains.h" #include "base/cxx17_backports.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/location.h" #include "base/memory/ptr_util.h" @@ -4309,7 +4308,7 @@ main_frame.BindNewEndpointAndPassReceiver(); mojo::AssociatedRemote<blink::mojom::RemoteMainFrameHost> main_frame_host; - ignore_result(main_frame_host.BindNewEndpointAndPassReceiver()); + std::ignore = main_frame_host.BindNewEndpointAndPassReceiver(); remote_main_frame_interfaces->main_frame_host = main_frame_host.Unbind(); // Send the message to create a proxy for B's new child frame in A. This @@ -4389,7 +4388,7 @@ mojom::CreateFrameParamsPtr params = mojom::CreateFrameParams::New(); params->routing_id = frame_routing_id; params->frame = pending_frame.InitWithNewEndpointAndPassReceiver(); - ignore_result(params->interface_broker.InitWithNewPipeAndPassReceiver()); + std::ignore = params->interface_broker.InitWithNewPipeAndPassReceiver(); params->previous_routing_id = previous_routing_id; params->opener_frame_token = absl::nullopt; params->parent_routing_id = @@ -4472,7 +4471,7 @@ mojom::CreateFrameParamsPtr params = mojom::CreateFrameParams::New(); params->routing_id = frame_routing_id; params->frame = pending_frame.InitWithNewEndpointAndPassReceiver(); - ignore_result(params->interface_broker.InitWithNewPipeAndPassReceiver()); + std::ignore = params->interface_broker.InitWithNewPipeAndPassReceiver(); params->previous_routing_id = IPC::mojom::kRoutingIdNone; params->opener_frame_token = absl::nullopt; params->parent_routing_id = parent_routing_id; @@ -4484,10 +4483,10 @@ blink_frame_widget.InitWithNewEndpointAndPassReceiver(); params->widget_params->widget = blink_widget.InitWithNewEndpointAndPassReceiver(); - ignore_result(params->widget_params->frame_widget_host - .InitWithNewEndpointAndPassReceiver()); - ignore_result(params->widget_params->widget_host - .InitWithNewEndpointAndPassReceiver()); + std::ignore = params->widget_params->frame_widget_host + .InitWithNewEndpointAndPassReceiver(); + std::ignore = + params->widget_params->widget_host.InitWithNewEndpointAndPassReceiver(); params->widget_params->visual_properties.screen_infos = display::ScreenInfos(display::ScreenInfo()); params->replication_state = blink::mojom::FrameReplicationState::New();
diff --git a/content/browser/storage_service_restart_browsertest.cc b/content/browser/storage_service_restart_browsertest.cc index a9cedde4..9c811d4 100644 --- a/content/browser/storage_service_restart_browsertest.cc +++ b/content/browser/storage_service_restart_browsertest.cc
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" +#include <tuple> + #include "base/run_loop.h" #include "base/test/bind.h" #include "base/test/scoped_feature_list.h" @@ -115,8 +116,8 @@ // operation after a Storage Service crash. EXPECT_TRUE( NavigateToURL(shell(), GetTestUrl("dom_storage", "crash_recovery.html"))); - ignore_result( - EvalJs(shell()->web_contents(), R"(setSessionStorageValue("foo", 42))")); + std::ignore = + EvalJs(shell()->web_contents(), R"(setSessionStorageValue("foo", 42))"); // Note that for Session Storage we don't need to wait for a commit. This is // racy, but that's the point: whether or not a commit happens in time, the @@ -140,8 +141,8 @@ // after a Storage Service crash. EXPECT_TRUE( NavigateToURL(shell(), GetTestUrl("dom_storage", "crash_recovery.html"))); - ignore_result( - EvalJs(shell()->web_contents(), R"(setLocalStorageValue("foo", 42))")); + std::ignore = + EvalJs(shell()->web_contents(), R"(setLocalStorageValue("foo", 42))"); // We wait for the above storage request to be fully committed to disk. This // ensures that renderer gets the correct value when recovering from the
diff --git a/content/browser/storage_service_sandbox_browsertest.cc b/content/browser/storage_service_sandbox_browsertest.cc index 98d389f..7883888a 100644 --- a/content/browser/storage_service_sandbox_browsertest.cc +++ b/content/browser/storage_service_sandbox_browsertest.cc
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" +#include <tuple> + #include "base/run_loop.h" #include "base/test/bind.h" #include "base/test/scoped_feature_list.h" @@ -96,8 +97,8 @@ IN_PROC_BROWSER_TEST_F(StorageServiceSandboxBrowserTest, PRE_DomStorage) { EXPECT_TRUE(NavigateToURL(shell(), GetTestUrl(nullptr, "empty.html"))); - ignore_result( - EvalJs(shell()->web_contents(), R"(window.localStorage.yeet = 42)")); + std::ignore = + EvalJs(shell()->web_contents(), R"(window.localStorage.yeet = 42)"); WaitForAnyLocalStorageData(); FlushLocalStorage(); }
diff --git a/content/browser/tracing/tracing_controller_impl_data_endpoint.cc b/content/browser/tracing/tracing_controller_impl_data_endpoint.cc index 59642db..314a3734 100644 --- a/content/browser/tracing/tracing_controller_impl_data_endpoint.cc +++ b/content/browser/tracing/tracing_controller_impl_data_endpoint.cc
@@ -3,11 +3,11 @@ // found in the LICENSE file. #include <memory> +#include <tuple> #include <utility> #include "base/bind.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/json/json_writer.h" #include "base/memory/raw_ptr.h" #include "base/memory/ref_counted_memory.h" @@ -89,7 +89,7 @@ void ReceiveTraceChunkOnBlockingThread(std::unique_ptr<std::string> chunk) { if (!OpenFileIfNeededOnBlockingThread()) return; - ignore_result(fwrite(chunk->c_str(), chunk->size(), 1, file_)); + std::ignore = fwrite(chunk->c_str(), chunk->size(), 1, file_); } bool OpenFileIfNeededOnBlockingThread() {
diff --git a/content/browser/web_contents/web_contents_impl.cc b/content/browser/web_contents/web_contents_impl.cc index e6d5b77..daec329 100644 --- a/content/browser/web_contents/web_contents_impl.cc +++ b/content/browser/web_contents/web_contents_impl.cc
@@ -5263,7 +5263,8 @@ void WebContentsImpl::Find(int request_id, const std::u16string& search_text, - blink::mojom::FindOptionsPtr options) { + blink::mojom::FindOptionsPtr options, + bool skip_delay) { OPTIONAL_TRACE_EVENT0("content", "WebContentsImpl::Find"); // Cowardly refuse to search for no text. if (search_text.empty()) { @@ -5272,7 +5273,7 @@ } GetOrCreateFindRequestManager()->Find(request_id, search_text, - std::move(options)); + std::move(options), skip_delay); } void WebContentsImpl::StopFinding(StopFindAction action) {
diff --git a/content/browser/web_contents/web_contents_impl.h b/content/browser/web_contents/web_contents_impl.h index 2633685..eea51b44 100644 --- a/content/browser/web_contents/web_contents_impl.h +++ b/content/browser/web_contents/web_contents_impl.h
@@ -527,7 +527,8 @@ WebContents::ImageDownloadCallback callback) override; void Find(int request_id, const std::u16string& search_text, - blink::mojom::FindOptionsPtr options) override; + blink::mojom::FindOptionsPtr options, + bool skip_delay = false) override; void StopFinding(StopFindAction action) override; bool WasEverAudible() override; bool IsFullscreen() override;
diff --git a/content/browser/web_contents/web_contents_impl_browsertest.cc b/content/browser/web_contents/web_contents_impl_browsertest.cc index 8bd037d8..c3b8ade 100644 --- a/content/browser/web_contents/web_contents_impl_browsertest.cc +++ b/content/browser/web_contents/web_contents_impl_browsertest.cc
@@ -4,6 +4,7 @@ #include <algorithm> #include <array> +#include <tuple> #include <utility> #include <vector> @@ -13,7 +14,6 @@ #include "base/containers/contains.h" #include "base/feature_list.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/memory/ptr_util.h" #include "base/memory/raw_ptr.h" #include "base/memory/ref_counted.h" @@ -4959,7 +4959,7 @@ web_contents->UpdateWindowControlsOverlay(bounding_client_rect); TitleWatcher title_watcher(web_contents, u"ongeometrychange"); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); } private: @@ -5039,7 +5039,7 @@ gfx::Rect bounding_client_rect = gfx::Rect(2, 3, 4, 5); web_contents->UpdateWindowControlsOverlay(bounding_client_rect); TitleWatcher title_watcher(web_contents, u"ongeometrychange1"); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); // Expect the "geometrychange" event to have fired once. EXPECT_EQ(1, EvalJs(web_contents, "geometrychangeCount")); @@ -5082,7 +5082,7 @@ "}")); content::HostZoomMap::SetZoomLevel(web_contents, 1.5); TitleWatcher title_watcher(web_contents, u"ongeometrychangefromzoomlevel"); - ignore_result(title_watcher.WaitAndGetTitle()); + std::ignore = title_watcher.WaitAndGetTitle(); // Validate the event payload. double zoom_factor = blink::PageZoomLevelToZoomFactor( @@ -5435,8 +5435,8 @@ // now disabled. { PCScanReadyToCommitObserver observer(shell()->web_contents()); - ignore_result(ExecJs(shell()->web_contents()->GetMainFrame(), - JsReplace("location = $1", prendering_url))); + std::ignore = ExecJs(shell()->web_contents()->GetMainFrame(), + JsReplace("location = $1", prendering_url)); observer.Wait(); EXPECT_FALSE(observer.WasPCScanEnabled()); }
diff --git a/content/browser/webauth/authenticator_common.cc b/content/browser/webauth/authenticator_common.cc index a4cb8fb..126ab52 100644 --- a/content/browser/webauth/authenticator_common.cc +++ b/content/browser/webauth/authenticator_common.cc
@@ -45,6 +45,7 @@ #include "device/fido/fido_constants.h" #include "device/fido/fido_parsing_utils.h" #include "device/fido/fido_transport_protocol.h" +#include "device/fido/fido_types.h" #include "device/fido/filter.h" #include "device/fido/get_assertion_request_handler.h" #include "device/fido/make_credential_request_handler.h" @@ -1290,14 +1291,15 @@ DCHECK(response_data.has_value()); DCHECK(authenticator); - transport_ = authenticator->AuthenticatorTransport(); + absl::optional<device::FidoTransportProtocol> transport = + authenticator->AuthenticatorTransport(); bool is_transport_used_internal = false; bool is_transport_used_cable = false; - if (transport_) { + if (transport) { is_transport_used_internal = - *transport_ == device::FidoTransportProtocol::kInternal; + *transport == device::FidoTransportProtocol::kInternal; is_transport_used_cable = - *transport_ == + *transport == device::FidoTransportProtocol::kCloudAssistedBluetoothLowEnergy; } @@ -1505,7 +1507,6 @@ case device::GetAssertionStatus::kSuccess: DCHECK(response_data.has_value()); DCHECK(authenticator); - transport_ = authenticator->AuthenticatorTransport(); // Show an account picker for requests with empty allow lists. // Authenticators may omit the identifying information in the user entity @@ -1614,6 +1615,12 @@ common_info->id = Base64UrlEncode(common_info->raw_id); response->info = std::move(common_info); + response->authenticator_attachment = + response_data.transport_used() + ? device::AuthenticatorAttachmentFromTransport( + *response_data.transport_used()) + : device::AuthenticatorAttachment::kAny; + // The transport list must not contain duplicates but the order doesn't matter // because Blink will sort the resulting strings before returning them. std::vector<device::FidoTransportProtocol> transports; @@ -1631,9 +1638,6 @@ } response->transports = std::move(transports); - response->has_transport = transport_.has_value(); - if (response->has_transport) - response->transport = *transport_; bool did_create_hmac_secret = false; bool did_store_cred_blob = false; @@ -1759,9 +1763,11 @@ response->info->authenticator_data = response_data.authenticator_data.SerializeToByteArray(); response->signature = response_data.signature; - response->has_transport = transport_.has_value(); - if (response->has_transport) - response->transport = *transport_; + response->authenticator_attachment = + response_data.transport_used + ? device::AuthenticatorAttachmentFromTransport( + *response_data.transport_used) + : device::AuthenticatorAttachment::kAny; response_data.user_entity ? response->user_handle.emplace(response_data.user_entity->id) : response->user_handle.emplace();
diff --git a/content/browser/webauth/authenticator_common.h b/content/browser/webauth/authenticator_common.h index fc9c50f..193a6ca 100644 --- a/content/browser/webauth/authenticator_common.h +++ b/content/browser/webauth/authenticator_common.h
@@ -234,9 +234,6 @@ blink::mojom::Authenticator::GetAssertionCallback get_assertion_response_callback_; std::string client_data_json_; - // Transport used during authentication. May be empty if unknown, e.g. on old - // Windows. - absl::optional<device::FidoTransportProtocol> transport_; // empty_allow_list_ is true iff a GetAssertion is currently pending and the // request did not list any credential IDs in the allow list. bool empty_allow_list_ = false;
diff --git a/content/browser/webauth/authenticator_impl_unittest.cc b/content/browser/webauth/authenticator_impl_unittest.cc index dfcf3e7..738354e 100644 --- a/content/browser/webauth/authenticator_impl_unittest.cc +++ b/content/browser/webauth/authenticator_impl_unittest.cc
@@ -62,6 +62,7 @@ #include "device/fido/fido_authenticator.h" #include "device/fido/fido_constants.h" #include "device/fido/fido_test_data.h" +#include "device/fido/fido_transport_protocol.h" #include "device/fido/fido_types.h" #include "device/fido/filter.h" #include "device/fido/hid/fake_hid_impl_for_testing.h" @@ -4150,6 +4151,10 @@ device::FidoTransportProtocol::kBluetoothLowEnergy, device::FidoTransportProtocol::kNearFieldCommunication, device::FidoTransportProtocol::kInternal}) { + device::AuthenticatorAttachment attachment = + (transport == device::FidoTransportProtocol::kInternal + ? device::AuthenticatorAttachment::kPlatform + : device::AuthenticatorAttachment::kCrossPlatform); ResetVirtualDevice(); virtual_device_factory_->SetSupportedProtocol( device::ProtocolVersion::kCtap2); @@ -4161,7 +4166,7 @@ MakeCredentialResult create_result = AuthenticatorMakeCredential(std::move(create_options)); ASSERT_EQ(create_result.status, AuthenticatorStatus::SUCCESS); - EXPECT_EQ(create_result.response->transport, transport); + EXPECT_EQ(create_result.response->authenticator_attachment, attachment); PublicKeyCredentialRequestOptionsPtr get_options = GetTestPublicKeyCredentialRequestOptions(); @@ -4172,7 +4177,7 @@ GetAssertionResult get_result = AuthenticatorGetAssertion(std::move(get_options)); ASSERT_EQ(get_result.status, AuthenticatorStatus::SUCCESS); - EXPECT_EQ(get_result.response->transport, transport); + EXPECT_EQ(get_result.response->authenticator_attachment, attachment); } }
diff --git a/content/browser/webrtc/webrtc_internals.cc b/content/browser/webrtc/webrtc_internals.cc index c9da582..513c74b 100644 --- a/content/browser/webrtc/webrtc_internals.cc +++ b/content/browser/webrtc/webrtc_internals.cc
@@ -193,7 +193,7 @@ auto it = FindRecord(frame_id, lid); if (it != peer_connection_data_.GetList().end()) { - MaybeClosePeerConnection(&*it); + MaybeClosePeerConnection(*it); peer_connection_data_.EraseListIter(it); } @@ -217,13 +217,13 @@ if (type == "iceconnectionstatechange") { if (value == "connected" || value == "checking" || value == "completed") { - MaybeMarkPeerConnectionAsConnected(&*it); + MaybeMarkPeerConnectionAsConnected(*it); } else if (value == "failed" || value == "disconnected" || value == "closed" || value == "new") { - MaybeMarkPeerConnectionAsNotConnected(&*it); + MaybeMarkPeerConnectionAsNotConnected(*it); } } else if (type == "close") { - MaybeClosePeerConnection(&*it); + MaybeClosePeerConnection(*it); } else if (type == "setConfiguration") { // Update the configuration we have for this connection. it->SetStringKey("rtcConfiguration", value); @@ -605,13 +605,11 @@ // by the exiting renderer. base::Value::ListView peer_conn_view = peer_connection_data_.GetList(); for (int i = peer_conn_view.size() - 1; i >= 0; --i) { - base::DictionaryValue* record = nullptr; DCHECK(peer_conn_view[i].is_dict()); - peer_conn_view[i].GetAsDictionary(&record); absl::optional<int> this_rid, this_lid; - this_rid = record->FindIntKey("rid"); - this_lid = record->FindIntKey("lid"); + this_rid = peer_conn_view[i].FindIntKey("rid"); + this_lid = peer_conn_view[i].FindIntKey("lid"); if (this_rid.value_or(0) == render_process_id) { if (!observers_.empty()) { @@ -620,7 +618,7 @@ update.SetIntKey("lid", this_lid.value_or(0)); SendUpdate("remove-peer-connection", std::move(update)); } - MaybeClosePeerConnection(record); + MaybeClosePeerConnection(peer_conn_view[i]); peer_connection_data_.EraseListIter(peer_conn_view.begin() + i); } } @@ -632,11 +630,10 @@ base::Value::ListView get_user_media_requests_view = get_user_media_requests_.GetList(); for (int i = get_user_media_requests_view.size() - 1; i >= 0; --i) { - base::DictionaryValue* record = nullptr; DCHECK(get_user_media_requests_view[i].is_dict()); - get_user_media_requests_view[i].GetAsDictionary(&record); - absl::optional<int> this_rid = record->FindIntKey("rid"); + absl::optional<int> this_rid = + get_user_media_requests_view[i].FindIntKey("rid"); if (this_rid.value_or(0) == render_process_id) { get_user_media_requests_.EraseListIter( @@ -669,21 +666,21 @@ } } -void WebRTCInternals::MaybeClosePeerConnection(base::Value* record) { - absl::optional<bool> is_open = record->FindBoolKey("isOpen"); +void WebRTCInternals::MaybeClosePeerConnection(base::Value& record) { + absl::optional<bool> is_open = record.FindBoolKey("isOpen"); DCHECK(is_open.has_value()); if (!*is_open) return; - record->SetBoolKey("isOpen", false); + record.SetBoolKey("isOpen", false); MaybeMarkPeerConnectionAsNotConnected(record); } -void WebRTCInternals::MaybeMarkPeerConnectionAsConnected(base::Value* record) { - bool was_connected = record->FindBoolKey("connected").value_or(true); +void WebRTCInternals::MaybeMarkPeerConnectionAsConnected(base::Value& record) { + bool was_connected = record.FindBoolKey("connected").value_or(true); if (!was_connected) { ++num_connected_connections_; - record->SetBoolKey("connected", true); + record.SetBoolKey("connected", true); UpdateWakeLock(); for (auto& observer : connections_observers_) observer.OnConnectionsCountChange(num_connected_connections_); @@ -691,10 +688,10 @@ } void WebRTCInternals::MaybeMarkPeerConnectionAsNotConnected( - base::Value* record) { - bool was_connected = record->FindBoolKey("connected").value_or(false); + base::Value& record) { + bool was_connected = record.FindBoolKey("connected").value_or(false); if (was_connected) { - record->SetBoolKey("connected", false); + record.SetBoolKey("connected", false); --num_connected_connections_; DCHECK_GE(num_connected_connections_, 0); UpdateWakeLock(); @@ -753,12 +750,10 @@ base::Value::ListView peer_conn_view = peer_connection_data_.GetList(); for (auto it = peer_conn_view.begin(); it != peer_conn_view.end(); ++it) { - base::DictionaryValue* record = nullptr; DCHECK(it->is_dict()); - it->GetAsDictionary(&record); - int this_rid = record->FindIntKey("rid").value_or(0); - int this_lid = record->FindIntKey("lid").value_or(0); + int this_rid = it->FindIntKey("rid").value_or(0); + int this_lid = it->FindIntKey("lid").value_or(0); if (this_rid == frame_id.child_id && this_lid == lid) return it;
diff --git a/content/browser/webrtc/webrtc_internals.h b/content/browser/webrtc/webrtc_internals.h index 8f484fb..530b8b9 100644 --- a/content/browser/webrtc/webrtc_internals.h +++ b/content/browser/webrtc/webrtc_internals.h
@@ -167,10 +167,10 @@ // Updates the number of open PeerConnections. Called when a PeerConnection // is stopped or removed. - void MaybeClosePeerConnection(base::Value* record); + void MaybeClosePeerConnection(base::Value& record); - void MaybeMarkPeerConnectionAsConnected(base::Value* record); - void MaybeMarkPeerConnectionAsNotConnected(base::Value* record); + void MaybeMarkPeerConnectionAsConnected(base::Value& record); + void MaybeMarkPeerConnectionAsNotConnected(base::Value& record); // Called whenever a PeerConnection is created or stopped in order to // request/cancel a wake lock on suspending the current application for power
diff --git a/content/child/child_thread_impl.cc b/content/child/child_thread_impl.cc index 35850e1..f9026fd5 100644 --- a/content/child/child_thread_impl.cc +++ b/content/child/child_thread_impl.cc
@@ -8,6 +8,7 @@ #include <memory> #include <string> +#include <tuple> #include <utility> #include "base/base_switches.h" @@ -19,7 +20,6 @@ #include "base/debug/leak_annotations.h" #include "base/debug/profiler.h" #include "base/files/file.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/location.h" #include "base/logging.h" @@ -154,7 +154,7 @@ // delegate object alive for the lifetime of the process. WaitAndExitDelegate* leaking_delegate = delegate.release(); ANNOTATE_LEAKING_OBJECT_PTR(leaking_delegate); - ignore_result(leaking_delegate); + std::ignore = leaking_delegate; return true; } #endif
diff --git a/content/child/field_trial.cc b/content/child/field_trial.cc index 93f12ea..21847d0 100644 --- a/content/child/field_trial.cc +++ b/content/child/field_trial.cc
@@ -3,11 +3,13 @@ // found in the LICENSE file. #include "content/child/field_trial.h" + +#include <tuple> + #include "base/base_switches.h" #include "base/command_line.h" #include "base/debug/leak_annotations.h" #include "base/feature_list.h" -#include "base/ignore_result.h" #include "base/metrics/field_trial.h" #include "build/build_config.h" #include "content/public/common/content_descriptors.h" @@ -30,7 +32,7 @@ base::FieldTrialList* leaked_field_trial_list = new base::FieldTrialList(nullptr); ANNOTATE_LEAKING_OBJECT_PTR(leaked_field_trial_list); - ignore_result(leaked_field_trial_list); + std::ignore = leaked_field_trial_list; // Ensure any field trials in browser are reflected into the child process. base::FieldTrialList::CreateTrialsFromCommandLine(command_line,
diff --git a/content/common/child_process_host_impl.cc b/content/common/child_process_host_impl.cc index e573e26..1a63031 100644 --- a/content/common/child_process_host_impl.cc +++ b/content/common/child_process_host_impl.cc
@@ -5,6 +5,7 @@ #include "content/common/child_process_host_impl.h" #include <limits> +#include <tuple> #include "base/atomic_sequence_num.h" #include "base/clang_profiling_buildflags.h" @@ -12,7 +13,6 @@ #include "base/files/file.h" #include "base/files/file_path.h" #include "base/hash/hash.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/memory/ptr_util.h" #include "base/metrics/histogram_macros.h" @@ -123,7 +123,7 @@ if (ipc_mode_ == IpcMode::kLegacy) { // In legacy mode, we only have an IPC Channel. Bind ChildProcess to a // disconnected pipe so it quietly discards messages. - ignore_result(child_process_.BindNewPipeAndPassReceiver()); + std::ignore = child_process_.BindNewPipeAndPassReceiver(); channel_ = IPC::ChannelMojo::Create( mojo_invitation_->AttachMessagePipe( kChildProcessReceiverAttachmentName),
diff --git a/content/public/android/java/src/org/chromium/content/browser/selection/SelectionPopupControllerImpl.java b/content/public/android/java/src/org/chromium/content/browser/selection/SelectionPopupControllerImpl.java index 2693d46..6b20051 100644 --- a/content/public/android/java/src/org/chromium/content/browser/selection/SelectionPopupControllerImpl.java +++ b/content/public/android/java/src/org/chromium/content/browser/selection/SelectionPopupControllerImpl.java
@@ -1501,10 +1501,11 @@ } @CalledByNative - private void onSelectAroundCaretSuccess(int extendedStartAdjust, int extendedEndAdjust) { + private void onSelectAroundCaretSuccess(int extendedStartAdjust, int extendedEndAdjust, + int wordStartAdjust, int wordEndAdjust) { if (mSelectionClient != null) { - SelectAroundCaretResult result = - new SelectAroundCaretResult(extendedStartAdjust, extendedEndAdjust); + SelectAroundCaretResult result = new SelectAroundCaretResult( + extendedStartAdjust, extendedEndAdjust, wordStartAdjust, wordEndAdjust); mSelectionClient.selectAroundCaretAck(result); } }
diff --git a/content/public/android/java/src/org/chromium/content_public/browser/SelectAroundCaretResult.java b/content/public/android/java/src/org/chromium/content_public/browser/SelectAroundCaretResult.java index f4532bf..aba1950 100644 --- a/content/public/android/java/src/org/chromium/content_public/browser/SelectAroundCaretResult.java +++ b/content/public/android/java/src/org/chromium/content_public/browser/SelectAroundCaretResult.java
@@ -11,32 +11,57 @@ public class SelectAroundCaretResult { private final int mExtendedStartAdjust; private final int mExtendedEndAdjust; + private final int mWordStartAdjust; + private final int mWordEndAdjust; /** - * @return The start offset difference between the extended selection (if - * not using Word granularity) and the previous selection (caret). + * @return The start offset difference between the extended selection and the initial selection + * (caret). */ public int getExtendedStartAdjust() { return mExtendedStartAdjust; } /** - * @return The end offset difference between the extended selection (if not - * using Word granularity) and the previous selection (caret). + * @return The end offset difference between the extended selection and the initial selection + * (caret). */ public int getExtendedEndAdjust() { return mExtendedEndAdjust; } /** + * @return The start offset difference between the word selection (regardless of the extended + * selection granularity) and the initial selection (caret). + */ + public int getWordStartAdjust() { + return mWordStartAdjust; + } + + /** + * @return The end offset difference between the word selection (regardless of the extended + * selection granularity) and the initial selection (caret). + */ + public int getWordEndAdjust() { + return mWordEndAdjust; + } + + /** * Create {@link SelectAroundCaretResult} instance. * @param extendedStartAdjust The start offset difference between the extended selection and the - * previous selection (caret). + * initial selection (caret). * @param extendedEndAdjust The end offset difference between the extended selection and the - * previous selection (caret). + * initial selection (caret). + * @param wordStartAdjust The start offset difference between the word selection (regardless of + * the extended selection granularity) and the initial selection (caret). + * @param wordEndAdjust The end offset difference between the word selection (regardless of the + * extended selection granularity) and the initial selection (caret). */ - public SelectAroundCaretResult(int extendedStartAdjust, int extendedEndAdjust) { + public SelectAroundCaretResult(int extendedStartAdjust, int extendedEndAdjust, + int wordStartAdjust, int wordEndAdjust) { mExtendedStartAdjust = extendedStartAdjust; mExtendedEndAdjust = extendedEndAdjust; + mWordStartAdjust = wordStartAdjust; + mWordEndAdjust = wordEndAdjust; } } \ No newline at end of file
diff --git a/content/public/android/javatests/src/org/chromium/content/browser/accessibility/WebContentsAccessibilityTreeTest.java b/content/public/android/javatests/src/org/chromium/content/browser/accessibility/WebContentsAccessibilityTreeTest.java index c1c8e798..54686b30 100644 --- a/content/public/android/javatests/src/org/chromium/content/browser/accessibility/WebContentsAccessibilityTreeTest.java +++ b/content/public/android/javatests/src/org/chromium/content/browser/accessibility/WebContentsAccessibilityTreeTest.java
@@ -454,6 +454,7 @@ @Test @SmallTest + @DisabledTest(message = "https://crbug.com/1282189") public void test_ariaHiddenIframe() { performAriaTest("aria-hidden-iframe.html"); }
diff --git a/content/public/browser/web_contents.h b/content/public/browser/web_contents.h index 3e209a6..6dd28a6 100644 --- a/content/public/browser/web_contents.h +++ b/content/public/browser/web_contents.h
@@ -1121,10 +1121,13 @@ bool bypass_cache, ImageDownloadCallback callback) = 0; - // Finds text on a page. |search_text| should not be empty. + // Finds text on a page. |search_text| should not be empty. |skip_delay| + // indicates that the find request should be sent to the renderer immediately + // instead of waiting for privacy/performance mitigations. virtual void Find(int request_id, const std::u16string& search_text, - blink::mojom::FindOptionsPtr options) = 0; + blink::mojom::FindOptionsPtr options, + bool skip_delay = false) = 0; // Notifies the renderer that the user has closed the FindInPage window // (and what action to take regarding the selection).
diff --git a/content/public/common/content_features.cc b/content/public/common/content_features.cc index 7a0026d..98857a4 100644 --- a/content/public/common/content_features.cc +++ b/content/public/common/content_features.cc
@@ -813,14 +813,6 @@ base::FEATURE_DISABLED_BY_DEFAULT}; #endif -// Experiment allowing control over what requests are intercepted by Service -// Worker fetch events. By setting a Service-Worker-Subresource-Filter HTTP -// header on the document to some string, only requests which contain a fragment -// matching the header string will be intercepted. When not set, Service Workers -// will intercept all requests, as normal. -const base::Feature kServiceWorkerSubresourceFilter{ - "ServiceWorkerSubresourceFilter", base::FEATURE_DISABLED_BY_DEFAULT}; - // Controls whether to isolate sites of documents that specify an eligible // Cross-Origin-Opener-Policy header. Note that this is only intended to be // used on Android, which does not use strict site isolation. See
diff --git a/content/public/common/content_features.h b/content/public/common/content_features.h index 4bb2105..af6d6dfb 100644 --- a/content/public/common/content_features.h +++ b/content/public/common/content_features.h
@@ -202,7 +202,6 @@ #endif // defined(OS_CHROMEOS) CONTENT_EXPORT extern const base::Feature kWebOTP; CONTENT_EXPORT extern const base::Feature kWebOTPAssertionFeaturePolicy; -CONTENT_EXPORT extern const base::Feature kServiceWorkerSubresourceFilter; CONTENT_EXPORT extern const base::Feature kSpareRendererForSitePerProcess; CONTENT_EXPORT extern const base::Feature kStopVideoCaptureOnScreenLock; CONTENT_EXPORT extern const base::Feature kStorageServiceOutOfProcess;
diff --git a/content/public/common/custom_handlers/protocol_handler.cc b/content/public/common/custom_handlers/protocol_handler.cc index 3488c18..b2b64c85 100644 --- a/content/public/common/custom_handlers/protocol_handler.cc +++ b/content/public/common/custom_handlers/protocol_handler.cc
@@ -61,7 +61,7 @@ bool ProtocolHandler::IsValidDict(const base::DictionaryValue* value) { // Note that "title" parameter is ignored. // The |last_modified| field is optional as it was introduced in M68. - return value->HasKey("protocol") && value->HasKey("url"); + return value->FindKey("protocol") && value->FindKey("url"); } bool ProtocolHandler::IsValid() const { @@ -105,23 +105,26 @@ base::Time time; blink::ProtocolHandlerSecurityLevel security_level = blink::ProtocolHandlerSecurityLevel::kStrict; - value->GetString("protocol", &protocol); - value->GetString("url", &url); + if (const std::string* protocol_in = value->FindStringKey("protocol")) + protocol = *protocol_in; + if (const std::string* url_in = value->FindStringKey("url")) + url = *url_in; absl::optional<base::Time> time_value = base::ValueToTime(value->FindKey("last_modified")); // Treat invalid times as the default value. if (time_value) time = *time_value; absl::optional<int> security_level_value = - value->FindIntPath("security_level"); + value->FindIntKey("security_level"); if (security_level_value) { security_level = blink::ProtocolHandlerSecurityLevelFrom(*security_level_value); } - if (value->HasKey("app_id")) { + if (const base::Value* app_id_val = value->FindKey("app_id")) { std::string app_id; - value->GetString("app_id", &app_id); + if (app_id_val->is_string()) + app_id = app_id_val->GetString(); return ProtocolHandler(protocol, GURL(url), app_id, time, security_level); }
diff --git a/content/public/test/mock_render_process_host.cc b/content/public/test/mock_render_process_host.cc index 29d3de6..f94ee55 100644 --- a/content/public/test/mock_render_process_host.cc +++ b/content/public/test/mock_render_process_host.cc
@@ -5,12 +5,12 @@ #include "content/public/test/mock_render_process_host.h" #include <algorithm> +#include <tuple> #include <utility> #include <vector> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/location.h" #include "base/no_destructor.h" @@ -456,8 +456,8 @@ if (!renderer_interface_) { renderer_interface_ = std::make_unique<mojo::AssociatedRemote<mojom::Renderer>>(); - ignore_result( - renderer_interface_->BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = + renderer_interface_->BindNewEndpointAndPassDedicatedReceiver(); } return renderer_interface_->get(); }
diff --git a/content/public/test/mock_render_thread.cc b/content/public/test/mock_render_thread.cc index 87aa32b..0daad07 100644 --- a/content/public/test/mock_render_thread.cc +++ b/content/public/test/mock_render_thread.cc
@@ -5,8 +5,8 @@ #include "content/public/test/mock_render_thread.h" #include <memory> +#include <tuple> -#include "base/ignore_result.h" #include "base/logging.h" #include "base/task/single_thread_task_runner.h" #include "base/threading/thread_task_runner_handle.h" @@ -310,13 +310,13 @@ blink_frame_widget_receiver = blink_frame_widget.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::FrameWidgetHost> blink_frame_widget_host; - ignore_result( - blink_frame_widget_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = + blink_frame_widget_host.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::Widget> blink_widget; mojo::PendingAssociatedReceiver<blink::mojom::Widget> blink_widget_receiver = blink_widget.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::WidgetHost> blink_widget_host; - ignore_result(blink_widget_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = blink_widget_host.BindNewEndpointAndPassDedicatedReceiver(); widget_params->frame_widget = std::move(blink_frame_widget_receiver); widget_params->frame_widget_host = blink_frame_widget_host.Unbind();
diff --git a/content/public/test/policy_container_utils.cc b/content/public/test/policy_container_utils.cc index 3cab2d5..2c85ec0 100644 --- a/content/public/test/policy_container_utils.cc +++ b/content/public/test/policy_container_utils.cc
@@ -4,7 +4,8 @@ #include "content/public/test/policy_container_utils.h" -#include "base/ignore_result.h" +#include <tuple> + #include "third_party/blink/public/mojom/frame/policy_container.mojom.h" namespace content { @@ -12,8 +13,8 @@ blink::mojom::PolicyContainerPtr CreateStubPolicyContainer() { mojo::PendingAssociatedRemote<blink::mojom::PolicyContainerHost> stub_policy_container_remote; - ignore_result( - stub_policy_container_remote.InitWithNewEndpointAndPassReceiver()); + std::ignore = + stub_policy_container_remote.InitWithNewEndpointAndPassReceiver(); return blink::mojom::PolicyContainer::New( blink::mojom::PolicyContainerPolicies::New(), std::move(stub_policy_container_remote));
diff --git a/content/public/test/prerender_test_util.cc b/content/public/test/prerender_test_util.cc index db730e5..262f11b 100644 --- a/content/public/test/prerender_test_util.cc +++ b/content/public/test/prerender_test_util.cc
@@ -4,8 +4,9 @@ #include "content/public/test/prerender_test_util.h" +#include <tuple> + #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/trace_event/typed_macros.h" #include "content/browser/prerender/prerender_host_registry.h" #include "content/browser/renderer_host/frame_tree.h" @@ -346,8 +347,8 @@ // approach just to ignore it instead of fixing the timing issue. When // ExecJs() actually fails, the remaining test steps should fail, so it // should be safe to ignore it. - ignore_result( - ExecJs(prerender_render_frame_host, JsReplace("location = $1", gurl))); + std::ignore = + ExecJs(prerender_render_frame_host, JsReplace("location = $1", gurl)); } // static @@ -380,8 +381,8 @@ // approach just to ignore it instead of fixing the timing issue. When // ExecJs() actually fails, the remaining test steps should fail, so it // should be safe to ignore it. - ignore_result( - ExecJs(web_contents.GetMainFrame(), JsReplace("location = $1", gurl))); + std::ignore = + ExecJs(web_contents.GetMainFrame(), JsReplace("location = $1", gurl)); observer.Wait(); }
diff --git a/content/public/test/referrer_unittest.cc b/content/public/test/referrer_unittest.cc index 1962891..fc78f84 100644 --- a/content/public/test/referrer_unittest.cc +++ b/content/public/test/referrer_unittest.cc
@@ -4,7 +4,8 @@ #include "content/public/common/referrer.h" -#include "base/ignore_result.h" +#include <tuple> + #include "base/test/gtest_util.h" #include "net/url_request/referrer_policy.h" #include "testing/gtest/include/gtest/gtest.h" @@ -16,16 +17,18 @@ using ReferrerSanitizerTest = testing::Test; TEST_F(ReferrerSanitizerTest, SanitizesPolicyForEmptyReferrers) { - EXPECT_DCHECK_DEATH(ignore_result(Referrer::SanitizeForRequest( - GURL("https://a"), - Referrer(GURL(), static_cast<network::mojom::ReferrerPolicy>(200))))); + EXPECT_DCHECK_DEATH( + std::ignore = Referrer::SanitizeForRequest( + GURL("https://a"), + Referrer(GURL(), static_cast<network::mojom::ReferrerPolicy>(200)))); } TEST_F(ReferrerSanitizerTest, SanitizesPolicyForNonEmptyReferrers) { - EXPECT_DCHECK_DEATH(ignore_result(Referrer::SanitizeForRequest( - GURL("https://a"), - Referrer(GURL("http://b"), - static_cast<network::mojom::ReferrerPolicy>(200))))); + EXPECT_DCHECK_DEATH( + std::ignore = Referrer::SanitizeForRequest( + GURL("https://a"), + Referrer(GURL("http://b"), + static_cast<network::mojom::ReferrerPolicy>(200)))); } TEST_F(ReferrerSanitizerTest, SanitizeOriginForRequest) {
diff --git a/content/public/test/render_view_test.cc b/content/public/test/render_view_test.cc index 3e83516..9cd9b18f 100644 --- a/content/public/test/render_view_test.cc +++ b/content/public/test/render_view_test.cc
@@ -7,10 +7,10 @@ #include <stddef.h> #include <cctype> +#include <tuple> #include "base/bind.h" #include "base/callback_helpers.h" -#include "base/ignore_result.h" #include "base/location.h" #include "base/memory/raw_ptr.h" #include "base/run_loop.h" @@ -514,8 +514,8 @@ main_frame_params->routing_id = render_thread_->GetNextRoutingID(); main_frame_params->frame = TestRenderFrame::CreateStubFrameReceiver(); // Ignoring the returned PendingReceiver because it is not bound to anything - ignore_result( - main_frame_params->interface_broker.InitWithNewPipeAndPassReceiver()); + std::ignore = + main_frame_params->interface_broker.InitWithNewPipeAndPassReceiver(); policy_container_host_ = std::make_unique<MockPolicyContainerHost>(); main_frame_params->policy_container = policy_container_host_->CreatePolicyContainerForBlink(); @@ -561,7 +561,7 @@ mojo::Remote<blink::mojom::LeakDetector> leak_detector; mojo::GenericPendingReceiver receiver( leak_detector.BindNewPipeAndPassReceiver()); - ignore_result(binders_.TryBind(&receiver)); + std::ignore = binders_.TryBind(&receiver); // Close the main |view_| as well as any other windows that might have been // opened by the test.
diff --git a/content/public/test/test_storage_partition.cc b/content/public/test/test_storage_partition.cc index bc39c16..2e8e73f 100644 --- a/content/public/test/test_storage_partition.cc +++ b/content/public/test/test_storage_partition.cc
@@ -4,7 +4,8 @@ #include "content/public/test/test_storage_partition.h" -#include "base/ignore_result.h" +#include <tuple> + #include "components/leveldb_proto/public/proto_database_provider.h" #include "content/public/browser/file_system_access_entry_factory.h" #include "services/network/public/mojom/cookie_manager.mojom.h" @@ -89,7 +90,7 @@ // Bind and throw away the receiver. If testing is required, then add a method // to set the remote. if (!local_storage_control_.is_bound()) - ignore_result(local_storage_control_.BindNewPipeAndPassReceiver()); + std::ignore = local_storage_control_.BindNewPipeAndPassReceiver(); return local_storage_control_.get(); } @@ -97,7 +98,7 @@ // Bind and throw away the receiver. If testing is required, then add a method // to set the remote. if (!indexed_db_control_.is_bound()) - ignore_result(indexed_db_control_.BindNewPipeAndPassReceiver()); + std::ignore = indexed_db_control_.BindNewPipeAndPassReceiver(); return *indexed_db_control_; } @@ -127,7 +128,7 @@ // Bind and throw away the receiver. If testing is required, then add a method // to set the remote. if (!cache_storage_control_.is_bound()) - ignore_result(cache_storage_control_.BindNewPipeAndPassReceiver()); + std::ignore = cache_storage_control_.BindNewPipeAndPassReceiver(); return cache_storage_control_.get(); }
diff --git a/content/renderer/BUILD.gn b/content/renderer/BUILD.gn index 01fcc4b..d669f97 100644 --- a/content/renderer/BUILD.gn +++ b/content/renderer/BUILD.gn
@@ -237,7 +237,6 @@ "//base:i18n", "//build:chromecast_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//cc", "//cc/animation", "//cc/mojo_embedder",
diff --git a/content/renderer/media/android/stream_texture_proxy_unittest.cc b/content/renderer/media/android/stream_texture_proxy_unittest.cc index ba1f57d..c35005a 100644 --- a/content/renderer/media/android/stream_texture_proxy_unittest.cc +++ b/content/renderer/media/android/stream_texture_proxy_unittest.cc
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/ignore_result.h" +#include <tuple> + #include "base/memory/ptr_util.h" #include "base/memory/scoped_refptr.h" #include "base/test/task_environment.h" @@ -46,7 +47,7 @@ // Create the StreamTextureHost with a valid |channel_|. Note that route_id // does not matter here for the test we are writing. mojo::PendingAssociatedRemote<gpu::mojom::StreamTexture> texture; - ignore_result(texture.InitWithNewEndpointAndPassReceiver()); + std::ignore = texture.InitWithNewEndpointAndPassReceiver(); texture.EnableUnassociatedUsage(); auto host = std::make_unique<StreamTextureHost>(channel_, 1 /* route_id */, std::move(texture));
diff --git a/content/renderer/media/media_factory.cc b/content/renderer/media/media_factory.cc index a973c8d2..57641a53 100644 --- a/content/renderer/media/media_factory.cc +++ b/content/renderer/media/media_factory.cc
@@ -19,7 +19,6 @@ #include "build/build_config.h" #include "build/buildflag.h" #include "build/chromecast_buildflags.h" -#include "build/os_buildflags.h" #include "cc/trees/layer_tree_settings.h" #include "content/public/common/content_client.h" #include "content/public/common/content_switches.h"
diff --git a/content/renderer/mock_agent_scheduling_group.cc b/content/renderer/mock_agent_scheduling_group.cc index 23a9327..e74b84701 100644 --- a/content/renderer/mock_agent_scheduling_group.cc +++ b/content/renderer/mock_agent_scheduling_group.cc
@@ -4,7 +4,8 @@ #include "content/renderer/mock_agent_scheduling_group.h" -#include "base/ignore_result.h" +#include <tuple> + #include "content/renderer/render_thread_impl.h" #include "third_party/blink/public/mojom/browser_interface_broker.mojom.h" @@ -57,11 +58,11 @@ void MockAgentSchedulingGroup::Init() { mojo::AssociatedRemote<mojom::AgentSchedulingGroupHost> agent_scheduling_group_host; - ignore_result( - agent_scheduling_group_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = + agent_scheduling_group_host.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<mojom::RouteProvider> browser_route_provider; - ignore_result( - browser_route_provider.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = + browser_route_provider.BindNewEndpointAndPassDedicatedReceiver(); BindAssociatedInterfaces( agent_scheduling_group_host.Unbind(), browser_route_provider.Unbind(),
diff --git a/content/renderer/render_frame_impl.cc b/content/renderer/render_frame_impl.cc index 8f5d497..4a82485 100644 --- a/content/renderer/render_frame_impl.cc +++ b/content/renderer/render_frame_impl.cc
@@ -2951,17 +2951,6 @@ new_loader_factories->Clone())); } - WebString subresource_filter = navigation_params->response.HttpHeaderField( - WebString::FromUTF8("Service-Worker-Subresource-Filter")); - if (!subresource_filter.IsEmpty()) { - ServiceWorkerNetworkProviderForFrame* provider = - static_cast<ServiceWorkerNetworkProviderForFrame*>( - navigation_params->service_worker_network_provider.get()); - DCHECK(provider); - - provider->context()->SetSubresourceFilter(subresource_filter.Utf8()); - } - DCHECK(!pending_loader_factories_); pending_loader_factories_ = std::move(new_loader_factories); pending_code_cache_host_ = std::move(code_cache_host);
diff --git a/content/renderer/render_frame_impl_browsertest.cc b/content/renderer/render_frame_impl_browsertest.cc index a7c256b5..0f5d9f2 100644 --- a/content/renderer/render_frame_impl_browsertest.cc +++ b/content/renderer/render_frame_impl_browsertest.cc
@@ -3,6 +3,7 @@ // found in the LICENSE file. #include <stdint.h> + #include <tuple> #include <utility> @@ -10,7 +11,6 @@ #include "base/callback_helpers.h" #include "base/command_line.h" #include "base/debug/leak_annotations.h" -#include "base/ignore_result.h" #include "base/run_loop.h" #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" @@ -138,7 +138,7 @@ main_frame.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::RemoteMainFrameHost> main_frame_host; - ignore_result(main_frame_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = main_frame_host.BindNewEndpointAndPassDedicatedReceiver(); remote_main_frame_interfaces->main_frame_host = main_frame_host.Unbind(); RenderFrameImpl::FromWebFrame(
diff --git a/content/renderer/render_view_browsertest.cc b/content/renderer/render_view_browsertest.cc index 55156ed..5ed5981 100644 --- a/content/renderer/render_view_browsertest.cc +++ b/content/renderer/render_view_browsertest.cc
@@ -13,7 +13,6 @@ #include "base/callback_helpers.h" #include "base/command_line.h" #include "base/cxx17_backports.h" -#include "base/ignore_result.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" #include "base/location.h" @@ -273,7 +272,7 @@ interfaces->main_frame = main_frame.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::RemoteMainFrameHost> main_frame_host; - ignore_result(main_frame_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = main_frame_host.BindNewEndpointAndPassDedicatedReceiver(); interfaces->main_frame_host = main_frame_host.Unbind(); return interfaces; @@ -1107,7 +1106,7 @@ blink_widget.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::WidgetHost> blink_widget_host; - ignore_result(blink_widget_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = blink_widget_host.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::FrameWidget> blink_frame_widget; mojo::PendingAssociatedReceiver<blink::mojom::FrameWidget> @@ -1115,8 +1114,8 @@ blink_frame_widget.BindNewEndpointAndPassDedicatedReceiver(); mojo::AssociatedRemote<blink::mojom::FrameWidgetHost> blink_frame_widget_host; - ignore_result( - blink_frame_widget_host.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = + blink_frame_widget_host.BindNewEndpointAndPassDedicatedReceiver(); widget_params->frame_widget = std::move(blink_frame_widget_receiver); widget_params->frame_widget_host = blink_frame_widget_host.Unbind();
diff --git a/content/renderer/service_worker/service_worker_network_provider_for_frame.cc b/content/renderer/service_worker/service_worker_network_provider_for_frame.cc index ce87ff6b..554b919d 100644 --- a/content/renderer/service_worker/service_worker_network_provider_for_frame.cc +++ b/content/renderer/service_worker/service_worker_network_provider_for_frame.cc
@@ -138,21 +138,6 @@ if (request.GetSkipServiceWorker()) return nullptr; - if (observer_ && observer_->render_frame() - ->GetWebFrame() - ->ServiceWorkerSubresourceFilterEnabled()) { - const std::string subresource_filter = context()->subresource_filter(); - // If the document has a subresource filter set and the requested URL does - // not match it, do not intercept the request. - if (!subresource_filter.empty() && - gurl.ref().find(subresource_filter) == std::string::npos) { - observer_->ReportFeatureUsage( - blink::mojom::WebFeature:: - kServiceWorkerSubresourceFilterBypassedRequest); - return nullptr; - } - } - // Record use counter for intercepting requests from opaque stylesheets. // TODO(crbug.com/898497): Remove this feature usage once we have enough data. if (observer_ && request.IsFromOriginDirtyStyleSheet()) {
diff --git a/content/renderer/service_worker/service_worker_provider_context.h b/content/renderer/service_worker/service_worker_provider_context.h index b2045274..bc9a84b 100644 --- a/content/renderer/service_worker/service_worker_provider_context.h +++ b/content/renderer/service_worker/service_worker_provider_context.h
@@ -201,11 +201,6 @@ const override; const blink::WebString client_id() const override; - std::string subresource_filter() const { return subresource_filter_; } - void SetSubresourceFilter(const std::string& filter) { - subresource_filter_ = filter; - } - private: friend class base::DeleteHelper<ServiceWorkerProviderContext>; friend class base::RefCountedThreadSafe<ServiceWorkerProviderContext, @@ -336,8 +331,6 @@ bool sent_execution_ready_ = false; - std::string subresource_filter_; - // Contains pending receivers whose corresponding requests are still // in-flight. The pending receivers are taken by // TakePendingWorkerTimingReceiver() when the request is completed.
diff --git a/content/shell/app/shell_main_delegate.cc b/content/shell/app/shell_main_delegate.cc index 62aeb5e..dd80654 100644 --- a/content/shell/app/shell_main_delegate.cc +++ b/content/shell/app/shell_main_delegate.cc
@@ -5,6 +5,7 @@ #include "content/shell/app/shell_main_delegate.h" #include <iostream> +#include <tuple> #include <utility> #include "base/base_paths.h" @@ -13,7 +14,6 @@ #include "base/cpu.h" #include "base/files/file.h" #include "base/files/file_path.h" -#include "base/ignore_result.h" #include "base/lazy_instance.h" #include "base/logging.h" #include "base/path_service.h" @@ -262,7 +262,7 @@ main_runner->Initialize(std::move(main_function_params)); DCHECK_LT(initialize_exit_code, 0) << "BrowserMainRunner::Initialize failed in ShellMainDelegate"; - ignore_result(main_runner.release()); + std::ignore = main_runner.release(); // Return 0 as BrowserMain() should not be called after this, bounce up to // the system message loop for ContentShell, and we're already done thanks // to the |ui_task| for browser tests.
diff --git a/content/test/BUILD.gn b/content/test/BUILD.gn index 8e92405..b9a7341 100644 --- a/content/test/BUILD.gn +++ b/content/test/BUILD.gn
@@ -1284,7 +1284,6 @@ "../browser/service_worker/service_worker_no_best_effort_tasks_browsertest.cc", "../browser/service_worker/service_worker_offline_capability_check_browsertest.cc", "../browser/service_worker/service_worker_process_browsertest.cc", - "../browser/service_worker/service_worker_subresource_filter_browsertest.cc", "../browser/service_worker/service_worker_version_browsertest.cc", "../browser/session_history_browsertest.cc", "../browser/shape_detection/shape_detection_browsertest.cc",
diff --git a/content/test/test_navigation_url_loader.cc b/content/test/test_navigation_url_loader.cc index 76b0861..4ccb572 100644 --- a/content/test/test_navigation_url_loader.cc +++ b/content/test/test_navigation_url_loader.cc
@@ -4,9 +4,9 @@ #include "content/test/test_navigation_url_loader.h" +#include <tuple> #include <utility> -#include "base/ignore_result.h" #include "content/browser/loader/navigation_early_hints_manager.h" #include "content/browser/loader/navigation_url_loader_delegate.h" #include "content/browser/navigation_subresource_loader_params.h" @@ -97,7 +97,7 @@ // purpose of this is not to violate some DCHECKs when the navigation commits. mojo::PendingRemote<network::mojom::URLLoaderClient> url_loader_client_remote; mojo::PendingRemote<network::mojom::URLLoader> url_loader_remote; - ignore_result(url_loader_remote.InitWithNewPipeAndPassReceiver()); + std::ignore = url_loader_remote.InitWithNewPipeAndPassReceiver(); auto url_loader_client_endpoints = network::mojom::URLLoaderClientEndpoints::New( std::move(url_loader_remote),
diff --git a/content/test/test_render_view_host.cc b/content/test/test_render_view_host.cc index 6a097ce..b2b2e5e 100644 --- a/content/test/test_render_view_host.cc +++ b/content/test/test_render_view_host.cc
@@ -5,8 +5,8 @@ #include "content/test/test_render_view_host.h" #include <memory> +#include <tuple> -#include "base/ignore_result.h" #include "base/strings/utf_string_conversions.h" #include "build/build_config.h" #include "components/viz/common/surfaces/parent_local_surface_id_allocator.h" @@ -406,7 +406,7 @@ // Pretend that mojo connections of the RemoteFrame is transferred to // renderer process and bound in blink. mojo::AssociatedRemote<blink::mojom::RemoteMainFrame> remote_main_frame; - ignore_result(remote_main_frame.BindNewEndpointAndPassDedicatedReceiver()); + std::ignore = remote_main_frame.BindNewEndpointAndPassDedicatedReceiver(); proxy_host->BindRemoteMainFrameInterfaces( remote_main_frame.Unbind(), mojo::AssociatedRemote<blink::mojom::RemoteMainFrameHost>()
diff --git a/content/web_test/renderer/gc_controller.cc b/content/web_test/renderer/gc_controller.cc index a343bbf..1c20b29 100644 --- a/content/web_test/renderer/gc_controller.cc +++ b/content/web_test/renderer/gc_controller.cc
@@ -4,8 +4,9 @@ #include "content/web_test/renderer/gc_controller.h" +#include <tuple> + #include "base/bind.h" -#include "base/ignore_result.h" #include "gin/arguments.h" #include "gin/handle.h" #include "gin/object_template_builder.h" @@ -103,7 +104,7 @@ isolate, v8::MicrotasksScope::kDoNotRunMicrotasks); auto result = func->Call(context, context->Global(), 0, nullptr); // Swallow potential exception. - ignore_result(result); + std::ignore = result; } void GCController::MinorCollect(const gin::Arguments& args) {
diff --git a/content/zygote/zygote_linux.cc b/content/zygote/zygote_linux.cc index a93ff40..be67fab 100644 --- a/content/zygote/zygote_linux.cc +++ b/content/zygote/zygote_linux.cc
@@ -14,11 +14,11 @@ #include <sys/types.h> #include <sys/wait.h> +#include <tuple> #include <utility> #include "base/command_line.h" #include "base/files/file_util.h" -#include "base/ignore_result.h" #include "base/linux_util.h" #include "base/logging.h" #include "base/pickle.h" @@ -591,7 +591,7 @@ // Pass ownership of file descriptors from fds to GlobalDescriptors. for (base::ScopedFD& fd : fds) - ignore_result(fd.release()); + std::ignore = fd.release(); base::GlobalDescriptors::GetInstance()->Reset(mapping); // Reset the process-wide command line to our new command line.
diff --git a/device/fido/appid_exclude_probe_task.cc b/device/fido/appid_exclude_probe_task.cc index 065fded7..d36ea64 100644 --- a/device/fido/appid_exclude_probe_task.cc +++ b/device/fido/appid_exclude_probe_task.cc
@@ -65,7 +65,8 @@ device(), std::move(request), base::BindOnce(&AppIdExcludeProbeTask::HandleResponseToSilentSignRequest, weak_factory_.GetWeakPtr()), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), /*string_fixup_predicate=*/nullptr); silent_sign_operation_->Start(); }
diff --git a/device/fido/authenticator_get_assertion_response.h b/device/fido/authenticator_get_assertion_response.h index cd6df48..0b29f28f3 100644 --- a/device/fido/authenticator_get_assertion_response.h +++ b/device/fido/authenticator_get_assertion_response.h
@@ -69,6 +69,10 @@ // Whether a large blob was successfully written as part of this GetAssertion // request. bool large_blob_written = false; + + // The transport used to generate this response. This is unknown when using + // the Windows WebAuthn API. + absl::optional<FidoTransportProtocol> transport_used; }; } // namespace device
diff --git a/device/fido/cable/v2_authenticator.cc b/device/fido/cable/v2_authenticator.cc index a10ada9..b195c28f 100644 --- a/device/fido/cable/v2_authenticator.cc +++ b/device/fido/cable/v2_authenticator.cc
@@ -394,19 +394,11 @@ return; } - // It should be the case that all post-handshake messages fall into - // a single padding bucket. (It doesn't have to be the smallest one.) - // - // This check should be: - // DCHECK_EQ(post_handshake_msg_bytes->size(), - // kPostHandshakeMsgPaddingGranularity); - // - // ... but we're waiting to roll out a protocol change that allows it. - // For now, check that the messages fit within the future padding - // granularity, which will also highlight this when that constant is - // rename to remove "Future". - DCHECK_LE(post_handshake_msg_bytes->size(), - kFuturePostHandshakeMsgPaddingGranularity); + // All post-handshake messages should fit into the same padding bucket. + // It doesn't have to be the smallest one, but that's currently true + // which yields this easy check: + DCHECK_EQ(post_handshake_msg_bytes->size(), + kPostHandshakeMsgPaddingGranularity); if (!crypter_->Encrypt(&post_handshake_msg_bytes.value())) { FIDO_LOG(ERROR) << "failed to encrypt post-handshake message";
diff --git a/device/fido/cable/v2_constants.h b/device/fido/cable/v2_constants.h index 12ad128..a212231 100644 --- a/device/fido/cable/v2_constants.h +++ b/device/fido/cable/v2_constants.h
@@ -64,11 +64,7 @@ // kPostHandshakeMsgPaddingGranularity is the granularity of the padding added // to the post-handshake message. This should be sufficiently large to pad away // all information about the contents of this message. -constexpr size_t kPostHandshakeMsgPaddingGranularity = 256; -// kFuturePostHandshakeMsgPaddingGranularity will be the granularity of the -// padding added to the post-handshake message. This is currently only used for -// testing. -constexpr size_t kFuturePostHandshakeMsgPaddingGranularity = 512; +constexpr size_t kPostHandshakeMsgPaddingGranularity = 512; } // namespace cablev2 } // namespace device
diff --git a/device/fido/cable/v2_handshake.cc b/device/fido/cable/v2_handshake.cc index 0fec587a..5e33bb6a 100644 --- a/device/fido/cable/v2_handshake.cc +++ b/device/fido/cable/v2_handshake.cc
@@ -569,6 +569,15 @@ absl::optional<std::vector<uint8_t>> EncodePaddedCBORMap( cbor::Value::MapValue map) { + // The number of padding bytes is a uint16_t, so the granularity cannot be + // larger than that. + static_assert(kPostHandshakeMsgPaddingGranularity > 0, ""); + static_assert(kPostHandshakeMsgPaddingGranularity - 1 <= + std::numeric_limits<uint16_t>::max()); + // The granularity must also be a power of two. + static_assert((kPostHandshakeMsgPaddingGranularity & + (kPostHandshakeMsgPaddingGranularity - 1)) == 0); + absl::optional<std::vector<uint8_t>> cbor_bytes = cbor::Writer::Write(cbor::Value(std::move(map))); if (!cbor_bytes) { @@ -576,60 +585,39 @@ } base::CheckedNumeric<size_t> padded_size_checked = cbor_bytes->size(); - padded_size_checked += 1; // padding-length byte - padded_size_checked = (padded_size_checked + 255) & ~255; + padded_size_checked += sizeof(uint16_t); // padding-length bytes + padded_size_checked = + (padded_size_checked + kPostHandshakeMsgPaddingGranularity - 1) & + ~(kPostHandshakeMsgPaddingGranularity - 1); if (!padded_size_checked.IsValid()) { return absl::nullopt; } const size_t padded_size = padded_size_checked.ValueOrDie(); - DCHECK_GT(padded_size, cbor_bytes->size()); - const size_t extra_padding = padded_size - cbor_bytes->size(); + DCHECK_GE(padded_size, cbor_bytes->size() + sizeof(uint16_t)); + const size_t extra_bytes = padded_size - cbor_bytes->size(); + const size_t num_padding_bytes = + extra_bytes - sizeof(uint16_t) /* length of padding length */; cbor_bytes->resize(padded_size); - DCHECK_LE(extra_padding, 256u); - cbor_bytes->at(padded_size - 1) = static_cast<uint8_t>(extra_padding - 1); + const uint16_t num_padding_bytes16 = + base::checked_cast<uint16_t>(num_padding_bytes); + memcpy(&cbor_bytes.value()[padded_size - sizeof(num_padding_bytes16)], + &num_padding_bytes16, sizeof(num_padding_bytes16)); return *cbor_bytes; } -// DecodePaddedCBORMap16 is the future replacement for |DecodePaddedCBORMap|, -// below. It parses a slightly different format that allows for more padding. -// (This is needed because some structures ended up larger than initially -// expected.) In order to transition, |DecodePaddedCBORMap| calls this function -// if it fails to parse. In the future we can drop supporting the old format and -// start sending new-format messages. This works because CBOR parsing doesn't -// depend on the length of the input (other than to fail on truncation) so -// there's no ambiguity about the parse. -absl::optional<cbor::Value> DecodePaddedCBORMap16( - base::span<const uint8_t> input) { - if (input.size() < sizeof(uint16_t)) { - return absl::nullopt; - } +namespace { - uint16_t padding_length16; - memcpy(&padding_length16, &input[input.size() - sizeof(padding_length16)], - sizeof(padding_length16)); - const size_t padding_length = padding_length16; - if (padding_length + sizeof(uint16_t) > input.size()) { - FIDO_LOG(DEBUG) << "Invalid padding in caBLE handshake message"; - return absl::nullopt; - } - input = input.subspan(0, input.size() - padding_length - sizeof(uint16_t)); - - absl::optional<cbor::Value> payload = cbor::Reader::Read(input); - if (!payload || !payload->is_map()) { - FIDO_LOG(DEBUG) << "CBOR parse failure in caBLE handshake message"; - return absl::nullopt; - } - - return payload; -} - -absl::optional<cbor::Value> DecodePaddedCBORMap( +// DecodePaddedCBORMap8 performs the actions of |DecodePaddedCBORMap| using the +// old padding format. We still support this format for backwards compatibility. +// See comment in |DecodePaddedCBORMap|. +// +// TODO(agl): remove support for this padding format. (Chromium started sending +// the new format with M99.) +absl::optional<cbor::Value> DecodePaddedCBORMap8( const base::span<const uint8_t> input) { - // TODO: replace this with the body of |DecodePaddedCBORMap16| once M92 is - // everywhere. if (input.empty()) { return absl::nullopt; } @@ -642,12 +630,54 @@ absl::optional<cbor::Value> payload = cbor::Reader::Read(unpadded_input); if (!payload || !payload->is_map()) { - return DecodePaddedCBORMap16(input); + return absl::nullopt; } return payload; } +// DecodePaddedCBORMap16 performs the actions of |DecodePaddedCBORMap| using the +// new padding format. See comment in |DecodePaddedCBORMap|. +absl::optional<cbor::Value> DecodePaddedCBORMap16( + base::span<const uint8_t> input) { + if (input.size() < sizeof(uint16_t)) { + return absl::nullopt; + } + + uint16_t padding_length16; + memcpy(&padding_length16, &input[input.size() - sizeof(padding_length16)], + sizeof(padding_length16)); + const size_t padding_length = padding_length16; + if (padding_length + sizeof(uint16_t) > input.size()) { + return absl::nullopt; + } + input = input.subspan(0, input.size() - padding_length - sizeof(uint16_t)); + + absl::optional<cbor::Value> payload = cbor::Reader::Read(input); + if (!payload || !payload->is_map()) { + return absl::nullopt; + } + + return payload; +} + +} // namespace + +absl::optional<cbor::Value> DecodePaddedCBORMap( + base::span<const uint8_t> input) { + // Two padding formats are currently in use. They are unambiguous so we try + // each, new first. Eventually the old format can be removed once enough time + // has passed since M99. + absl::optional<cbor::Value> result = DecodePaddedCBORMap16(input); + if (!result) { + result = DecodePaddedCBORMap8(input); + } + if (!result) { + FIDO_LOG(DEBUG) << "Invalid padding in caBLE handshake message"; + } + return result; +} + Crypter::Crypter(base::span<const uint8_t, 32> read_key, base::span<const uint8_t, 32> write_key) : read_key_(fido_parsing_utils::Materialize(read_key)),
diff --git a/device/fido/cable/v2_handshake_unittest.cc b/device/fido/cable/v2_handshake_unittest.cc index 1795f13..132f661 100644 --- a/device/fido/cable/v2_handshake_unittest.cc +++ b/device/fido/cable/v2_handshake_unittest.cc
@@ -115,20 +115,11 @@ EXPECT_EQ(1u, decoded->GetMap().size()); } -// FutureEncodePaddedCBORMapFunction is the future replacement for -// |EncodePaddedCBORMap|. See comment on |DecodePaddedCBORMap16|. -absl::optional<std::vector<uint8_t>> FutureEncodePaddedCBORMapFunction( +// EncodePaddedCBORMapOld is the old padding function that used to be used. +// We should still be compatible with it until M99 has been out in the world +// for long enough. +absl::optional<std::vector<uint8_t>> EncodePaddedCBORMapOld( cbor::Value::MapValue map) { - // TODO: when promoting this function, update comment on - // |kPostHandshakeMsgPaddingGranularity|. - - // The number of padding bytes is a uint16_t, so the granularity cannot be - // larger than that. - static_assert(kFuturePostHandshakeMsgPaddingGranularity > 0, ""); - static_assert(kFuturePostHandshakeMsgPaddingGranularity - 1 <= - std::numeric_limits<uint16_t>::max(), - ""); - absl::optional<std::vector<uint8_t>> cbor_bytes = cbor::Writer::Write(cbor::Value(std::move(map))); if (!cbor_bytes) { @@ -136,41 +127,33 @@ } base::CheckedNumeric<size_t> padded_size_checked = cbor_bytes->size(); - padded_size_checked += sizeof(uint16_t); // padding-length bytes - padded_size_checked = - (padded_size_checked + kFuturePostHandshakeMsgPaddingGranularity - 1) & - ~(kFuturePostHandshakeMsgPaddingGranularity - 1); + padded_size_checked += 1; // padding-length byte + padded_size_checked = (padded_size_checked + 255) & ~255; if (!padded_size_checked.IsValid()) { return absl::nullopt; } const size_t padded_size = padded_size_checked.ValueOrDie(); - DCHECK_GE(padded_size, cbor_bytes->size() + sizeof(uint16_t)); - const size_t extra_bytes = padded_size - cbor_bytes->size(); - const size_t num_padding_bytes = - extra_bytes - sizeof(uint16_t) /* length of padding length */; + DCHECK_GT(padded_size, cbor_bytes->size()); + const size_t extra_padding = padded_size - cbor_bytes->size(); cbor_bytes->resize(padded_size); - const uint16_t num_padding_bytes16 = - base::checked_cast<uint16_t>(num_padding_bytes); - memcpy(&cbor_bytes.value()[padded_size - sizeof(num_padding_bytes16)], - &num_padding_bytes16, sizeof(num_padding_bytes16)); + DCHECK_LE(extra_padding, 256u); + cbor_bytes->at(padded_size - 1) = static_cast<uint8_t>(extra_padding - 1); return *cbor_bytes; } -TEST(CableV2Encoding, FuturePaddedCBOR) { - // Test that we can decode messages padded by the encoding function that - // will be used in the future. +TEST(CableV2Encoding, OldPaddedCBOR) { + // Test that we can decode messages padded by the old encoding function. for (size_t i = 0; i < 512; i++) { SCOPED_TRACE(i); - // Check that new->old direction works. const std::vector<uint8_t> dummy_array(i); cbor::Value::MapValue map; map.emplace(1, dummy_array); absl::optional<std::vector<uint8_t>> encoded = - FutureEncodePaddedCBORMapFunction(std::move(map)); + EncodePaddedCBORMapOld(std::move(map)); ASSERT_TRUE(encoded); absl::optional<cbor::Value> decoded = DecodePaddedCBORMap(*encoded);
diff --git a/device/fido/cros/authenticator.cc b/device/fido/cros/authenticator.cc index 716812859..39b5f15 100644 --- a/device/fido/cros/authenticator.cc +++ b/device/fido/cros/authenticator.cc
@@ -18,6 +18,7 @@ #include "device/fido/attestation_statement_formats.h" #include "device/fido/authenticator_data.h" #include "device/fido/fido_parsing_utils.h" +#include "device/fido/fido_transport_protocol.h" #include "device/fido/opaque_attestation_statement.h" #include "third_party/cros_system_api/dbus/u2f/dbus-constants.h" @@ -255,6 +256,7 @@ assertion.signature().end()); AuthenticatorGetAssertionResponse authenticator_response( std::move(*authenticator_data), std::move(signature)); + authenticator_response.transport_used = FidoTransportProtocol::kInternal; const std::string& credential_id = assertion.credential_id(); authenticator_response.credential = PublicKeyCredentialDescriptor( CredentialType::kPublicKey,
diff --git a/device/fido/ctap_response_fuzzer.cc b/device/fido/ctap_response_fuzzer.cc index 600884831..46aeb20a 100644 --- a/device/fido/ctap_response_fuzzer.cc +++ b/device/fido/ctap_response_fuzzer.cc
@@ -47,25 +47,25 @@ } std::array<uint8_t, 32> relying_party_id_hash = {}; - auto response = device::ReadCTAPMakeCredentialResponse( + auto response = ReadCTAPMakeCredentialResponse( FidoTransportProtocol::kUsbHumanInterfaceDevice, input_cbor); if (response) response->EraseAttestationStatement(AttestationObject::AAGUID::kErase); - response = device::AuthenticatorMakeCredentialResponse:: - CreateFromU2fRegisterResponse( - FidoTransportProtocol::kUsbHumanInterfaceDevice, - relying_party_id_hash, input); + response = AuthenticatorMakeCredentialResponse::CreateFromU2fRegisterResponse( + FidoTransportProtocol::kUsbHumanInterfaceDevice, relying_party_id_hash, + input); if (response) response->EraseAttestationStatement(AttestationObject::AAGUID::kErase); - device::ReadCTAPGetAssertionResponse(input_cbor); + ReadCTAPGetAssertionResponse(FidoTransportProtocol::kUsbHumanInterfaceDevice, + input_cbor); std::vector<uint8_t> u2f_response_data(data, data + size); std::vector<uint8_t> key_handle(data, data + size); - device::AuthenticatorGetAssertionResponse::CreateFromU2fSignResponse( + AuthenticatorGetAssertionResponse::CreateFromU2fSignResponse( relying_party_id_hash, u2f_response_data, key_handle); - device::ReadCTAPGetInfoResponse(input); + ReadCTAPGetInfoResponse(input); return 0; }
diff --git a/device/fido/ctap_response_unittest.cc b/device/fido/ctap_response_unittest.cc index 1bdbdae..c2d0d696 100644 --- a/device/fido/ctap_response_unittest.cc +++ b/device/fido/ctap_response_unittest.cc
@@ -16,6 +16,7 @@ #include "device/fido/fido_constants.h" #include "device/fido/fido_parsing_utils.h" #include "device/fido/fido_test_data.h" +#include "device/fido/fido_transport_protocol.h" #include "device/fido/fido_types.h" #include "device/fido/opaque_attestation_statement.h" #include "device/fido/p256_public_key.h" @@ -519,8 +520,11 @@ // https://fidoalliance.org/specs/fido-v2.0-rd-20170927/fido-client-to-authenticator-protocol-v2.0-rd-20170927.html TEST(CTAPResponseTest, TestReadGetAssertionResponse) { auto get_assertion_response = ReadCTAPGetAssertionResponse( + FidoTransportProtocol::kBluetoothLowEnergy, DecodeCBOR(test_data::kDeviceGetAssertionResponse)); ASSERT_TRUE(get_assertion_response); + EXPECT_EQ(*get_assertion_response->transport_used, + FidoTransportProtocol::kBluetoothLowEnergy); ASSERT_TRUE(get_assertion_response->num_credentials); EXPECT_EQ(*get_assertion_response->num_credentials, 1u);
diff --git a/device/fido/device_response_converter.cc b/device/fido/device_response_converter.cc index eb8736f7..beccd103 100644 --- a/device/fido/device_response_converter.cc +++ b/device/fido/device_response_converter.cc
@@ -118,6 +118,7 @@ } absl::optional<AuthenticatorGetAssertionResponse> ReadCTAPGetAssertionResponse( + FidoTransportProtocol transport_used, const absl::optional<cbor::Value>& cbor) { if (!cbor || !cbor->is_map()) return absl::nullopt; @@ -141,6 +142,8 @@ AuthenticatorGetAssertionResponse response(std::move(*auth_data), std::move(signature)); + response.transport_used = transport_used; + it = response_map.find(CBOR(0x01)); if (it != response_map.end()) { auto credential =
diff --git a/device/fido/device_response_converter.h b/device/fido/device_response_converter.h index 3ef3109..72d2e64 100644 --- a/device/fido/device_response_converter.h +++ b/device/fido/device_response_converter.h
@@ -40,6 +40,7 @@ // |AuthenticatorGetAssertionResponse|. COMPONENT_EXPORT(DEVICE_FIDO) absl::optional<AuthenticatorGetAssertionResponse> ReadCTAPGetAssertionResponse( + FidoTransportProtocol transport_used, const absl::optional<cbor::Value>& cbor); // De-serializes CBOR encoded response to AuthenticatorGetInfo request to
diff --git a/device/fido/fido_device_authenticator.cc b/device/fido/fido_device_authenticator.cc index e4fbbc49..4a13c36 100644 --- a/device/fido/fido_device_authenticator.cc +++ b/device/fido/fido_device_authenticator.cc
@@ -206,7 +206,7 @@ void FidoDeviceAuthenticator::GetNextAssertion(GetAssertionCallback callback) { RunOperation<CtapGetNextAssertionRequest, AuthenticatorGetAssertionResponse>( CtapGetNextAssertionRequest(), std::move(callback), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, device_->DeviceTransport()), GetAssertionTask::StringFixupPredicate); }
diff --git a/device/fido/fido_transport_protocol.cc b/device/fido/fido_transport_protocol.cc index c3fbbce..8f8c25d 100644 --- a/device/fido/fido_transport_protocol.cc +++ b/device/fido/fido_transport_protocol.cc
@@ -28,7 +28,6 @@ return absl::nullopt; } -COMPONENT_EXPORT(DEVICE_FIDO) base::StringPiece ToString(FidoTransportProtocol protocol) { switch (protocol) { case FidoTransportProtocol::kUsbHumanInterfaceDevice: @@ -48,4 +47,18 @@ } } +AuthenticatorAttachment AuthenticatorAttachmentFromTransport( + FidoTransportProtocol transport) { + switch (transport) { + case FidoTransportProtocol::kInternal: + return AuthenticatorAttachment::kPlatform; + case FidoTransportProtocol::kUsbHumanInterfaceDevice: + case FidoTransportProtocol::kNearFieldCommunication: + case FidoTransportProtocol::kBluetoothLowEnergy: + case FidoTransportProtocol::kCloudAssistedBluetoothLowEnergy: + case FidoTransportProtocol::kAndroidAccessory: + return AuthenticatorAttachment::kCrossPlatform; + } +} + } // namespace device
diff --git a/device/fido/fido_transport_protocol.h b/device/fido/fido_transport_protocol.h index 971c85a6..39c28a0 100644 --- a/device/fido/fido_transport_protocol.h +++ b/device/fido/fido_transport_protocol.h
@@ -7,6 +7,7 @@ #include "base/component_export.h" #include "base/strings/string_piece.h" +#include "device/fido/fido_types.h" #include "third_party/abseil-cpp/absl/types/optional.h" namespace device { @@ -40,6 +41,10 @@ COMPONENT_EXPORT(DEVICE_FIDO) base::StringPiece ToString(FidoTransportProtocol protocol); +COMPONENT_EXPORT(DEVICE_FIDO) +AuthenticatorAttachment AuthenticatorAttachmentFromTransport( + FidoTransportProtocol transport); + } // namespace device #endif // DEVICE_FIDO_FIDO_TRANSPORT_PROTOCOL_H_
diff --git a/device/fido/get_assertion_task.cc b/device/fido/get_assertion_task.cc index 5a78af9..d04bd97 100644 --- a/device/fido/get_assertion_task.cc +++ b/device/fido/get_assertion_task.cc
@@ -173,7 +173,9 @@ device(), request_, base::BindOnce(&GetAssertionTask::HandleResponse, weak_factory_.GetWeakPtr(), request_.allow_list), - base::BindOnce(&ReadCTAPGetAssertionResponse), StringFixupPredicate); + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), + StringFixupPredicate); sign_operation_->Start(); return; } @@ -214,7 +216,9 @@ device(), std::move(request), base::BindOnce(&GetAssertionTask::HandleResponse, weak_factory_.GetWeakPtr(), request.allow_list), - base::BindOnce(&ReadCTAPGetAssertionResponse), StringFixupPredicate); + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), + StringFixupPredicate); sign_operation_->Start(); return; } @@ -228,7 +232,8 @@ device(), NextSilentRequest(), base::BindOnce(&GetAssertionTask::HandleResponseToSilentRequest, weak_factory_.GetWeakPtr()), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), /*string_fixup_predicate=*/nullptr); sign_operation_->Start(); } @@ -339,7 +344,8 @@ device(), std::move(request), base::BindOnce(&GetAssertionTask::HandleResponse, weak_factory_.GetWeakPtr(), request.allow_list), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), /*string_fixup_predicate=*/nullptr); sign_operation_->Start(); return; @@ -353,7 +359,8 @@ device(), NextSilentRequest(), base::BindOnce(&GetAssertionTask::HandleResponseToSilentRequest, weak_factory_.GetWeakPtr()), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), /*string_fixup_predicate=*/nullptr); sign_operation_->Start(); return;
diff --git a/device/fido/mac/get_assertion_operation.mm b/device/fido/mac/get_assertion_operation.mm index f444e1a..f525650 100644 --- a/device/fido/mac/get_assertion_operation.mm +++ b/device/fido/mac/get_assertion_operation.mm
@@ -6,6 +6,7 @@ #include <set> #include <string> +#include "device/fido/fido_transport_protocol.h" #import <Foundation/Foundation.h> @@ -140,6 +141,7 @@ } AuthenticatorGetAssertionResponse response(std::move(authenticator_data), std::move(*signature)); + response.transport_used = FidoTransportProtocol::kInternal; response.credential = PublicKeyCredentialDescriptor( CredentialType::kPublicKey, credential.credential_id); response.user_entity = metadata->ToPublicKeyCredentialUserEntity();
diff --git a/device/fido/make_credential_task.cc b/device/fido/make_credential_task.cc index fbb45d7..33fbccc 100644 --- a/device/fido/make_credential_task.cc +++ b/device/fido/make_credential_task.cc
@@ -247,7 +247,8 @@ device(), NextSilentRequest(), base::BindOnce(&MakeCredentialTask::HandleResponseToSilentSignRequest, weak_factory_.GetWeakPtr()), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), /*string_fixup_predicate=*/nullptr); silent_sign_operation_->Start(); } @@ -301,7 +302,8 @@ device(), NextSilentRequest(), base::BindOnce(&MakeCredentialTask::HandleResponseToSilentSignRequest, weak_factory_.GetWeakPtr()), - base::BindOnce(&ReadCTAPGetAssertionResponse), + base::BindOnce(&ReadCTAPGetAssertionResponse, + device()->DeviceTransport()), /*string_fixup_predicate=*/nullptr); silent_sign_operation_->Start(); return;
diff --git a/extensions/browser/guest_view/web_view/web_view_find_helper.cc b/extensions/browser/guest_view/web_view/web_view_find_helper.cc index 2a7cbf1..95e4d52 100644 --- a/extensions/browser/guest_view/web_view/web_view_find_helper.cc +++ b/extensions/browser/guest_view/web_view/web_view_find_helper.cc
@@ -137,7 +137,7 @@ } guest_web_contents->Find(current_find_request_id_, search_text, - std::move(full_options)); + std::move(full_options), /*skip_delay=*/true); } void WebViewFindHelper::FindReply(int request_id,
diff --git a/gpu/command_buffer/service/shared_image_backing_factory_gl_image.cc b/gpu/command_buffer/service/shared_image_backing_factory_gl_image.cc index b1148ea..f1185617 100644 --- a/gpu/command_buffer/service/shared_image_backing_factory_gl_image.cc +++ b/gpu/command_buffer/service/shared_image_backing_factory_gl_image.cc
@@ -265,12 +265,6 @@ *allow_legacy_mailbox = gr_context_type == GrContextType::kGL; return true; #else -#if BUILDFLAG(IS_CHROMEOS_ASH) - // On ChromeOS Ash, use only for SHARED_MEMORY gmb - if (gmb_type != gfx::SHARED_MEMORY_BUFFER) { - return false; - } -#endif // Doesn't support contexts other than GL for OOPR Canvas if (gr_context_type != GrContextType::kGL && ((usage & SHARED_IMAGE_USAGE_DISPLAY) || @@ -289,7 +283,6 @@ // return false if it needs interop factory return false; } - *allow_legacy_mailbox = gr_context_type == GrContextType::kGL; return true; #endif
diff --git a/gpu/command_buffer/service/shared_image_backing_factory_ozone.cc b/gpu/command_buffer/service/shared_image_backing_factory_ozone.cc index c46ab80..396c079 100644 --- a/gpu/command_buffer/service/shared_image_backing_factory_ozone.cc +++ b/gpu/command_buffer/service/shared_image_backing_factory_ozone.cc
@@ -15,7 +15,6 @@ #include "gpu/command_buffer/service/shared_image_backing_ozone.h" #include "gpu/command_buffer/service/shared_memory_region_wrapper.h" #include "gpu/vulkan/vulkan_device_queue.h" -#include "ui/gfx/buffer_types.h" #include "ui/gfx/gpu_memory_buffer.h" #include "ui/gfx/native_pixmap.h" #include "ui/gl/buildflags.h" @@ -190,6 +189,21 @@ gmb_type != gfx::SHARED_MEMORY_BUFFER) { return false; } + // TODO(crbug.com/969114): Not all shared image factory implementations + // support concurrent read/write usage. + if (usage & SHARED_IMAGE_USAGE_CONCURRENT_READ_WRITE) { + return false; + } + + // TODO(hitawala): Until SharedImageBackingOzone supports all use cases prefer + // using SharedImageBackingGLImage instead + bool needs_interop_factory = (gr_context_type == GrContextType::kVulkan && + (usage & SHARED_IMAGE_USAGE_DISPLAY)) || + (usage & SHARED_IMAGE_USAGE_WEBGPU) || + (usage & SHARED_IMAGE_USAGE_VIDEO_DECODE); + if (!needs_interop_factory) { + return false; + } *allow_legacy_mailbox = false; return true;
diff --git a/gpu/command_buffer/service/shared_image_factory.cc b/gpu/command_buffer/service/shared_image_factory.cc index 4b956e1..0a4bf25 100644 --- a/gpu/command_buffer/service/shared_image_factory.cc +++ b/gpu/command_buffer/service/shared_image_factory.cc
@@ -251,9 +251,12 @@ } #if BUILDFLAG(IS_CHROMEOS_ASH) - auto ozone_factory = - std::make_unique<SharedImageBackingFactoryOzone>(context_state); - factories_.push_back(std::move(ozone_factory)); + if (gpu_preferences.enable_webgpu || + gr_context_type_ == GrContextType::kVulkan) { + auto ozone_factory = + std::make_unique<SharedImageBackingFactoryOzone>(context_state); + factories_.push_back(std::move(ozone_factory)); + } #endif // IS_CHROMEOS_ASH #if defined(OS_FUCHSIA)
diff --git a/gpu/command_buffer/service/shared_image_representation_gl_ozone.cc b/gpu/command_buffer/service/shared_image_representation_gl_ozone.cc index 5879bc2..1d05d407 100644 --- a/gpu/command_buffer/service/shared_image_representation_gl_ozone.cc +++ b/gpu/command_buffer/service/shared_image_representation_gl_ozone.cc
@@ -217,6 +217,7 @@ GLuint internal_format = image->GetInternalFormat(); GLenum gl_format = image->GetDataFormat(); GLenum gl_type = image->GetDataType(); + scoped_refptr<gles2::TexturePassthrough> texture_passthrough = base::MakeRefCounted<gpu::gles2::TexturePassthrough>( *gl_texture_service_id, target, internal_format,
diff --git a/infra/config/generated/builders/ci/GPU FYI Android arm Builder/properties.textpb b/infra/config/generated/builders/ci/GPU FYI Android arm Builder/properties.textpb index 1c9949f..c01b8ac 100644 --- a/infra/config/generated/builders/ci/GPU FYI Android arm Builder/properties.textpb +++ b/infra/config/generated/builders/ci/GPU FYI Android arm Builder/properties.textpb
@@ -1,7 +1,7 @@ { "$build/reclient": { "instance": "rbe-chromium-trusted", - "jobs": 250, + "jobs": 500, "metrics_project": "chromium-reclient-metrics" }, "$recipe_engine/resultdb/test_presentation": {
diff --git a/infra/config/generated/builders/ci/GPU FYI Android arm64 Builder/properties.textpb b/infra/config/generated/builders/ci/GPU FYI Android arm64 Builder/properties.textpb index 1c9949f..c01b8ac 100644 --- a/infra/config/generated/builders/ci/GPU FYI Android arm64 Builder/properties.textpb +++ b/infra/config/generated/builders/ci/GPU FYI Android arm64 Builder/properties.textpb
@@ -1,7 +1,7 @@ { "$build/reclient": { "instance": "rbe-chromium-trusted", - "jobs": 250, + "jobs": 500, "metrics_project": "chromium-reclient-metrics" }, "$recipe_engine/resultdb/test_presentation": {
diff --git "a/infra/config/generated/builders/ci/GPU Linux Builder \050dbg\051/properties.textpb" "b/infra/config/generated/builders/ci/GPU Linux Builder \050dbg\051/properties.textpb" index 94f3cb4..94d4601 100644 --- "a/infra/config/generated/builders/ci/GPU Linux Builder \050dbg\051/properties.textpb" +++ "b/infra/config/generated/builders/ci/GPU Linux Builder \050dbg\051/properties.textpb"
@@ -1,9 +1,8 @@ { - "$build/goma": { - "enable_ats": true, - "rpc_extra_params": "?prod", - "server_host": "goma.chromium.org", - "use_luci_auth": true + "$build/reclient": { + "instance": "rbe-chromium-trusted", + "jobs": 250, + "metrics_project": "chromium-reclient-metrics" }, "$recipe_engine/resultdb/test_presentation": { "column_keys": [],
diff --git "a/infra/config/generated/builders/ci/Linux CFI \050reclient shadow\051/properties.textpb" "b/infra/config/generated/builders/ci/Linux CFI \050reclient shadow\051/properties.textpb" index dabc646..052ac08 100644 --- "a/infra/config/generated/builders/ci/Linux CFI \050reclient shadow\051/properties.textpb" +++ "b/infra/config/generated/builders/ci/Linux CFI \050reclient shadow\051/properties.textpb"
@@ -3,7 +3,10 @@ "ensure_verified": true, "instance": "rbe-chromium-trusted", "jobs": 400, - "metrics_project": "chromium-reclient-metrics" + "metrics_project": "chromium-reclient-metrics", + "rewrapper_env": { + "RBE_compare": "true" + } }, "$recipe_engine/resultdb/test_presentation": { "column_keys": [],
diff --git a/infra/config/generated/builders/ci/Linux FYI GPU TSAN Release/properties.textpb b/infra/config/generated/builders/ci/Linux FYI GPU TSAN Release/properties.textpb index def7636..1c9949f 100644 --- a/infra/config/generated/builders/ci/Linux FYI GPU TSAN Release/properties.textpb +++ b/infra/config/generated/builders/ci/Linux FYI GPU TSAN Release/properties.textpb
@@ -1,9 +1,8 @@ { - "$build/goma": { - "enable_ats": true, - "rpc_extra_params": "?prod", - "server_host": "goma.chromium.org", - "use_luci_auth": true + "$build/reclient": { + "instance": "rbe-chromium-trusted", + "jobs": 250, + "metrics_project": "chromium-reclient-metrics" }, "$recipe_engine/resultdb/test_presentation": { "column_keys": [],
diff --git a/infra/config/generated/builders/ci/Linux Viz/properties.textpb b/infra/config/generated/builders/ci/Linux Viz/properties.textpb index fc3bce3..b232002 100644 --- a/infra/config/generated/builders/ci/Linux Viz/properties.textpb +++ b/infra/config/generated/builders/ci/Linux Viz/properties.textpb
@@ -1,9 +1,8 @@ { - "$build/goma": { - "enable_ats": true, - "rpc_extra_params": "?prod", - "server_host": "goma.chromium.org", - "use_luci_auth": true + "$build/reclient": { + "instance": "rbe-chromium-trusted", + "jobs": 250, + "metrics_project": "chromium-reclient-metrics" }, "$recipe_engine/resultdb/test_presentation": { "column_keys": [],
diff --git a/infra/config/generated/builders/ci/Win11 Tests x64/properties.textpb b/infra/config/generated/builders/ci/Win11 Tests x64/properties.textpb new file mode 100644 index 0000000..ef42b767 --- /dev/null +++ b/infra/config/generated/builders/ci/Win11 Tests x64/properties.textpb
@@ -0,0 +1,11 @@ +{ + "$recipe_engine/resultdb/test_presentation": { + "column_keys": [], + "grouping_keys": [ + "status", + "v.test_suite" + ] + }, + "builder_group": "chromium.fyi", + "recipe": "chromium" +} \ No newline at end of file
diff --git a/infra/config/generated/builders/try/win11-x64-fyi-rel/properties.textpb b/infra/config/generated/builders/try/win11-x64-fyi-rel/properties.textpb new file mode 100644 index 0000000..b8b797a --- /dev/null +++ b/infra/config/generated/builders/try/win11-x64-fyi-rel/properties.textpb
@@ -0,0 +1,24 @@ +{ + "$build/code_coverage": { + "coverage_test_types": [ + "unit", + "overall" + ], + "use_clang_coverage": true + }, + "$build/goma": { + "enable_ats": false, + "rpc_extra_params": "?prod", + "server_host": "goma.chromium.org", + "use_luci_auth": true + }, + "$recipe_engine/resultdb/test_presentation": { + "column_keys": [], + "grouping_keys": [ + "status", + "v.test_suite" + ] + }, + "builder_group": "tryserver.chromium.win", + "recipe": "chromium_trybot" +} \ No newline at end of file
diff --git a/infra/config/generated/luci/commit-queue.cfg b/infra/config/generated/luci/commit-queue.cfg index 235db8a4..fe30bb16 100644 --- a/infra/config/generated/luci/commit-queue.cfg +++ b/infra/config/generated/luci/commit-queue.cfg
@@ -627,10 +627,6 @@ location_regexp: ".+/[+]/infra/config/.+" } builders { - name: "chromium/try/cast-binary-size" - includable_only: true - } - builders { name: "chromium/try/cast_shell_android" location_regexp: ".*" location_regexp_exclude: ".+/[+]/docs/.+" @@ -1910,6 +1906,10 @@ includable_only: true } builders { + name: "chromium/try/win11-x64-fyi-rel" + includable_only: true + } + builders { name: "chromium/try/win32-official" includable_only: true }
diff --git a/infra/config/generated/luci/cr-buildbucket.cfg b/infra/config/generated/luci/cr-buildbucket.cfg index a2381400d..1e2ac93 100644 --- a/infra/config/generated/luci/cr-buildbucket.cfg +++ b/infra/config/generated/luci/cr-buildbucket.cfg
@@ -22652,6 +22652,89 @@ } } builders { + name: "Win11 Tests x64" + swarming_host: "chromium-swarm.appspot.com" + dimensions: "builderless:1" + dimensions: "cores:8" + dimensions: "cpu:x86-64" + dimensions: "os:Windows-10" + dimensions: "pool:luci.chromium.ci" + dimensions: "ssd:0" + exe { + cipd_package: "infra/chromium/bootstrapper/${platform}" + cipd_version: "latest" + cmd: "bootstrapper" + } + properties: + '{' + ' "$bootstrap/exe": {' + ' "exe": {' + ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' + ' "cipd_version": "refs/heads/main",' + ' "cmd": [' + ' "luciexe"' + ' ]' + ' }' + ' },' + ' "$bootstrap/properties": {' + ' "properties_file": "infra/config/generated/builders/ci/Win11 Tests x64/properties.textpb",' + ' "top_level_project": {' + ' "ref": "refs/heads/main",' + ' "repo": {' + ' "host": "chromium.googlesource.com",' + ' "project": "chromium/src"' + ' }' + ' }' + ' },' + ' "builder_group": "chromium.fyi",' + ' "led_builder_is_bootstrapped": true,' + ' "recipe": "chromium"' + '}' + execution_timeout_secs: 36000 + build_numbers: YES + service_account: "chromium-ci-builder@chops-service-accounts.iam.gserviceaccount.com" + experiments { + key: "luci.recipes.use_python3" + value: 100 + } + experiments { + key: "luci.use_realms" + value: 100 + } + resultdb { + enable: true + bq_exports { + project: "chrome-luci-data" + dataset: "chromium" + table: "ci_test_results" + test_results {} + } + bq_exports { + project: "chrome-luci-data" + dataset: "chromium" + table: "gpu_ci_test_results" + test_results { + predicate { + test_id_regexp: "ninja://(chrome/test:|content/test:fuchsia_)telemetry_gpu_integration_test[^/]*/.+" + } + } + } + bq_exports { + project: "chrome-luci-data" + dataset: "chromium" + table: "blink_web_tests_ci_test_results" + test_results { + predicate { + test_id_regexp: "ninja://[^/]*blink_web_tests/.+" + } + } + } + history_options { + use_invocation_timestamp: true + } + } + } + builders { name: "Win7 (32) Tests" swarming_host: "chromium-swarm.appspot.com" dimensions: "builderless:1" @@ -23700,6 +23783,10 @@ build_numbers: YES service_account: "chromium-cipd-builder@chops-service-accounts.iam.gserviceaccount.com" experiments { + key: "luci.recipes.use_python3" + value: 100 + } + experiments { key: "luci.use_realms" value: 100 } @@ -54685,97 +54772,6 @@ } } builders { - name: "cast-binary-size" - swarming_host: "chromium-swarm.appspot.com" - dimensions: "builderless:1" - dimensions: "cores:8" - dimensions: "cpu:x86-64" - dimensions: "os:Ubuntu-18.04" - dimensions: "pool:luci.chromium.try" - dimensions: "ssd:0" - exe { - cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" - cipd_version: "refs/heads/main" - cmd: "luciexe" - } - properties: - '{' - ' "$build/binary_size": {' - ' "analyze_targets": [' - ' "//chromecast:cast_shell"' - ' ],' - ' "compile_targets": [' - ' "cast_shell"' - ' ]' - ' },' - ' "$build/goma": {' - ' "enable_ats": true,' - ' "rpc_extra_params": "?prod",' - ' "server_host": "goma.chromium.org",' - ' "use_luci_auth": true' - ' },' - ' "$recipe_engine/resultdb/test_presentation": {' - ' "column_keys": [],' - ' "grouping_keys": [' - ' "status",' - ' "v.test_suite"' - ' ]' - ' },' - ' "builder_group": "tryserver.chromium.linux",' - ' "recipe": "binary_size_cast_trybot"' - '}' - execution_timeout_secs: 14400 - expiration_secs: 7200 - grace_period { - seconds: 120 - } - caches { - name: "win_toolchain" - path: "win_toolchain" - } - build_numbers: YES - service_account: "chromium-try-builder@chops-service-accounts.iam.gserviceaccount.com" - task_template_canary_percentage { - value: 5 - } - experiments { - key: "luci.use_realms" - value: 100 - } - resultdb { - enable: true - bq_exports { - project: "chrome-luci-data" - dataset: "chromium" - table: "try_test_results" - test_results {} - } - bq_exports { - project: "chrome-luci-data" - dataset: "chromium" - table: "gpu_try_test_results" - test_results { - predicate { - test_id_regexp: "ninja://(chrome/test:|content/test:fuchsia_)telemetry_gpu_integration_test[^/]*/.+" - } - } - } - bq_exports { - project: "chrome-luci-data" - dataset: "chromium" - table: "blink_web_tests_try_test_results" - test_results { - predicate { - test_id_regexp: "ninja://[^/]*blink_web_tests/.+" - } - } - } - history_options { - use_invocation_timestamp: true - } - } - } - builders { name: "cast_shell_android" swarming_host: "chromium-swarm.appspot.com" dimensions: "builder:cast_shell_android" @@ -68514,7 +68510,7 @@ } experiments { key: "luci.recipes.use_python3" - value: 25 + value: 100 } experiments { key: "luci.use_realms" @@ -77238,7 +77234,7 @@ } experiments { key: "luci.recipes.use_python3" - value: 25 + value: 100 } experiments { key: "luci.use_realms" @@ -78810,6 +78806,100 @@ } } builders { + name: "win11-x64-fyi-rel" + swarming_host: "chromium-swarm.appspot.com" + dimensions: "builderless:1" + dimensions: "cores:8" + dimensions: "cpu:x86-64" + dimensions: "os:Windows-10" + dimensions: "pool:luci.chromium.try" + dimensions: "ssd:0" + exe { + cipd_package: "infra/chromium/bootstrapper/${platform}" + cipd_version: "latest" + cmd: "bootstrapper" + } + properties: + '{' + ' "$bootstrap/exe": {' + ' "exe": {' + ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' + ' "cipd_version": "refs/heads/main",' + ' "cmd": [' + ' "luciexe"' + ' ]' + ' }' + ' },' + ' "$bootstrap/properties": {' + ' "properties_file": "infra/config/generated/builders/try/win11-x64-fyi-rel/properties.textpb",' + ' "top_level_project": {' + ' "ref": "refs/heads/main",' + ' "repo": {' + ' "host": "chromium.googlesource.com",' + ' "project": "chromium/src"' + ' }' + ' }' + ' },' + ' "builder_group": "tryserver.chromium.win",' + ' "led_builder_is_bootstrapped": true,' + ' "recipe": "chromium_trybot"' + '}' + execution_timeout_secs: 14400 + expiration_secs: 7200 + grace_period { + seconds: 120 + } + caches { + name: "win_toolchain" + path: "win_toolchain" + } + build_numbers: YES + service_account: "chromium-try-builder@chops-service-accounts.iam.gserviceaccount.com" + task_template_canary_percentage { + value: 5 + } + experiments { + key: "luci.recipes.use_python3" + value: 100 + } + experiments { + key: "luci.use_realms" + value: 100 + } + resultdb { + enable: true + bq_exports { + project: "chrome-luci-data" + dataset: "chromium" + table: "try_test_results" + test_results {} + } + bq_exports { + project: "chrome-luci-data" + dataset: "chromium" + table: "gpu_try_test_results" + test_results { + predicate { + test_id_regexp: "ninja://(chrome/test:|content/test:fuchsia_)telemetry_gpu_integration_test[^/]*/.+" + } + } + } + bq_exports { + project: "chrome-luci-data" + dataset: "chromium" + table: "blink_web_tests_try_test_results" + test_results { + predicate { + test_id_regexp: "ninja://[^/]*blink_web_tests/.+" + } + } + } + history_options { + use_invocation_timestamp: true + } + } + } + builders { name: "win32-official" swarming_host: "chromium-swarm.appspot.com" dimensions: "builderless:1" @@ -80472,33 +80562,26 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Android Builder/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80523,33 +80606,26 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Android Builder (dbg)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80574,33 +80650,26 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Android Builder ARM64 (dbg)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80625,33 +80694,20 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Android Tests (dbg) (M Nexus5X)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80676,33 +80732,20 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Android Tests (dbg) (N Nexus5X)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80727,33 +80770,26 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Linux Builder/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80778,33 +80814,26 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Linux Builder (dbg)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80829,33 +80858,20 @@ dimensions: "os:Ubuntu-16.04|Ubuntu-18.04" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Linux Tester/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80880,33 +80896,25 @@ dimensions: "os:Mac" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Mac Builder/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80931,33 +80939,25 @@ dimensions: "os:Mac" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Mac Builder (dbg)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -80982,33 +80982,20 @@ dimensions: "os:Mac" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Mac Tester/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -81033,33 +81020,26 @@ dimensions: "os:Windows-10" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Win Builder/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -81084,33 +81064,26 @@ dimensions: "os:Windows-10" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' + ' "$build/goma": {' + ' "enable_ats": true,' + ' "rpc_extra_params": "?prod",' + ' "server_host": "goma.chromium.org",' + ' "use_luci_auth": true' ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Win Builder (dbg)/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -81135,33 +81108,20 @@ dimensions: "os:Windows-10" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Win10 Tester/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -81186,33 +81146,20 @@ dimensions: "os:Windows-10" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Win7 Tester/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200 @@ -81237,33 +81184,20 @@ dimensions: "os:Windows-10" dimensions: "pool:luci.chromium.webrtc.fyi" exe { - cipd_package: "infra/chromium/bootstrapper/${platform}" - cipd_version: "latest" - cmd: "bootstrapper" + cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" + cipd_version: "refs/heads/main" + cmd: "luciexe" } properties: '{' - ' "$bootstrap/exe": {' - ' "exe": {' - ' "cipd_package": "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build",' - ' "cipd_version": "refs/heads/main",' - ' "cmd": [' - ' "luciexe"' - ' ]' - ' }' - ' },' - ' "$bootstrap/properties": {' - ' "properties_file": "infra/config/generated/builders/webrtc.fyi/WebRTC Chromium FYI Win8 Tester/properties.textpb",' - ' "top_level_project": {' - ' "ref": "refs/heads/main",' - ' "repo": {' - ' "host": "chromium.googlesource.com",' - ' "project": "chromium/src"' - ' }' - ' }' + ' "$recipe_engine/resultdb/test_presentation": {' + ' "column_keys": [],' + ' "grouping_keys": [' + ' "status",' + ' "v.test_suite"' + ' ]' ' },' ' "builder_group": "chromium.webrtc.fyi",' - ' "led_builder_is_bootstrapped": true,' ' "recipe": "chromium"' '}' execution_timeout_secs: 7200
diff --git a/infra/config/generated/luci/luci-milo.cfg b/infra/config/generated/luci/luci-milo.cfg index 2618c5e..3cebf77 100644 --- a/infra/config/generated/luci/luci-milo.cfg +++ b/infra/config/generated/luci/luci-milo.cfg
@@ -6448,6 +6448,10 @@ category: "win10" } builders { + name: "buildbucket/luci.chromium.ci/Win11 Tests x64" + category: "win11" + } + builders { name: "buildbucket/luci.chromium.ci/win32-arm64-rel" category: "win32|arm64" } @@ -14130,9 +14134,6 @@ name: "buildbucket/luci.chromium.try/branch-config-verifier" } builders { - name: "buildbucket/luci.chromium.try/cast-binary-size" - } - builders { name: "buildbucket/luci.chromium.try/cast_shell_android" } builders { @@ -14919,6 +14920,9 @@ name: "buildbucket/luci.chromium.try/win10_chromium_x64_rel_ng_rts" } builders { + name: "buildbucket/luci.chromium.try/win11-x64-fyi-rel" + } + builders { name: "buildbucket/luci.chromium.try/win32-official" } builders { @@ -15452,9 +15456,6 @@ id: "tryserver.chromium.linux" name: "tryserver.chromium.linux" builders { - name: "buildbucket/luci.chromium.try/cast-binary-size" - } - builders { name: "buildbucket/luci.chromium.try/cast_shell_audio_linux" } builders { @@ -16061,6 +16062,9 @@ name: "buildbucket/luci.chromium.try/win10_chromium_x64_rel_ng_rts" } builders { + name: "buildbucket/luci.chromium.try/win11-x64-fyi-rel" + } + builders { name: "buildbucket/luci.chromium.try/win7-rel" } builders {
diff --git a/infra/config/generated/luci/luci-scheduler.cfg b/infra/config/generated/luci/luci-scheduler.cfg index 409ebba..d5a4f9f90 100644 --- a/infra/config/generated/luci/luci-scheduler.cfg +++ b/infra/config/generated/luci/luci-scheduler.cfg
@@ -3903,6 +3903,20 @@ } } job { + id: "Win11 Tests x64" + realm: "ci" + acls { + role: TRIGGERER + granted_to: "chromium-ci-builder@chops-service-accounts.iam.gserviceaccount.com" + } + acl_sets: "ci" + buildbucket { + server: "cr-buildbucket.appspot.com" + bucket: "luci.chromium.ci" + builder: "Win11 Tests x64" + } +} +job { id: "Win7 (32) Tests" realm: "ci" acls {
diff --git a/infra/config/generators/scheduler-noop-jobs.star b/infra/config/generators/scheduler-noop-jobs.star index 4bee49e..7e7f8ed 100644 --- a/infra/config/generators/scheduler-noop-jobs.star +++ b/infra/config/generators/scheduler-noop-jobs.star
@@ -23,6 +23,11 @@ # the branches "mac-osxbeta-rel": branches.DESKTOP_EXTENDED_STABLE_MILESTONE, + # This tester is triggered by 'Win x64 Builder', but it is an FYI builder + # and not mirrored by any branched try builders, so we do not need to run it + # on the branches + "Win11 Tests x64": branches.STANDARD_MILESTONE, + # These Android testers are triggered by 'Android arm Builder (dbg)', but we # don't have sufficient capacity of devices with older Android versions, so # we do not run them on the branches
diff --git a/infra/config/recipes.star b/infra/config/recipes.star index 9abf88c0..53dbab08 100644 --- a/infra/config/recipes.star +++ b/infra/config/recipes.star
@@ -81,6 +81,7 @@ build_recipe( name = "recipe:android/androidx_packager", + use_python3 = True, ) build_recipe( @@ -112,10 +113,6 @@ ) build_recipe( - name = "recipe:binary_size_cast_trybot", -) - -build_recipe( name = "recipe:binary_size_fuchsia_trybot", ) @@ -176,9 +173,7 @@ build_recipe( name = "recipe:chromium_libfuzzer_trybot", - experiments = { - "luci.recipes.use_python3": 25, - }, + use_python3 = True, ) build_recipe(
diff --git a/infra/config/subprojects/chromium/ci/chromium.fyi.star b/infra/config/subprojects/chromium/ci/chromium.fyi.star index d6023cdf..530c822 100644 --- a/infra/config/subprojects/chromium/ci/chromium.fyi.star +++ b/infra/config/subprojects/chromium/ci/chromium.fyi.star
@@ -38,6 +38,7 @@ "network", "viz", "win10", + "win11", "win32", "paeverywhere", "backuprefptr", @@ -86,7 +87,10 @@ console_view_entry = consoles.console_view_entry( category = "viz", ), + goma_backend = None, os = os.LINUX_BIONIC_SWITCH_TO_DEFAULT, + reclient_jobs = rbe_jobs.DEFAULT, + reclient_instance = rbe_instance.DEFAULT, ) ci.builder( @@ -646,6 +650,7 @@ reclient_jobs = 400, reclient_instance = rbe_instance.DEFAULT, reclient_ensure_verified = True, + reclient_rewrapper_env = {"RBE_compare": "true"}, ) # End - Reclient migration, phase 2, block 1 shadow builders @@ -1143,6 +1148,18 @@ ) ci.builder( + name = "Win11 Tests x64", + builderless = True, + console_view_entry = consoles.console_view_entry( + category = "win11", + ), + goma_backend = None, + main_console_view = None, + os = os.WINDOWS_10, + triggered_by = ["ci/Win x64 Builder"], +) + +ci.builder( name = "win32-arm64-rel", console_view_entry = consoles.console_view_entry( category = "win32|arm64",
diff --git a/infra/config/subprojects/chromium/ci/chromium.gpu.fyi.star b/infra/config/subprojects/chromium/ci/chromium.gpu.fyi.star index 89fa9f4..a07efb0 100644 --- a/infra/config/subprojects/chromium/ci/chromium.gpu.fyi.star +++ b/infra/config/subprojects/chromium/ci/chromium.gpu.fyi.star
@@ -153,7 +153,7 @@ short_name = "arm", ), goma_backend = None, - reclient_jobs = rbe_jobs.DEFAULT, + reclient_jobs = rbe_jobs.HIGH_JOBS_FOR_CI, reclient_instance = rbe_instance.DEFAULT, ) @@ -164,7 +164,7 @@ short_name = "arm64", ), goma_backend = None, - reclient_jobs = rbe_jobs.DEFAULT, + reclient_jobs = rbe_jobs.HIGH_JOBS_FOR_CI, reclient_instance = rbe_instance.DEFAULT, ) @@ -204,6 +204,9 @@ category = "Linux", short_name = "tsn", ), + goma_backend = None, + reclient_jobs = rbe_jobs.DEFAULT, + reclient_instance = rbe_instance.DEFAULT, ) # Builder + tester.
diff --git a/infra/config/subprojects/chromium/ci/chromium.gpu.star b/infra/config/subprojects/chromium/ci/chromium.gpu.star index 7863d35..06cd6c8 100644 --- a/infra/config/subprojects/chromium/ci/chromium.gpu.star +++ b/infra/config/subprojects/chromium/ci/chromium.gpu.star
@@ -56,6 +56,9 @@ category = "Linux", ), tree_closing = False, + goma_backend = None, + reclient_jobs = rbe_jobs.DEFAULT, + reclient_instance = rbe_instance.DEFAULT, ) ci.gpu.mac_builder(
diff --git a/infra/config/subprojects/chromium/try/tryserver.chromium.linux.star b/infra/config/subprojects/chromium/try/tryserver.chromium.linux.star index 2bd0b52..86606fcb 100644 --- a/infra/config/subprojects/chromium/try/tryserver.chromium.linux.star +++ b/infra/config/subprojects/chromium/try/tryserver.chromium.linux.star
@@ -61,22 +61,6 @@ ) try_.builder( - name = "cast-binary-size", - builderless = True, - executable = "recipe:binary_size_cast_trybot", - properties = { - "$build/binary_size": { - "analyze_targets": [ - "//chromecast:cast_shell", - ], - "compile_targets": [ - "cast_shell", - ], - }, - }, -) - -try_.builder( name = "fuchsia-binary-size", branch_selector = branches.STANDARD_MILESTONE, builderless = True,
diff --git a/infra/config/subprojects/chromium/try/tryserver.chromium.win.star b/infra/config/subprojects/chromium/try/tryserver.chromium.win.star index 6f7b91a..118156e 100644 --- a/infra/config/subprojects/chromium/try/tryserver.chromium.win.star +++ b/infra/config/subprojects/chromium/try/tryserver.chromium.win.star
@@ -106,6 +106,14 @@ ) try_.builder( + name = "win11-x64-fyi-rel", + builderless = True, + use_clang_coverage = True, + coverage_test_types = ["unit", "overall"], + os = os.WINDOWS_10, +) + +try_.builder( name = "win10_chromium_inverse_fieldtrials_x64_fyi_rel_ng", os = os.WINDOWS_10, )
diff --git a/infra/config/subprojects/webrtc/webrtc.fyi.star b/infra/config/subprojects/webrtc/webrtc.fyi.star index 8845f563..66b30e2a 100644 --- a/infra/config/subprojects/webrtc/webrtc.fyi.star +++ b/infra/config/subprojects/webrtc/webrtc.fyi.star
@@ -2,7 +2,7 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -load("//lib/builders.star", "builder", "cpu", "defaults", "goma", "os", "xcode") +load("//lib/builders.star", "cpu", "defaults", "goma", "os", "xcode", base_builder = "builder") luci.bucket( name = "webrtc.fyi", @@ -33,6 +33,10 @@ refs = ["refs/heads/main"], ) +def builder(**kwargs): + kwargs.setdefault("bootstrap", False) + return base_builder(**kwargs) + defaults.bucket.set("webrtc.fyi") defaults.builder_group.set("chromium.webrtc.fyi") defaults.builderless.set(None)
diff --git a/infra/inclusive_language_presubmit_exempt_dirs.txt b/infra/inclusive_language_presubmit_exempt_dirs.txt index 7e34076..e41e9de 100644 --- a/infra/inclusive_language_presubmit_exempt_dirs.txt +++ b/infra/inclusive_language_presubmit_exempt_dirs.txt
@@ -734,15 +734,30 @@ third_party/tensorflow-text/src/tensorflow_text/python/metrics 1 1 third_party/tensorflow-text/src/tensorflow_text/python/ops/test_data 1 1 third_party/test_fonts 15 1 +third_party/tflite_support/patches 8 2 +third_party/tflite_support/src 1 1 +third_party/tflite_support/src/tensorflow_lite_support/acceleration 3 1 +third_party/tflite_support/src/tensorflow_lite_support/c 2 2 +third_party/tflite_support/src/tensorflow_lite_support/cc/port/default 1 1 third_party/tflite_support/src/tensorflow_lite_support/cc/task 4 1 -third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision 9 4 -third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto 4 2 +third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio 1 1 +third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto 2 1 +third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision 13 5 +third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto 12 4 +third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text 5 3 third_party/tflite_support/src/tensorflow_lite_support/codegen 1 1 third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece 2 1 third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata 1 1 third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python 2 1 -third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop 4 2 +third_party/tflite_support/src/tensorflow_lite_support/custom_ops/testdata 5 1 +third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop 8 3 +third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python 12 3 +third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops 1 1 +third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier 1 1 +third_party/tflite_support/src/tensorflow_lite_support/metadata 5 1 third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata 3 1 +third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier 2 1 +third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image 1 1 third_party/tlslite 1 1 third_party/tlslite/patches 7 3 third_party/tlslite/tlslite 3 2
diff --git a/ios/chrome/browser/policy/reporting/report_scheduler_ios_unittest.mm b/ios/chrome/browser/policy/reporting/report_scheduler_ios_unittest.mm index 636d5d4..edc6e6a 100644 --- a/ios/chrome/browser/policy/reporting/report_scheduler_ios_unittest.mm +++ b/ios/chrome/browser/policy/reporting/report_scheduler_ios_unittest.mm
@@ -72,7 +72,8 @@ MockReportUploader& operator=(const MockReportUploader&) = delete; ~MockReportUploader() override = default; - MOCK_METHOD2(SetRequestAndUpload, void(ReportRequestQueue, ReportCallback)); + MOCK_METHOD3(SetRequestAndUpload, + void(ReportType, ReportRequestQueue, ReportCallback)); }; class ReportSchedulerIOSTest : public PlatformTest { @@ -193,8 +194,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -214,8 +215,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kTransientError)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kTransientError)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -235,8 +236,8 @@ EXPECT_CALL_SetupRegistrationWithSetDMToken(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kPersistentError)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kPersistentError)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -261,7 +262,7 @@ EXPECT_CALL_SetupRegistrationWithSetDMToken(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(0))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)).Times(0); + EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _, _)).Times(0); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -289,8 +290,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -309,8 +310,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting()); @@ -346,8 +347,8 @@ EXPECT_CALL_SetupRegistration(); EXPECT_CALL(*generator_, OnGenerate(ReportType::kFull, _)) .WillOnce(WithArgs<1>(ScheduleGeneratorCallback(1))); - EXPECT_CALL(*uploader_, SetRequestAndUpload(_, _)) - .WillOnce(RunOnceCallback<1>(ReportUploader::kSuccess)); + EXPECT_CALL(*uploader_, SetRequestAndUpload(ReportType::kFull, _, _)) + .WillOnce(RunOnceCallback<2>(ReportUploader::kSuccess)); CreateScheduler(); EXPECT_TRUE(scheduler_->IsNextReportScheduledForTesting());
diff --git a/ios/chrome/browser/providers/BUILD.gn b/ios/chrome/browser/providers/BUILD.gn index de7ee7f..0c486d0 100644 --- a/ios/chrome/browser/providers/BUILD.gn +++ b/ios/chrome/browser/providers/BUILD.gn
@@ -15,7 +15,6 @@ "//ios/chrome/browser/voice:voice_search_language", "//ios/chrome/browser/web:feature_flags", "//ios/public/provider/chrome/browser", - "//ios/public/provider/chrome/browser:font_size_java_script_feature", "//ios/public/provider/chrome/browser/discover_feed", "//ios/public/provider/chrome/browser/follow", "//ios/public/provider/chrome/browser/signin",
diff --git a/ios/chrome/browser/providers/text_zoom/BUILD.gn b/ios/chrome/browser/providers/text_zoom/BUILD.gn index efe34d2..393de26 100644 --- a/ios/chrome/browser/providers/text_zoom/BUILD.gn +++ b/ios/chrome/browser/providers/text_zoom/BUILD.gn
@@ -7,7 +7,7 @@ sources = [ "chromium_text_zoom.mm" ] deps = [ "//ios/chrome/browser/web:feature_flags", - "//ios/public/provider/chrome/browser:font_size_java_script_feature", + "//ios/chrome/browser/web/font_size", "//ios/public/provider/chrome/browser/text_zoom:text_zoom_api", "//ui/base", ]
diff --git a/ios/chrome/browser/providers/text_zoom/chromium_text_zoom.mm b/ios/chrome/browser/providers/text_zoom/chromium_text_zoom.mm index 68c532a..b108ef4 100644 --- a/ios/chrome/browser/providers/text_zoom/chromium_text_zoom.mm +++ b/ios/chrome/browser/providers/text_zoom/chromium_text_zoom.mm
@@ -3,7 +3,7 @@ // found in the LICENSE file. #include "ios/chrome/browser/web/features.h" -#import "ios/public/provider/chrome/browser/font_size_java_script_feature.h" +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" #import "ios/public/provider/chrome/browser/text_zoom/text_zoom_api.h" #include "ui/base/device_form_factor.h"
diff --git a/ios/chrome/browser/ssl/ios_captive_portal_blocking_page.mm b/ios/chrome/browser/ssl/ios_captive_portal_blocking_page.mm index 5333f0ea..f6511ec 100644 --- a/ios/chrome/browser/ssl/ios_captive_portal_blocking_page.mm +++ b/ios/chrome/browser/ssl/ios_captive_portal_blocking_page.mm
@@ -44,6 +44,8 @@ load_time_data->SetStringKey("iconClass", "icon-offline"); load_time_data->SetStringKey("type", "CAPTIVE_PORTAL"); load_time_data->SetBoolKey("overridable", false); + load_time_data->SetBoolKey("hide_primary_button", false); + load_time_data->SetStringKey( "primaryButtonText", l10n_util::GetStringUTF16(IDS_CAPTIVE_PORTAL_BUTTON_OPEN_LOGIN_PAGE));
diff --git a/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.h b/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.h index 689603d..82ff8af 100644 --- a/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.h +++ b/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.h
@@ -18,9 +18,6 @@ class ManagedBookmarkService; } // namespace bookmarks class Browser; -namespace base { -class ScopedTempDir; -} // namespace base class TestChromeBrowserState; // Provides common bookmark testing infrastructure. @@ -39,11 +36,6 @@ NSString* title); void ChangeTitle(NSString* title, const bookmarks::BookmarkNode* node); - // A state directory that outlives |task_environment_| is needed because - // CreateHistoryService/CreateBookmarkModel use the directory to host - // databases. See https://crbug.com/546640 for more details. - std::unique_ptr<base::ScopedTempDir> state_dir_; - web::WebTaskEnvironment task_environment_; IOSChromeScopedTestingLocalState local_state_; std::unique_ptr<Browser> browser_;
diff --git a/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.mm b/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.mm index 575fb620..bc93173 100644 --- a/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.mm +++ b/ios/chrome/browser/ui/bookmarks/bookmark_ios_unittest.mm
@@ -33,9 +33,6 @@ AuthenticationServiceFactory::GetInstance(), base::BindRepeating( &AuthenticationServiceFake::CreateAuthenticationService)); - state_dir_ = std::make_unique<base::ScopedTempDir>(); - ASSERT_TRUE(state_dir_->CreateUniqueTempDir()); - test_cbs_builder.SetPath(state_dir_->GetPath()); chrome_browser_state_ = test_cbs_builder.Build(); chrome_browser_state_->CreateBookmarkModel(true);
diff --git a/ios/chrome/browser/ui/recent_tabs/recent_tabs_menu_helper.mm b/ios/chrome/browser/ui/recent_tabs/recent_tabs_menu_helper.mm index b77623b..03b3790 100644 --- a/ios/chrome/browser/ui/recent_tabs/recent_tabs_menu_helper.mm +++ b/ios/chrome/browser/ui/recent_tabs/recent_tabs_menu_helper.mm
@@ -7,6 +7,7 @@ #import "base/ios/ios_util.h" #include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros.h" +#import "ios/chrome/browser/net/crurl.h" #import "ios/chrome/browser/ui/coordinators/chrome_coordinator.h" #import "ios/chrome/browser/ui/menu/browser_action_factory.h" #import "ios/chrome/browser/ui/menu/menu_histograms.h" @@ -73,10 +74,14 @@ NSMutableArray<UIMenuElement*>* menuElements = [[NSMutableArray alloc] init]; + GURL gurl; + if (item.URL) { + gurl = item.URL.gurl; + } [menuElements addObject: [actionFactory - actionToOpenInNewTabWithURL:item.URL + actionToOpenInNewTabWithURL:gurl completion:^{ [weakSelf.recentTabsPresentationDelegate showActiveRegularTabFromRecentTabs]; @@ -85,16 +90,16 @@ if (base::ios::IsMultipleScenesSupported()) { [menuElements addObject:[actionFactory - actionToOpenInNewWindowWithURL:item.URL + actionToOpenInNewWindowWithURL:gurl activityOrigin: WindowActivityRecentTabsOrigin]]; } - [menuElements addObject:[actionFactory actionToCopyURL:item.URL]]; + [menuElements addObject:[actionFactory actionToCopyURL:gurl]]; [menuElements addObject:[actionFactory actionToShareWithBlock:^{ [weakSelf.contextMenuDelegate - shareURL:item.URL + shareURL:gurl title:item.title scenario:ActivityScenario::RecentTabsEntry fromView:view];
diff --git a/ios/chrome/browser/ui/recent_tabs/recent_tabs_table_view_controller.mm b/ios/chrome/browser/ui/recent_tabs/recent_tabs_table_view_controller.mm index 44de30c6..2925cdc5 100644 --- a/ios/chrome/browser/ui/recent_tabs/recent_tabs_table_view_controller.mm +++ b/ios/chrome/browser/ui/recent_tabs/recent_tabs_table_view_controller.mm
@@ -358,7 +358,8 @@ TableViewURLItem* recentlyClosedTab = [[TableViewURLItem alloc] initWithType:ItemTypeRecentlyClosed]; recentlyClosedTab.title = base::SysUTF16ToNSString(navigationEntry.title()); - recentlyClosedTab.URL = navigationEntry.virtual_url(); + recentlyClosedTab.URL = + [[CrURL alloc] initWithGURL:navigationEntry.virtual_url()]; [self.tableViewModel addItem:recentlyClosedTab toSectionWithIdentifier:SectionIdentifierRecentlyClosedTabs]; } @@ -437,7 +438,7 @@ TableViewURLItem* sessionTabItem = [[TableViewURLItem alloc] initWithType:ItemTypeSessionTabData]; sessionTabItem.title = title; - sessionTabItem.URL = sessionTab->virtual_url; + sessionTabItem.URL = [[CrURL alloc] initWithGURL:sessionTab->virtual_url]; [model addItem:sessionTabItem toSectionWithIdentifier:[self sectionIdentifierForSession:session]]; } @@ -970,7 +971,10 @@ TableViewItem* item = [self.tableViewModel itemAtIndexPath:indexPath]; TableViewURLItem* URLItem = base::mac::ObjCCastStrict<TableViewURLItem>(item); - return [[URLInfo alloc] initWithURL:URLItem.URL title:URLItem.title]; + GURL gurl; + if (URLItem.URL) + gurl = URLItem.URL.gurl; + return [[URLInfo alloc] initWithURL:gurl title:URLItem.title]; } case ItemTypeRecentlyClosedHeader: @@ -1027,9 +1031,8 @@ TableViewURLCell* URLCell = base::mac::ObjCCastStrict<TableViewURLCell>(cell); NSString* itemIdentifier = URLItem.uniqueIdentifier; - CrURL* crurl = [[CrURL alloc] initWithGURL:URLItem.URL]; [self.imageDataSource - faviconForURL:crurl + faviconForURL:URLItem.URL completion:^(FaviconAttributes* attributes) { // Only set favicon if the cell hasn't been reused. if ([URLCell.cellUniqueIdentifier isEqualToString:itemIdentifier]) {
diff --git a/ios/chrome/browser/ui/settings/table_cell_catalog_view_controller.mm b/ios/chrome/browser/ui/settings/table_cell_catalog_view_controller.mm index f98be8473..ae4d0a3 100644 --- a/ios/chrome/browser/ui/settings/table_cell_catalog_view_controller.mm +++ b/ios/chrome/browser/ui/settings/table_cell_catalog_view_controller.mm
@@ -558,35 +558,36 @@ TableViewURLItem* item = [[TableViewURLItem alloc] initWithType:ItemTypeURLNoMetadata]; item.title = @"Google Design"; - item.URL = GURL("https://design.google.com"); + item.URL = [[CrURL alloc] initWithGURL:GURL("https://design.google.com")]; [model addItem:item toSectionWithIdentifier:SectionIdentifierURL]; item = [[TableViewURLItem alloc] initWithType:ItemTypeURLNoMetadata]; - item.URL = GURL("https://notitle.google.com"); + item.URL = [[CrURL alloc] initWithGURL:GURL("https://notitle.google.com")]; [model addItem:item toSectionWithIdentifier:SectionIdentifierURL]; item = [[TableViewURLItem alloc] initWithType:ItemTypeURLWithTimestamp]; item.title = @"Google"; - item.URL = GURL("https://www.google.com"); + item.URL = [[CrURL alloc] initWithGURL:GURL("https://www.google.com")]; item.metadata = @"3:42 PM"; [model addItem:item toSectionWithIdentifier:SectionIdentifierURL]; item = [[TableViewURLItem alloc] initWithType:ItemTypeURLWithSize]; item.title = @"World Series 2017: Houston Astros Defeat Someone Else"; - item.URL = GURL("https://m.bbc.com"); + item.URL = [[CrURL alloc] initWithGURL:GURL("https://m.bbc.com")]; item.metadata = @"176 KB"; [model addItem:item toSectionWithIdentifier:SectionIdentifierURL]; item = [[TableViewURLItem alloc] initWithType:ItemTypeURLWithSupplementalText]; item.title = @"Chrome | Google Blog"; - item.URL = GURL("https://blog.google/products/chrome/"); + item.URL = + [[CrURL alloc] initWithGURL:GURL("https://blog.google/products/chrome/")]; item.supplementalURLText = @"Read 4 days ago"; [model addItem:item toSectionWithIdentifier:SectionIdentifierURL]; item = [[TableViewURLItem alloc] initWithType:ItemTypeURLWithBadgeImage]; item.title = @"Photos - Google Photos"; - item.URL = GURL("https://photos.google.com/"); + item.URL = [[CrURL alloc] initWithGURL:GURL("https://photos.google.com/")]; item.badgeImage = [UIImage imageNamed:@"table_view_cell_check_mark"]; [model addItem:item toSectionWithIdentifier:SectionIdentifierURL]; }
diff --git a/ios/chrome/browser/ui/table_view/cells/BUILD.gn b/ios/chrome/browser/ui/table_view/cells/BUILD.gn index f1eda1b0..62fe2da 100644 --- a/ios/chrome/browser/ui/table_view/cells/BUILD.gn +++ b/ios/chrome/browser/ui/table_view/cells/BUILD.gn
@@ -116,6 +116,7 @@ ":cells_constants", "//base", "//base/test:test_support", + "//ios/chrome/browser/net:crurl", "//ios/chrome/browser/ui/icons", "//ios/chrome/browser/ui/table_view:styler", "//ios/chrome/common/ui/colors",
diff --git a/ios/chrome/browser/ui/table_view/cells/table_view_url_item.h b/ios/chrome/browser/ui/table_view/cells/table_view_url_item.h index c54ac6e..088a7819 100644 --- a/ios/chrome/browser/ui/table_view/cells/table_view_url_item.h +++ b/ios/chrome/browser/ui/table_view/cells/table_view_url_item.h
@@ -10,7 +10,7 @@ #import "ios/chrome/browser/ui/table_view/cells/table_view_cell.h" #import "ios/chrome/browser/ui/table_view/cells/table_view_item.h" -class GURL; +@class CrURL; @class FaviconView; @class TableViewURLCellFaviconBadgeView; @@ -19,8 +19,8 @@ // The title of the page at |URL|. @property(nonatomic, readwrite, copy) NSString* title; -// GURL from which the cell will retrieve a favicon and display the host name. -@property(nonatomic, assign) GURL URL; +// CrURL from which the cell will retrieve a favicon and display the host name. +@property(nonatomic, readwrite, strong) CrURL* URL; // Supplemental text used to describe the URL. @property(nonatomic, readwrite, copy) NSString* supplementalURLText; // Delimiter used to separate the URL hostname and the supplemental text.
diff --git a/ios/chrome/browser/ui/table_view/cells/table_view_url_item.mm b/ios/chrome/browser/ui/table_view/cells/table_view_url_item.mm index a3223627..9e27e6e 100644 --- a/ios/chrome/browser/ui/table_view/cells/table_view_url_item.mm +++ b/ios/chrome/browser/ui/table_view/cells/table_view_url_item.mm
@@ -6,6 +6,7 @@ #include "base/mac/foundation_util.h" #include "base/strings/sys_string_conversions.h" +#import "ios/chrome/browser/net/crurl.h" #import "ios/chrome/browser/ui/elements/favicon_container_view.h" #import "ios/chrome/browser/ui/table_view/cells/table_view_cells_constants.h" #import "ios/chrome/browser/ui/table_view/cells/table_view_url_cell_favicon_badge_view.h" @@ -71,21 +72,24 @@ } - (NSString*)uniqueIdentifier { - return base::SysUTF8ToNSString(self.URL.host()); + if (!self.URL) + return @""; + return base::SysUTF8ToNSString(self.URL.gurl.host()); } #pragma mark Private // Returns the text to use when configuring a TableViewURLCell's title label. - (NSString*)titleLabelText { - if (self.title.length) { + if (self.title.length) return self.title; - } else if (base::SysUTF8ToNSString(self.URL.host()).length) { - return base::SysUTF8ToNSString(self.URL.host()); - } else { - // Backup in case host returns nothing (e.g. about:blank). - return base::SysUTF8ToNSString(self.URL.spec()); - } + if (!self.URL) + return @""; + NSString* hostname = base::SysUTF8ToNSString(self.URL.gurl.host()); + if (hostname.length) + return hostname; + // Backup in case host returns nothing (e.g. about:blank). + return base::SysUTF8ToNSString(self.URL.gurl.spec()); } // Returns the text to use when configuring a TableViewURLCell's URL label. @@ -96,7 +100,9 @@ return self.supplementalURLText; // Append the hostname with the supplemental text. - NSString* hostname = base::SysUTF8ToNSString(self.URL.host()); + if (!self.URL) + return @""; + NSString* hostname = base::SysUTF8ToNSString(self.URL.gurl.host()); if (self.supplementalURLText.length) { NSString* delimeter = self.supplementalURLTextDelimiter.length
diff --git a/ios/chrome/browser/ui/table_view/cells/table_view_url_item_unittest.mm b/ios/chrome/browser/ui/table_view/cells/table_view_url_item_unittest.mm index 46c00fd..fe9056e 100644 --- a/ios/chrome/browser/ui/table_view/cells/table_view_url_item_unittest.mm +++ b/ios/chrome/browser/ui/table_view/cells/table_view_url_item_unittest.mm
@@ -6,6 +6,7 @@ #include "base/mac/foundation_util.h" #include "base/strings/sys_string_conversions.h" +#import "ios/chrome/browser/net/crurl.h" #import "ios/chrome/browser/ui/table_view/chrome_table_view_styler.h" #include "net/base/mac/url_conversions.h" #include "testing/gtest/include/gtest/gtest.h" @@ -30,7 +31,7 @@ TableViewURLItem* item = [[TableViewURLItem alloc] initWithType:0]; item.title = titleText; - item.URL = net::GURLWithNSURL([NSURL URLWithString:URLText]); + item.URL = [[CrURL alloc] initWithNSURL:[NSURL URLWithString:URLText]]; item.metadata = metadataText; id cell = [[[item cellClass] alloc] init]; @@ -91,7 +92,7 @@ TableViewURLItem* item = [[TableViewURLItem alloc] initWithType:0]; item.title = kTitle; - item.URL = kURL; + item.URL = [[CrURL alloc] initWithGURL:kURL]; item.supplementalURLText = kSupplementalURLText; item.supplementalURLTextDelimiter = kSupplementalURLTextDelimiter; @@ -110,7 +111,7 @@ NSString* const kSupplementalURLText = @"supplement"; TableViewURLItem* item = [[TableViewURLItem alloc] initWithType:0]; - item.URL = kURL; + item.URL = [[CrURL alloc] initWithGURL:kURL]; item.supplementalURLText = kSupplementalURLText; id cell = [[[item cellClass] alloc] init];
diff --git a/ios/chrome/browser/web/BUILD.gn b/ios/chrome/browser/web/BUILD.gn index 15be8c73..f7ae60a 100644 --- a/ios/chrome/browser/web/BUILD.gn +++ b/ios/chrome/browser/web/BUILD.gn
@@ -245,6 +245,7 @@ "//ios/chrome/browser/ui/reading_list:reading_list_javascript_feature", "//ios/chrome/browser/ui/util", "//ios/chrome/browser/web:feature_flags", + "//ios/chrome/browser/web/font_size", "//ios/chrome/browser/web/image_fetch", "//ios/chrome/browser/web/java_script_console", "//ios/chrome/browser/web/print", @@ -254,7 +255,6 @@ "//ios/components/security_interstitials/lookalikes", "//ios/components/webui:url_constants", "//ios/net", - "//ios/public/provider/chrome/browser:font_size_java_script_feature", "//ios/public/provider/chrome/browser/url_rewriters:url_rewriters_api", "//ios/web", "//ios/web/common",
diff --git a/ios/chrome/browser/web/chrome_web_client.mm b/ios/chrome/browser/web/chrome_web_client.mm index 97acf5e..c68d57b 100644 --- a/ios/chrome/browser/web/chrome_web_client.mm +++ b/ios/chrome/browser/web/chrome_web_client.mm
@@ -49,6 +49,7 @@ #include "ios/chrome/browser/web/error_page_controller_bridge.h" #import "ios/chrome/browser/web/error_page_util.h" #include "ios/chrome/browser/web/features.h" +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" #include "ios/chrome/browser/web/image_fetch/image_fetch_java_script_feature.h" #import "ios/chrome/browser/web/java_script_console/java_script_console_feature.h" #import "ios/chrome/browser/web/java_script_console/java_script_console_feature_factory.h" @@ -62,7 +63,6 @@ #import "ios/components/security_interstitials/lookalikes/lookalike_url_error.h" #include "ios/components/webui/web_ui_url_constants.h" #import "ios/net/protocol_handler_util.h" -#import "ios/public/provider/chrome/browser/font_size_java_script_feature.h" #include "ios/public/provider/chrome/browser/url_rewriters/url_rewriters_api.h" #include "ios/web/common/features.h" #include "ios/web/common/user_agent.h"
diff --git a/ios/chrome/browser/web/font_size/BUILD.gn b/ios/chrome/browser/web/font_size/BUILD.gn index b0493ed..f618ea8 100644 --- a/ios/chrome/browser/web/font_size/BUILD.gn +++ b/ios/chrome/browser/web/font_size/BUILD.gn
@@ -7,6 +7,8 @@ source_set("font_size") { configs += [ "//build/config/compiler:enable_arc" ] sources = [ + "font_size_java_script_feature.h", + "font_size_java_script_feature.mm", "font_size_tab_helper.h", "font_size_tab_helper.mm", ] @@ -22,10 +24,10 @@ "//ios/chrome/browser/browser_state", "//ios/chrome/browser/web:feature_flags", "//ios/components/ui_util", - "//ios/public/provider/chrome/browser:font_size_java_script_feature", "//ios/public/provider/chrome/browser/text_zoom:text_zoom_api", "//services/metrics/public/cpp:ukm_builders", ] + public_deps = [ "//ios/web/public/js_messaging" ] } js_compile_bundle("font_size_js") { @@ -59,7 +61,6 @@ "//ios/chrome/browser/web:test_support", "//ios/chrome/test/fakes", "//ios/public/provider/chrome/browser", - "//ios/public/provider/chrome/browser:font_size_java_script_feature", "//ios/web/public/js_messaging", "//ios/web/public/test", "//testing/gtest",
diff --git a/ios/chrome/browser/web/font_size/font_size_java_script_feature.h b/ios/chrome/browser/web/font_size/font_size_java_script_feature.h new file mode 100644 index 0000000..94d6d5a --- /dev/null +++ b/ios/chrome/browser/web/font_size/font_size_java_script_feature.h
@@ -0,0 +1,38 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef IOS_CHROME_BROWSER_WEB_FONT_SIZE_FONT_SIZE_JAVA_SCRIPT_FEATURE_H_ +#define IOS_CHROME_BROWSER_WEB_FONT_SIZE_FONT_SIZE_JAVA_SCRIPT_FEATURE_H_ + +#include "base/no_destructor.h" +#include "ios/web/public/js_messaging/java_script_feature.h" + +namespace web { +class WebFrame; +class WebState; +} // namespace web + +// Feature which adjusts the font size on a page. +class FontSizeJavaScriptFeature : public web::JavaScriptFeature { + public: + static FontSizeJavaScriptFeature* GetInstance(); + + // Adjusts the font size in all frames of |web_state| by |size| percentage. + void AdjustFontSize(web::WebState* web_state, int size); + + // Adjusts the font size in |web_frame| by |size| percentage. + void AdjustFontSize(web::WebFrame* web_frame, int size); + + private: + friend class base::NoDestructor<FontSizeJavaScriptFeature>; + + FontSizeJavaScriptFeature(); + ~FontSizeJavaScriptFeature() override; + + FontSizeJavaScriptFeature(const FontSizeJavaScriptFeature&) = delete; + FontSizeJavaScriptFeature& operator=(const FontSizeJavaScriptFeature&) = + delete; +}; + +#endif // IOS_CHROME_BROWSER_WEB_FONT_SIZE_FONT_SIZE_JAVA_SCRIPT_FEATURE_H_
diff --git a/ios/chrome/browser/web/font_size/font_size_java_script_feature.mm b/ios/chrome/browser/web/font_size/font_size_java_script_feature.mm new file mode 100644 index 0000000..bc4522c --- /dev/null +++ b/ios/chrome/browser/web/font_size/font_size_java_script_feature.mm
@@ -0,0 +1,50 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" + +#include "base/no_destructor.h" +#include "base/values.h" +#import "ios/web/public/js_messaging/web_frame.h" +#include "ios/web/public/js_messaging/web_frames_manager.h" +#import "ios/web/public/web_state.h" + +#if !defined(__has_feature) || !__has_feature(objc_arc) +#error "This file requires ARC support." +#endif + +namespace { +const char kFontSizeScript[] = "font_size_js"; +} // namespace + +// static +FontSizeJavaScriptFeature* FontSizeJavaScriptFeature::GetInstance() { + static base::NoDestructor<FontSizeJavaScriptFeature> instance; + return instance.get(); +} + +void FontSizeJavaScriptFeature::AdjustFontSize(web::WebState* web_state, + int size) { + for (web::WebFrame* frame : + web_state->GetWebFramesManager()->GetAllWebFrames()) { + AdjustFontSize(frame, size); + } +} + +void FontSizeJavaScriptFeature::AdjustFontSize(web::WebFrame* web_frame, + int size) { + std::vector<base::Value> parameters; + parameters.push_back(base::Value(size)); + CallJavaScriptFunction(web_frame, "font_size.adjustFontSize", parameters); +} + +FontSizeJavaScriptFeature::FontSizeJavaScriptFeature() + : web::JavaScriptFeature( + web::JavaScriptFeature::ContentWorld::kAnyContentWorld, + {FeatureScript::CreateWithFilename( + kFontSizeScript, + FeatureScript::InjectionTime::kDocumentStart, + FeatureScript::TargetFrames::kAllFrames)}) {} + +FontSizeJavaScriptFeature::~FontSizeJavaScriptFeature() = default;
diff --git a/ios/chrome/browser/web/font_size/font_size_tab_helper.mm b/ios/chrome/browser/web/font_size/font_size_tab_helper.mm index 62e6367..eb5cc41 100644 --- a/ios/chrome/browser/web/font_size/font_size_tab_helper.mm +++ b/ios/chrome/browser/web/font_size/font_size_tab_helper.mm
@@ -21,8 +21,8 @@ #include "ios/chrome/browser/browser_state/chrome_browser_state.h" #include "ios/chrome/browser/pref_names.h" #include "ios/chrome/browser/web/features.h" +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" #include "ios/components/ui_util/dynamic_type_util.h" -#import "ios/public/provider/chrome/browser/font_size_java_script_feature.h" #import "ios/public/provider/chrome/browser/text_zoom/text_zoom_api.h" #include "services/metrics/public/cpp/ukm_builders.h"
diff --git a/ios/chrome/browser/web/font_size/font_size_tab_helper_unittest.mm b/ios/chrome/browser/web/font_size/font_size_tab_helper_unittest.mm index 122a22e..6ac3509 100644 --- a/ios/chrome/browser/web/font_size/font_size_tab_helper_unittest.mm +++ b/ios/chrome/browser/web/font_size/font_size_tab_helper_unittest.mm
@@ -17,7 +17,7 @@ #include "ios/chrome/browser/prefs/browser_prefs.h" #import "ios/chrome/browser/web/chrome_web_test.h" #include "ios/chrome/browser/web/features.h" -#import "ios/public/provider/chrome/browser/font_size_java_script_feature.h" +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" #import "ios/web/public/test/fakes/fake_web_client.h" #import "ios/web/public/test/fakes/fake_web_state.h" #include "testing/gtest/include/gtest/gtest.h"
diff --git a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1 index 13e8e75..de37135 100644 --- a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -4a9c4a2f9de0a11ac8862ceffccf0547bdea32a9 \ No newline at end of file +ca93720d6b740f7e6d1289bf31fea01877c34cb1 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1 index fc3c814..fdbabf5 100644 --- a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -4e179a5f43ee114d4d0543c5238f9ecb46f5ce39 \ No newline at end of file +88d82a87b45efce403881ae4ac0a02af8deccd45 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1 index 6f0d7aa..5da96400 100644 --- a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -47bc9b9e48c04637a36c8ee671420f10dd502f61 \ No newline at end of file +032600b608568fa6962a3a17d6f7291704285d4a \ No newline at end of file
diff --git a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1 index a49f05b..c956986 100644 --- a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -f0182ca3707e7afbbed2244c5386fdc84318432d \ No newline at end of file +e15736feae9e035632b4c618d86630b9d1544c7d \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1 index 41126d5..70b2b93 100644 --- a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -f2ad453581118d1e7ab25a8893eca776da2e3f0f \ No newline at end of file +612c1689789139c4a762e36d0c87bc1b875f21d1 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1 index dcb7abc..4f2750875 100644 --- a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -ba5a3c5dc830eb462773263d1c82f05ef1960852 \ No newline at end of file +2bfdcfb0cedc351b47fe6320d9c160878c986509 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1 index f11b6d2..c27eb11 100644 --- a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -faf8a5a0eb5af30818b9833493f2cd79041953ee \ No newline at end of file +b3f895acd5dd8cb0d5191dbfb85b0d8030b687d7 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1 index 1a07fa9..ff7208a 100644 --- a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -8c503cf9981878c59055ea6145e3d8c0a0c3b5fe \ No newline at end of file +b29bffbf6b2071a663b7b0b10b279b000105cfa4 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1 index be64e98e..c5b23e89 100644 --- a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -d06247d0cca73ffa0f0199a5b64a5c842dc34e6b \ No newline at end of file +02ef5930e78b5dc56a2fb214899af0c0aa5080db \ No newline at end of file
diff --git a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1 index d2e1f5c..de5579c 100644 --- a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -0c9e00a30b130501474d3fb894e085ae8fa34e23 \ No newline at end of file +34f718cc00bc8b61f2c74138dd89d76d27f8a91c \ No newline at end of file
diff --git a/ios/public/provider/chrome/browser/BUILD.gn b/ios/public/provider/chrome/browser/BUILD.gn index 16c0eba..e7c72802 100644 --- a/ios/public/provider/chrome/browser/BUILD.gn +++ b/ios/public/provider/chrome/browser/BUILD.gn
@@ -11,7 +11,6 @@ "chrome_browser_provider.mm", ] deps = [ - ":font_size_java_script_feature", "//base", "//components/metrics", "//ios/public/provider/chrome/browser/mailto", @@ -25,15 +24,8 @@ source_set("font_size_java_script_feature") { configs += [ "//build/config/compiler:enable_arc" ] - sources = [ - "font_size_java_script_feature.h", - "font_size_java_script_feature.mm", - ] - deps = [ - "//base", - "//ios/web/public", - ] - public_deps = [ "//ios/web/public/js_messaging" ] + sources = [ "font_size_java_script_feature.h" ] + public_deps = [ "//ios/chrome/browser/web/font_size" ] } group("provider_api") {
diff --git a/ios/public/provider/chrome/browser/DEPS b/ios/public/provider/chrome/browser/DEPS index 4099d13..65f0b72 100644 --- a/ios/public/provider/chrome/browser/DEPS +++ b/ios/public/provider/chrome/browser/DEPS
@@ -12,3 +12,9 @@ # provider is not allowed to depends on //ios/chrome. "-ios/public/provider/chrome/browser/browser_state/chrome_browser_state.h", ] + +specific_include_rules = { + "font_size_java_script_feature.h": [ + "+ios/chrome/browser/web/font_size/font_size_java_script_feature.h" + ] +}
diff --git a/ios/public/provider/chrome/browser/font_size_java_script_feature.h b/ios/public/provider/chrome/browser/font_size_java_script_feature.h index abf9a8b..75dc7f2 100644 --- a/ios/public/provider/chrome/browser/font_size_java_script_feature.h +++ b/ios/public/provider/chrome/browser/font_size_java_script_feature.h
@@ -1,38 +1,13 @@ -// Copyright 2021 The Chromium Authors. All rights reserved. +// Copyright 2022 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #ifndef IOS_PUBLIC_PROVIDER_CHROME_BROWSER_FONT_SIZE_JAVA_SCRIPT_FEATURE_H_ #define IOS_PUBLIC_PROVIDER_CHROME_BROWSER_FONT_SIZE_JAVA_SCRIPT_FEATURE_H_ -#include "base/no_destructor.h" -#include "ios/web/public/js_messaging/java_script_feature.h" - -namespace web { -class WebFrame; -class WebState; -} // namespace web - -// Feature which adjusts the font size on a page. -class FontSizeJavaScriptFeature : public web::JavaScriptFeature { - public: - static FontSizeJavaScriptFeature* GetInstance(); - - // Adjusts the font size in all frames of |web_state| by |size| percentage. - void AdjustFontSize(web::WebState* web_state, int size); - - // Adjusts the font size in |web_frame| by |size| percentage. - void AdjustFontSize(web::WebFrame* web_frame, int size); - - private: - friend class base::NoDestructor<FontSizeJavaScriptFeature>; - - FontSizeJavaScriptFeature(); - ~FontSizeJavaScriptFeature() override; - - FontSizeJavaScriptFeature(const FontSizeJavaScriptFeature&) = delete; - FontSizeJavaScriptFeature& operator=(const FontSizeJavaScriptFeature&) = - delete; -}; +// This is a forwarding header to allow renaming this file without breaking +// the internal build. It will be removed once the internal repository has +// been converted to use the new path. +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" #endif // IOS_PUBLIC_PROVIDER_CHROME_BROWSER_FONT_SIZE_JAVA_SCRIPT_FEATURE_H_
diff --git a/ios/public/provider/chrome/browser/font_size_java_script_feature.mm b/ios/public/provider/chrome/browser/font_size_java_script_feature.mm deleted file mode 100644 index 2fb16a8..0000000 --- a/ios/public/provider/chrome/browser/font_size_java_script_feature.mm +++ /dev/null
@@ -1,49 +0,0 @@ -// Copyright 2021 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#import "ios/public/provider/chrome/browser/font_size_java_script_feature.h" - -#include "base/no_destructor.h" -#include "base/values.h" -#import "ios/web/public/js_messaging/web_frame.h" -#include "ios/web/public/js_messaging/web_frames_manager.h" -#import "ios/web/public/web_state.h" - -#if !defined(__has_feature) || !__has_feature(objc_arc) -#error "This file requires ARC support." -#endif - -namespace { -const char kFontSizeScript[] = "font_size_js"; -} // namespace - -// static -FontSizeJavaScriptFeature* FontSizeJavaScriptFeature::GetInstance() { - static base::NoDestructor<FontSizeJavaScriptFeature> instance; - return instance.get(); -} - -void FontSizeJavaScriptFeature::AdjustFontSize(web::WebState* web_state, - int size) { - for (web::WebFrame* frame : - web_state->GetWebFramesManager()->GetAllWebFrames()) { - AdjustFontSize(frame, size); - } -} - -void FontSizeJavaScriptFeature::AdjustFontSize(web::WebFrame* web_frame, - int size) { - std::vector<base::Value> parameters; - parameters.push_back(base::Value(size)); - CallJavaScriptFunction(web_frame, "font_size.adjustFontSize", parameters); -} - -FontSizeJavaScriptFeature::FontSizeJavaScriptFeature() - : web::JavaScriptFeature( - web::JavaScriptFeature::ContentWorld::kAnyContentWorld, - {FeatureScript::CreateWithFilename( - kFontSizeScript, - FeatureScript::InjectionTime::kDocumentStart, - FeatureScript::TargetFrames::kAllFrames)}) {} -FontSizeJavaScriptFeature::~FontSizeJavaScriptFeature() = default;
diff --git a/ios/public/provider/chrome/browser/text_zoom/BUILD.gn b/ios/public/provider/chrome/browser/text_zoom/BUILD.gn index c339c5d..1e0c235 100644 --- a/ios/public/provider/chrome/browser/text_zoom/BUILD.gn +++ b/ios/public/provider/chrome/browser/text_zoom/BUILD.gn
@@ -13,6 +13,6 @@ sources = [ "test_text_zoom.mm" ] deps = [ ":text_zoom_api", - "//ios/public/provider/chrome/browser:font_size_java_script_feature", + "//ios/chrome/browser/web/font_size", ] }
diff --git a/ios/public/provider/chrome/browser/text_zoom/DEPS b/ios/public/provider/chrome/browser/text_zoom/DEPS new file mode 100644 index 0000000..2c0e1b6 --- /dev/null +++ b/ios/public/provider/chrome/browser/text_zoom/DEPS
@@ -0,0 +1,7 @@ +include_rules = [] + +specific_include_rules = { + "test_text_zoom.mm": [ + "+ios/chrome/browser/web/font_size/font_size_java_script_feature.h", + ] +}
diff --git a/ios/public/provider/chrome/browser/text_zoom/test_text_zoom.mm b/ios/public/provider/chrome/browser/text_zoom/test_text_zoom.mm index 4777b3b..a3c02693 100644 --- a/ios/public/provider/chrome/browser/text_zoom/test_text_zoom.mm +++ b/ios/public/provider/chrome/browser/text_zoom/test_text_zoom.mm
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#import "ios/public/provider/chrome/browser/font_size_java_script_feature.h" +#import "ios/chrome/browser/web/font_size/font_size_java_script_feature.h" #import "ios/public/provider/chrome/browser/text_zoom/text_zoom_api.h" #if !defined(__has_feature) || !__has_feature(objc_arc)
diff --git a/ios/web/security/crw_cert_verification_controller_unittest.mm b/ios/web/security/crw_cert_verification_controller_unittest.mm index e4fbf16..4d8852cc 100644 --- a/ios/web/security/crw_cert_verification_controller_unittest.mm +++ b/ios/web/security/crw_cert_verification_controller_unittest.mm
@@ -4,7 +4,7 @@ #import "ios/web/security/crw_cert_verification_controller.h" -#include "base/mac/foundation_util.h" +#include "base/mac/bridging.h" #import "base/test/ios/wait_util.h" #include "ios/web/public/test/web_test.h" #include "ios/web/public/thread/web_thread.h" @@ -45,10 +45,10 @@ cert_.get())); ASSERT_TRUE(chain); valid_trust_ = web::CreateServerTrustFromChain( - base::mac::CFToNSCast(chain.get()), kHostName); + base::mac::CFToNSPtrCast(chain.get()), kHostName); web::EnsureFutureTrustEvaluationSucceeds(valid_trust_.get()); invalid_trust_ = web::CreateServerTrustFromChain( - base::mac::CFToNSCast(chain.get()), kHostName); + base::mac::CFToNSPtrCast(chain.get()), kHostName); } // Synchronously returns result of
diff --git a/ios/web/security/crw_ssl_status_updater_unittest.mm b/ios/web/security/crw_ssl_status_updater_unittest.mm index e804b515..3f4bd667 100644 --- a/ios/web/security/crw_ssl_status_updater_unittest.mm +++ b/ios/web/security/crw_ssl_status_updater_unittest.mm
@@ -6,7 +6,7 @@ #import <WebKit/WebKit.h> -#include "base/mac/foundation_util.h" +#include "base/mac/bridging.h" #include "base/strings/sys_string_conversions.h" #import "ios/web/navigation/navigation_manager_impl.h" #import "ios/web/public/navigation/navigation_item.h" @@ -108,7 +108,7 @@ net::x509_util::CreateSecCertificateArrayForX509Certificate( cert.get())); ASSERT_TRUE(chain); - trust_ = CreateServerTrustFromChain(base::mac::CFToNSCast(chain.get()), + trust_ = CreateServerTrustFromChain(base::mac::CFToNSPtrCast(chain.get()), kHostName); }
diff --git a/ios/web/security/wk_web_view_security_util_unittest.mm b/ios/web/security/wk_web_view_security_util_unittest.mm index 4fde904b..7c6d32d 100644 --- a/ios/web/security/wk_web_view_security_util_unittest.mm +++ b/ios/web/security/wk_web_view_security_util_unittest.mm
@@ -9,7 +9,7 @@ #include <memory> -#include "base/mac/foundation_util.h" +#include "base/mac/bridging.h" #include "base/mac/scoped_cftyperef.h" #include "crypto/rsa_private_key.h" #include "net/cert/x509_certificate.h" @@ -60,7 +60,7 @@ base::ScopedCFTypeRef<SecTrustRef> CreateTestTrust(NSArray* cert_chain) { base::ScopedCFTypeRef<SecPolicyRef> policy(SecPolicyCreateBasicX509()); SecTrustRef trust = nullptr; - SecTrustCreateWithCertificates(base::mac::NSToCFCast(cert_chain), policy, + SecTrustCreateWithCertificates(base::mac::NSToCFPtrCast(cert_chain), policy, &trust); return base::ScopedCFTypeRef<SecTrustRef>(trust); }
diff --git a/ios/web/web_view/error_translation_util_unittest.mm b/ios/web/web_view/error_translation_util_unittest.mm index ce79526..b08c553b 100644 --- a/ios/web/web_view/error_translation_util_unittest.mm +++ b/ios/web/web_view/error_translation_util_unittest.mm
@@ -6,7 +6,7 @@ #import <Foundation/Foundation.h> -#include "base/mac/foundation_util.h" +#include "base/mac/bridging.h" #import "ios/net/protocol_handler_util.h" #include "ios/web/test/test_url_constants.h" #import "net/base/mac/url_conversions.h" @@ -74,7 +74,7 @@ // underlying error. TEST_F(ErrorTranslationUtilTest, UnknownCFNetworkError) { NSError* error = [[NSError alloc] - initWithDomain:base::mac::CFToNSCast(kCFErrorDomainCFNetwork) + initWithDomain:base::mac::CFToNSPtrCast(kCFErrorDomainCFNetwork) code:kCFURLErrorUnknown userInfo:nil]; NSError* net_error = NetErrorFromError(error);
diff --git a/ipc/BUILD.gn b/ipc/BUILD.gn index 833da0a6..1ac4c2c 100644 --- a/ipc/BUILD.gn +++ b/ipc/BUILD.gn
@@ -284,7 +284,6 @@ "//base", "//base:i18n", "//base/test:test_support", - "//build:os_buildflags", "//crypto", "//mojo/core/test:test_support", "//testing/gtest",
diff --git a/ipc/ipc_message_attachment_set_posix_unittest.cc b/ipc/ipc_message_attachment_set_posix_unittest.cc index e0ed3b5..da20449 100644 --- a/ipc/ipc_message_attachment_set_posix_unittest.cc +++ b/ipc/ipc_message_attachment_set_posix_unittest.cc
@@ -12,7 +12,6 @@ #include "base/posix/eintr_wrapper.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "ipc/ipc_platform_file_attachment_posix.h" #include "testing/gtest/include/gtest/gtest.h"
diff --git a/ipc/ipc_send_fds_test.cc b/ipc/ipc_send_fds_test.cc index 0523d9b..dad1710 100644 --- a/ipc/ipc_send_fds_test.cc +++ b/ipc/ipc_send_fds_test.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "build/os_buildflags.h" +#include "build/build_config.h" #if BUILDFLAG(IS_MAC) extern "C" {
diff --git a/media/audio/BUILD.gn b/media/audio/BUILD.gn index 5e1b1f5..b19e098 100644 --- a/media/audio/BUILD.gn +++ b/media/audio/BUILD.gn
@@ -406,7 +406,6 @@ "//base", "//base/test:test_support", "//build:chromeos_buildflags", - "//build:os_buildflags", "//media:test_support", "//testing/gmock", "//testing/gtest",
diff --git a/media/audio/audio_input_unittest.cc b/media/audio/audio_input_unittest.cc index ce2f2589..34a16e8 100644 --- a/media/audio/audio_input_unittest.cc +++ b/media/audio/audio_input_unittest.cc
@@ -17,7 +17,6 @@ #include "base/test/test_message_loop.h" #include "base/threading/platform_thread.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "media/audio/audio_device_description.h" #include "media/audio/audio_device_info_accessor_for_tests.h" #include "media/audio/audio_io.h"
diff --git a/media/formats/BUILD.gn b/media/formats/BUILD.gn index d24a03d..212c2559 100644 --- a/media/formats/BUILD.gn +++ b/media/formats/BUILD.gn
@@ -189,6 +189,8 @@ "hls/parse_context.cc", "hls/parse_context.h", "hls/parse_status.h", + "hls/tags.cc", + "hls/tags.h", "hls/types.cc", "hls/types.h", ] @@ -302,6 +304,7 @@ # TODO(https://crbug.com/1266991): This should be gated behind `enable_hls_demuxer`, once that's enabled by default. sources += [ "hls/items_unittest.cc", + "hls/tags_unittest.cc", "hls/types_unittest.cc", ] }
diff --git a/media/formats/hls/parse_context.cc b/media/formats/hls/parse_context.cc index 2905749..dfbc05f 100644 --- a/media/formats/hls/parse_context.cc +++ b/media/formats/hls/parse_context.cc
@@ -18,6 +18,10 @@ return SourceString(line, 1, str); } +SourceString SourceString::CreateForTesting(base::StringPiece str) { + return SourceString::CreateForTesting(1, 1, str); +} + SourceString SourceString::CreateForTesting(size_t line, size_t column, base::StringPiece str) {
diff --git a/media/formats/hls/parse_context.h b/media/formats/hls/parse_context.h index 53ad33e..9518e1f 100644 --- a/media/formats/hls/parse_context.h +++ b/media/formats/hls/parse_context.h
@@ -23,6 +23,7 @@ static SourceString Create(base::PassKey<SourceLineIterator>, size_t line, base::StringPiece str); + static SourceString CreateForTesting(base::StringPiece str); static SourceString CreateForTesting(size_t line, size_t column, base::StringPiece str);
diff --git a/media/formats/hls/parse_status.h b/media/formats/hls/parse_status.h index 43673a8..43bce92 100644 --- a/media/formats/hls/parse_status.h +++ b/media/formats/hls/parse_status.h
@@ -17,6 +17,7 @@ kFailedToParseDecimalInteger, kFailedToParseDecimalFloatingPoint, kFailedToParseSignedDecimalFloatingPoint, + kInvalidPlaylistVersion, kPlaylistMissingM3uTag, kMediaSegmentMissingInfTag, };
diff --git a/media/formats/hls/tags.cc b/media/formats/hls/tags.cc new file mode 100644 index 0000000..ebbf6d5 --- /dev/null +++ b/media/formats/hls/tags.cc
@@ -0,0 +1,67 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/formats/hls/tags.h" + +#include <cstddef> +#include "media/formats/hls/parse_status.h" + +namespace media { +namespace hls { + +ParseStatus::Or<M3uTag> M3uTag::Parse(TagItem tag) { + DCHECK(tag.kind == TagKind::kM3u); + if (!tag.content.Str().empty()) { + return ParseStatusCode::kMalformedTag; + } + + return M3uTag{}; +} + +ParseStatus::Or<XVersionTag> XVersionTag::Parse(TagItem tag) { + DCHECK(tag.kind == TagKind::kXVersion); + + auto value_result = types::ParseDecimalInteger(tag.content); + if (value_result.has_error()) { + return ParseStatus(ParseStatusCode::kMalformedTag) + .AddCause(std::move(value_result).error()); + } + + // Reject invalid version numbers. + // For valid version numbers, caller will decide if the version is supported. + auto value = std::move(value_result).value(); + if (value == 0) { + return ParseStatusCode::kInvalidPlaylistVersion; + } + + return XVersionTag{.version = value}; +} + +ParseStatus::Or<InfTag> InfTag::Parse(TagItem tag) { + DCHECK(tag.kind == TagKind::kInf); + + // Inf tags have the form #EXTINF:<duration>,[<title>] + // Find the comma. + auto comma = tag.content.Str().find_first_of(','); + if (comma == base::StringPiece::npos) { + return ParseStatusCode::kMalformedTag; + } + + auto duration_str = tag.content.Substr(0, comma); + auto title_str = tag.content.Substr(comma + 1); + + // Extract duration + // TODO(crbug.com/1284763): Below version 3 this should be rounded to an + // integer + auto duration = types::ParseDecimalFloatingPoint(duration_str); + if (duration.has_error()) { + return ParseStatus(ParseStatusCode::kMalformedTag) + .AddCause(std::move(duration).error()); + } + + return InfTag{.duration = std::move(duration).value(), .title = title_str}; +} + +} // namespace hls +} // namespace media
diff --git a/media/formats/hls/tags.h b/media/formats/hls/tags.h new file mode 100644 index 0000000..053c1618 --- /dev/null +++ b/media/formats/hls/tags.h
@@ -0,0 +1,45 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MEDIA_FORMATS_HLS_TAGS_H_ +#define MEDIA_FORMATS_HLS_TAGS_H_ + +#include "media/base/media_export.h" +#include "media/formats/hls/items.h" +#include "media/formats/hls/parse_status.h" +#include "media/formats/hls/types.h" + +namespace media { +namespace hls { + +// Represents the contents of the #EXTM3U tag +struct M3uTag { + static constexpr TagKind kKind = TagKind::kM3u; + static MEDIA_EXPORT ParseStatus::Or<M3uTag> Parse(TagItem); +}; + +// Represents the contents of the #EXT-X-VERSION tag +struct XVersionTag { + static constexpr TagKind kKind = TagKind::kXVersion; + static MEDIA_EXPORT ParseStatus::Or<XVersionTag> Parse(TagItem); + + types::DecimalInteger version; +}; + +// Represents the contents of the #EXTINF tag +struct InfTag { + static constexpr TagKind kKind = TagKind::kInf; + static MEDIA_EXPORT ParseStatus::Or<InfTag> Parse(TagItem); + + // Target duration of the media segment, in seconds. + types::DecimalFloatingPoint duration; + + // Human-readable title of the media segment. + SourceString title; +}; + +} // namespace hls +} // namespace media + +#endif // MEDIA_FORMATS_HLS_TAGS_H_
diff --git a/media/formats/hls/tags_unittest.cc b/media/formats/hls/tags_unittest.cc new file mode 100644 index 0000000..dc2f97a --- /dev/null +++ b/media/formats/hls/tags_unittest.cc
@@ -0,0 +1,126 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/formats/hls/tags.h" +#include "media/formats/hls/parse_context.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace media { +namespace hls { + +template <typename T> +void ErrorTest(SourceString content, ParseStatusCode expected_status) { + auto tag = TagItem{.kind = T::kKind, .content = content}; + auto result = T::Parse(tag); + EXPECT_TRUE(result.has_error()); + auto error = std::move(result).error(); + EXPECT_EQ(error.code(), expected_status); +} + +template <typename T> +T OkTest(SourceString content) { + auto tag = TagItem{.kind = T::kKind, .content = content}; + auto result = T::Parse(tag); + EXPECT_TRUE(result.has_value()); + return std::move(result).value(); +} + +TEST(HlsFormatParserTest, ParseM3uTagTest) { + // Empty content is the only allowed content + OkTest<M3uTag>(SourceString::CreateForTesting("")); + + // Test with non-empty content + ErrorTest<M3uTag>(SourceString::CreateForTesting(" "), + ParseStatusCode::kMalformedTag); + ErrorTest<M3uTag>(SourceString::CreateForTesting("a"), + ParseStatusCode::kMalformedTag); + ErrorTest<M3uTag>(SourceString::CreateForTesting("1234"), + ParseStatusCode::kMalformedTag); + ErrorTest<M3uTag>(SourceString::CreateForTesting("\t"), + ParseStatusCode::kMalformedTag); +} + +TEST(HlsFormatParserTest, ParseXVersionTagTest) { + // Test valid versions + XVersionTag tag = OkTest<XVersionTag>(SourceString::CreateForTesting("1")); + EXPECT_EQ(tag.version, 1u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("2")); + EXPECT_EQ(tag.version, 2u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("3")); + EXPECT_EQ(tag.version, 3u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("4")); + EXPECT_EQ(tag.version, 4u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("5")); + EXPECT_EQ(tag.version, 5u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("6")); + EXPECT_EQ(tag.version, 6u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("7")); + EXPECT_EQ(tag.version, 7u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("8")); + EXPECT_EQ(tag.version, 8u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("9")); + EXPECT_EQ(tag.version, 9u); + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("10")); + EXPECT_EQ(tag.version, 10u); + + // While unsupported playlist versions are rejected, that's NOT the + // responsibility of this tag parsing function. The playlist should be + // rejected at a higher level. + tag = OkTest<XVersionTag>(SourceString::CreateForTesting("99999")); + EXPECT_EQ(tag.version, 99999u); + + // Test invalid versions + ErrorTest<XVersionTag>(SourceString::CreateForTesting(""), + ParseStatusCode::kMalformedTag); + ErrorTest<XVersionTag>(SourceString::CreateForTesting("0"), + ParseStatusCode::kInvalidPlaylistVersion); + ErrorTest<XVersionTag>(SourceString::CreateForTesting("-1"), + ParseStatusCode::kMalformedTag); + ErrorTest<XVersionTag>(SourceString::CreateForTesting("1.0"), + ParseStatusCode::kMalformedTag); + ErrorTest<XVersionTag>(SourceString::CreateForTesting("asdf"), + ParseStatusCode::kMalformedTag); + ErrorTest<XVersionTag>(SourceString::CreateForTesting(" 1 "), + ParseStatusCode::kMalformedTag); +} + +TEST(HlsFormatParserTest, ParseInfTagTest) { + // Test some valid tags + InfTag tag = OkTest<InfTag>(SourceString::CreateForTesting("1234,")); + EXPECT_EQ(tag.duration, 1234.0); + EXPECT_EQ(tag.title.Str(), ""); + + tag = OkTest<InfTag>(SourceString::CreateForTesting("1.234,")); + EXPECT_EQ(tag.duration, 1.234); + EXPECT_EQ(tag.title.Str(), ""); + + // The spec implies that whitespace characters like this usually aren't + // permitted, but "\t" is a common occurrence for the title value. + tag = OkTest<InfTag>(SourceString::CreateForTesting("99.5,\t")); + EXPECT_EQ(tag.duration, 99.5); + EXPECT_EQ(tag.title.Str(), "\t"); + + tag = OkTest<InfTag>(SourceString::CreateForTesting("9.5,,,,")); + EXPECT_EQ(tag.duration, 9.5); + EXPECT_EQ(tag.title.Str(), ",,,"); + + tag = OkTest<InfTag>(SourceString::CreateForTesting("12,asdfsdf ")); + EXPECT_EQ(tag.duration, 12.0); + EXPECT_EQ(tag.title.Str(), "asdfsdf "); + + // Test some invalid tags + ErrorTest<InfTag>(SourceString::CreateForTesting(""), + ParseStatusCode::kMalformedTag); + ErrorTest<InfTag>(SourceString::CreateForTesting(","), + ParseStatusCode::kMalformedTag); + ErrorTest<InfTag>(SourceString::CreateForTesting("-123,"), + ParseStatusCode::kMalformedTag); + ErrorTest<InfTag>(SourceString::CreateForTesting("123"), + ParseStatusCode::kMalformedTag); + ErrorTest<InfTag>(SourceString::CreateForTesting("asdf,"), + ParseStatusCode::kMalformedTag); +} + +} // namespace hls +} // namespace media
diff --git a/mojo/public/rust/OWNERS b/mojo/public/rust/OWNERS new file mode 100644 index 0000000..3ba9ffb --- /dev/null +++ b/mojo/public/rust/OWNERS
@@ -0,0 +1,3 @@ +file://build/rust/OWNERS + +collinbaker@chromium.org \ No newline at end of file
diff --git a/mojo/public/rust/bindings/decoding.rs b/mojo/public/rust/bindings/decoding.rs new file mode 100644 index 0000000..304817b --- /dev/null +++ b/mojo/public/rust/bindings/decoding.rs
@@ -0,0 +1,398 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use bindings::encoding::{ + Bits, Context, DataHeader, DataHeaderValue, MojomNumeric, DATA_HEADER_SIZE, +}; +use bindings::mojom::{MojomEncodable, MOJOM_NULL_POINTER, UNION_SIZE}; +use bindings::util; + +use std::mem; +use std::ptr; +use std::vec::Vec; + +use system; +use system::{CastHandle, Handle, UntypedHandle}; + +#[derive(Debug, Eq, PartialEq)] +pub enum ValidationError { + DifferentSizedArraysInMap, + IllegalHandle, + IllegalMemoryRange, + IllegalPointer, + MessageHeaderInvalidFlags, + MessageHeaderMissingRequestId, + MessageHeaderUnknownMethod, + MisalignedObject, + UnexpectedArrayHeader, + UnexpectedInvalidHandle, + UnexpectedNullPointer, + UnexpectedNullUnion, + UnexpectedStructHeader, +} + +impl ValidationError { + pub fn as_str(self) -> &'static str { + match self { + ValidationError::DifferentSizedArraysInMap => { + "VALIDATION_ERROR_DIFFERENT_SIZED_ARRAYS_IN_MAP" + } + ValidationError::IllegalHandle => "VALIDATION_ERROR_ILLEGAL_HANDLE", + ValidationError::IllegalMemoryRange => "VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE", + ValidationError::IllegalPointer => "VALIDATION_ERROR_ILLEGAL_POINTER", + ValidationError::MessageHeaderInvalidFlags => { + "VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS" + } + ValidationError::MessageHeaderMissingRequestId => { + "VALIDATION_ERROR_MESSAGE_HEADER_MISSING_REQUEST_ID" + } + ValidationError::MessageHeaderUnknownMethod => { + "VALIDATION_ERROR_MESSAGE_HEADER_UNKNOWN_METHOD" + } + ValidationError::MisalignedObject => "VALIDATION_ERROR_MISALIGNED_OBJECT", + ValidationError::UnexpectedArrayHeader => "VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER", + ValidationError::UnexpectedInvalidHandle => { + "VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE" + } + ValidationError::UnexpectedNullPointer => "VALIDATION_ERROR_UNEXPECTED_NULL_POINTER", + ValidationError::UnexpectedNullUnion => "VALIDATION_ERROR_UNEXPECTED_NULL_UNION", + ValidationError::UnexpectedStructHeader => "VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER", + } + } +} + +/// An decoding state represents the decoding logic for a single +/// Mojom object that is NOT inlined, such as a struct or an array. +pub struct DecodingState<'slice> { + /// The buffer the state may write to. + data: &'slice [u8], + + /// The offset of this serialized object into the overall buffer. + global_offset: usize, + + /// The current offset within 'data'. + offset: usize, + + /// The current bit offset within 'data'. + bit_offset: Bits, +} + +impl<'slice> DecodingState<'slice> { + /// Create a new decoding state. + pub fn new(buffer: &'slice [u8], offset: usize) -> DecodingState<'slice> { + DecodingState { data: buffer, global_offset: offset, offset: 0, bit_offset: Bits(0) } + } + + /// Align the decoding state to the next byte. + pub fn align_to_byte(&mut self) { + if self.bit_offset > Bits(0) { + self.offset += 1; + self.bit_offset = Bits(0); + } + } + + /// Align the decoding state to the next 'bytes' boundary. + pub fn align_to_bytes(&mut self, bytes: usize) { + if self.offset != 0 { + self.offset = util::align_bytes(self.offset, bytes); + } + } + + /// Read a primitive from the buffer without incrementing the offset. + fn read_in_place<T: MojomNumeric>(&mut self) -> T { + let mut value: T = Default::default(); + debug_assert!(mem::size_of::<T>() + self.offset <= self.data.len()); + let ptr = (&self.data[self.offset..]).as_ptr(); + unsafe { + ptr::copy_nonoverlapping( + mem::transmute::<*const u8, *const T>(ptr), + &mut value as *mut T, + 1, + ); + } + value + } + + /// Read a primitive from the buffer and increment the offset. + fn read<T: MojomNumeric>(&mut self) -> T { + let value = self.read_in_place::<T>(); + self.bit_offset = Bits(0); + self.offset += mem::size_of::<T>(); + value + } + + /// Decode a primitive from the buffer, naturally aligning before we read. + pub fn decode<T: MojomNumeric>(&mut self) -> T { + self.align_to_byte(); + self.align_to_bytes(mem::size_of::<T>()); + self.read::<T>() + } + + /// Decode a boolean value from the buffer as one bit. + pub fn decode_bool(&mut self) -> bool { + let offset = self.offset; + // Check the bit by getting the set bit and checking if its non-zero + let value = (self.data[offset] & self.bit_offset.as_set_bit()) > 0; + self.bit_offset += Bits(1); + let (bits, bytes) = self.bit_offset.as_bits_and_bytes(); + self.offset += bytes; + self.bit_offset = bits; + value + } + + /// If we encounter a null pointer, increment past it. + /// + /// Returns if we skipped or not. + pub fn skip_if_null_pointer(&mut self) -> bool { + self.align_to_byte(); + self.align_to_bytes(8); + let ptr = self.read_in_place::<u64>(); + if ptr == MOJOM_NULL_POINTER { + self.offset += 8; + } + (ptr == MOJOM_NULL_POINTER) + } + + /// If we encounter a null union, increment past it. + /// + /// Returns if we skipped or not. + pub fn skip_if_null_union(&mut self) -> bool { + self.align_to_byte(); + self.align_to_bytes(8); + let size = self.read_in_place::<u32>(); + if size == 0 { + self.offset += UNION_SIZE; + } + (size == 0) + } + + /// If we encounter a null handle, increment past it. + /// + /// Returns if we skipped or not. + pub fn skip_if_null_handle(&mut self) -> bool { + self.align_to_byte(); + self.align_to_bytes(4); + let index = self.read_in_place::<i32>(); + if index < 0 { + self.offset += 4; + } + (index < 0) + } + + /// If we encounter a null interface, increment past it. + /// + /// Returns if we skipped or not. + pub fn skip_if_null_interface(&mut self) -> bool { + self.align_to_byte(); + self.align_to_bytes(4); + let index = self.read_in_place::<i32>(); + if index < 0 { + self.offset += 8; + } + (index < 0) + } + + /// Decode a pointer from the buffer as a global offset into the buffer. + /// + /// The pointer in the buffer is an offset relative to the pointer to another + /// location in the buffer. We convert that to an absolute offset with respect + /// to the buffer before returning. This is our defintion of a pointer. + pub fn decode_pointer(&mut self) -> Option<u64> { + self.align_to_byte(); + self.align_to_bytes(8); + let current_location = (self.global_offset + self.offset) as u64; + let offset = self.read::<u64>(); + if offset == MOJOM_NULL_POINTER { + Some(MOJOM_NULL_POINTER) + } else { + offset.checked_add(current_location) + } + } + + /// A routine for decoding an array header. + /// + /// Must be called with offset zero (that is, it must be the first thing + /// decoded). Performs numerous validation checks. + pub fn decode_array_header<T>(&mut self) -> Result<DataHeader, ValidationError> + where + T: MojomEncodable, + { + debug_assert_eq!(self.offset, 0); + // Make sure we can read the size first... + if self.data.len() < mem::size_of::<u32>() { + return Err(ValidationError::UnexpectedArrayHeader); + } + let bytes = self.decode::<u32>(); + if (bytes as usize) < DATA_HEADER_SIZE { + return Err(ValidationError::UnexpectedArrayHeader); + } + let elems = self.decode::<u32>(); + match T::embed_size(&Default::default()).checked_mul(elems as usize) { + Some(value) => { + if (bytes as usize) < value.as_bytes() + DATA_HEADER_SIZE { + return Err(ValidationError::UnexpectedArrayHeader); + } + } + None => return Err(ValidationError::UnexpectedArrayHeader), + } + Ok(DataHeader::new(bytes as usize, DataHeaderValue::Elements(elems))) + } + + /// A routine for decoding an struct header. + /// + /// Must be called with offset zero (that is, it must be the first thing + /// decoded). Performs numerous validation checks. + pub fn decode_struct_header( + &mut self, + versions: &[(u32, u32)], + ) -> Result<DataHeader, ValidationError> { + debug_assert_eq!(self.offset, 0); + // Make sure we can read the size first... + if self.data.len() < mem::size_of::<u32>() { + return Err(ValidationError::UnexpectedStructHeader); + } + let bytes = self.decode::<u32>(); + if (bytes as usize) < DATA_HEADER_SIZE { + return Err(ValidationError::UnexpectedStructHeader); + } + let version = self.decode::<u32>(); + // Versioning validation: versions are generated as a sorted array of tuples, so + // to find the version we are given by the header we use a binary search. + match versions.binary_search_by(|val| val.0.cmp(&version)) { + Ok(idx) => { + let (_, size) = versions[idx]; + if bytes != size { + return Err(ValidationError::UnexpectedStructHeader); + } + } + Err(idx) => { + if idx == 0 { + panic!( + "Should be earliest version? \ + Versions: {:?}, \ + Version: {}, \ + Size: {}", + versions, version, bytes + ); + } + let len = versions.len(); + let (latest_version, _) = versions[len - 1]; + let (_, size) = versions[idx - 1]; + // If this is higher than any version we know, its okay for the size to be bigger, + // but if its a version we know about, it must match the size. + if (version > latest_version && bytes < size) + || (version <= latest_version && bytes != size) + { + return Err(ValidationError::UnexpectedStructHeader); + } + } + } + Ok(DataHeader::new(bytes as usize, DataHeaderValue::Version(version))) + } +} + +/// A struct that will encode a given Mojom object and convert it into +/// bytes and a vector of handles. +pub struct Decoder<'slice> { + bytes: usize, + buffer: Option<&'slice [u8]>, + states: Vec<DecodingState<'slice>>, + handles: Vec<UntypedHandle>, + handles_claimed: usize, // A length that claims all handles were claimed up to this index + max_offset: usize, // Represents the maximum value an offset may have +} + +impl<'slice> Decoder<'slice> { + /// Create a new Decoder. + pub fn new(buffer: &'slice [u8], handles: Vec<UntypedHandle>) -> Decoder<'slice> { + let max_offset = buffer.len(); + Decoder { + bytes: 0, + buffer: Some(buffer), + states: Vec::new(), + handles: handles, + handles_claimed: 0, + max_offset: max_offset, + } + } + + /// Claim space in the buffer to start decoding some object. + /// + /// Creates a new decoding state for the object and returns a context. + pub fn claim(&mut self, offset: usize) -> Result<Context, ValidationError> { + // Check if the layout order is sane + if offset < self.bytes { + return Err(ValidationError::IllegalMemoryRange); + } + // Check for 8-byte alignment + if offset & 7 != 0 { + return Err(ValidationError::MisalignedObject); + } + // Bounds check on offset + if offset > self.max_offset { + return Err(ValidationError::IllegalPointer); + } + let mut buffer = self.buffer.take().expect("No buffer?"); + let space = offset - self.bytes; + buffer = &buffer[space..]; + // Make sure we can even read the bytes in the header + if buffer.len() < mem::size_of::<u32>() { + return Err(ValidationError::IllegalMemoryRange); + } + // Read the number of bytes in the memory region according to the data header + let mut read_size: u32 = 0; + unsafe { + ptr::copy_nonoverlapping( + mem::transmute::<*const u8, *const u32>(buffer.as_ptr()), + &mut read_size as *mut u32, + mem::size_of::<u32>(), + ); + } + let size = u32::from_le(read_size) as usize; + // Make sure the size we read is sane... + if size > buffer.len() { + return Err(ValidationError::IllegalMemoryRange); + } + // TODO(mknyszek): Check size for validation + let (claimed, unclaimed) = buffer.split_at(size); + self.states.push(DecodingState::new(claimed, offset)); + self.buffer = Some(unclaimed); + self.bytes += space + size; + Ok(Context::new(self.states.len() - 1)) + } + + /// Claims a handle at some particular index in the given handles array. + /// + /// Returns the handle with all type information in-tact. + pub fn claim_handle<T: Handle + CastHandle>( + &mut self, + index: i32, + ) -> Result<T, ValidationError> { + let real_index = if index >= 0 { + index as usize + } else { + return Err(ValidationError::UnexpectedInvalidHandle); + }; + // If the index exceeds our number of handles or if we have already claimed that handle + if real_index >= self.handles.len() || real_index < self.handles_claimed { + return Err(ValidationError::IllegalHandle); + } + self.handles_claimed = real_index + 1; + let raw_handle = self.handles[real_index].get_native_handle(); + unsafe { + self.handles[real_index].invalidate(); + Ok(T::from_untyped(system::acquire(raw_handle))) + } + } + + /// Immutably borrow a decoding state via Context. + pub fn get(&self, context: &Context) -> &DecodingState<'slice> { + &self.states[context.id()] + } + + /// Mutably borrow a decoding state via Context. + pub fn get_mut(&mut self, context: &Context) -> &mut DecodingState<'slice> { + &mut self.states[context.id()] + } +}
diff --git a/mojo/public/rust/bindings/encoding.rs b/mojo/public/rust/bindings/encoding.rs new file mode 100644 index 0000000..39e0744 --- /dev/null +++ b/mojo/public/rust/bindings/encoding.rs
@@ -0,0 +1,399 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use bindings::mojom::MOJOM_NULL_POINTER; +use bindings::util; + +use std::mem; +use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; +use std::ptr; +use std::vec::Vec; + +use system::UntypedHandle; + +/// Represents some count of bits. +/// +/// Used to distinguish when we have a bit and a byte +/// count. The byte count will go in a usize, while we +/// can use this structure to safely count bits without +/// running into some subtle bugs or crazy errors. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] +pub struct Bits(pub usize); + +impl Bits { + /// Convert bit representation to bytes, rounding up to the nearest byte. + pub fn as_bytes(self) -> usize { + util::bits_to_bytes(self.0) + } + + /// Convert to a number of bytes plus the number of bits leftover + /// that could not fit in a full byte. + pub fn as_bits_and_bytes(self) -> (Bits, usize) { + (Bits(self.0 & 7), self.0 >> 3) + } + + /// Return 1 left-shifted by the amount of bits stored here. + /// + /// Only guaranteed to work for up to 8 bits. + pub fn as_set_bit(self) -> u8 { + debug_assert!(self.0 < 8); + 1 << (self.0 & 7) + } + + pub fn checked_mul(self, val: usize) -> Option<Bits> { + match val.checked_mul(self.0) { + Some(result) => Some(Bits(result)), + None => None, + } + } + + /// Align the bits to some number of bytes. + pub fn align_to_bytes(&mut self, bytes: usize) { + self.0 = util::align_bytes(self.0, 8 * bytes); + } +} + +impl Add for Bits { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Bits(self.0 + rhs.0) + } +} + +impl AddAssign for Bits { + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0 + } +} + +impl Mul<usize> for Bits { + type Output = Self; + fn mul(self, rhs: usize) -> Self { + Bits(self.0 * rhs) + } +} + +/// This trait is intended to be used by Mojom primitive values +/// in order to be identified in generic contexts. +pub trait MojomNumeric: + Copy + + Clone + + Sized + + Add<Self> + + Sub<Self, Output = Self> + + Mul<Self> + + Div<Self, Output = Self> + + Rem<Self, Output = Self> + + PartialEq<Self> + + Default +{ + /// Converts the primitive to a little-endian representation (the mojom endianness). + fn to_mojom_endian(self) -> Self; +} + +macro_rules! impl_mojom_numeric_for_prim { + ($($t:ty),*) => { + $( + impl MojomNumeric for $t { + fn to_mojom_endian(self) -> $t { self.to_le() } + } + )* + } +} + +impl_mojom_numeric_for_prim!(i8, i16, i32, i64, u8, u16, u32, u64); + +impl MojomNumeric for f32 { + fn to_mojom_endian(self) -> f32 { + unsafe { mem::transmute::<u32, f32>(mem::transmute::<f32, u32>(self).to_le()) } + } +} + +impl MojomNumeric for f64 { + fn to_mojom_endian(self) -> f64 { + unsafe { mem::transmute::<u64, f64>(mem::transmute::<f64, u64>(self).to_le()) } + } +} + +/// Align to the Mojom default of 8 bytes. +pub fn align_default(bytes: usize) -> usize { + util::align_bytes(bytes, 8) +} + +/// The size in bytes of any data header. +pub const DATA_HEADER_SIZE: usize = 8; + +/// A value that goes in the second u32 of a +/// a data header. +/// +/// Since the data header can head many types, +/// this enum represents all the kinds of data +/// that can end up in a data header. +#[derive(Clone, Copy)] +pub enum DataHeaderValue { + Elements(u32), + Version(u32), + UnionTag(u32), +} + +impl DataHeaderValue { + /// Get the raw u32 value. + fn as_raw(self) -> u32 { + match self { + DataHeaderValue::Elements(v) => v, + DataHeaderValue::Version(v) => v, + DataHeaderValue::UnionTag(v) => v, + } + } +} + +/// A data header is placed at the beginning of every serialized +/// Mojom object, providing its size as well as some extra meta-data. +/// +/// The meta-data should always come from a DataHeaderValue. +pub struct DataHeader { + size: u32, + data: u32, +} + +impl DataHeader { + /// Create a new DataHeader. + pub fn new(size: usize, data: DataHeaderValue) -> DataHeader { + DataHeader { size: size as u32, data: data.as_raw() } + } + + /// Getter for size. + pub fn size(&self) -> u32 { + self.size + } + + /// Getter for extra meta-data. + pub fn data(&self) -> u32 { + self.data + } +} + +/// This context object represents an encoding/decoding context. +#[derive(Clone, Default)] +pub struct Context { + /// An index representing an encoding state. + id: usize, + + /// Whether or not our current context is directly inside of + /// a union. + is_union: bool, +} + +impl Context { + /// Create a new context with all data default. + pub fn new(id: usize) -> Context { + Context { id: id, is_union: false } + } + + /// Getter for the encoding state ID. + pub fn id(&self) -> usize { + self.id + } + + /// Getter for whether or not we are in a union. + pub fn is_union(&self) -> bool { + self.is_union + } + + /// Change whether or not we are inside of a union and create that + /// as a new context. + pub fn set_is_union(&self, value: bool) -> Context { + let mut new_context = self.clone(); + new_context.is_union = value; + new_context + } +} + +/// An encoding state represents the encoding logic for a single +/// Mojom object that is NOT inlined, such as a struct or an array. +pub struct EncodingState<'slice> { + /// The buffer the state may write to. + data: &'slice mut [u8], + + /// The offset of this serialized object into the overall buffer. + global_offset: usize, + + /// The current offset within 'data'. + offset: usize, + + /// The current bit offset within 'data'. + bit_offset: Bits, +} + +impl<'slice> EncodingState<'slice> { + /// Create a new encoding state. + /// + /// Note: the encoder will not allocate a buffer for you, rather + /// a pre-allocated buffer must be passed in. + pub fn new( + buffer: &'slice mut [u8], + header: &DataHeader, + offset: usize, + ) -> EncodingState<'slice> { + let mut state = + EncodingState { data: buffer, global_offset: offset, offset: 0, bit_offset: Bits(0) }; + state.write(header.size()); + state.write(header.data()); + state + } + + /// Align the encoding state to the next byte. + pub fn align_to_byte(&mut self) { + if self.bit_offset > Bits(0) { + self.offset += 1; + self.bit_offset = Bits(0); + } + } + + /// Align the encoding state to the next 'bytes' boundary. + pub fn align_to_bytes(&mut self, bytes: usize) { + self.offset = util::align_bytes(self.offset, bytes); + } + + /// Write a primitive into the buffer. + fn write<T: MojomNumeric>(&mut self, data: T) { + let num_bytes = mem::size_of::<T>(); + let bytes = data.to_mojom_endian(); + debug_assert!(num_bytes + self.offset <= self.data.len()); + unsafe { + ptr::copy_nonoverlapping( + mem::transmute::<&T, *const u8>(&bytes), + (&mut self.data[self.offset..]).as_mut_ptr(), + num_bytes, + ); + } + self.bit_offset = Bits(0); + self.offset += num_bytes; + } + + /// Encode a primitive into the buffer, naturally aligning it. + pub fn encode<T: MojomNumeric>(&mut self, data: T) { + self.align_to_byte(); + self.align_to_bytes(mem::size_of::<T>()); + self.write(data); + } + + /// Encode a boolean value into the buffer as one bit. + pub fn encode_bool(&mut self, data: bool) { + let offset = self.offset; + if data { + self.data[offset] |= self.bit_offset.as_set_bit(); + } + self.bit_offset += Bits(1); + let (bits, bytes) = self.bit_offset.as_bits_and_bytes(); + self.offset += bytes; + self.bit_offset = bits; + } + + /// Encode a null union into the buffer. + pub fn encode_null_union(&mut self) { + self.align_to_byte(); + self.align_to_bytes(8); + self.write(0 as u32); // Size + self.write(0 as u32); // Tag + self.write(0 as u64); // Data + } + + /// Encode a null pointer into the buffer. + pub fn encode_null_pointer(&mut self) { + self.align_to_byte(); + self.align_to_bytes(8); + self.encode(MOJOM_NULL_POINTER); + } + + /// Encode a null handle into the buffer. + pub fn encode_null_handle(&mut self) { + self.align_to_byte(); + self.align_to_bytes(4); + self.encode(-1 as i32); + } + + /// Encode a non-null pointer into the buffer. + /// + /// 'location' is an absolute location in the global buffer, but + /// Mojom pointers are offsets relative to the pointer, so we + /// perform that conversion here before writing. + pub fn encode_pointer(&mut self, location: u64) { + self.align_to_byte(); + self.align_to_bytes(8); + let current_location = (self.global_offset + self.offset) as u64; + debug_assert!(location >= current_location); + self.encode(location - current_location); + } +} + +/// A struct that will encode a given Mojom object and convert it into +/// bytes and a vector of handles. +pub struct Encoder<'slice> { + bytes: usize, + buffer: Option<&'slice mut [u8]>, + states: Vec<EncodingState<'slice>>, + handles: Vec<UntypedHandle>, +} + +impl<'slice> Encoder<'slice> { + /// Create a new Encoder. + pub fn new(buffer: &'slice mut [u8]) -> Encoder<'slice> { + Encoder { bytes: 0, buffer: Some(buffer), states: Vec::new(), handles: Vec::new() } + } + + /// Get the current encoded size (useful for writing pointers). + pub fn size(&self) -> usize { + self.bytes + } + + /// Start encoding a new object with its data header. + /// + /// Creates a new encoding state for the object. + pub fn add(&mut self, header: &DataHeader) -> Option<Context> { + let buf = self.buffer.take().unwrap(); + if buf.len() < (header.size() as usize) { + self.buffer = Some(buf); + return None; + } + let obj_bytes = header.size() as usize; + let (claimed, rest) = buf.split_at_mut(obj_bytes); + self.states.push(EncodingState::new(claimed, header, self.bytes)); + self.bytes += obj_bytes; + let padding_bytes = align_default(obj_bytes) - obj_bytes; + if padding_bytes <= rest.len() { + let (_, new_buffer) = rest.split_at_mut(padding_bytes); + self.bytes += padding_bytes; + self.buffer = Some(new_buffer); + } else { + self.buffer = Some(rest); + } + Some(Context::new(self.states.len() - 1)) + } + + /// Adds a handle and returns an offset to that handle in the + /// final handle vector. + pub fn add_handle(&mut self, handle: UntypedHandle) -> usize { + self.handles.push(handle); + self.handles.len() - 1 + } + + /// Immutably borrow an encoding state via Context. + pub fn get(&self, context: &Context) -> &EncodingState<'slice> { + &self.states[context.id()] + } + + /// Mutably borrow an encoding state via Context. + pub fn get_mut(&mut self, context: &Context) -> &mut EncodingState<'slice> { + &mut self.states[context.id()] + } + + /// Signal to finish encoding by destroying the Encoder and returning the final + /// handle vector. + /// + /// Note: No byte buffer is returned as that is pre-allocated. + pub fn unwrap(self) -> Vec<UntypedHandle> { + self.handles + } +}
diff --git a/mojo/public/rust/bindings/macros.rs b/mojo/public/rust/bindings/macros.rs new file mode 100644 index 0000000..63848a5 --- /dev/null +++ b/mojo/public/rust/bindings/macros.rs
@@ -0,0 +1,160 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// This macro provides a common implementation of MojomEncodable +/// for MojomPointer types. +/// +/// Note: it does not implement compute_size(); +/// +/// The Rust type system currently lacks the facilities to do this +/// generically (need mutually excludable traits) but this macro +/// should be replaced as soon as this is possible. +#[macro_export] +macro_rules! impl_encodable_for_pointer { + () => { + fn mojom_alignment() -> usize { + 8 // All mojom pointers are 8 bytes in length, and thus are 8-byte aligned + } + fn mojom_type() -> $crate::bindings::mojom::MojomType { + $crate::bindings::mojom::MojomType::Pointer + } + fn embed_size( + _context: &$crate::bindings::encoding::Context, + ) -> $crate::bindings::encoding::Bits { + $crate::bindings::mojom::POINTER_BIT_SIZE + } + fn encode( + self, + encoder: &mut $crate::bindings::encoding::Encoder, + context: $crate::bindings::encoding::Context, + ) { + let loc = encoder.size() as u64; + { + let state = encoder.get_mut(&context); + state.encode_pointer(loc); + } + self.encode_new(encoder, context); + } + fn decode( + decoder: &mut $crate::bindings::decoding::Decoder, + context: $crate::bindings::encoding::Context, + ) -> Result<Self, ValidationError> { + let ptr = { + let state = decoder.get_mut(&context); + match state.decode_pointer() { + Some(ptr) => ptr, + None => return Err(ValidationError::IllegalPointer), + } + }; + if ptr == $crate::bindings::mojom::MOJOM_NULL_POINTER { + Err(ValidationError::UnexpectedNullPointer) + } else { + Self::decode_new(decoder, context, ptr) + } + } + }; +} + +/// This macro provides a common implementation of MojomEncodable +/// for MojomUnion types. +/// +/// Note: it does not implement compute_size(); +/// +/// The Rust type system currently lacks the facilities to do this +/// generically (need mutually excludable traits) but this macro +/// should be replaced as soon as this is possible. +#[macro_export] +macro_rules! impl_encodable_for_union { + () => { + fn mojom_alignment() -> usize { + 8 + } + fn mojom_type() -> $crate::bindings::mojom::MojomType { + $crate::bindings::mojom::MojomType::Union + } + fn embed_size( + context: &$crate::bindings::encoding::Context, + ) -> $crate::bindings::encoding::Bits { + if context.is_union() { + Self::nested_embed_size() + } else { + Self::inline_embed_size() + } + } + fn encode( + self, + encoder: &mut $crate::bindings::encoding::Encoder, + context: $crate::bindings::encoding::Context, + ) { + if context.is_union() { + self.nested_encode(encoder, context); + } else { + self.inline_encode(encoder, context.set_is_union(true)); + } + } + fn decode( + decoder: &mut $crate::bindings::decoding::Decoder, + context: $crate::bindings::encoding::Context, + ) -> Result<Self, ValidationError> { + if context.is_union() { + Self::nested_decode(decoder, context) + } else { + Self::inline_decode(decoder, context.set_is_union(true)) + } + } + }; +} + +/// This macro provides a common implementation of MojomEncodable +/// for MojomInterface types. +/// +/// Note: it does not implement compute_size(); +/// +/// The Rust type system currently lacks the facilities to do this +/// generically (need mutually excludable traits) but this macro +/// should be replaced as soon as this is possible. +#[macro_export] +macro_rules! impl_encodable_for_interface { + () => { + fn mojom_alignment() -> usize { + 4 + } + fn mojom_type() -> $crate::bindings::mojom::MojomType { + $crate::bindings::mojom::MojomType::Interface + } + fn embed_size( + _context: &$crate::bindings::encoding::Context, + ) -> $crate::bindings::encoding::Bits { + use std::mem; + $crate::bindings::encoding::Bits(2 * 8 * mem::size_of::<u32>()) + } + fn compute_size(&self, _context: $crate::bindings::encoding::Context) -> usize { + 0 // Indicates that this type is inlined and it adds nothing external to the size + } + fn encode( + self, + encoder: &mut $crate::bindings::encoding::Encoder, + context: $crate::bindings::encoding::Context, + ) { + let version = self.version(); + let pos = encoder.add_handle(self.as_untyped()); + let mut state = encoder.get_mut(&context); + state.encode(pos as i32); + state.encode(version as u32); + } + fn decode( + decoder: &mut $crate::bindings::decoding::Decoder, + context: $crate::bindings::encoding::Context, + ) -> Result<Self, ValidationError> { + let (handle_index, version) = { + let mut state = decoder.get_mut(&context); + (state.decode::<i32>(), state.decode::<u32>()) + }; + let handle = + try!(decoder + .claim_handle::<$crate::system::message_pipe::MessageEndpoint>(handle_index)); + Ok(Self::with_version(handle, version)) + } + }; +}
diff --git a/mojo/public/rust/bindings/message.rs b/mojo/public/rust/bindings/message.rs new file mode 100644 index 0000000..9e1a830 --- /dev/null +++ b/mojo/public/rust/bindings/message.rs
@@ -0,0 +1,99 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use bindings::decoding::{Decoder, ValidationError}; +use bindings::encoding; +use bindings::encoding::{Context, DataHeaderValue, Encoder, DATA_HEADER_SIZE}; +use bindings::mojom::{MojomEncodable, MojomPointer, MojomStruct}; + +/// A flag for the message header indicating that no flag has been set. +pub const MESSAGE_HEADER_NO_FLAG: u32 = 0; + +/// A flag for the message header indicating that this message expects +/// a response. +pub const MESSAGE_HEADER_EXPECT_RESPONSE: u32 = 1; + +/// A flag for the message header indicating that this message is +/// a response. +pub const MESSAGE_HEADER_IS_RESPONSE: u32 = 2; + +const MESSAGE_HEADER_VERSIONS: [(u32, u32); 2] = [(0, 16), (1, 24)]; + +/// A message header object implemented as a Mojom struct. +pub struct MessageHeader { + pub version: u32, + pub name: u32, + pub flags: u32, + pub request_id: u64, +} + +impl MessageHeader { + /// Create a new MessageHeader. + pub fn new(version: u32, name: u32, flags: u32) -> MessageHeader { + MessageHeader { version: version, name: name, flags: flags, request_id: 0 } + } +} + +impl MojomPointer for MessageHeader { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(self.version) + } + + /// Get the serialized size. + /// + /// This value differs based on whether or not + /// a request_id is necessary. + fn serialized_size(&self, _context: &Context) -> usize { + let mut size = DATA_HEADER_SIZE + 8; + if self.flags != MESSAGE_HEADER_NO_FLAG { + size += 8; + } + encoding::align_default(size) + } + + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.name, encoder, context.clone()); + MojomEncodable::encode(self.flags, encoder, context.clone()); + if self.version > 0 { + MojomEncodable::encode(self.request_id, encoder, context.clone()); + } + } + + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let mut state = decoder.get_mut(&context); + let version = match state.decode_struct_header(&MESSAGE_HEADER_VERSIONS) { + Ok(header) => header.data(), + Err(err) => return Err(err), + }; + let name = state.decode::<u32>(); + let flags = state.decode::<u32>(); + if flags > MESSAGE_HEADER_IS_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + if version == 0 { + if flags == MESSAGE_HEADER_IS_RESPONSE || flags == MESSAGE_HEADER_EXPECT_RESPONSE { + return Err(ValidationError::MessageHeaderMissingRequestId); + } + Ok(MessageHeader { version: version, name: name, flags: flags, request_id: 0 }) + } else if version == 1 { + Ok(MessageHeader { + version: version, + name: name, + flags: flags, + request_id: state.decode::<u64>(), + }) + } else { + return Err(ValidationError::UnexpectedStructHeader); + } + } +} + +impl MojomEncodable for MessageHeader { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + self.serialized_size(&context) + } +} + +impl MojomStruct for MessageHeader {}
diff --git a/mojo/public/rust/bindings/mod.rs b/mojo/public/rust/bindings/mod.rs new file mode 100644 index 0000000..b8475b47 --- /dev/null +++ b/mojo/public/rust/bindings/mod.rs
@@ -0,0 +1,13 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#[macro_use] +mod macros; +mod util; + +pub mod decoding; +pub mod encoding; +pub mod message; +pub mod mojom; +pub mod run_loop;
diff --git a/mojo/public/rust/bindings/mojom.rs b/mojo/public/rust/bindings/mojom.rs new file mode 100644 index 0000000..9ea3f99 --- /dev/null +++ b/mojo/public/rust/bindings/mojom.rs
@@ -0,0 +1,913 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use bindings::decoding::{Decoder, ValidationError}; +use bindings::encoding; +use bindings::encoding::{Bits, Context, DataHeader, DataHeaderValue, Encoder, DATA_HEADER_SIZE}; +use bindings::message::MessageHeader; + +use std::cmp::Eq; +use std::collections::HashMap; +use std::hash::Hash; +use std::mem; +use std::panic; +use std::ptr; +use std::vec::Vec; + +use system::data_pipe; +use system::message_pipe; +use system::shared_buffer; +use system::wait_set; +use system::{CastHandle, Handle, MojoResult, UntypedHandle}; + +/// The size of a Mojom map plus header in bytes. +const MAP_SIZE: usize = 24; + +/// The sorted set of versions for a map. +const MAP_VERSIONS: [(u32, u32); 1] = [(0, MAP_SIZE as u32)]; + +/// The size of a Mojom union in bytes (header included). +pub const UNION_SIZE: usize = 16; + +/// The size of a Mojom pointer in bits. +pub const POINTER_BIT_SIZE: Bits = Bits(64); + +/// The value of a Mojom null pointer. +pub const MOJOM_NULL_POINTER: u64 = 0; + +/// An enumeration of all the possible low-level Mojom types. +pub enum MojomType { + Simple, + Pointer, + Union, + Handle, + Interface, +} + +/// Whatever implements this trait can be serialized in the Mojom format. +pub trait MojomEncodable: Sized { + /// Get the Mojom type. + fn mojom_type() -> MojomType; + + /// Get this type's Mojom alignment. + fn mojom_alignment() -> usize; + + /// The amount of space in bits the type takes up when inlined + /// into another type at serialization time. + fn embed_size(context: &Context) -> Bits; + + /// Recursively computes the size of the complete Mojom archive + /// starting from this type. + fn compute_size(&self, context: Context) -> usize; + + /// Encodes this type into the encoder given a context. + fn encode(self, encoder: &mut Encoder, context: Context); + + /// Using a decoder, decodes itself out of a byte buffer. + fn decode(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError>; +} + +/// Whatever implements this trait is a Mojom pointer type which means +/// that on encode, a pointer is inlined and the implementer is +/// serialized elsewhere in the output buffer. +pub trait MojomPointer: MojomEncodable { + /// Get the DataHeader meta-data for this pointer type. + fn header_data(&self) -> DataHeaderValue; + + /// Get the size of only this type when serialized. + fn serialized_size(&self, context: &Context) -> usize; + + /// Encodes the actual values of the type into the encoder. + fn encode_value(self, encoder: &mut Encoder, context: Context); + + /// Decodes the actual values of the type into the decoder. + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError>; + + /// Writes a pointer inlined into the current context before calling + /// encode_value. + fn encode_new(self, encoder: &mut Encoder, context: Context) { + let data_size = self.serialized_size(&context); + let data_header = DataHeader::new(data_size, self.header_data()); + let new_context = encoder.add(&data_header).unwrap(); + self.encode_value(encoder, new_context); + } + + /// Reads a pointer inlined into the current context before calling + /// decode_value. + fn decode_new( + decoder: &mut Decoder, + _context: Context, + pointer: u64, + ) -> Result<Self, ValidationError> { + match decoder.claim(pointer as usize) { + Ok(new_context) => Self::decode_value(decoder, new_context), + Err(err) => Err(err), + } + } +} + +/// Whatever implements this trait is a Mojom union type which means that +/// on encode it is inlined, but if the union is nested inside of another +/// union type, it is treated as a pointer type. +pub trait MojomUnion: MojomEncodable { + /// Get the union's current tag. + fn get_tag(&self) -> u32; + + /// Encode the actual value of the union. + fn encode_value(self, encoder: &mut Encoder, context: Context); + + /// Decode the actual value of the union. + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError>; + + /// The embed_size for when the union acts as a pointer type. + fn nested_embed_size() -> Bits { + POINTER_BIT_SIZE + } + + /// The encoding routine for when the union acts as a pointer type. + fn nested_encode(self, encoder: &mut Encoder, context: Context) { + let loc = encoder.size() as u64; + { + let state = encoder.get_mut(&context); + state.encode_pointer(loc); + } + let tag = DataHeaderValue::UnionTag(self.get_tag()); + let data_header = DataHeader::new(UNION_SIZE, tag); + let new_context = encoder.add(&data_header).unwrap(); + self.encode_value(encoder, new_context.set_is_union(true)); + } + + /// The decoding routine for when the union acts as a pointer type. + fn nested_decode(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let global_offset = { + let state = decoder.get_mut(&context); + match state.decode_pointer() { + Some(ptr) => ptr as usize, + None => return Err(ValidationError::IllegalPointer), + } + }; + if global_offset == (MOJOM_NULL_POINTER as usize) { + return Err(ValidationError::UnexpectedNullPointer); + } + match decoder.claim(global_offset as usize) { + Ok(new_context) => Self::decode_value(decoder, new_context), + Err(err) => Err(err), + } + } + + /// The embed_size for when the union is inlined into the current context. + fn inline_embed_size() -> Bits { + Bits(8 * (UNION_SIZE as usize)) + } + + /// The encoding routine for when the union is inlined into the current context. + fn inline_encode(self, encoder: &mut Encoder, context: Context) { + { + let mut state = encoder.get_mut(&context); + state.align_to_bytes(8); + state.encode(UNION_SIZE as u32); + state.encode(self.get_tag()); + } + self.encode_value(encoder, context.clone()); + { + let mut state = encoder.get_mut(&context); + state.align_to_bytes(8); + state.align_to_byte(); + } + } + + /// The decoding routine for when the union is inlined into the current context. + fn inline_decode(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + { + let mut state = decoder.get_mut(&context); + state.align_to_byte(); + state.align_to_bytes(8); + } + let value = Self::decode_value(decoder, context.clone()); + { + let mut state = decoder.get_mut(&context); + state.align_to_byte(); + state.align_to_bytes(8); + } + value + } +} + +/// A marker trait that marks Mojo handles as encodable. +pub trait MojomHandle: CastHandle + MojomEncodable {} + +/// Whatever implements this trait is considered to be a Mojom +/// interface, that is, a message pipe which conforms to some +/// messaging interface. +/// +/// We force an underlying message pipe to be used via the pipe() +/// and unwrap() routines. +pub trait MojomInterface: MojomEncodable { + /// Get the service name for this interface. + fn service_name() -> &'static str; + + /// Get the version for this interface. + fn version(&self) -> u32; + + /// Access the underlying message pipe for this interface. + fn pipe(&self) -> &message_pipe::MessageEndpoint; + + /// Unwrap the interface into its underlying message pipe. + fn unwrap(self) -> message_pipe::MessageEndpoint; +} + +/// An error that may occur when sending data over a Mojom interface. +#[derive(Debug)] +pub enum MojomSendError { + /// Failed to write to the underlying message pipe. + FailedWrite(MojoResult), + + /// The version is too old to write the attempted message. + OldVersion(u32, u32), +} + +/// Whatever implements this trait is considered to be a Mojom +/// interface that may send messages of some generic type. +/// +/// When implementing this trait, the correct way is to specify +/// a tighter trait bound than MojomMessage that limits the types +/// available for sending to those that are valid messages available +/// to the interface. +/// +/// TODO(mknyszek): Add sending control messages +pub trait MojomInterfaceSend<R: MojomMessage>: MojomInterface { + /// Creates a message. + fn create_request(&self, req_id: u64, payload: R) -> (Vec<u8>, Vec<UntypedHandle>) { + let mut header = R::create_header(); + header.request_id = req_id; + let header_size = header.compute_size(Default::default()); + let size = header_size + payload.compute_size(Default::default()); + let mut buffer: Vec<u8> = Vec::with_capacity(size); + buffer.resize(size, 0); + let handles = { + let (header_buf, rest_buf) = buffer.split_at_mut(header_size); + let mut handles = header.serialize(header_buf); + handles.extend(payload.serialize(rest_buf).into_iter()); + handles + }; + (buffer, handles) + } + + /// Creates and sends a message, and returns its request ID. + fn send_request(&self, req_id: u64, payload: R) -> Result<(), MojomSendError> { + if self.version() < R::min_version() { + return Err(MojomSendError::OldVersion(self.version(), R::min_version())); + } + let (buffer, handles) = self.create_request(req_id, payload); + match self.pipe().write(&buffer, handles, mpflags!(Write::None)) { + MojoResult::Okay => Ok(()), + err => Err(MojomSendError::FailedWrite(err)), + } + } +} + +/// An error that may occur when attempting to recieve a message over a +/// Mojom interface. +#[derive(Debug)] +pub enum MojomRecvError { + /// Failed to read from the underlying message pipe. + FailedRead(MojoResult), + + /// Failed to validate the buffer during decode. + FailedValidation(ValidationError), +} + +/// Whatever implements this trait is considered to be a Mojom +/// interface that may recieve messages for some interface. +/// +/// When implementing this trait, specify the container "union" type +/// which can contain any of the potential messages that may be recieved. +/// This way, we can return that type and let the user multiplex over +/// what message was received. +/// +/// TODO(mknyszek): Add responding to control messages +pub trait MojomInterfaceRecv: MojomInterface { + type Container: MojomMessageOption; + + /// Tries to read a message from a pipe and decodes it. + fn recv_response(&self) -> Result<(u64, Self::Container), MojomRecvError> { + match self.pipe().read(mpflags!(Read::None)) { + Ok((buffer, handles)) => match Self::Container::decode_message(buffer, handles) { + Ok((req_id, val)) => Ok((req_id, val)), + Err(err) => Err(MojomRecvError::FailedValidation(err)), + }, + Err(err) => Err(MojomRecvError::FailedRead(err)), + } + } +} + +/// Whatever implements this trait is considered to be a Mojom struct. +/// +/// Mojom structs are always the root of any Mojom message. Thus, we +/// provide convenience functions for serialization here. +pub trait MojomStruct: MojomPointer { + /// Given a pre-allocated buffer, the struct serializes itself. + fn serialize(self, buffer: &mut [u8]) -> Vec<UntypedHandle> { + let mut encoder = Encoder::new(buffer); + self.encode_new(&mut encoder, Default::default()); + encoder.unwrap() + } + + /// The struct computes its own size, allocates a buffer, and then + /// serializes itself into that buffer. + fn auto_serialize(self) -> (Vec<u8>, Vec<UntypedHandle>) { + let size = self.compute_size(Default::default()); + let mut buf = Vec::with_capacity(size); + buf.resize(size, 0); + let handles = self.serialize(&mut buf); + (buf, handles) + } + + /// Decode the type from a byte array and a set of handles. + fn deserialize(buffer: &[u8], handles: Vec<UntypedHandle>) -> Result<Self, ValidationError> { + let mut decoder = Decoder::new(buffer, handles); + Self::decode_new(&mut decoder, Default::default(), 0) + } +} + +/// Marks a MojomStruct as being capable of being sent across some +/// Mojom interface. +pub trait MojomMessage: MojomStruct { + fn min_version() -> u32; + fn create_header() -> MessageHeader; +} + +/// The trait for a "container" type intended to be used in MojomInterfaceRecv. +/// +/// This trait contains the decode logic which decodes based on the message header +/// and returns itself: a union type which may contain any of the possible messages +/// that may be sent across this interface. +pub trait MojomMessageOption: Sized { + /// Decodes the actual payload of the message. + /// + /// Implemented by a code generator. + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError>; + + /// Decodes the message header and then the payload, returning a new + /// copy of itself and the request ID found in the header. + fn decode_message( + buffer: Vec<u8>, + handles: Vec<UntypedHandle>, + ) -> Result<(u64, Self), ValidationError> { + let header = try!(MessageHeader::deserialize(&buffer[..], Vec::new())); + let payload_buffer = &buffer[header.serialized_size(&Default::default())..]; + let req_id = header.request_id; + let ret = try!(Self::decode_payload(header, payload_buffer, handles)); + Ok((req_id, ret)) + } +} + +// ********************************************** // +// ****** IMPLEMENTATIONS FOR COMMON TYPES ****** // +// ********************************************** // + +macro_rules! impl_encodable_for_prim { + ($($prim_type:ty),*) => { + $( + impl MojomEncodable for $prim_type { + fn mojom_type() -> MojomType { + MojomType::Simple + } + fn mojom_alignment() -> usize { + mem::size_of::<$prim_type>() + } + fn embed_size(_context: &Context) -> Bits { + Bits(8 * mem::size_of::<$prim_type>()) + } + fn compute_size(&self, _context: Context) -> usize { + 0 // Indicates that this type is inlined and it adds nothing external to the size + } + fn encode(self, encoder: &mut Encoder, context: Context) { + let mut state = encoder.get_mut(&context); + state.encode(self); + } + fn decode(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let mut state = decoder.get_mut(&context); + Ok(state.decode::<Self>()) + } + } + )* + } +} + +impl_encodable_for_prim!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64); + +impl MojomEncodable for bool { + fn mojom_alignment() -> usize { + panic!("Should never check_decode mojom_alignment of bools (they're bit-aligned)!"); + } + fn mojom_type() -> MojomType { + MojomType::Simple + } + fn embed_size(_context: &Context) -> Bits { + Bits(1) + } + fn compute_size(&self, _context: Context) -> usize { + 0 // Indicates that this type is inlined and it adds nothing external to the size + } + fn encode(self, encoder: &mut Encoder, context: Context) { + let mut state = encoder.get_mut(&context); + state.encode_bool(self); + } + fn decode(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let mut state = decoder.get_mut(&context); + Ok(state.decode_bool()) + } +} + +// Options should be considered to represent nullability the Mojom IDL. +// Any type wrapped in an Option type is nullable. + +impl<T: MojomEncodable> MojomEncodable for Option<T> { + fn mojom_alignment() -> usize { + T::mojom_alignment() + } + fn mojom_type() -> MojomType { + T::mojom_type() + } + fn embed_size(context: &Context) -> Bits { + T::embed_size(context) + } + fn compute_size(&self, context: Context) -> usize { + match *self { + Some(ref value) => value.compute_size(context), + None => 0, + } + } + fn encode(self, encoder: &mut Encoder, context: Context) { + match self { + Some(value) => value.encode(encoder, context), + None => { + let mut state = encoder.get_mut(&context); + match T::mojom_type() { + MojomType::Pointer => state.encode_null_pointer(), + MojomType::Union => state.encode_null_union(), + MojomType::Handle => state.encode_null_handle(), + MojomType::Interface => { + state.encode_null_handle(); + state.encode(0 as u32); + } + MojomType::Simple => panic!("Unexpected simple type in Option!"), + } + } + } + } + fn decode(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let skipped = { + let mut state = decoder.get_mut(&context); + match T::mojom_type() { + MojomType::Pointer => state.skip_if_null_pointer(), + MojomType::Union => state.skip_if_null_union(), + MojomType::Handle => state.skip_if_null_handle(), + MojomType::Interface => state.skip_if_null_interface(), + MojomType::Simple => panic!("Unexpected simple type in Option!"), + } + }; + if skipped { + Ok(None) + } else { + match T::decode(decoder, context) { + Ok(value) => Ok(Some(value)), + Err(err) => Err(err), + } + } + } +} + +macro_rules! impl_pointer_for_array { + () => { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Elements(self.len() as u32) + } + fn serialized_size(&self, context: &Context) -> usize { + DATA_HEADER_SIZE + + if self.len() > 0 { (T::embed_size(context) * self.len()).as_bytes() } else { 0 } + } + }; +} + +macro_rules! impl_encodable_for_array { + () => { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + let mut size = encoding::align_default(self.serialized_size(&context)); + for elem in self.iter() { + size += elem.compute_size(context.clone()); + } + size + } + }; +} + +impl<T: MojomEncodable> MojomPointer for Vec<T> { + impl_pointer_for_array!(); + fn encode_value(self, encoder: &mut Encoder, context: Context) { + for elem in self.into_iter() { + elem.encode(encoder, context.clone()); + } + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Vec<T>, ValidationError> { + let elems = { + let mut state = decoder.get_mut(&context); + match state.decode_array_header::<T>() { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let mut value = Vec::with_capacity(elems as usize); + for _ in 0..elems { + match T::decode(decoder, context.clone()) { + Ok(elem) => value.push(elem), + Err(err) => return Err(err), + } + } + Ok(value) + } +} + +impl<T: MojomEncodable> MojomEncodable for Vec<T> { + impl_encodable_for_array!(); +} + +macro_rules! impl_encodable_for_fixed_array { + ($($len:expr),*) => { + $( + impl<T: MojomEncodable> MojomPointer for [T; $len] { + impl_pointer_for_array!(); + fn encode_value(mut self, encoder: &mut Encoder, context: Context) { + let mut panic_error = None; + let mut moves = 0; + unsafe { + // In order to move elements out of an array we need to replace the + // value with uninitialized memory. + for elem in self.iter_mut() { + let owned_elem = mem::replace(elem, mem::uninitialized()); + // We need to handle if an unwinding panic happens to prevent use of + // uninitialized memory... + let next_context = context.clone(); + // We assert everything going into this closure is unwind safe. If anything + // is added, PLEASE make sure it is also unwind safe... + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + owned_elem.encode(encoder, next_context); + })); + if let Err(err) = result { + panic_error = Some(err); + break; + } + moves += 1; + } + if let Some(err) = panic_error { + for i in moves..self.len() { + ptr::drop_in_place(&mut self[i] as *mut T); + } + // Forget the array to prevent a drop + mem::forget(self); + // Continue unwinding + panic::resume_unwind(err); + } + // We cannot risk drop() getting run on the array values, so we just + // forget self. + mem::forget(self); + } + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<[T; $len], ValidationError> { + let elems = { + let mut state = decoder.get_mut(&context); + match state.decode_array_header::<T>() { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + if elems != $len { + return Err(ValidationError::UnexpectedArrayHeader); + } + let mut array: [T; $len]; + let mut panic_error = None; + let mut inits = 0; + let mut error = None; + unsafe { + // Since we don't force Clone to be implemented on Mojom types + // (mainly due to handles) we need to create this array as uninitialized + // and initialize it manually. + array = mem::uninitialized(); + for elem in &mut array[..] { + // When a panic unwinds it may try to read and drop uninitialized + // memory, so we need to catch this. However, we pass mutable state! + // This could be bad as we could observe a broken invariant inside + // of decoder and access it as usual, but we do NOT access decoder + // here, nor do we ever unwind through one of decoder's methods. + // Therefore, it should be safe to assert that decoder is unwind safe. + let next_context = context.clone(); + // We assert everything going into this closure is unwind safe. If anything + // is added, PLEASE make sure it is also unwind safe... + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + T::decode(decoder, next_context) + })); + match result { + Ok(non_panic_value) => match non_panic_value { + Ok(value) => ptr::write(elem, value), + Err(err) => { + error = Some(err); + break; + }, + }, + Err(err) => { + panic_error = Some(err); + break; + }, + } + inits += 1; + } + if panic_error.is_some() || error.is_some() { + // Drop everything that was initialized + for i in 0..inits { + ptr::drop_in_place(&mut array[i] as *mut T); + } + // Forget the array to prevent a drop + mem::forget(array); + if let Some(err) = panic_error { + panic::resume_unwind(err); + } + return Err(error.take().expect("Corrupted stack?")); + } + } + Ok(array) + } + } + impl<T: MojomEncodable> MojomEncodable for [T; $len] { + impl_encodable_for_array!(); + } + )* + } +} + +// Unfortunately, we cannot be generic over the length of a fixed array +// even though its part of the type (this will hopefully be added in the +// future) so for now we implement encodable for only the first 33 fixed +// size array types. +impl_encodable_for_fixed_array!( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32 +); + +impl<T: MojomEncodable> MojomPointer for Box<[T]> { + impl_pointer_for_array!(); + fn encode_value(self, encoder: &mut Encoder, context: Context) { + for elem in self.into_vec().into_iter() { + elem.encode(encoder, context.clone()); + } + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Box<[T]>, ValidationError> { + match Vec::<T>::decode_value(decoder, context) { + Ok(vec) => Ok(vec.into_boxed_slice()), + Err(err) => Err(err), + } + } +} + +impl<T: MojomEncodable> MojomEncodable for Box<[T]> { + impl_encodable_for_array!(); +} + +// We can represent a Mojom string as just a Rust String type +// since both are UTF-8. +impl MojomPointer for String { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Elements(self.len() as u32) + } + fn serialized_size(&self, _context: &Context) -> usize { + DATA_HEADER_SIZE + self.len() + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + for byte in self.as_bytes() { + byte.encode(encoder, context.clone()); + } + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<String, ValidationError> { + let mut state = decoder.get_mut(&context); + let elems = match state.decode_array_header::<u8>() { + Ok(header) => header.data(), + Err(err) => return Err(err), + }; + let mut value = Vec::with_capacity(elems as usize); + for _ in 0..elems { + value.push(state.decode::<u8>()); + } + match String::from_utf8(value) { + Ok(string) => Ok(string), + Err(err) => panic!("Error decoding String: {}", err), + } + } +} + +impl MojomEncodable for String { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + } +} + +/// Helper function to clean up duplicate code in HashMap. +fn array_claim_and_decode_header<T: MojomEncodable>( + decoder: &mut Decoder, + offset: usize, +) -> Result<(Context, usize), ValidationError> { + let context = match decoder.claim(offset) { + Ok(new_context) => new_context, + Err(err) => return Err(err), + }; + let elems = { + let state = decoder.get_mut(&context); + match state.decode_array_header::<T>() { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + Ok((context, elems as usize)) +} + +impl<K: MojomEncodable + Eq + Hash, V: MojomEncodable> MojomPointer for HashMap<K, V> { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + MAP_SIZE + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + let elems = self.len(); + let meta_value = DataHeaderValue::Elements(elems as u32); + // We need to move values into this vector because we can't copy the keys. + // (Handles are not copyable so MojomEncodable cannot be copyable!) + let mut vals_vec = Vec::with_capacity(elems); + // Key setup + // Write a pointer to the keys array. + let keys_loc = encoder.size() as u64; + { + let state = encoder.get_mut(&context); + state.encode_pointer(keys_loc); + } + // Create the keys data header + let keys_bytes = DATA_HEADER_SIZE + (K::embed_size(&context) * elems).as_bytes(); + let keys_data_header = DataHeader::new(keys_bytes, meta_value); + // Claim space for the keys array in the encoder + let keys_context = encoder.add(&keys_data_header).unwrap(); + // Encode keys, setup vals + for (key, value) in self.into_iter() { + key.encode(encoder, keys_context.clone()); + vals_vec.push(value); + } + // Encode vals + vals_vec.encode(encoder, context.clone()) + } + fn decode_value( + decoder: &mut Decoder, + context: Context, + ) -> Result<HashMap<K, V>, ValidationError> { + let (keys_offset, vals_offset) = { + let state = decoder.get_mut(&context); + match state.decode_struct_header(&MAP_VERSIONS) { + Ok(_) => (), + Err(err) => return Err(err), + }; + // Decode the keys pointer and check for overflow + let keys_offset = match state.decode_pointer() { + Some(ptr) => ptr, + None => return Err(ValidationError::IllegalPointer), + }; + // Decode the keys pointer and check for overflow + let vals_offset = match state.decode_pointer() { + Some(ptr) => ptr, + None => return Err(ValidationError::IllegalPointer), + }; + if keys_offset == MOJOM_NULL_POINTER || vals_offset == MOJOM_NULL_POINTER { + return Err(ValidationError::UnexpectedNullPointer); + } + (keys_offset as usize, vals_offset as usize) + }; + let (keys_context, keys_elems) = + match array_claim_and_decode_header::<K>(decoder, keys_offset) { + Ok((context, elems)) => (context, elems), + Err(err) => return Err(err), + }; + let mut keys_vec: Vec<K> = Vec::with_capacity(keys_elems as usize); + for _ in 0..keys_elems { + let key = match K::decode(decoder, keys_context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + keys_vec.push(key); + } + let (vals_context, vals_elems) = + match array_claim_and_decode_header::<V>(decoder, vals_offset) { + Ok((context, elems)) => (context, elems), + Err(err) => return Err(err), + }; + if keys_elems != vals_elems { + return Err(ValidationError::DifferentSizedArraysInMap); + } + let mut map = HashMap::with_capacity(keys_elems as usize); + for key in keys_vec.into_iter() { + let val = match V::decode(decoder, vals_context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + map.insert(key, val); + } + Ok(map) + } +} + +impl<K: MojomEncodable + Eq + Hash, V: MojomEncodable> MojomEncodable for HashMap<K, V> { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + let mut size = encoding::align_default(self.serialized_size(&context)); + // The size of the one array + size += DATA_HEADER_SIZE; + size += (K::embed_size(&context) * self.len()).as_bytes(); + size = encoding::align_default(size); + // Any extra space used by the keys + for (key, _) in self { + size += key.compute_size(context.clone()); + } + // Need to re-align after this for the next array + size = encoding::align_default(size); + // The size of the one array + size += DATA_HEADER_SIZE; + size += (V::embed_size(&context) * self.len()).as_bytes(); + size = encoding::align_default(size); + // Any extra space used by the values + for (_, value) in self { + size += value.compute_size(context.clone()); + } + // Align one more time at the end to keep the next object aligned. + encoding::align_default(size) + } +} + +impl<T: MojomEncodable + CastHandle + Handle> MojomHandle for T {} + +macro_rules! impl_encodable_for_handle { + ($handle_type:path) => { + fn mojom_alignment() -> usize { + 4 + } + fn mojom_type() -> MojomType { + MojomType::Handle + } + fn embed_size(_context: &Context) -> Bits { + Bits(8 * mem::size_of::<u32>()) + } + fn compute_size(&self, _context: Context) -> usize { + 0 + } + fn encode(self, encoder: &mut Encoder, context: Context) { + let pos = encoder.add_handle(self.as_untyped()); + let mut state = encoder.get_mut(&context); + state.encode(pos as i32); + } + fn decode( + decoder: &mut Decoder, + context: Context, + ) -> Result<$handle_type, ValidationError> { + let handle_index = { + let mut state = decoder.get_mut(&context); + state.decode::<i32>() + }; + decoder.claim_handle::<$handle_type>(handle_index) + } + }; +} + +impl MojomEncodable for UntypedHandle { + impl_encodable_for_handle!(UntypedHandle); +} + +impl MojomEncodable for message_pipe::MessageEndpoint { + impl_encodable_for_handle!(message_pipe::MessageEndpoint); +} + +impl MojomEncodable for shared_buffer::SharedBuffer { + impl_encodable_for_handle!(shared_buffer::SharedBuffer); +} + +impl<T> MojomEncodable for data_pipe::Consumer<T> { + impl_encodable_for_handle!(data_pipe::Consumer<T>); +} + +impl<T> MojomEncodable for data_pipe::Producer<T> { + impl_encodable_for_handle!(data_pipe::Producer<T>); +} + +impl MojomEncodable for wait_set::WaitSet { + impl_encodable_for_handle!(wait_set::WaitSet); +}
diff --git a/mojo/public/rust/bindings/run_loop.rs b/mojo/public/rust/bindings/run_loop.rs new file mode 100644 index 0000000..1f50e03 --- /dev/null +++ b/mojo/public/rust/bindings/run_loop.rs
@@ -0,0 +1,668 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! This module contains a thread-local run-loop. +//! +//! The run-loop may have handles and handlers pre-registers +//! (and in fact, must) in order to keep running. The run-loop +//! executes until it has no more handles or handlers on itself, +//! or until it is told to quit via stop(). +//! +//! The run-loop waits until some signals on some handle is satisfied, +//! at which point it wakes up and executes the appropriate handler +//! method. This handler method may then be used to further populate +//! or de-populate the run-loop. +//! +//! As of yet, the run-loop is NOT thread-safe. Although it is useful +//! to be able to register tasks or handles from one thread onto +//! another thread's run-loop, this is as-of-yet unsupported, and +//! Rust should complain loudly when you try to do any threading here. + +use std::cell::RefCell; +use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd}; +use std::collections::BinaryHeap; +use std::collections::HashMap; +use std::i64; +use std::u32; +use std::vec::Vec; + +use system; +use system::core; +use system::wait_set; +use system::{Handle, MojoResult, MOJO_INDEFINITE}; + +/// Define the equivalent of MOJO_INDEFINITE for absolute deadlines +const MOJO_INDEFINITE_ABSOLUTE: system::MojoTimeTicks = 0; + +// TODO(mknyszek): The numbers below are arbitrary and come from the C++ bindings, +// and should probably be changed at some point + +/// Initial size of the result buffer. +const INITIAL_WAIT_SET_NUM_RESULTS: usize = 16; + +/// Maximum size of the result buffer. +const MAXIMUM_WAIT_SET_NUM_RESULTS: usize = 256; + +/// Thread-local data structure for keeping track of handles to wait on. +thread_local!(static TL_RUN_LOOP: RefCell<RunLoop<'static, 'static>> = RefCell::new(RunLoop::new())); + +/// Token representing handle/callback to wait on for this thread only. This +/// token only has meaning on the thread in which the handle was registered. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Token(u64); + +impl Token { + /// Get the wait token's "cookie" form, suitable for use in a wait set. + fn as_cookie(&self) -> u64 { + self.0 + } +} + +/// Represents the possible error cases that may occur when waiting +/// on a handle in a RunLoop. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum WaitError { + /// The handle has been closed or is otherwise no longer valid. + HandleClosed, + + /// The handle is currently busy in some transaction. + HandleBusy, + + /// It has been determined that the signals provided will never + /// be satisfied for this handle. + Unsatisfiable, +} + +/// A trait which defines an interface to be a handler usable by +/// a RunLoop. +pub trait Handler { + /// Called after a successful wait. + fn on_ready(&mut self, runloop: &mut RunLoop, token: Token); + + /// Called after the given deadline expires. + fn on_timeout(&mut self, runloop: &mut RunLoop, token: Token); + + /// Called when an unexpected error occurs. + fn on_error(&mut self, runloop: &mut RunLoop, token: Token, error: WaitError); +} + +/// A wrapper struct for carrying the handler as well as various information +/// about it. +struct HandlerInfo<'h> { + /// The handle for which we are waiting. + /// + /// We keep this handle around so that we may easily re-register. + handle: system::MojoHandle, + + /// The handler, boxed up. + /// + /// The handler is in an Option type because if it is currently being + /// used in a callback, we must take ownership to avoid mutability + /// cycles. The easiest way to do this is to take() from the Option then + /// put it back. + handler: Option<Box<Handler + 'h>>, + + /// An absolute deadline in terms of time ticks. + /// + /// This is the most recently updated deadline that + /// we should be watching out for. All others for this + /// token may be considered "stale". + deadline: system::MojoTimeTicks, +} + +impl<'h> HandlerInfo<'h> { + /// Take the handler out of its Option type. + pub fn take(&mut self) -> Option<Box<Handler + 'h>> { + self.handler.take() + } + + /// Put a new handler into the Option type. + pub fn give(&mut self, handler: Box<Handler + 'h>) { + self.handler = Some(handler); + } + + /// Getter for the system::MojoHandle held inside. + pub fn handle(&self) -> system::MojoHandle { + self.handle + } + + /// Getter for the current absolute deadline held inside. + pub fn deadline(&self) -> system::MojoTimeTicks { + self.deadline + } + + /// Setter to update the current absolute deadline. + pub fn set_deadline(&mut self, deadline: system::MojoTimeTicks) { + self.deadline = deadline + } +} + +/// A wrapper struct for carrying the task as well as some information about +/// it. +struct TaskInfo<'t> { + /// The task, boxed up. + closure: Box<FnMut(&mut RunLoop) + 't>, + + /// An absolute deadline in terms of time ticks. + /// + /// This is the most recently updated deadline that + /// we should be watching out for. All others for this + /// token may be considered "stale". + deadline: system::MojoTimeTicks, +} + +impl<'t> TaskInfo<'t> { + /// Executes the task within the info object, consuming it + /// in the process. + pub fn execute_task(mut self, run_loop: &mut RunLoop) { + (*self.closure)(run_loop); + } + + /// Getter for the current absolute deadline held inside. + pub fn deadline(&self) -> system::MojoTimeTicks { + self.deadline + } +} + +impl<'t> PartialEq for TaskInfo<'t> { + /// Equality for TaskInfo in terms of its deadline + fn eq(&self, other: &TaskInfo) -> bool { + self.deadline == other.deadline + } +} + +impl<'t> Eq for TaskInfo<'t> {} + +impl<'t> PartialOrd for TaskInfo<'t> { + /// Partial comparison for TaskInfo in terms of its deadline + /// + /// Reverses the comparison because the Rust std library only + /// offers a max-heap, and we need a min-heap. + fn partial_cmp(&self, other: &TaskInfo) -> Option<Ordering> { + other.deadline.partial_cmp(&self.deadline) + } +} + +impl<'t> Ord for TaskInfo<'t> { + /// Implement comparisons for Task Info. + /// + /// Reverses the comparison because the Rust std library only + /// offers a max-heap, and we need a min-heap. + fn cmp(&self, other: &Self) -> Ordering { + other.deadline.cmp(&self.deadline) + } +} + +/// Wrapper struct intended to be used in a priority queue +/// for efficiently retrieving the next closest deadline. +#[derive(Clone)] +struct DeadlineInfo { + /// The ID of the associated Handler struct in the RunLoop's + /// hash map. + token: Token, + + /// An absolute deadline in terms of time ticks. + deadline: system::MojoTimeTicks, +} + +impl DeadlineInfo { + /// Getter for an immutable borrow for the token inside. + pub fn token(&self) -> &Token { + &self.token + } + + /// Getter for the absolute deadline inside. + pub fn deadline(&self) -> system::MojoTimeTicks { + self.deadline + } +} + +impl PartialEq for DeadlineInfo { + /// Equality for DeadlineInfo in terms of its deadline + fn eq(&self, other: &DeadlineInfo) -> bool { + self.deadline == other.deadline + } +} + +impl Eq for DeadlineInfo {} + +impl PartialOrd for DeadlineInfo { + /// Partial comparison for DeadlineInfo in terms of its deadline + /// + /// Reverses the comparison because the Rust std library only + /// offers a max-heap, and we need a min-heap. + fn partial_cmp(&self, other: &DeadlineInfo) -> Option<Ordering> { + other.deadline.partial_cmp(&self.deadline) + } +} + +impl Ord for DeadlineInfo { + /// Implement comparisons for Deadline Info. + /// + /// Reverses the comparison because the Rust std library only + /// offers a max-heap, and we need a min-heap. + fn cmp(&self, other: &Self) -> Ordering { + other.deadline.cmp(&self.deadline) + } +} + +/// Convert a mojo deadline (which is a relative deadline to "now") to +/// an absolute deadline based on time ticks. +fn absolute_deadline(deadline: system::MojoDeadline) -> system::MojoTimeTicks { + if deadline == MOJO_INDEFINITE { + return MOJO_INDEFINITE_ABSOLUTE; + } + let mut converted = MOJO_INDEFINITE_ABSOLUTE; + let max_time_ticks = i64::MAX as system::MojoDeadline; + if deadline <= max_time_ticks { + let now = core::get_time_ticks_now(); + if deadline <= (max_time_ticks - (now as u64)) { + converted = (deadline as system::MojoTimeTicks) + now + } + } + converted +} + +/// Convert an absolute deadline to a mojo deadline which is relative to some +/// notion of "now". +/// +/// If the deadline is earlier than "now", this routine rounds up to "now". +fn relative_deadline( + deadline: system::MojoTimeTicks, + now: system::MojoTimeTicks, +) -> system::MojoDeadline { + if deadline == MOJO_INDEFINITE_ABSOLUTE { + MOJO_INDEFINITE + } else if now >= deadline { + 0 + } else { + (deadline - now) as system::MojoDeadline + } +} + +/// This structure contains all information necessary to wait on handles +/// asynchronously. +/// +/// Ultimately, it should only be a singleton living in +/// thread-local storage. +pub struct RunLoop<'h, 't> { + /// Running count of the next available token slot. + token_count: u64, + + /// A map of handlers. + /// + /// TODO(mknyszek): Try out a Slab allocator instead of a hashmap. + handlers: HashMap<Token, HandlerInfo<'h>>, + + /// A min-heap of delayed tasks in order to pull the soonest task to + /// execute efficiently. + tasks: BinaryHeap<TaskInfo<'t>>, + + /// A min-heap containing deadlines in order to pull out the soonest + /// deadline efficiently. + /// + /// Warning: may contain "stale" deadlines which are not kept in the + /// map! + deadlines: BinaryHeap<DeadlineInfo>, + + /// The Mojo structure keeping track of handles and signals. + /// + /// This structure must be kept in sync with handlers. + handle_set: wait_set::WaitSet, + + /// A flag that tells the RunLoop whether it should quit. + should_quit: bool, + + /// A flag that indicates whether the RunLoop is running or now + running: bool, +} + +impl<'h, 't> RunLoop<'h, 't> { + /// Create a new RunLoop. + pub fn new() -> RunLoop<'h, 't> { + RunLoop { + token_count: 0, + handlers: HashMap::new(), + tasks: BinaryHeap::new(), + deadlines: BinaryHeap::new(), + handle_set: wait_set::WaitSet::new(wsflags!(Create::None)).unwrap(), + should_quit: false, + running: false, + } + } + + /// Generate a new Token for this RunLoop + fn generate_token(&mut self) -> Token { + self.token_count = self.token_count.wrapping_add(1); + Token(self.token_count) + } + + /// Adds a new entry to the runloop queue. + pub fn register<H>( + &mut self, + handle: &Handle, + signals: system::HandleSignals, + deadline: system::MojoDeadline, + handler: H, + ) -> Token + where + H: Handler + 'h, + { + let token = self.generate_token(); + let abs_deadline = absolute_deadline(deadline); + self.handle_set.add(handle, signals, token.as_cookie(), wsflags!(Add::None)); + self.deadlines.push(DeadlineInfo { token: token.clone(), deadline: abs_deadline }); + debug_assert!(!self.handlers.contains_key(&token)); + self.handlers.insert( + token.clone(), + HandlerInfo { + handle: handle.get_native_handle(), + handler: Some(Box::new(handler)), + deadline: abs_deadline, + }, + ); + token + } + + /// Updates the signals and deadline of an existing handler in the + /// runloop via token. The token remains unchanged and valid. + /// + /// Returns true on a successful update and false if the token was not + /// found. + pub fn reregister( + &mut self, + token: &Token, + signals: system::HandleSignals, + deadline: system::MojoDeadline, + ) -> bool { + match self.handlers.get_mut(&token) { + Some(info) => { + let _result = self.handle_set.remove(token.as_cookie()); + debug_assert_eq!(_result, MojoResult::Okay); + let abs_deadline = absolute_deadline(deadline); + // Update what deadline we should be looking for, rendering + // all previously set deadlines "stale". + info.set_deadline(abs_deadline); + // Add a new deadline + self.deadlines.push(DeadlineInfo { token: token.clone(), deadline: abs_deadline }); + // Acquire the raw handle held by the HandlerInfo in order to + // call the wait_set's add method. Invalidate it immediately after + // in order to prevent the handle from being closed. + // + // It's perfectly okay for the handle to be invalid, so although this + // is all unsafe, the whole system should just call the handler with an + // error. + let mut dummy = unsafe { system::acquire(info.handle()) }; + self.handle_set.add(&dummy, signals, token.as_cookie(), wsflags!(Add::None)); + unsafe { + dummy.invalidate(); + } + true + } + None => false, + } + } + + /// Removes an entry from the runloop. + /// + /// Since we cannot remove from the deadlines heap, we just leave the deadline + /// in there as "stale", and we handle those when trying to find the next closest + /// deadline. + pub fn deregister(&mut self, token: Token) -> bool { + match self.handlers.remove(&token) { + Some(_) => { + let _result = self.handle_set.remove(token.as_cookie()); + debug_assert_eq!(_result, MojoResult::Okay); + true + } + None => false, + } + } + + /// Adds a task to be run by the runloop after some delay. + /// + /// Returns a token if the delay is valid, otherwise returns None. + pub fn post_task<F>(&mut self, task: F, delay: system::MojoTimeTicks) -> Result<(), ()> + where + F: FnMut(&mut RunLoop) + 't, + { + let now = core::get_time_ticks_now(); + if delay > i64::MAX - now { + return Err(()); + } + let deadline = now + delay; + self.tasks.push(TaskInfo { closure: Box::new(task), deadline: deadline }); + Ok(()) + } + + /// Uses the binary heaps to get the next closest deadline. + /// + /// Removes stale deadline entries as they are found, but + /// does not otherwise modify the heap of deadlines. + fn get_next_deadline(&mut self) -> system::MojoTimeTicks { + debug_assert!(!self.handlers.is_empty()); + let top_task_deadline = match self.tasks.peek() { + Some(info) => info.deadline(), + None => MOJO_INDEFINITE_ABSOLUTE, + }; + let mut top = match self.deadlines.peek() { + Some(info) => info.clone(), + None => return MOJO_INDEFINITE_ABSOLUTE, + }; + while !self.handlers.contains_key(top.token()) { + self.deadlines.pop(); + top = match self.deadlines.peek() { + Some(info) => info.clone(), + None => return MOJO_INDEFINITE_ABSOLUTE, + } + } + if top_task_deadline != MOJO_INDEFINITE_ABSOLUTE && top_task_deadline < top.deadline() { + top_task_deadline + } else { + top.deadline() + } + } + + /// Gets a handler by token to be manipulated in a consistent environment. + /// + /// This method provides a method of accessing a handler in order to manipulate + /// it in a manner that avoids a borrow cycle, that is, it take()s the handler + /// out of the HashMap, and returns it when manipulation has completed. + fn get_handler_with<F>(&mut self, token: &Token, invoker: F) + where + F: FnOnce(&mut Self, &mut Box<Handler + 'h>, Token, system::MojoTimeTicks), + { + // Logic for pulling out the handler as well as its current deadline. + // + // Unfortunately, pulling out the handler value here and "onto the stack" + // (it probably won't actually end up on the stack thanks to optimizations) + // is necessary since otherwise the borrow checker complains that we pass + // a mutable reference to the RunLoop and the handler (as &mut self) to + // the callbacks at the same time. This is understandably unsafe since + // modifying the hashmap with register and deregister can invalidate the + // reference to self in the callback. In the C++ bindings and in other Rust + // event loops this is exactly what happens, but I avoided this. The downside + // is that we can no longer nest event loop run() calls. Once we put a handler + // onto the stack here, we can no longer call its callback in a nested manner + // from the RunLoop. I could just enable nesting with this one restriction, that + // the handler calling run() will always be ignored, but this is unintuitive. + let (mut handler, deadline) = match self.handlers.get_mut(&token) { + Some(ref_info) => ( + match ref_info.take() { + Some(handler) => handler, + None => return, + }, + ref_info.deadline(), + ), + None => return, + }; + // Call the closure that will invoke the callbacks. + invoker(self, &mut handler, token.clone(), deadline); + // Restore the handler to its location in the HashMap + if let Some(ref_info) = self.handlers.get_mut(&token) { + ref_info.give(handler); + } + } + + /// For all the results we received, we notify the respective handle + /// owners of the results by calling their given callbacks. + /// + /// We do NOT dequeue notified handles. + fn notify_of_results(&mut self, results: &Vec<system::WaitSetResult>) { + debug_assert!(!self.handlers.is_empty()); + for wsr in results.iter() { + let token = Token(wsr.cookie()); + self.get_handler_with(&token, move |runloop, boxed_handler, token, _dl| { + let handler = boxed_handler.as_mut(); + match wsr.result() { + MojoResult::Okay => handler.on_ready(runloop, token), + MojoResult::Cancelled => { + handler.on_error(runloop, token, WaitError::HandleClosed) + } + MojoResult::Busy => handler.on_error(runloop, token, WaitError::HandleBusy), + MojoResult::FailedPrecondition => { + handler.on_error(runloop, token, WaitError::Unsatisfiable) + } + other => panic!("Unexpected result received after waiting: {}", other), + } + }); + // In order to quit as soon as possible, we should check to quit after every + // potential handler call, as any of them could have signaled to quit. + if self.should_quit { + break; + } + } + } + + /// Since the deadline expired, we notify the relevant handle + /// owners of the expiration by calling their given callbacks. + /// + /// We do NOT dequeue notified handles. + fn notify_of_expired(&mut self, expired_deadline: system::MojoTimeTicks) { + debug_assert!(!self.handlers.is_empty()); + let mut top = match self.deadlines.peek() { + Some(info) => info.clone(), + None => panic!("Should not be in notify_of_expired without at least one deadline!"), + }; + while expired_deadline >= top.deadline() { + let next_deadline = top.deadline(); + self.get_handler_with( + top.token(), + move |runloop, boxed_handler, token, expected_dl| { + let handler = boxed_handler.as_mut(); + if next_deadline == expected_dl { + handler.on_timeout(runloop, token); + } + }, + ); + // In order to quit as soon as possible, we should check to quit after every + // potential handler call, as any of them could have signaled to quit. + if self.should_quit { + break; + } + // Remove the deadline + self.deadlines.pop(); + // Break if the next deadline has not yet expired. + top = match self.deadlines.peek() { + Some(info) => info.clone(), + None => break, + }; + } + } + + /// Iterates through all tasks whose deadline has passed and executes + /// them, consuming their information object. + /// + /// These tasks all have access to the RunLoop so that they may reschedule + /// themselves or manipulate the RunLoop in some other way. + fn execute_ready_tasks(&mut self) { + let now = core::get_time_ticks_now(); + let mut deadline = match self.tasks.peek() { + Some(info) => info.deadline(), + None => return, + }; + while deadline < now { + let top = self.tasks.pop().expect("Sudden change to heap?"); + top.execute_task(self); + if self.should_quit { + return; + } + deadline = match self.tasks.peek() { + Some(info) => info.deadline(), + None => return, + }; + } + } + + /// Blocks on handle_set.wait_on_set using the information contained + /// within itself. + /// + /// This method blocks for only as long as the shortest deadline among all + /// handles this thread has registered. This method returns immediately as + /// soon as any one handle has its signals satisfied, fails to ever have its + /// signals satisfied, or reaches its deadline. + fn wait(&mut self, results_buffer: &mut Vec<system::WaitSetResult>) { + debug_assert!(!self.handlers.is_empty()); + self.execute_ready_tasks(); + // If after executing a task we quit or there are no handles, + // we have no reason to continue. + if self.handlers.is_empty() || self.should_quit { + return; + } + let deadline = self.get_next_deadline(); + let until_deadline = relative_deadline(deadline, core::get_time_ticks_now()); + // Perform the wait + match self.handle_set.wait_on_set(until_deadline, results_buffer) { + Ok(max_results) => { + self.notify_of_results(results_buffer); + // Clear the buffer since we don't need the results anymore. + // Helps prevent a copy if we resize the buffer. + results_buffer.clear(); + // Increase the size of the buffer if there are more results + // we could be holding. + let capacity = results_buffer.capacity(); + if capacity < MAXIMUM_WAIT_SET_NUM_RESULTS && capacity < (max_results) as usize { + results_buffer.reserve(capacity); + } + } + Err(result) => { + assert_eq!(result, MojoResult::DeadlineExceeded); + self.notify_of_expired(deadline); + } + } + } + + /// Loop forever until a callback tells us to quit. + pub fn run(&mut self) { + // It's an error it already be running... + if self.running { + panic!("RunLoop is already running!"); + } + self.running = true; + self.should_quit = false; + let mut buffer: Vec<system::WaitSetResult> = + Vec::with_capacity(INITIAL_WAIT_SET_NUM_RESULTS); + // Loop while we haven't been signaled to quit, and there's something to wait on. + while !self.should_quit && !self.handlers.is_empty() { + self.wait(&mut buffer) + } + self.running = false; + } + + /// Set a flag to quit at the next available moment. + pub fn quit(&mut self) { + self.should_quit = true; + } +} + +/// Provides a scope to modify the current thread's runloop. +pub fn with_current<F>(modifier: F) +where + F: FnOnce(&mut RunLoop), +{ + TL_RUN_LOOP.with(|ref_runloop| { + let mut runloop = ref_runloop.borrow_mut(); + modifier(&mut *runloop); + }); +}
diff --git a/mojo/public/rust/bindings/util.rs b/mojo/public/rust/bindings/util.rs new file mode 100644 index 0000000..208a5d7 --- /dev/null +++ b/mojo/public/rust/bindings/util.rs
@@ -0,0 +1,49 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! This module contains some useful functions for encoding. + +/// Given some size value, the size is aligned to some number in bytes. +/// +/// Neither the size nor bytes may be zero (those are always aligned) and +/// bytes must be a power of two (other alignments don't make sense). +pub fn align_bytes(size: usize, bytes: usize) -> usize { + debug_assert!(bytes != 0); + debug_assert!((bytes & (!bytes + 1)) == bytes); + (size + bytes - 1) & (!(bytes - 1)) +} + +/// Converts some number of bits into however many bytes are needed to +/// represent that bit size. +pub fn bits_to_bytes(bits: usize) -> usize { + ((bits + 7) >> 3) +} + +#[cfg(test)] +mod tests { + use super::align_bytes; + use super::bits_to_bytes; + + #[test] + fn check_align_bytes() { + assert_eq!(align_bytes(12, 8), 16); + assert_eq!(align_bytes(16, 4), 16); + assert_eq!(align_bytes(1, 1), 1); + } + + #[test] + #[should_panic] + fn check_bad_align_bytes() { + assert_eq!(align_bytes(15, 7), 21); + assert_eq!(align_bytes(2, 0), 0); + } + + #[test] + fn check_bits_to_bytes() { + assert_eq!(bits_to_bytes(8), 1); + assert_eq!(bits_to_bytes(0), 0); + assert_eq!(bits_to_bytes(1), 1); + assert_eq!(bits_to_bytes(21), 3); + } +}
diff --git a/mojo/public/rust/lib.rs b/mojo/public/rust/lib.rs new file mode 100644 index 0000000..f28143d --- /dev/null +++ b/mojo/public/rust/lib.rs
@@ -0,0 +1,143 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#[macro_use] +mod macros { + /// This macro must be used at the top-level in any + /// Rust Mojo application. It defines and abstracts away the + /// hook needed by Mojo in order to set up the basic + /// functionality (see mojo::system::ffi). It must take the + /// name of a function who returns a MojoResult and takes + /// exactly one argument: a mojo::handle::Handle, or on in + /// other words, an untyped handle. + #[macro_export] + macro_rules! set_mojo_main { + ( $fn_main:ident ) => { + #[allow(bad_style)] + #[no_mangle] + pub fn MojoMain(app_request_handle: mojo::system::MojoHandle) -> mojo::MojoResult { + use mojo::system::CastHandle; + use std::panic; + let handle = unsafe { + mojo::system::message_pipe::MessageEndpoint::from_untyped( + mojo::system::acquire(app_request_handle), + ) + }; + let result = panic::catch_unwind(|| $fn_main(handle)); + match result { + Ok(value) => value, + Err(_) => mojo::MojoResult::Aborted, + } + } + }; + } + + /// This macro assists in generating flags for + /// functions and methods found in mojo::system::message_pipe. + /// + /// See mojo::system::message_pipe for the available flags + /// that may be passed. + /// + /// # Examples + /// + /// # mpflags!(Create::None); + /// # mpflags!(Read::MayDiscard); + #[macro_export] + macro_rules! mpflags { + ( $( $flag:path ),* ) => {{ + use $crate::system::message_pipe::*; + $( + ($flag as u32) + )|* + }} + } + + /// This macro assists in generating flags for + /// functions and methods found in mojo::system::data_pipe. + /// + /// See mojo::system::data_pipe for the available flags + /// that may be passed. + /// + /// # Examples + /// + /// # dpflags!(Create::None); + /// # dpflags!(Read::AllOrNone, Read::Discard); + #[macro_export] + macro_rules! dpflags { + ( $( $flag:path ),* ) => {{ + use $crate::system::data_pipe::*; + $( + ($flag as u32) + )|* + }} + } + + /// This macro assists in generating flags for + /// functions and methods found in mojo::system::shared_buffer. + /// + /// See mojo::system::shared_buffer for the available flags + /// that may be passed. + /// + /// # Examples + /// + /// # sbflags!(Create::None); + /// # sbflags!(Map::None); + #[macro_export] + macro_rules! sbflags { + ( $( $flag:path ),* ) => {{ + use $crate::system::shared_buffer::*; + $( + ($flag as u32) + )|* + }} + } + + /// This macro assists in generating flags for + /// functions and methods found in mojo::system::wait_set. + /// + /// See mojo::system::wait_set for the available flags + /// that may be passed. + /// + /// # Examples + /// + /// # wsflags!(Create::None); + /// # wsflags!(Add::None); + #[macro_export] + macro_rules! wsflags { + ( $( $flag:path ),* ) => {{ + use $crate::system::wait_set::*; + $( + ($flag as u32) + )|* + }} + } + + /// This macro assists in generating MojoSignals to be + /// used in wait() and wait_many(), part of mojo::system::core. + /// + /// See mojo::system::handle for the available signals + /// that may be checked for by wait() and wait_many(). + /// + /// # Examples + /// + /// # signals!(Signals::Readable, Signals::Writable); + /// # signals!(Signals::PeerClosed); + #[macro_export] + macro_rules! signals { + ( $( $flag:path ),* ) => {{ + use $crate::system::Signals; + $crate::system::HandleSignals::new( + $( + ($flag as u32) + )|* + ) + }} + } +} + +#[macro_use] +pub mod bindings; +pub mod system; + +pub use system::MojoResult;
diff --git a/mojo/public/rust/system/core.rs b/mojo/public/rust/system/core.rs new file mode 100644 index 0000000..9a958ea --- /dev/null +++ b/mojo/public/rust/system/core.rs
@@ -0,0 +1,66 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::ptr; +use std::u32; +use std::vec; + +use system::ffi; +// This full import is intentional; nearly every type in mojo_types needs to be used. +use system::handle; +use system::mojo_types::*; + +/// Get the time ticks now according to the Mojo IPC. As +/// can be seen in the documentation for the Mojo C API, +/// time ticks are meaningless in an absolute sense. Instead, +/// one should compare the results of two of these calls to +/// get a proper notion of time passing. +pub fn get_time_ticks_now() -> MojoTimeTicks { + unsafe { ffi::MojoGetTimeTicksNow() } +} + +/// Waits on many handles (or rather any structures that wrap +/// handles) until the signals declared in 'signals' for each handle +/// are triggered, waiting for a maximum global time of 'deadline'. +/// This function blocks. +pub fn wait_many( + handles: &[&handle::Handle], + signals: &[HandleSignals], + states: &mut [SignalsState], + deadline: MojoDeadline, +) -> (i32, MojoResult) { + assert_eq!(handles.len(), signals.len()); + assert!(states.len() == handles.len() || states.len() == 0); + let num_inputs = handles.len(); + if num_inputs == 0 { + let result = MojoResult::from_code(unsafe { + ffi::MojoWaitMany( + ptr::null(), + ptr::null(), + 0, + deadline, + ptr::null_mut(), + ptr::null_mut(), + ) + }); + return (-1, result); + } + let states_ptr = if states.len() != 0 { states.as_mut_ptr() } else { ptr::null_mut() }; + let mut index: u32 = u32::MAX; + let result = unsafe { + let mut raw_handles: vec::Vec<MojoHandle> = vec::Vec::with_capacity(num_inputs); + for handle in handles.iter() { + raw_handles.push(handle.get_native_handle()) + } + MojoResult::from_code(ffi::MojoWaitMany( + raw_handles.as_ptr(), + signals.as_ptr(), + num_inputs as u32, + deadline, + &mut index as *mut u32, + states_ptr, + )) + }; + (index as i32, result) +}
diff --git a/mojo/public/rust/system/data_pipe.rs b/mojo/public/rust/system/data_pipe.rs new file mode 100644 index 0000000..d9a4061 --- /dev/null +++ b/mojo/public/rust/system/data_pipe.rs
@@ -0,0 +1,449 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::marker; +use std::mem; +use std::ops; +use std::ptr; +use std::slice; +use std::vec; + +use system::ffi; +// This full import is intentional; nearly every type in mojo_types needs to be used. +use system::handle; +use system::handle::{CastHandle, Handle}; +use system::mojo_types::*; + +#[repr(u32)] +/// Create flags for data pipes +pub enum Create { + None = 0, +} + +#[repr(u32)] +/// Write flags for data pipes +pub enum Write { + None = 0, + + /// Write all the data to the pipe if possible or none at all + AllOrNone = 1 << 0, +} + +#[repr(u32)] +/// Read flags for data pipes +pub enum Read { + None = 0, + + /// Read all the data from the pipe if possible, or none at all + AllOrNone = 1 << 0, + + /// Dequeue the message recieved rather than reading it + Discard = 1 << 1, + + /// Get information about the queue on the pipe but do not perform the + /// read + Query = 1 << 2, + + /// Read data off the pipe's queue but do not dequeue it + Peek = 1 << 3, +} + +/// Intermediary structure in a two-phase read. +/// Reads of the requested buffer must be done directly +/// through this data structure which must then be committed. +pub struct ReadDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + buffer: &'b [T], + + /// Contains a reference to parent to end commit + /// and prevent it from outliving its parent handle. + parent: &'p Consumer<T>, +} + +impl<'b, 'p, T> ReadDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + /// Attempts to commit the read, that is, end the two-phase read + /// started by the parent Consumer<T> object. On a successful + /// commit, consumes self, otherwise returns self to try again. + pub fn commit(self, bytes_read: usize) -> Option<(Self, MojoResult)> { + let result = unsafe { self.parent.end_read(bytes_read) }; + if result == MojoResult::Okay { + None + } else { + Some((self, result)) + } + } + + /// Returns the length of the underlying buffer + pub fn len(&self) -> usize { + self.buffer.len() + } +} + +impl<'b, 'p, T> ops::Index<usize> for ReadDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + type Output = T; + + /// Overloads the indexing ([]) operator for reads. + /// + /// Part of reimplementing the array interface to be + /// able to use the structure naturally. + fn index(&self, index: usize) -> &T { + &self.buffer[index] + } +} + +/// Intermediary structure in a two-phase write. +/// Writes to the requested buffer must be done directly +/// through this data structure which must then be committed. +pub struct WriteDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + buffer: &'b mut [T], + + /// Contains a reference to parent to end commit + /// and prevent it from outliving its parent handle. + parent: &'p Producer<T>, +} + +impl<'b, 'p, T> WriteDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + /// Attempts to commit the write, that is, end the two-phase + /// write started by a Producer. On a successful + /// commit, consumes self, otherwise returns self to try again. + pub fn commit(self, bytes_written: usize) -> Option<(Self, MojoResult)> { + let result = unsafe { self.parent.end_write(bytes_written) }; + if result == MojoResult::Okay { + None + } else { + Some((self, result)) + } + } + + /// Returns the length of the underlying buffer + pub fn len(&self) -> usize { + self.buffer.len() + } +} + +impl<'b, 'p, T> ops::Index<usize> for WriteDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + type Output = T; + + /// Overloads the indexing ([]) operator for reads. + /// + /// Part of reimplementing the array interface to be + /// able to use the structure naturally. + fn index(&self, index: usize) -> &T { + &self.buffer[index] + } +} + +impl<'b, 'p, T> ops::IndexMut<usize> for WriteDataBuffer<'b, 'p, T> +where + 'p: 'b, + T: 'p, +{ + /// Overloads the indexing ([]) operator for writes. + /// + /// Part of reimplementing the array interface to be + /// able to use the structure naturally. + fn index_mut(&mut self, index: usize) -> &mut T { + &mut self.buffer[index] + } +} + +/// Creates a data pipe, represented as a consumer +/// and a producer. Additionally, we associate a type +/// T with the data pipe, as data pipes operate in terms +/// of elements. In this way we can enforce type safety. +/// +/// Capacity, as an input, must be given in number of elements. +/// Use a capacity of 0 in order to use some system-dependent +/// default capacity. +pub fn create<T>( + flags: CreateFlags, + capacity: u32, +) -> Result<(Consumer<T>, Producer<T>), MojoResult> { + let elem_size = mem::size_of::<T>() as u32; + let opts = ffi::MojoCreateDataPipeOptions { + struct_size: mem::size_of::<ffi::MojoCreateDataPipeOptions>() as u32, + flags: flags, + element_num_bytes: elem_size, + capacity_num_bytes: capacity * elem_size, + _align: [], + }; + // TODO(mknyszek): Make sure handles are valid + let mut chandle: MojoHandle = 0; + let mut phandle: MojoHandle = 0; + let raw_opts = &opts as *const ffi::MojoCreateDataPipeOptions; + let r = MojoResult::from_code(unsafe { + ffi::MojoCreateDataPipe( + raw_opts, + &mut phandle as *mut MojoHandle, + &mut chandle as *mut MojoHandle, + ) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok(( + Consumer::<T> { + handle: unsafe { handle::acquire(chandle) }, + _elem_type: marker::PhantomData, + }, + Producer::<T> { + handle: unsafe { handle::acquire(phandle) }, + _elem_type: marker::PhantomData, + }, + )) + } +} + +/// Creates a data pipe, represented as a consumer +/// and a producer, using the default Mojo options. +pub fn create_default() -> Result<(Consumer<u8>, Producer<u8>), MojoResult> { + create::<u8>(Create::None as u32, 0) +} + +/// Represents the consumer half of a data pipe. +/// This data structure wraps a handle and acts +/// effectively as a typed handle. +/// +/// The purpose of the _elem_type field is to associate +/// a type with the consumer, as a data pipe works +/// in elements. +pub struct Consumer<T> { + handle: handle::UntypedHandle, + _elem_type: marker::PhantomData<T>, +} + +impl<T> Consumer<T> { + /// Perform a read operation on the consumer end of the data pipe. As + /// a result, we get an std::vec::Vec filled with whatever was written. + pub fn read(&self, flags: ReadFlags) -> Result<vec::Vec<T>, MojoResult> { + let mut num_bytes: u32 = 0; + let r_prelim = unsafe { + ffi::MojoReadData( + self.handle.get_native_handle(), + ptr::null_mut() as *mut ffi::c_void, + &mut num_bytes as *mut u32, + 1 << 2 as ReadFlags, + ) + }; + if r_prelim != 0 || num_bytes == 0 { + return Err(MojoResult::from_code(r_prelim)); + } + let elem_size: u32 = mem::size_of::<T>() as u32; + // TODO(mknyszek): make sure elem_size divides into num_bytes + let mut buf: vec::Vec<T> = vec::Vec::with_capacity((num_bytes / elem_size) as usize); + let r = MojoResult::from_code(unsafe { + ffi::MojoReadData( + self.handle.get_native_handle(), + buf.as_mut_ptr() as *const ffi::c_void, + &mut num_bytes as *mut u32, + flags, + ) + }); + unsafe { buf.set_len((num_bytes / elem_size) as usize) } + if r != MojoResult::Okay { + Err(r) + } else { + Ok(buf) + } + } + + /// Start two-phase read and return a ReadDataBuffer to perform + /// read and commit. + pub fn begin(&self, flags: ReadFlags) -> Result<ReadDataBuffer<T>, MojoResult> { + let wrapped_result = unsafe { self.begin_read(flags) }; + match wrapped_result { + Ok(arr) => Ok(ReadDataBuffer::<T> { buffer: arr, parent: self }), + Err(r) => Err(r), + } + } + + /// A private function that performs the first half of two-phase reading. + /// Kept private because it is unsafe to use (the array received may not + /// be valid if end_read is performed). + unsafe fn begin_read(&self, flags: ReadFlags) -> Result<&[T], MojoResult> { + let mut buf_num_bytes: u32 = 0; + let mut pbuf: *mut ffi::c_void = mem::uninitialized(); + let r = MojoResult::from_code(ffi::MojoBeginReadData( + self.handle.get_native_handle(), + &mut pbuf, + &mut buf_num_bytes as *mut u32, + flags, + )); + if r != MojoResult::Okay { + Err(r) + } else { + let buf_elems = (buf_num_bytes as usize) / mem::size_of::<T>(); + let buf = slice::from_raw_parts(pbuf as *mut T, buf_elems); + Ok(buf) + } + } + + /// A private function that performs the second half of two-phase reading. + /// Kept private because it is unsafe to use (the array received from start_read + /// may not be valid if end_read is performed). + /// + /// Also assumes loads/stores aren't reordered such that a load/store may be + /// optimized to be run AFTER MojoEndReadData(). In general, this is true as long + /// as raw pointers are used, but Rust's memory model is still undefined. If you're + /// getting a bad/strange runtime error, it might be for this reason. + unsafe fn end_read(&self, elems_read: usize) -> MojoResult { + let elem_size = mem::size_of::<T>(); + MojoResult::from_code(ffi::MojoEndReadData( + self.handle.get_native_handle(), + (elems_read * elem_size) as u32, + )) + } +} + +impl<T> CastHandle for Consumer<T> { + /// Generates a Consumer from an untyped handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + unsafe fn from_untyped(handle: handle::UntypedHandle) -> Self { + Consumer::<T> { handle: handle, _elem_type: marker::PhantomData } + } + + /// Consumes this object and produces a plain handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + fn as_untyped(self) -> handle::UntypedHandle { + self.handle + } +} + +impl<T> Handle for Consumer<T> { + /// Returns the native handle wrapped by this structure. + /// + /// See mojo::system::handle for information on handle wrappers + fn get_native_handle(&self) -> MojoHandle { + self.handle.get_native_handle() + } +} + +/// Represents the consumer half of a data pipe. +/// This data structure wraps a handle and acts +/// effectively as a typed handle. +/// +/// The purpose of the _elem_type field is to associate +/// a type with the consumer, as a data pipe works +/// in elements. +pub struct Producer<T> { + handle: handle::UntypedHandle, + _elem_type: marker::PhantomData<T>, +} + +impl<T> Producer<T> { + /// Perform a write operation on the producer end of the data pipe. + /// Returns the number of elements actually written. + pub fn write(&self, data: &[T], flags: WriteFlags) -> Result<usize, MojoResult> { + let mut num_bytes = (data.len() * mem::size_of::<T>()) as u32; + let r = MojoResult::from_code(unsafe { + ffi::MojoWriteData( + self.handle.get_native_handle(), + data.as_ptr() as *const ffi::c_void, + &mut num_bytes as *mut u32, + flags, + ) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok(num_bytes as usize) + } + } + + /// Start two-phase write and return a WriteDataBuffer to perform + /// write and commit. + /// + /// Borrows self as mutable so that no other operation may happen on + /// the producer until the two-phase write is committed. + pub fn begin(&self, flags: WriteFlags) -> Result<WriteDataBuffer<T>, MojoResult> { + let wrapped_result = unsafe { self.begin_write(flags) }; + match wrapped_result { + Ok(arr) => Ok(WriteDataBuffer::<T> { buffer: arr, parent: self }), + Err(r) => Err(r), + } + } + + /// A private function that performs the first half of two-phase writing. + /// Kept private because it is unsafe to use (the array received may not + /// be valid if end_write is performed). + unsafe fn begin_write(&self, flags: WriteFlags) -> Result<&mut [T], MojoResult> { + let mut buf_num_bytes: u32 = 0; + let mut pbuf: *mut ffi::c_void = mem::uninitialized(); + let r = MojoResult::from_code(ffi::MojoBeginWriteData( + self.handle.get_native_handle(), + &mut pbuf, + &mut buf_num_bytes as *mut u32, + flags, + )); + if r != MojoResult::Okay { + Err(r) + } else { + let buf_elems = (buf_num_bytes as usize) / mem::size_of::<T>(); + let buf = slice::from_raw_parts_mut(pbuf as *mut T, buf_elems); + Ok(buf) + } + } + + /// A private function that performs the second half of two-phase writing. + /// Kept private because it is unsafe to use (the array received from start_write + /// may not be valid if end_write is performed). + /// + /// Also assumes loads/stores aren't reordered such that a load/store may be + /// optimized to be run AFTER MojoEndWriteData(). In general, this is true as long + /// as raw pointers are used, but Rust's memory model is still undefined. If you're + /// getting a bad/strange runtime error, it might be for this reason. + unsafe fn end_write(&self, elems_written: usize) -> MojoResult { + let elem_size = mem::size_of::<T>(); + MojoResult::from_code(ffi::MojoEndWriteData( + self.handle.get_native_handle(), + (elems_written * elem_size) as u32, + )) + } +} + +impl<T> CastHandle for Producer<T> { + /// Generates a Consumer from an untyped handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + unsafe fn from_untyped(handle: handle::UntypedHandle) -> Self { + Producer::<T> { handle: handle, _elem_type: marker::PhantomData } + } + + /// Consumes this object and produces a plain handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + fn as_untyped(self) -> handle::UntypedHandle { + self.handle + } +} + +impl<T> Handle for Producer<T> { + /// Returns the native handle wrapped by this structure. + /// + /// See mojo::system::handle for information on handle wrappers + fn get_native_handle(&self) -> MojoHandle { + self.handle.get_native_handle() + } +}
diff --git a/mojo/public/rust/system/ffi.rs b/mojo/public/rust/system/ffi.rs new file mode 100644 index 0000000..12615bf0 --- /dev/null +++ b/mojo/public/rust/system/ffi.rs
@@ -0,0 +1,252 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! This ffi module is used to interact with the +//! Mojo C bindings API. The structures below are +//! undocumented because they are pulled exactly +//! from the header files of the Mojo C bindings +//! API which can be found in the Mojo repository[1] +//! under //mojo/public/c/include/mojo/system. Very +//! clear documentation on these structures and +//! functions can be found there. It is not worth +//! elaborating on these all again here. +//! +//! [1] https://github.com/domokit/mojo + +// This full import is intentional; nearly every type in mojo_types needs to be used. +use system::mojo_types::*; + +#[allow(bad_style)] +/// This empty enum is used solely to provide +/// a notion of void from C. The truth is, the +/// correct move here is to use the libc Rust +/// package but, as it turns out, that's the only +/// part of libc we actually need. Rather than +/// force ourselves to pull in a dependency, we +/// instead implement libc's notion of c_void +/// here. +pub enum c_void {} + +pub mod types { + //! Defines some C-compatible types for the ffi layer of + //! the bindings. + + pub type MojoCreateSharedBufferOptionsFlags = u32; + pub type MojoDuplicateBufferHandleOptionsFlags = u32; + pub type MojoBufferInfoFlags = u32; + pub type MojoMapBufferFlags = u32; + pub type MojoCreateDataPipeOptionsFlags = u32; + pub type MojoWriteDataFlags = u32; + pub type MojoReadDataFlags = u32; + pub type MojoHandleSignals = u32; + pub type MojoCreateMessagePipeOptionsFlags = u32; + pub type MojoWriteMessageFlags = u32; + pub type MojoReadMessageFlags = u32; + pub type MojoCreateWaitSetOptionsFlags = u32; + pub type MojoWaitSetAddOptionsFlags = u32; + pub type MojoResultCode = u32; +} + +use system::ffi::types::*; + +#[repr(C)] +pub struct MojoCreateSharedBufferOptions { + pub struct_size: u32, + pub flags: MojoCreateSharedBufferOptionsFlags, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[repr(C)] +pub struct MojoDuplicateBufferHandleOptions { + pub struct_size: u32, + pub flags: MojoDuplicateBufferHandleOptionsFlags, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[repr(C)] +pub struct MojoBufferInformation { + pub struct_size: u32, + pub flags: MojoBufferInfoFlags, + pub num_bytes: u64, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[repr(C)] +pub struct MojoCreateDataPipeOptions { + pub struct_size: u32, + pub flags: MojoCreateDataPipeOptionsFlags, + pub element_num_bytes: u32, + pub capacity_num_bytes: u32, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[repr(C)] +pub struct MojoCreateMessagePipeOptions { + pub struct_size: u32, + pub flags: MojoCreateMessagePipeOptionsFlags, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[repr(C)] +pub struct MojoCreateWaitSetOptions { + pub struct_size: u32, + pub flags: MojoCreateWaitSetOptionsFlags, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[repr(C)] +pub struct MojoWaitSetAddOptions { + pub struct_size: u32, + pub flags: MojoWaitSetAddOptionsFlags, + pub _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +#[link] +extern "C" { + // From //mojo/public/c/include/mojo/system/buffer.h + pub fn MojoCreateSharedBuffer( + options: *const MojoCreateSharedBufferOptions, + num_bytes: u64, + shared_buffer_handle: *mut MojoHandle, + ) -> MojoResultCode; + + pub fn MojoDuplicateBufferHandle( + handle: MojoHandle, + options: *const MojoDuplicateBufferHandleOptions, + new_buffer_handle: *mut MojoHandle, + ) -> MojoResultCode; + + pub fn MojoGetBufferInformation( + buffer_handle: MojoHandle, + info: *mut MojoBufferInformation, + info_num_bytes: u32, + ) -> MojoResultCode; + + pub fn MojoMapBuffer( + buffer_handle: MojoHandle, + offset: u64, + num_bytes: u64, + buffer: *mut *mut c_void, + flags: MojoMapBufferFlags, + ) -> MojoResultCode; + + pub fn MojoUnmapBuffer(buffer: *const c_void) -> MojoResultCode; + + // From //mojo/public/c/include/mojo/system/data_pipe.h + pub fn MojoCreateDataPipe( + options: *const MojoCreateDataPipeOptions, + data_pipe_producer_handle: *mut MojoHandle, + data_pipe_consumer_handle: *mut MojoHandle, + ) -> MojoResultCode; + + pub fn MojoWriteData( + data_pipe_producer_handle: MojoHandle, + elements: *const c_void, + num_bytes: *mut u32, + flags: MojoWriteDataFlags, + ) -> MojoResultCode; + + pub fn MojoBeginWriteData( + data_pipe_producer_handle: MojoHandle, + buffer: *mut *mut c_void, + buffer_num_bytes: *mut u32, + flags: MojoWriteDataFlags, + ) -> MojoResultCode; + + pub fn MojoEndWriteData( + data_pipe_producer_handle: MojoHandle, + num_bytes_written: u32, + ) -> MojoResultCode; + + pub fn MojoReadData( + data_pipe_consumer_handle: MojoHandle, + elements: *const c_void, + num_bytes: *mut u32, + flags: MojoReadDataFlags, + ) -> MojoResultCode; + + pub fn MojoBeginReadData( + data_pipe_consumer_handle: MojoHandle, + buffer: *mut *mut c_void, + buffer_num_bytes: *mut u32, + flags: MojoReadDataFlags, + ) -> MojoResultCode; + + pub fn MojoEndReadData( + data_pipe_consumer_handle: MojoHandle, + num_bytes_written: u32, + ) -> MojoResultCode; + + // From //mojo/public/c/include/mojo/system/handle.h + pub fn MojoClose(handle: MojoHandle) -> MojoResultCode; + + // From //mojo/public/c/include/mojo/system/message_pipe.h + pub fn MojoCreateMessagePipe( + options: *const MojoCreateMessagePipeOptions, + message_pipe_handle0: *mut MojoHandle, + message_pipe_handle1: *mut MojoHandle, + ) -> MojoResultCode; + + pub fn MojoWriteMessage( + message_pipe_handle: MojoHandle, + bytes: *const c_void, + num_bytes: u32, + handles: *const MojoHandle, + num_handles: u32, + flags: MojoWriteMessageFlags, + ) -> MojoResultCode; + + pub fn MojoReadMessage( + message_pipe_handle: MojoHandle, + bytes: *mut c_void, + num_bytes: *mut u32, + handles: *mut MojoHandle, + num_handles: *mut u32, + flags: MojoWriteMessageFlags, + ) -> MojoResultCode; + + // From //mojo/public/c/include/mojo/system/time.h + pub fn MojoGetTimeTicksNow() -> MojoTimeTicks; + + // From //mojo/public/c/include/mojo/system/wait.h + pub fn MojoWait( + handle: MojoHandle, + signals: HandleSignals, + deadline: MojoDeadline, + signals_state: *mut SignalsState, + ) -> MojoResultCode; + + pub fn MojoWaitMany( + handles: *const MojoHandle, + signals: *const HandleSignals, + num_handles: u32, + deadline: MojoDeadline, + result_index: *mut u32, + signals_states: *mut SignalsState, + ) -> MojoResultCode; + + // From //mojo/public/c/include/mojo/system/wait_set.h + pub fn MojoCreateWaitSet( + options: *const MojoCreateWaitSetOptions, + handle: *mut MojoHandle, + ) -> MojoResultCode; + + pub fn MojoWaitSetAdd( + wait_set_handle: MojoHandle, + handle: MojoHandle, + signals: HandleSignals, + cookie: u64, + options: *const MojoWaitSetAddOptions, + ) -> MojoResultCode; + + pub fn MojoWaitSetRemove(wait_set_handle: MojoHandle, cookie: u64) -> MojoResultCode; + + pub fn MojoWaitSetWait( + wait_set_handle: MojoHandle, + deadline: MojoDeadline, + num_results: *mut u32, + results: *mut WaitSetResult, + max_results: *mut u32, + ) -> MojoResultCode; +}
diff --git a/mojo/public/rust/system/handle.rs b/mojo/public/rust/system/handle.rs new file mode 100644 index 0000000..7cd8e3a --- /dev/null +++ b/mojo/public/rust/system/handle.rs
@@ -0,0 +1,124 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! The way Mojo handles are handled in Rust is very similar +//! to Go, though more type-safe. Here we define an "untyped" +//! handle which means that what the handle actually represents +//! is unknown. This is the basic handle wrapper, and it handles +//! closing the handle once the wrapper goes out of scope, therefore +//! preventing any resources from being leaked. "Typed" handles +//! are MessageEndpoints or Consumers/Producers in this library +//! and they represent handles which represent parts of message pipes, +//! data pipes, and shared buffers. Typed handles wrap untyped handles +//! but act much the same as untyped handles. + +use system::ffi; +// This full import is intentional; nearly every type in mojo_types needs to be used. +use system::mojo_types::*; + +/// The CastHandle trait defines an interface to convert between +/// typed and untyped handles. These are only used internally for +/// typed handles. +pub trait CastHandle { + /// Passes an ownership of an untyped handle and produces a + /// typed handle which owns that untyped handle + /// + /// Casting to a typed handle is unsafe because the handle may + /// not necessarily be a handle to the typed primitive being + /// casted to internally in Mojo + unsafe fn from_untyped(handle: UntypedHandle) -> Self; + + /// Consumes a typed handle and releases ownership of the + /// untyped handle it owned + fn as_untyped(self) -> UntypedHandle; +} + +/// The Handle trait simply means that we can extract +/// the native integer-based Mojo handle from the typed handle. +pub trait Handle { + /// Returns the native handle wrapped by whatever structure + /// implements this trait. + fn get_native_handle(&self) -> MojoHandle; + + /// Waits on the handle wrapped in the current struct until the signals + /// declared in 'signals' are triggered, waiting for a maximum time of + /// 'deadline'. This method blocks. + /// + /// Returns the satisfied and satisfiable signals respectively for this + /// handle when waiting is done. + fn wait(&self, signals: HandleSignals, deadline: MojoDeadline) -> (SignalsState, MojoResult) { + let mut state: SignalsState = Default::default(); + let r = unsafe { + ffi::MojoWait( + self.get_native_handle(), + signals, + deadline, + &mut state as *mut SignalsState, + ) + }; + (state, MojoResult::from_code(r)) + } +} + +/// The basic untyped handle that wraps a MojoHandle (a u32) +pub struct UntypedHandle { + /// The native Mojo handle + value: MojoHandle, +} + +impl UntypedHandle { + /// Invalidates the Handle by setting its native handle to + /// zero, the canonical invalid handle in Mojo. + /// + /// This function is unsafe because clearing a native handle + /// without closing it is a resource leak. + pub unsafe fn invalidate(&mut self) { + self.value = 0 + } + + /// Checks if the native handle is valid (0 = canonical invalid handle). + pub fn is_valid(&self) -> bool { + self.value != 0 + } +} + +impl Handle for UntypedHandle { + /// Pulls out a copy of the native handle wrapped by this structure. + fn get_native_handle(&self) -> MojoHandle { + self.value + } +} + +impl CastHandle for UntypedHandle { + /// Casting an untyped handle is a no-op, but we include + /// this to eliminate code duplication. + unsafe fn from_untyped(handle: UntypedHandle) -> Self { + handle + } + + /// Casting to an untyped handle is a no-op, but we include + /// this to eliminate code duplication. + fn as_untyped(self) -> UntypedHandle { + self + } +} + +impl Drop for UntypedHandle { + /// The destructor for an untyped handle which closes the native handle + /// it wraps. + fn drop(&mut self) { + if self.is_valid() { + let result = MojoResult::from_code(unsafe { ffi::MojoClose(self.get_native_handle()) }); + if result != MojoResult::Okay { + panic!("Failed to close handle! Reason: {}", result); + } + } + } +} + +/// Acquires a native handle by wrapping it in an untyped handle, allowing +/// us to track the resource and free it appropriately +pub unsafe fn acquire(handle: MojoHandle) -> UntypedHandle { + UntypedHandle { value: handle } +}
diff --git a/mojo/public/rust/system/message_pipe.rs b/mojo/public/rust/system/message_pipe.rs new file mode 100644 index 0000000..19f46c2 --- /dev/null +++ b/mojo/public/rust/system/message_pipe.rs
@@ -0,0 +1,220 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::mem; +use std::ptr; +use std::vec; + +use system::ffi; +use system::handle; +use system::handle::{CastHandle, Handle}; +// This full import is intentional; nearly every type in mojo_types needs to be used. +use system::mojo_types::*; + +#[repr(u32)] +/// Create flags for message pipes +pub enum Create { + None = 0, +} + +#[repr(u32)] +/// Write flags for message pipes +pub enum Write { + None = 0, +} + +#[repr(u32)] +/// Read flags for message pipes +pub enum Read { + None = 0, + + /// If the message is unable to be + /// read for whatever reason, dequeue + /// it anyway + MayDiscard = 1 << 0, +} + +/// Creates a message pipe in Mojo and gives back two +/// MessageEndpoints which represent the endpoints of the +/// message pipe +pub fn create(flags: CreateFlags) -> Result<(MessageEndpoint, MessageEndpoint), MojoResult> { + let mut handle0: MojoHandle = 0; + let mut handle1: MojoHandle = 0; + let opts = ffi::MojoCreateMessagePipeOptions { + struct_size: mem::size_of::<ffi::MojoCreateMessagePipeOptions>() as u32, + flags: flags, + _align: [], + }; + let raw_opts = &opts as *const ffi::MojoCreateMessagePipeOptions; + let r = MojoResult::from_code(unsafe { + ffi::MojoCreateMessagePipe( + raw_opts, + &mut handle0 as *mut MojoHandle, + &mut handle1 as *mut MojoHandle, + ) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok(( + MessageEndpoint { handle: unsafe { handle::acquire(handle0) } }, + MessageEndpoint { handle: unsafe { handle::acquire(handle1) } }, + )) + } +} + +/// Represents the one endpoint of a message pipe. +/// This data structure wraps a handle and acts +/// effectively as a typed handle. +pub struct MessageEndpoint { + handle: handle::UntypedHandle, +} + +impl MessageEndpoint { + /// Read the next message from the endpoint. Messages in Mojo + /// are some set of bytes plus a bunch of handles, so we + /// return both a vector of bytes and a vector of untyped handles. + /// + /// Because the handles are untyped, it is up to the user of this + /// library to know what type the handle actually is and to use + /// from_untyped in order to convert the handle to the correct type. + /// This is abstracted away, however, when using the Mojo bindings + /// generator where you may specify your interface in Mojom. + /// + /// If an empty message (that is, it has neither data nor handles) + /// is received, it will show up as an Err() containing MojoResult::Okay. + pub fn read( + &self, + flags: ReadFlags, + ) -> Result<(vec::Vec<u8>, vec::Vec<handle::UntypedHandle>), MojoResult> { + let mut num_bytes: u32 = 0; + let mut num_handles: u32 = 0; + let result_prelim = MojoResult::from_code(unsafe { + ffi::MojoReadMessage( + self.handle.get_native_handle(), + ptr::null_mut(), + &mut num_bytes as *mut u32, + ptr::null_mut(), + &mut num_handles as *mut u32, + flags, + ) + }); + if result_prelim != MojoResult::ResourceExhausted { + return Err(result_prelim); + } + let mut buf: vec::Vec<u8> = vec::Vec::with_capacity(num_bytes as usize); + let mut raw_handles: vec::Vec<MojoHandle> = vec::Vec::with_capacity(num_handles as usize); + let buf_ptr; + if num_bytes == 0 { + buf_ptr = ptr::null_mut(); + } else { + buf_ptr = buf.as_mut_ptr() as *mut ffi::c_void; + } + let raw_handles_ptr; + if num_handles == 0 { + raw_handles_ptr = ptr::null_mut(); + } else { + raw_handles_ptr = raw_handles.as_mut_ptr(); + } + let r = MojoResult::from_code(unsafe { + ffi::MojoReadMessage( + self.handle.get_native_handle(), + buf_ptr, + &mut num_bytes as *mut u32, + raw_handles_ptr, + &mut num_handles as *mut u32, + flags, + ) + }); + unsafe { + buf.set_len(num_bytes as usize); + raw_handles.set_len(num_handles as usize); + } + let mut handles: vec::Vec<handle::UntypedHandle> = + vec::Vec::with_capacity(num_handles as usize); + for raw_handle in raw_handles.iter() { + handles.push(unsafe { handle::acquire(*raw_handle) }); + } + if r != MojoResult::Okay { + Err(r) + } else { + Ok((buf, handles)) + } + } + + /// Write a message to the endpoint. Messages in Mojo + /// are some set of bytes plus a bunch of handles, so we + /// return both a vector of bytes and a vector of untyped handles. + /// + /// Because the handles are untyped, it is up to the user of this + /// library to know what type the handle actually is and to use + /// from_untyped in order to convert the handle to the correct type. + /// This is abstracted away, however, when using the Mojo bindings + /// generator where you may specify your interface in Mojom. + /// + /// Additionally, the handles passed in are consumed. This is because + /// Mojo handles operate on move semantics much like Rust data types. + /// When a handle is sent through a message pipe it is invalidated and + /// may not even be represented by the same integer on the other side, + /// so care must be taken to design your application with this in mind. + pub fn write( + &self, + bytes: &[u8], + mut handles: vec::Vec<handle::UntypedHandle>, + flags: WriteFlags, + ) -> MojoResult { + let bytes_ptr; + if bytes.len() == 0 { + bytes_ptr = ptr::null(); + } else { + bytes_ptr = bytes.as_ptr() as *const ffi::c_void; + } + let mut raw_handles: vec::Vec<MojoHandle> = vec::Vec::with_capacity(handles.len()); + for handle in handles.iter_mut() { + unsafe { + raw_handles.push(handle.get_native_handle()); + handle.invalidate(); + } + } + let raw_handles_ptr; + if raw_handles.len() == 0 { + raw_handles_ptr = ptr::null(); + } else { + raw_handles_ptr = raw_handles.as_ptr(); + } + return MojoResult::from_code(unsafe { + ffi::MojoWriteMessage( + self.handle.get_native_handle(), + bytes_ptr, + bytes.len() as u32, + raw_handles_ptr, + raw_handles.len() as u32, + flags, + ) + }); + } +} + +impl CastHandle for MessageEndpoint { + /// Generates a MessageEndpoint from an untyped handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + unsafe fn from_untyped(handle: handle::UntypedHandle) -> Self { + MessageEndpoint { handle: handle } + } + + /// Consumes this object and produces a plain handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + fn as_untyped(self) -> handle::UntypedHandle { + self.handle + } +} + +impl Handle for MessageEndpoint { + /// Returns the native handle wrapped by this structure. + /// + /// See mojo::system::handle for information on handle wrappers + fn get_native_handle(&self) -> MojoHandle { + self.handle.get_native_handle() + } +}
diff --git a/mojo/public/rust/system/mod.rs b/mojo/public/rust/system/mod.rs new file mode 100644 index 0000000..0258e79 --- /dev/null +++ b/mojo/public/rust/system/mod.rs
@@ -0,0 +1,18 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +mod ffi; +mod handle; +mod mojo_types; + +pub mod core; +pub mod data_pipe; +pub mod message_pipe; +pub mod shared_buffer; +pub mod wait_set; + +// In order to keep the interface clean, we re-export basic Mojo and handle +// types and traits here in the system module. +pub use system::handle::*; +pub use system::mojo_types::*;
diff --git a/mojo/public/rust/system/mojo_types.rs b/mojo/public/rust/system/mojo_types.rs new file mode 100644 index 0000000..1e9a67b --- /dev/null +++ b/mojo/public/rust/system/mojo_types.rs
@@ -0,0 +1,274 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! This module contains a variety of types which are used +//! for representing representing flag arguments a little bit +//! better than just as u32 and some other basic Mojo types that +//! we need to expose. +//! +//! This module also provides MojoResult which is the canonical +//! result coding system used by Mojo. +//! +//! Many places in the system code directly import this module as +//! a whole because it is intended to be used that way. It contains +//! all of the basic types needed by all system-level Mojo bindings. + +use std::fmt; +use std::u64; +use system::ffi::types::*; + +/// A MojoHandle is represented as a plain 32-bit unsigned int. +pub type MojoHandle = u32; + +/// Represents time ticks as specified by Mojo. A time tick value +/// is meaningless when not used relative to another time tick. +pub type MojoTimeTicks = i64; + +/// Represents a deadline for wait() calls. +pub type MojoDeadline = u64; +pub static MOJO_INDEFINITE: MojoDeadline = u64::MAX; + +pub type CreateFlags = u32; +pub type DuplicateFlags = u32; +pub type InfoFlags = u32; +pub type MapFlags = u32; +pub type WriteFlags = u32; +pub type ReadFlags = u32; +pub type AddFlags = u32; + +/// MojoResult represents anything that can happen +/// as a result of performing some operation in Mojo. +/// +/// It's implementation matches exactly that found in +/// the Mojo C API so this enum can be used across the +/// FFI boundary simply by using "as u32". +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(u32)] +pub enum MojoResult { + Okay = 0x0, + Cancelled = 0x1, + Unknown = 0x2, + InvalidArgument = 0x3, + DeadlineExceeded = 0x4, + NotFound = 0x5, + AlreadyExists = 0x6, + PermissionDenied = 0x7, + ResourceExhausted = 0x8, + FailedPrecondition = 0x9, + Aborted = 0xa, + OutOfRange = 0xb, + Unimplemented = 0xc, + Internal = 0xd, + Unavailable = 0xe, + DataLoss = 0xf, + Busy = 0x0019, + ShouldWait = 0x001e, + InvalidResult, +} + +impl MojoResult { + /// Convert a raw u32 code given by the C Mojo functions + /// into a MojoResult. + pub fn from_code(code: MojoResultCode) -> MojoResult { + match code as u32 { + 0x0 => MojoResult::Okay, + 0x1 => MojoResult::Cancelled, + 0x2 => MojoResult::Unknown, + 0x3 => MojoResult::InvalidArgument, + 0x4 => MojoResult::DeadlineExceeded, + 0x5 => MojoResult::NotFound, + 0x6 => MojoResult::AlreadyExists, + 0x7 => MojoResult::PermissionDenied, + 0x8 => MojoResult::ResourceExhausted, + 0x9 => MojoResult::FailedPrecondition, + 0xa => MojoResult::Aborted, + 0xb => MojoResult::OutOfRange, + 0xc => MojoResult::Unimplemented, + 0xd => MojoResult::Internal, + 0xe => MojoResult::Unavailable, + 0xf => MojoResult::DataLoss, + 0x0019 => MojoResult::Busy, + 0x001e => MojoResult::ShouldWait, + _ => MojoResult::InvalidResult, + } + } + + pub fn to_str(&self) -> &'static str { + match *self { + MojoResult::Okay => "OK", + MojoResult::Cancelled => "Cancelled", + MojoResult::Unknown => "Unknown", + MojoResult::InvalidArgument => "Invalid Argument", + MojoResult::DeadlineExceeded => "Deadline Exceeded", + MojoResult::NotFound => "Not Found", + MojoResult::AlreadyExists => "Already Exists", + MojoResult::PermissionDenied => "Permission Denied", + MojoResult::ResourceExhausted => "Resource Exhausted", + MojoResult::FailedPrecondition => "Failed Precondition", + MojoResult::Aborted => "Aborted", + MojoResult::OutOfRange => "Out Of Range", + MojoResult::Unimplemented => "Unimplemented", + MojoResult::Internal => "Internal", + MojoResult::Unavailable => "Unavailable", + MojoResult::DataLoss => "Data Loss", + MojoResult::Busy => "Busy", + MojoResult::ShouldWait => "Should Wait", + MojoResult::InvalidResult => "Something went very wrong", + } + } +} + +impl fmt::Display for MojoResult { + /// Allow a MojoResult to be displayed in a sane manner. + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.to_str()) + } +} + +/// This tuple struct represents a bit vector configuration of possible +/// Mojo signals. Used in wait() and wait_many() primarily as a convenience. +/// +/// One invariant must be true for this data structure and it is that: +/// sizeof(HandleSignals) == sizeof(MojoHandleSignals) +/// If this is ever not the case or there is a way in Rust to ensure that, +/// this data structure must be updated to reflect that. +#[repr(C)] +#[derive(Clone, Copy, Default, PartialEq)] +pub struct HandleSignals(MojoHandleSignals); + +impl HandleSignals { + /// Create a new HandleSignals given the raw MojoHandleSignals + pub fn new(s: MojoHandleSignals) -> HandleSignals { + HandleSignals(s) + } + + /// Check if the readable flag is set + pub fn is_readable(&self) -> bool { + (self.0 & (Signals::Readable as u32)) != 0 + } + + /// Check if the writable flag is set + pub fn is_writable(&self) -> bool { + (self.0 & (Signals::Writable as u32)) != 0 + } + + /// Check if the peer-closed flag is set + pub fn is_peer_closed(&self) -> bool { + (self.0 & (Signals::PeerClosed as u32)) != 0 + } + + /// Check if the read threshold flag is set + pub fn is_read_threshold(&self) -> bool { + (self.0 & (Signals::ReadThreshold as u32)) != 0 + } + + /// Check if the write threshold flag is set + pub fn is_write_threshold(&self) -> bool { + (self.0 & (Signals::WriteThreshold as u32)) != 0 + } + + /// Pull the raw MojoHandleSignals out of the data structure + pub fn get_bits(&self) -> MojoHandleSignals { + self.0 + } +} + +/// Represents the signals state of a handle: which signals are satisfied, +/// and which are satisfiable. +/// +/// One invariant must be true for this data structure and it is that: +/// sizeof(SignalsState) == sizeof(MojoSignalsState) (defined in handle.h) +/// If this is ever not the case or there is a way in Rust to ensure that, +/// this data structure must be updated to reflect that. +#[repr(C)] +#[derive(Default)] +pub struct SignalsState { + satisfied: HandleSignals, + satisfiable: HandleSignals, + _align: [u32; 0], // Hack to align to a 4-byte boundary +} + +impl SignalsState { + /// Generates a new SignalsState + pub fn new(satisfied: HandleSignals, satisfiable: HandleSignals) -> SignalsState { + SignalsState { satisfied: satisfied, satisfiable: satisfiable, _align: [] } + } + /// Gets a reference to the satisfied signals + pub fn satisfied(&self) -> &HandleSignals { + &self.satisfied + } + /// Gets a reference to the satisfiable signals + pub fn satisfiable(&self) -> &HandleSignals { + &self.satisfiable + } + /// Consume the SignalsState and release its tender interior + /// + /// Returns (satisfied, satisfiable) + pub fn unwrap(self) -> (HandleSignals, HandleSignals) { + (self.satisfied, self.satisfiable) + } +} + +/// The different signals options that can be +/// used by wait() and wait_many(). You may use +/// these directly to build a bit-vector, but +/// the signals! macro will already do it for you. +/// See the root of the library for more information. +#[repr(u32)] +pub enum Signals { + None = 0, + /// Wait for the handle to be readable + Readable = 1 << 0, + + /// Wait for the handle to be writable + Writable = 1 << 1, + + /// Wait for the handle to be closed by the peer + /// (for message pipes and data pipes, this is + /// the counterpart handle to the pipe) + PeerClosed = 1 << 2, + + /// Wait for the handle to have at least some + /// readable data + ReadThreshold = 1 << 3, + + /// Wait for the handle to allow for at least + /// some data to be writable + WriteThreshold = 1 << 4, +} + +/// The result struct used by the wait_set module +/// to return wait result information. Should remain +/// semantically identical to the implementation of +/// this struct in wait_set.h in the C bindings. +/// +/// This struct should never be constructed by anything +/// but the Mojo system in MojoWaitSetWait. +#[repr(C)] +pub struct WaitSetResult { + cookie: u64, + result: MojoResultCode, + reserved: u32, + signals_state: SignalsState, + _align: [u64; 0], // Hack to align struct to 8 byte boundary +} + +impl WaitSetResult { + /// Getter for the cookie corresponding to the handle + /// which just finished waiting. + pub fn cookie(&self) -> u64 { + self.cookie + } + + /// Getter for the wait result. + pub fn result(&self) -> MojoResult { + MojoResult::from_code(self.result) + } + + /// Getter for the signals state that comes with any + /// wait result. + pub fn state(&self) -> &SignalsState { + &self.signals_state + } +}
diff --git a/mojo/public/rust/system/shared_buffer.rs b/mojo/public/rust/system/shared_buffer.rs new file mode 100644 index 0000000..1b893772 --- /dev/null +++ b/mojo/public/rust/system/shared_buffer.rs
@@ -0,0 +1,245 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::mem; +use std::ptr; +use std::slice; + +use system::ffi; +// This full import is intentional; nearly every type in mojo_types needs to be used. +use system::handle; +use system::handle::{CastHandle, Handle}; +use system::mojo_types::*; + +#[repr(u32)] +/// Create flags for shared buffers +pub enum Create { + None = 0, +} + +#[repr(u32)] +/// Duplicate flags for shared buffers +pub enum Duplicate { + None = 0, +} + +#[repr(u32)] +/// Map flags for shared buffers +pub enum Map { + None = 0, +} + +#[repr(u32)] +/// Map flags for shared buffers +pub enum Info { + None = 0, +} + +/// A MappedBuffer represents the result of +/// calling map_buffer on a shared buffer handle. +/// +/// The C API allocates a buffer which can then be +/// read or written to through this data structure. +/// +/// The importance of this data structure is that +/// we bind the lifetime of the slice given to us +/// from the C API to this structure. Additionally, +/// reads and writes to this data structure are guaranteed +/// to be able to propagate through Mojo as they are +/// volatile. Language optimization backends are generally +/// unaware of other address spaces, and since this structure +/// represents some buffer from another address space, we +/// need to make sure loads and stores are volatile. +/// +/// Changes to this data structure are propagated through Mojo +/// on the next Mojo operation (that is, Mojo operations are +/// considered barriers). So, unmapping the buffer, sending a +/// message across a pipe, duplicating a shared buffer handle, +/// etc. are all valid ways of propagating changes. The read +/// and write methods do NOT guarantee changes to propagate. +/// +/// This structure also prevents resource leaks by +/// unmapping the buffer it contains on destruction. +pub struct MappedBuffer<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MappedBuffer<'a> { + /// Returns the length of the wrapped buffer. + /// + /// Part of reimplementing the array interface to be + /// able to use the structure naturally. + pub fn len(&self) -> usize { + self.buffer.len() + } + + /// Read safely from the shared buffer. Makes sure a real load + /// is performed by marking the read as volatile. + pub fn read(&self, index: usize) -> u8 { + unsafe { ptr::read_volatile((&self.buffer[index]) as *const u8) } + } + + /// Write safely to the shared buffer. Makes sure a real store + /// is performed by marking the store as volatile. + pub fn write(&mut self, index: usize, value: u8) { + unsafe { + ptr::write_volatile((&mut self.buffer[index]) as *mut u8, value); + } + } + + /// Returns the slice this buffer wraps. + /// + /// The reason this method is unsafe is because the way Rust maps + /// reads and writes down to loads and stores may not be to real + /// loads and stores which are required to allow changes to propagate + /// through Mojo. If you are not careful, some writes and reads may be + /// to incorrect data! Use at your own risk. + pub unsafe fn as_slice(&'a mut self) -> &'a mut [u8] { + self.buffer + } +} + +impl<'a> Drop for MappedBuffer<'a> { + /// The destructor for MappedBuffer. Unmaps the buffer it + /// encloses by using the original, raw pointer to the mapped + /// memory region. + fn drop(&mut self) { + let r = MojoResult::from_code(unsafe { + ffi::MojoUnmapBuffer(self.buffer.as_ptr() as *const ffi::c_void) + }); + if r != MojoResult::Okay { + panic!("Failed to unmap buffer. Mojo Error: {}", r); + } + } +} + +/// Creates a shared buffer in Mojo and returns a SharedBuffer +/// structure which represents a handle to the shared buffer. +pub fn create(flags: CreateFlags, num_bytes: u64) -> Result<SharedBuffer, MojoResult> { + let opts = ffi::MojoCreateSharedBufferOptions { + struct_size: mem::size_of::<ffi::MojoCreateSharedBufferOptions>() as u32, + flags: flags, + _align: [], + }; + let raw_opts = &opts as *const ffi::MojoCreateSharedBufferOptions; + let mut h: MojoHandle = 0; + let r = MojoResult::from_code(unsafe { + ffi::MojoCreateSharedBuffer(raw_opts, num_bytes, &mut h as *mut MojoHandle) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok(SharedBuffer { handle: unsafe { handle::acquire(h) } }) + } +} + +/// Represents a handle to a shared buffer in Mojo. +/// This data structure wraps a handle and acts +/// effectively as a typed handle. +pub struct SharedBuffer { + handle: handle::UntypedHandle, +} + +impl SharedBuffer { + /// Duplicates the shared buffer handle. This is NOT the same + /// as cloning the structure which is illegal since cloning could + /// lead to resource leaks. Instead this uses Mojo to duplicate the + /// buffer handle (though the handle itself may not be represented by + /// the same number) that maps to the same shared buffer as the original. + pub fn duplicate(&self, flags: DuplicateFlags) -> Result<SharedBuffer, MojoResult> { + let opts = ffi::MojoDuplicateBufferHandleOptions { + struct_size: mem::size_of::<ffi::MojoDuplicateBufferHandleOptions>() as u32, + flags: flags, + _align: [], + }; + let raw_opts = &opts as *const ffi::MojoDuplicateBufferHandleOptions; + let mut dup_h: MojoHandle = 0; + let r = MojoResult::from_code(unsafe { + ffi::MojoDuplicateBufferHandle( + self.handle.get_native_handle(), + raw_opts, + &mut dup_h as *mut MojoHandle, + ) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok(SharedBuffer { handle: unsafe { handle::acquire(dup_h) } }) + } + } + + /// Map the shared buffer into local memory. Generates a MappedBuffer + /// structure. See MappedBuffer for more information on how to use it. + pub fn map<'a>( + &self, + offset: u64, + num_bytes: u64, + flags: MapFlags, + ) -> Result<MappedBuffer<'a>, MojoResult> { + unsafe { + let mut ptr: *mut ffi::c_void = mem::uninitialized(); + let r = MojoResult::from_code(ffi::MojoMapBuffer( + self.handle.get_native_handle(), + offset, + num_bytes, + &mut ptr, + flags, + )); + if r != MojoResult::Okay { + Err(r) + } else { + let mut buf = slice::from_raw_parts_mut(ptr as *mut u8, num_bytes as usize); + Ok(MappedBuffer { buffer: buf }) + } + } + } + + /// Retrieves information about a shared buffer the this handle. The return + /// value is a set of flags (a bit vector in a u32) representing different + /// aspects of the shared buffer and the size of the shared buffer. + pub fn get_info(&self) -> Result<(InfoFlags, u64), MojoResult> { + let info_size = mem::size_of::<ffi::MojoBufferInformation>() as u32; + let mut info = ffi::MojoBufferInformation { + struct_size: info_size, + flags: 0, + num_bytes: 0, + _align: [], + }; + let r = MojoResult::from_code(unsafe { + ffi::MojoGetBufferInformation( + self.handle.get_native_handle(), + &mut info as *mut ffi::MojoBufferInformation, + info_size, + ) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok((info.flags, info.num_bytes)) + } + } +} + +impl CastHandle for SharedBuffer { + /// Generates a SharedBuffer from an untyped handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + unsafe fn from_untyped(handle: handle::UntypedHandle) -> Self { + SharedBuffer { handle: handle } + } + + /// Consumes this object and produces a plain handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + fn as_untyped(self) -> handle::UntypedHandle { + self.handle + } +} + +impl Handle for SharedBuffer { + /// Returns the native handle wrapped by this structure. + /// + /// See mojo::system::handle for information on handle wrappers + fn get_native_handle(&self) -> MojoHandle { + self.handle.get_native_handle() + } +}
diff --git a/mojo/public/rust/system/wait_set.rs b/mojo/public/rust/system/wait_set.rs new file mode 100644 index 0000000..27117838 --- /dev/null +++ b/mojo/public/rust/system/wait_set.rs
@@ -0,0 +1,161 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::mem; +use std::ptr; + +use system::ffi; +use system::handle; +use system::handle::{CastHandle, Handle}; +use system::mojo_types; +use system::mojo_types::MojoResult; + +#[repr(u32)] +/// Create flags for wait sets +pub enum Create { + None = 0, +} + +#[repr(u32)] +/// Add flags for wait sets +pub enum Add { + None = 0, +} + +/// This struct represents a handle to a wait set in the Mojo system. +/// +/// The primary purpose of a wait set is to provide an abstraction for +/// efficiently waiting asynchronously (and cooperatively) on a set of +/// handles which are registered with it. +pub struct WaitSet { + handle: handle::UntypedHandle, +} + +impl WaitSet { + /// Creates a new WaitSet object in the Mojo system, and returns a wrapper + /// for it. If creation fails, returns the result code. + pub fn new(flags: mojo_types::CreateFlags) -> Result<WaitSet, MojoResult> { + let mut raw_handle: mojo_types::MojoHandle = 0; + let opts = ffi::MojoCreateWaitSetOptions { + struct_size: mem::size_of::<ffi::MojoCreateWaitSetOptions>() as u32, + flags: flags, + _align: [], + }; + let raw_opts = &opts as *const ffi::MojoCreateWaitSetOptions; + let r = MojoResult::from_code(unsafe { + ffi::MojoCreateWaitSet(raw_opts, &mut raw_handle as *mut mojo_types::MojoHandle) + }); + if r != MojoResult::Okay { + Err(r) + } else { + Ok(WaitSet { handle: unsafe { handle::acquire(raw_handle) } }) + } + } + + /// Adds a handle to the underlying wait set. + /// + /// The handle that is added may go invalid, at which point the result + /// returned from wait_on_set for this handle will be `Cancelled'. + /// + /// One can pass in a unique cookie value which is used to identify the + /// handle in the wait result. Currently there are no supported flags, + /// but the argument is kept for future usage. + pub fn add( + &mut self, + handle: &Handle, + signals: mojo_types::HandleSignals, + cookie: u64, + flags: mojo_types::AddFlags, + ) -> MojoResult { + let opts = ffi::MojoWaitSetAddOptions { + struct_size: mem::size_of::<ffi::MojoWaitSetAddOptions>() as u32, + flags: flags, + _align: [], + }; + let raw_opts = &opts as *const ffi::MojoWaitSetAddOptions; + MojoResult::from_code(unsafe { + ffi::MojoWaitSetAdd( + self.handle.get_native_handle(), + handle.get_native_handle(), + signals, + cookie, + raw_opts, + ) + }) + } + + /// Removes a handle from the underlying wait set by cookie value. + pub fn remove(&mut self, cookie: u64) -> MojoResult { + MojoResult::from_code(unsafe { ffi::MojoWaitSetRemove(self.get_native_handle(), cookie) }) + } + + /// Waits on this wait set. + /// + /// The conditions for the wait to end include: + /// * A handle has its requested signals satisfied. + /// * A handle is determined to never be able to have its requested + /// signals satisfied. + /// * The deadline expires. + /// * This wait set handle becomes invalid (Fatal error in this bindings). + /// + /// On a successful wait, we return the maximum number of results that could + /// possibly be returned (similar to the total number of registered handles). + /// Additionally, populates the output vector with the results of each handle + /// that completed waiting. + /// + /// On a failed wait, we return the result code. + pub fn wait_on_set( + &self, + deadline: mojo_types::MojoDeadline, + output: &mut Vec<mojo_types::WaitSetResult>, + ) -> Result<u32, MojoResult> { + assert!((output.capacity() as u64) <= ((1 as u64) << 32)); + let mut num_results = output.capacity() as u32; + let mut max_results: u32 = 0; + let mut output_ptr = output.as_mut_ptr(); + if num_results == 0 { + output_ptr = ptr::null_mut(); + } + let r = MojoResult::from_code(unsafe { + ffi::MojoWaitSetWait( + self.handle.get_native_handle(), + deadline, + &mut num_results as *mut u32, + output_ptr, + &mut max_results as *mut u32, + ) + }); + unsafe { + output.set_len(num_results as usize); + } + if r == MojoResult::Okay { + Ok(max_results) + } else { + Err(r) + } + } +} + +impl CastHandle for WaitSet { + /// Generates a WaitSet from an untyped handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + unsafe fn from_untyped(handle: handle::UntypedHandle) -> Self { + WaitSet { handle: handle } + } + + /// Consumes this object and produces a plain handle wrapper + /// See mojo::system::handle for information on untyped vs. typed + fn as_untyped(self) -> handle::UntypedHandle { + self.handle + } +} + +impl Handle for WaitSet { + /// Returns the native handle wrapped by this structure. + /// + /// See mojo::system::handle for information on handle wrappers + fn get_native_handle(&self) -> mojo_types::MojoHandle { + self.handle.get_native_handle() + } +}
diff --git a/mojo/public/rust/tests/encoding.rs b/mojo/public/rust/tests/encoding.rs new file mode 100644 index 0000000..45d2f3b --- /dev/null +++ b/mojo/public/rust/tests/encoding.rs
@@ -0,0 +1,705 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Tests encoding and decoding functionality in the bindings package +//! +//! Test failure is defined as the function returning via panicking +//! and the result being caught in the test! macro. If a test function +//! returns without panicking, it is assumed to pass. + +#[macro_use] +extern crate mojo; + +use mojo::bindings::encoding::Context; +use mojo::bindings::message::MessageHeader; +use mojo::bindings::mojom::{MojomInterface, MojomPointer, MojomStruct, MojomUnion}; + +use mojo::system; +use mojo::system::Handle; + +use std::collections::HashMap; + +#[macro_use] +mod util; + +use util::mojom_validation::*; + +/// This macro is a wrapper for the tests! macro as it takes advantage of the +/// shared code between tests. +/// +/// Given a test name, it will generate a test function. In this test function +/// we perform the following steps: +/// 1. Decode the header of the validation input. +/// 2. Verify the decoded header is what we expect. +/// 3. Decode the payload of the validation input. +/// 4. Verify the decoded payload is what we expect. +/// 5. Take the decoded payload and re-encode. +/// 6. Decode this re-encoded payload. +/// 7. Verify the re-decoded payload is what we expect. +/// +/// Each test should sufficiently verify the operation of the encoding and decoding +/// frameworks as we first verify the decoder works correctly on the "golden files", +/// then verify that the encoder works by encoding the decoded output, and decoding +/// that once again. +macro_rules! encoding_tests { + ($($name:ident { MessageHeader => $header_cls:expr, $req_type:ident => $cls:expr } )*) => { + tests! { + $( + fn $name() { + let data = include_str!(concat!("../../interfaces/bindings/tests/data/validation/", + stringify!($name), + ".data")); + match util::parse_validation_test(data) { + Ok((mut data, num_handles)) => { + let mut mock_handles = Vec::with_capacity(num_handles); + for _ in 0..num_handles { + mock_handles.push(unsafe { system::acquire(0) }); + } + println!("{}: Decoding header", stringify!($name)); + let header = MessageHeader::deserialize(&mut data[..], Vec::new()).expect("Should not error"); + let ctxt: Context = Default::default(); + let header_size = header.serialized_size(&ctxt); + let header_cls = $header_cls; + println!("{}: Verifying decoded header", stringify!($name)); + header_cls(header); + let payload_buffer = &mut data[header_size..]; + let cls = $cls; + println!("{}: Decoding payload", stringify!($name)); + let decoded_payload = $req_type::deserialize(payload_buffer, mock_handles).expect("Should not error"); + println!("{}: Verifying decoded payload", stringify!($name)); + cls(&decoded_payload); + println!("{}: Re-encoding payload", stringify!($name)); + let (mut encoded_payload, handles) = decoded_payload.auto_serialize(); + println!("{}: Decoding payload again", stringify!($name)); + let redecoded_payload = $req_type::deserialize(&mut encoded_payload[..], handles).expect("Should not error"); + println!("{}: Verifying decoded payload again", stringify!($name)); + cls(&redecoded_payload); + }, + Err(msg) => panic!("Error: {}", msg), + } + } + )* + } + } +} + +encoding_tests! { + conformance_mthd0_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 0); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod0Request => { + |payload: &ConformanceTestInterfaceMethod0Request| { + assert_eq!(payload.param0, -1.0); + } + } + } + conformance_mthd1_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 1); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod1Request => { + |payload: &ConformanceTestInterfaceMethod1Request| { + assert_eq!(payload.param0.i, 1234); + } + } + } + conformance_mthd2_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 2); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod2Request => { + |payload: &ConformanceTestInterfaceMethod2Request| { + assert_eq!(payload.param0.struct_a.i, 12345); + assert_eq!(payload.param1.i, 67890); + } + } + } + conformance_mthd3_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 3); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod3Request => { + |payload: &ConformanceTestInterfaceMethod3Request| { + assert_eq!(payload.param0, vec![true, false, true, false, + true, false, true, false, + true, true, true, true]); + } + } + } + conformance_mthd4_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 4); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod4Request => { + |payload: &ConformanceTestInterfaceMethod4Request| { + assert_eq!(payload.param0.data, vec![0, 1, 2]); + assert_eq!(payload.param1, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + } + } + } + conformance_mthd5_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 5); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod5Request => { + |payload: &ConformanceTestInterfaceMethod5Request| { + assert_eq!(payload.param0.struct_d.message_pipes.len(), 2); + for h in payload.param0.struct_d.message_pipes.iter() { + assert_eq!(h.get_native_handle(), 0); + } + assert_eq!(payload.param0.data_pipe_consumer.get_native_handle(), 0); + assert_eq!(payload.param1.get_native_handle(), 0); + } + } + } + conformance_mthd6_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 6); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod6Request => { + |payload: &ConformanceTestInterfaceMethod6Request| { + assert_eq!(payload.param0, vec![vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]); + } + } + } + conformance_mthd7_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 7); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod7Request => { + |payload: &ConformanceTestInterfaceMethod7Request| { + assert_eq!(payload.param0.fixed_size_array, [0, 1, 2]); + assert_eq!(payload.param1, [None, Some([0, 1, 2])]); + } + } + } + conformance_mthd8_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 8); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod8Request => { + |payload: &ConformanceTestInterfaceMethod8Request| { + assert_eq!(payload.param0, + vec![None, Some(vec![String::from_utf8(vec![0, 1, 2, 3, 4]).unwrap()]), None]); + } + } + } + conformance_mthd9_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 9); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod9Request => { + |payload: &ConformanceTestInterfaceMethod9Request| { + assert!(payload.param0.is_some()); + if let Some(ref v) = payload.param0 { + assert_eq!(v.len(), 2); + assert_eq!(v[0].len(), 2); + assert_eq!(v[1].len(), 3); + assert!(v[0][0].is_some()); + assert!(v[0][1].is_none()); + assert!(v[1][0].is_some()); + assert!(v[1][1].is_none()); + assert!(v[1][2].is_some()); + assert_eq!(v[0][0].as_ref().unwrap().get_native_handle(), 0); + assert_eq!(v[1][0].as_ref().unwrap().get_native_handle(), 0); + assert_eq!(v[1][2].as_ref().unwrap().get_native_handle(), 0); + } + } + } + } + conformance_mthd9_good_null_array { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 9); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod9Request => { + |payload: &ConformanceTestInterfaceMethod9Request| { + assert!(payload.param0.is_none()); + } + } + } + conformance_mthd10_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 10); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod10Request => { + |payload: &ConformanceTestInterfaceMethod10Request| { + let mut map = HashMap::with_capacity(2); + map.insert(String::from_utf8(vec![0, 1, 2, 3, 4]).unwrap(), 1); + map.insert(String::from_utf8(vec![5, 6, 7, 8, 9]).unwrap(), 2); + assert_eq!(payload.param0, map); + } + } + } + // Non-unique keys are strange... + // Right now, we handle them by silently overwriting. + // Maybe this will be an error in the future. + // The insert calls below reflect the duplicate keys that the + // test provides, and an example as to how overwriting happens. + conformance_mthd10_good_non_unique_keys { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 10); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod10Request => { + |payload: &ConformanceTestInterfaceMethod10Request| { + let mut map = HashMap::with_capacity(2); + map.insert(String::from_utf8(vec![0, 1, 2, 3, 4]).unwrap(), 1); + map.insert(String::from_utf8(vec![0, 1, 2, 3, 4]).unwrap(), 2); + assert_eq!(payload.param0, map); + } + } + } + conformance_mthd11_good_version0 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 11); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod11Request => { + |payload: &ConformanceTestInterfaceMethod11Request| { + assert_eq!(payload.param0.i, 123); + assert_eq!(payload.param0.b, false); + assert!(payload.param0.struct_a.is_none()); + assert!(payload.param0.str.is_none()); + } + } + } + conformance_mthd11_good_version1 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 11); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod11Request => { + |payload: &ConformanceTestInterfaceMethod11Request| { + assert_eq!(payload.param0.i, 123); + assert_eq!(payload.param0.b, false); + assert!(payload.param0.struct_a.is_none()); + assert!(payload.param0.str.is_none()); + } + } + } + conformance_mthd11_good_version2 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 11); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod11Request => { + |payload: &ConformanceTestInterfaceMethod11Request| { + assert_eq!(payload.param0.i, 123); + assert_eq!(payload.param0.b, false); + assert!(payload.param0.struct_a.is_none()); + assert!(payload.param0.str.is_none()); + } + } + } + conformance_mthd11_good_version3 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 11); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod11Request => { + |payload: &ConformanceTestInterfaceMethod11Request| { + assert_eq!(payload.param0.i, 123); + assert_eq!(payload.param0.b, true); + assert!(payload.param0.struct_a.is_none()); + assert_eq!(payload.param0.str, Some(String::from_utf8(vec![0, 1]).unwrap())); + } + } + } + conformance_mthd11_good_version_newer_than_known_1 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 11); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod11Request => { + |payload: &ConformanceTestInterfaceMethod11Request| { + assert_eq!(payload.param0.i, 123); + assert_eq!(payload.param0.b, true); + assert!(payload.param0.struct_a.is_none()); + assert_eq!(payload.param0.str, Some(String::from_utf8(vec![0, 1]).unwrap())); + } + } + } + conformance_mthd11_good_version_newer_than_known_2 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 11); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod11Request => { + |payload: &ConformanceTestInterfaceMethod11Request| { + assert_eq!(payload.param0.i, 123); + assert_eq!(payload.param0.b, true); + assert!(payload.param0.struct_a.is_none()); + assert_eq!(payload.param0.str, Some(String::from_utf8(vec![0, 1]).unwrap())); + } + } + } + conformance_mthd13_good_1 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 13); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod13Request => { + |payload: &ConformanceTestInterfaceMethod13Request| { + assert!(payload.param0.is_none()); + assert_eq!(payload.param1, 65535); + assert!(payload.param2.is_none()); + } + } + } + conformance_mthd13_good_2 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 13); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod13Request => { + |payload: &ConformanceTestInterfaceMethod13Request| { + assert!(payload.param0.is_some()); + assert_eq!(payload.param0.as_ref().unwrap().pipe().get_native_handle(), 0); + assert_eq!(payload.param1, 65535); + assert!(payload.param2.is_some()); + assert_eq!(payload.param2.as_ref().unwrap().pipe().get_native_handle(), 0); + } + } + } + conformance_mthd14_good_1 { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::a(ref val) => assert_eq!(*val, 54), + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_array_in_union { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::d(ref val) => assert_eq!(*val, Some(vec![0, 1, 2])), + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_map_in_union { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + let mut map = HashMap::with_capacity(2); + map.insert(String::from_utf8(vec![0, 1, 2, 3, 4]).unwrap(), 1); + map.insert(String::from_utf8(vec![5, 6, 7, 8, 9]).unwrap(), 2); + match payload.param0 { + UnionA::e(ref val) => assert_eq!(*val, Some(map)), + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_nested_union { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::f(ref val) => { + assert!(val.is_some()); + let inner = val.as_ref().unwrap(); + match *inner { + UnionB::b(inner_val) => assert_eq!(inner_val, 10), + _ => panic!("Incorrect inner union variant! Tag found: {}", inner.get_tag()), + } + }, + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_null_array_in_union { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::d(ref val) => assert_eq!(*val, None), + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_null_map_in_union { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::e(ref val) => assert_eq!(*val, None), + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_struct_in_union { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::c(ref val) => { + let struct_val = val.as_ref().unwrap(); + assert_eq!(struct_val.i, 20); + }, + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd14_good_unknown_union_tag { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 14); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod14Request => { + |payload: &ConformanceTestInterfaceMethod14Request| { + match payload.param0 { + UnionA::_Unknown(ref val) => assert_eq!(*val, 54), + _ => panic!("Incorrect union variant! Tag found: {}", payload.param0.get_tag()), + } + } + } + } + conformance_mthd15_good_union_in_array { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 15); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod15Request => { + |payload: &ConformanceTestInterfaceMethod15Request| { + assert_eq!(payload.param0.a, true); + assert_eq!(payload.param0.b, 22); + assert!(payload.param0.c.is_none()); + assert!(payload.param0.d.is_some()); + assert!(payload.param0.e.is_none()); + let array = payload.param0.d.as_ref().unwrap(); + assert_eq!(array.len(), 3); + for u in array.iter() { + match *u { + UnionA::b(ref val) => assert_eq!(*val, 10), + _ => panic!("Incorrect union variant! Tag found: {}", u.get_tag()), + } + } + } + } + } + conformance_mthd15_good_union_in_map { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 15); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod15Request => { + |payload: &ConformanceTestInterfaceMethod15Request| { + assert_eq!(payload.param0.a, true); + assert_eq!(payload.param0.b, 22); + assert!(payload.param0.c.is_none()); + assert!(payload.param0.d.is_none()); + assert!(payload.param0.e.is_some()); + let map = payload.param0.e.as_ref().unwrap(); + assert_eq!(map.len(), 3); + let mut expect_keys = HashMap::with_capacity(3); + expect_keys.insert(8, false); + expect_keys.insert(7, false); + expect_keys.insert(1, false); + for (key, value) in map.iter() { + expect_keys.insert(*key, true); + match *value { + UnionA::b(ref val) => assert_eq!(*val, 10), + _ => panic!("Incorrect union variant! Tag found: {}", value.get_tag()), + } + } + for (key, value) in expect_keys.iter() { + if *value == false { + panic!("Expected key `{}`, but not found!", *key); + } + } + } + } + } + conformance_mthd15_good_union_in_struct { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 0); + assert_eq!(header.name, 15); + assert_eq!(header.flags, 0); + } + }, + ConformanceTestInterfaceMethod15Request => { + |payload: &ConformanceTestInterfaceMethod15Request| { + assert_eq!(payload.param0.a, true); + assert_eq!(payload.param0.b, 22); + assert!(payload.param0.c.is_some()); + assert!(payload.param0.d.is_none()); + assert!(payload.param0.e.is_none()); + let union_val = payload.param0.c.as_ref().unwrap(); + match *union_val { + UnionA::b(ref val) => assert_eq!(*val, 54), + _ => panic!("Incorrect union variant! Tag found: {}", union_val.get_tag()), + } + } + } + } + integration_intf_rqst_mthd0_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 1); + assert_eq!(header.name, 0); + assert_eq!(header.flags, 1); + assert_eq!(header.request_id, 7); + } + }, + IntegrationTestInterfaceMethod0Request => { + |payload: &IntegrationTestInterfaceMethod0Request| { + assert_eq!(payload.param0.a, -1); + } + } + } + integration_intf_resp_mthd0_good { + MessageHeader => { + |header: MessageHeader| { + assert_eq!(header.version, 1); + assert_eq!(header.name, 0); + assert_eq!(header.flags, 2); + assert_eq!(header.request_id, 1); + } + }, + IntegrationTestInterfaceMethod0Response => { + |payload: &IntegrationTestInterfaceMethod0Response| { + assert_eq!(payload.param0, vec![0]); + } + } + } +}
diff --git a/mojo/public/rust/tests/integration.rs b/mojo/public/rust/tests/integration.rs new file mode 100644 index 0000000..3f551f8 --- /dev/null +++ b/mojo/public/rust/tests/integration.rs
@@ -0,0 +1,67 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Tests some higher-level functionality of Mojom interfaces. +//! +//! Test failure is defined as the function returning via panicking +//! and the result being caught in the test! macro. If a test function +//! returns without panicking, it is assumed to pass. + +#[macro_use] +extern crate mojo; + +#[macro_use] +mod util; + +use mojo::bindings::mojom::{MojomInterface, MojomInterfaceRecv, MojomInterfaceSend}; +use mojo::system::message_pipe; +use mojo::system::{Handle, MOJO_INDEFINITE}; + +use std::thread; + +use util::mojom_validation::*; + +tests! { + // Tests basic client and server interaction over a thread + fn send_and_recv() { + let (endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + // Client and server handles + let client = IntegrationTestInterfaceClient::new(endpt0); + let server = IntegrationTestInterfaceServer::with_version(endpt1, 0); + // Client thread + let handle = thread::spawn(move || { + // Send request + client.send_request(5, IntegrationTestInterfaceMethod0Request { + param0: BasicStruct { + a: -1, + }, + }).unwrap(); + // Wait for response + client.pipe().wait(signals!(Signals::Readable), MOJO_INDEFINITE); + // Decode response + let (req_id, options) = client.recv_response().unwrap(); + assert_eq!(req_id, 5); + match options { + IntegrationTestInterfaceResponseOption::IntegrationTestInterfaceMethod0(msg) => { + assert_eq!(msg.param0, vec![1, 2, 3]); + }, + } + }); + // Wait for request + server.pipe().wait(signals!(Signals::Readable), MOJO_INDEFINITE); + // Decode request + let (req_id, options) = server.recv_response().unwrap(); + assert_eq!(req_id, 5); + match options { + IntegrationTestInterfaceRequestOption::IntegrationTestInterfaceMethod0(msg) => { + assert_eq!(msg.param0.a, -1); + }, + } + // Send response + server.send_request(5, IntegrationTestInterfaceMethod0Response { + param0: vec![1, 2, 3], + }).unwrap(); + let _ = handle.join(); + } +}
diff --git a/mojo/public/rust/tests/regression.rs b/mojo/public/rust/tests/regression.rs new file mode 100644 index 0000000..9b10159 --- /dev/null +++ b/mojo/public/rust/tests/regression.rs
@@ -0,0 +1,106 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Tests validation functionality in the bindings package +//! +//! Test failure is defined as the function returning via panicking +//! and the result being caught in the test! macro. If a test function +//! returns without panicking, it is assumed to pass. + +#[macro_use] +extern crate mojo; + +use mojo::bindings::decoding::{Decoder, ValidationError}; +use mojo::bindings::encoding; +use mojo::bindings::encoding::{Context, DataHeaderValue, Encoder}; +use mojo::bindings::mojom::{MojomEncodable, MojomPointer, MojomStruct}; +use mojo::system; +use mojo::system::UntypedHandle; + +#[macro_use] +mod util; + +const STRUCT_A_VERSIONS: [(u32, u32); 1] = [(0, 16)]; + +struct StructA<T: MojomEncodable> { + param0: [T; 3], +} + +impl<T: MojomEncodable> MojomPointer for StructA<T> { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let _version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&STRUCT_A_VERSIONS) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <[T; 3]>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructA { param0: param0 }) + } +} + +impl<T: MojomEncodable> MojomEncodable for StructA<T> { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl<T: MojomEncodable> MojomStruct for StructA<T> {} + +tests! { + // Fixed size arrays have complex and unsafe semantics to ensure + // there are no memory leaks. We test this behavior here to make + // sure memory isn't becoming corrupted. + fn regression_fixed_size_array_error_propagates_safely() { + let handle1 = unsafe { system::acquire(0) }; + let handle2 = unsafe { system::acquire(0) }; + let handle3 = unsafe { system::acquire(0) }; + let val = StructA { + param0: [handle1, handle2, handle3], + }; + let (mut buffer, mut handles) = val.auto_serialize(); + handles.truncate(1); + let new_val = <StructA<UntypedHandle>>::deserialize(&mut buffer[..], handles); + match new_val { + Ok(_) => panic!("Value should not be okay!"), + Err(err) => assert_eq!(err, ValidationError::IllegalHandle), + } + } + + // Same as the above test, but verifies that drop() is called. + // For the only handle that should drop, we make the handle some + // random number which is potentially a valid handle. When on + // drop() we try to close it, we should panic. + #[should_panic] + fn regression_fixed_size_array_verify_drop() { + let handle1 = unsafe { system::acquire(42) }; + let handle2 = unsafe { system::acquire(0) }; + let handle3 = unsafe { system::acquire(0) }; + let val = StructA { + param0: [handle1, handle2, handle3], + }; + let (mut buffer, mut handles) = val.auto_serialize(); + handles.truncate(1); + let new_val = <StructA<UntypedHandle>>::deserialize(&mut buffer[..], handles); + match new_val { + Ok(_) => panic!("Value should not be okay!"), + Err(err) => assert_eq!(err, ValidationError::IllegalHandle), + } + } +}
diff --git a/mojo/public/rust/tests/run_loop.rs b/mojo/public/rust/tests/run_loop.rs new file mode 100644 index 0000000..47aed22 --- /dev/null +++ b/mojo/public/rust/tests/run_loop.rs
@@ -0,0 +1,400 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Tests all functionality in the system package +//! +//! Test failure is defined as the function returning via panicking +//! and the result being caught in the test! macro. If a test function +//! returns without panicking, it is assumed to pass. + +#[macro_use] +extern crate mojo; + +#[macro_use] +mod util; + +use mojo::bindings::run_loop; +use mojo::bindings::run_loop::{Handler, RunLoop, Token, WaitError}; +use mojo::system::message_pipe; +use mojo::system::MOJO_INDEFINITE; + +use std::cell::Cell; +use std::rc::Rc; + +struct HandlerExpectReady {} + +impl Handler for HandlerExpectReady { + fn on_ready(&mut self, runloop: &mut RunLoop, token: Token) { + runloop.deregister(token); + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected ready"); + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected ready"); + } +} + +struct HandlerExpectTimeout {} + +impl Handler for HandlerExpectTimeout { + fn on_ready(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Ready when expected timeout"); + } + fn on_timeout(&mut self, runloop: &mut RunLoop, token: Token) { + runloop.deregister(token); + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected timeout"); + } +} + +struct HandlerExpectError {} + +impl Handler for HandlerExpectError { + fn on_ready(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Ready when expected error"); + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected error"); + } + fn on_error(&mut self, runloop: &mut RunLoop, token: Token, error: WaitError) { + assert_eq!(error, WaitError::Unsatisfiable); + runloop.deregister(token); + } +} + +struct HandlerQuit {} + +impl Handler for HandlerQuit { + fn on_ready(&mut self, runloop: &mut RunLoop, _token: Token) { + runloop.quit(); + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected error"); + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected ready"); + } +} + +struct HandlerRegister {} + +impl Handler for HandlerRegister { + fn on_ready(&mut self, runloop: &mut RunLoop, token: Token) { + let (_, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let _ = runloop.register( + &endpt1, + signals!(Signals::Writable), + MOJO_INDEFINITE, + HandlerDeregisterOther { other: token }, + ); + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected error"); + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected ready"); + } +} + +struct HandlerDeregisterOther { + other: Token, +} + +impl Handler for HandlerDeregisterOther { + fn on_ready(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Ready when expected error"); + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected error"); + } + fn on_error(&mut self, runloop: &mut RunLoop, token: Token, error: WaitError) { + assert_eq!(error, WaitError::HandleClosed); + runloop.deregister(token); + runloop.deregister(self.other.clone()); + } +} + +struct HandlerReregister { + count: u64, +} + +impl Handler for HandlerReregister { + fn on_ready(&mut self, runloop: &mut RunLoop, token: Token) { + runloop.deregister(token); + } + fn on_timeout(&mut self, runloop: &mut RunLoop, token: Token) { + if self.count < 10 { + runloop.reregister(&token, signals!(Signals::Readable), 0); + self.count += 1; + } else { + runloop.reregister(&token, signals!(Signals::Writable), MOJO_INDEFINITE); + } + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected ready"); + } +} + +struct HandlerNesting { + count: u64, +} + +impl Handler for HandlerNesting { + fn on_ready(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Ready when expected timeout"); + } + fn on_timeout(&mut self, runloop: &mut RunLoop, token: Token) { + let mut nested_runloop = run_loop::RunLoop::new(); + if self.count < 10 { + let handler = HandlerNesting { count: self.count + 1 }; + let (endpt0, _endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let _ = nested_runloop.register(&endpt0, signals!(Signals::Readable), 0, handler); + nested_runloop.run(); + } else { + let handler = HandlerNesting { count: self.count + 1 }; + let (endpt0, _) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let _ = nested_runloop.register(&endpt0, signals!(Signals::Readable), 0, handler); + nested_runloop.run(); + } + runloop.deregister(token); + } + fn on_error(&mut self, runloop: &mut RunLoop, token: Token, error: WaitError) { + assert_eq!(error, WaitError::Unsatisfiable); + assert_eq!(self.count, 11); + runloop.deregister(token); + } +} + +struct HandlerBadNesting {} + +impl Handler for HandlerBadNesting { + fn on_ready(&mut self, runloop: &mut RunLoop, _token: Token) { + runloop.quit(); + } + fn on_timeout(&mut self, runloop: &mut RunLoop, _token: Token) { + runloop.run(); + } + fn on_error(&mut self, runloop: &mut RunLoop, _token: Token, _error: WaitError) { + runloop.quit(); + } +} + +struct HandlerTasks { + count: Rc<Cell<u64>>, +} + +impl Handler for HandlerTasks { + fn on_ready(&mut self, runloop: &mut RunLoop, token: Token) { + let r = self.count.clone(); + let _ = runloop.post_task( + move |_runloop| { + let val = (*r).get(); + (*r).set(val + 1); + }, + 10, + ); + if (*self.count).get() > 10 { + runloop.deregister(token); + } + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected error"); + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected ready"); + } +} + +struct NestedTasks { + count: Rc<Cell<u64>>, + quitter: bool, +} + +impl Handler for NestedTasks { + fn on_ready(&mut self, runloop: &mut RunLoop, token: Token) { + let r = self.count.clone(); + let quit = self.quitter; + let _ = runloop.post_task( + move |runloop| { + let r2 = r.clone(); + let tk = token.clone(); + if (*r).get() < 10 { + let _ = runloop.post_task( + move |_runloop| { + let val = (*r2).get(); + (*r2).set(val + 1); + }, + 0, + ); + } else { + if quit { + runloop.quit(); + } else { + runloop.deregister(tk); + } + } + }, + 0, + ); + } + fn on_timeout(&mut self, _runloop: &mut RunLoop, _token: Token) { + panic!("Timed-out when expected error"); + } + fn on_error(&mut self, _runloop: &mut RunLoop, _token: Token, _error: WaitError) { + panic!("Error when expected ready"); + } +} + +tests! { + // Verifies that after adding and removing, we can run, exit and be + // left in a consistent state. + fn add_remove() { + run_loop::with_current(|runloop| { + let (endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let token0 = runloop.register(&endpt0, signals!(Signals::Writable), 0, HandlerExpectReady {}); + let token1 = runloop.register(&endpt1, signals!(Signals::Writable), 0, HandlerExpectReady {}); + runloop.deregister(token1); + runloop.deregister(token0); + runloop.run(); + }) + } + + // Verifies that generated tokens are unique. + fn tokens() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let mut vec = Vec::new(); + run_loop::with_current(|runloop| { + for _ in 0..10 { + vec.push(runloop.register(&endpt1, signals!(Signals::None), 0, HandlerExpectReady {})); + } + for i in 0..10 { + for j in 0..10 { + if i != j { + assert!(vec[i] != vec[j]); + } + } + } + }); + } + + // Verifies that the handler's "on_ready" function is called. + fn notify_results() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Writable), MOJO_INDEFINITE, HandlerExpectReady {}); + runloop.run(); + }); + } + + // Verifies that the handler's "on_timeout" function is called. + fn notify_timeout() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Readable), 0, HandlerExpectTimeout {}); + runloop.run(); + }); + } + + // Verifies that the handler's "on_error" function is called. + fn notify_error() { + // Drop the first endpoint immediately + let (_, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Readable), 0, HandlerExpectError {}); + runloop.run(); + }); + } + + // Verifies that the handler's "on_ready" function is called which only quits. + fn notify_ready_quit() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Writable), MOJO_INDEFINITE, HandlerQuit {}); + runloop.run(); + }); + } + + // Tests more complex behavior, i.e. the interaction between two handlers. + fn register_deregister() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Writable), MOJO_INDEFINITE, HandlerRegister {}); + runloop.run(); + }); + } + + // Tests reregistering. + fn reregister() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Readable), 0, HandlerReregister { count: 0 }); + runloop.run(); + }); + } + + // Tests nesting run loops by having a handler create a new one. + fn nesting() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Readable), 0, HandlerNesting { count: 0 }); + runloop.run(); + }); + } + + // Tests to make sure nesting with the SAME runloop fails. + #[should_panic] + fn bad_nesting() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Readable), 0, HandlerBadNesting {}); + runloop.run(); + }); + } + + // Tests adding a simple task that adds a handler. + fn simple_task() { + run_loop::with_current(|runloop| { + let _ = runloop.post_task(|runloop| { + let (_, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let _ = runloop.register(&endpt1, signals!(Signals::Readable), 0, HandlerExpectError {}); + }, 0); + runloop.run(); + }); + } + + // Tests using a handler that adds a bunch of tasks. + fn handler_tasks() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let r = Rc::new(Cell::new(0)); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Writable), 0, HandlerTasks { count: r.clone() }); + runloop.run(); + assert!((*r).get() >= 11); + }); + } + + // Tests using a handler that adds a bunch of tasks. + fn nested_tasks() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let r = Rc::new(Cell::new(0)); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Writable), 0, NestedTasks { count: r.clone(), quitter: false }); + runloop.run(); + assert!((*r).get() >= 10); + }); + } + + // Tests using a handler that adds a bunch of tasks. + fn nested_tasks_quit() { + let (_endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let r = Rc::new(Cell::new(0)); + run_loop::with_current(|runloop| { + let _ = runloop.register(&endpt1, signals!(Signals::Writable), 0, NestedTasks { count: r.clone(), quitter: true }); + runloop.run(); + assert!((*r).get() >= 10); + }); + } +}
diff --git a/mojo/public/rust/tests/system.rs b/mojo/public/rust/tests/system.rs new file mode 100644 index 0000000..bf4f1f7c --- /dev/null +++ b/mojo/public/rust/tests/system.rs
@@ -0,0 +1,264 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Tests all functionality in the system package +//! +//! Test failure is defined as the function returning via panicking +//! and the result being caught in the test! macro. If a test function +//! returns without panicking, it is assumed to pass. + +#[macro_use] +extern crate mojo; + +#[macro_use] +mod util; + +use mojo::system; +use mojo::system::core; +use mojo::system::data_pipe; +use mojo::system::message_pipe; +use mojo::system::shared_buffer; +use mojo::system::wait_set; +use mojo::system::{CastHandle, Handle}; + +use std::string::String; +use std::thread; +use std::vec::Vec; + +tests! { + fn get_time_ticks_now() { + let x = core::get_time_ticks_now(); + assert!(x >= 10); + } + + fn handle() { + let sb = shared_buffer::create(sbflags!(Create::None), 1).unwrap(); + let handle = sb.as_untyped(); + unsafe { + assert_eq!((handle.get_native_handle() != 0), handle.is_valid()); + assert!(handle.get_native_handle() != 0 && handle.is_valid()); + let mut h2 = system::acquire(handle.get_native_handle()); + assert!(h2.is_valid()); + h2.invalidate(); + assert!(!h2.is_valid()); + } + } + + fn shared_buffer() { + let bufsize = 100; + let sb1; + { + let mut buf; + { + let sb_c = shared_buffer::create(sbflags!(Create::None), bufsize).unwrap(); + // Extract original handle to check against + let sb_h = sb_c.get_native_handle(); + // Test casting of handle types + let sb_u = sb_c.as_untyped(); + assert_eq!(sb_u.get_native_handle(), sb_h); + let sb = unsafe { shared_buffer::SharedBuffer::from_untyped(sb_u) }; + assert_eq!(sb.get_native_handle(), sb_h); + // Test map + buf = sb.map(0, bufsize, sbflags!(Map::None)).unwrap(); + assert_eq!(buf.len(), bufsize as usize); + // Test get info + let (flags, size) = sb.get_info().unwrap(); + assert_eq!(flags, sbflags!(Info::None)); + assert_eq!(size, bufsize); + buf.write(50, 34); + // Test duplicate + sb1 = sb.duplicate(sbflags!(Duplicate::None)).unwrap(); + } + // sb gets closed + buf.write(51, 35); + } + // buf just got closed + // remap to buf1 from sb1 + let buf1 = sb1.map(50, 50, sbflags!(Map::None)).unwrap(); + assert_eq!(buf1.len(), 50); + // verify buffer contents + assert_eq!(buf1.read(0), 34); + assert_eq!(buf1.read(1), 35); + } + + fn message_pipe() { + let (endpt, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + // Extract original handle to check against + let endpt_h = endpt.get_native_handle(); + // Test casting of handle types + let endpt_u = endpt.as_untyped(); + assert_eq!(endpt_u.get_native_handle(), endpt_h); + { + let endpt0 = unsafe { message_pipe::MessageEndpoint::from_untyped(endpt_u) }; + assert_eq!(endpt0.get_native_handle(), endpt_h); + { + let (s, r) = endpt0.wait(signals!(Signals::Readable), 0); + assert_eq!(r, mojo::MojoResult::DeadlineExceeded); + assert!(s.satisfied().is_writable()); + assert!(s.satisfiable().is_readable()); + assert!(s.satisfiable().is_writable()); + assert!(s.satisfiable().is_peer_closed()); + } + { + let (s, r) = endpt0.wait(signals!(Signals::Writable), system::MOJO_INDEFINITE); + assert_eq!(r, mojo::MojoResult::Okay); + assert!(s.satisfied().is_writable()); + assert!(s.satisfiable().is_readable()); + assert!(s.satisfiable().is_writable()); + assert!(s.satisfiable().is_peer_closed()); + } + match endpt0.read(mpflags!(Read::None)) { + Ok((_msg, _handles)) => panic!("Read should not have succeeded."), + Err(r) => assert_eq!(r, mojo::MojoResult::ShouldWait), + } + let hello = "hello".to_string().into_bytes(); + let write_result = endpt1.write(&hello, Vec::new(), mpflags!(Write::None)); + assert_eq!(write_result, mojo::MojoResult::Okay); + { + let (s, r) = endpt0.wait(signals!(Signals::Readable), system::MOJO_INDEFINITE); + assert_eq!(r, mojo::MojoResult::Okay); + assert!(s.satisfied().is_readable(), s.satisfied().is_writable()); + assert!(s.satisfiable().is_readable()); + assert!(s.satisfiable().is_writable()); + assert!(s.satisfiable().is_peer_closed()); + } + let hello_data; + match endpt0.read(mpflags!(Read::None)) { + Ok((msg, _handles)) => hello_data = msg, + Err(r) => panic!("Failed to read message on endpt0, error: {}", r), + } + assert_eq!(String::from_utf8(hello_data).unwrap(), "hello".to_string()); + { + let handles: Vec<&Handle> = vec![&endpt0]; + let signals: Vec<system::HandleSignals> = vec![signals!(Signals::Readable)]; + let mut states: Vec<system::SignalsState> = vec![Default::default()]; + let (idx, r) = core::wait_many(&handles, &signals, &mut states, 10); + assert_eq!(r, mojo::MojoResult::DeadlineExceeded); + assert_eq!(idx, -1); + assert_eq!(states.len(), 1); + assert!(states[0].satisfied().is_writable()); + assert!(states[0].satisfiable().is_readable()); + assert!(states[0].satisfiable().is_writable()); + assert!(states[0].satisfiable().is_peer_closed()); + } + } + let (s, r) = endpt1.wait(signals!(Signals::Readable, Signals::Writable), + system::MOJO_INDEFINITE); + assert_eq!(r, mojo::MojoResult::FailedPrecondition); + assert!(s.satisfied().is_peer_closed()); + assert_eq!(s.satisfiable().get_bits(), system::Signals::PeerClosed as u32); + } + + fn data_pipe() { + let (cons0, prod0) = data_pipe::create_default().unwrap(); + // Extract original handle to check against + let cons_h = cons0.get_native_handle(); + let prod_h = prod0.get_native_handle(); + // Test casting of handle types + let cons_u = cons0.as_untyped(); + let prod_u = prod0.as_untyped(); + assert_eq!(cons_u.get_native_handle(), cons_h); + assert_eq!(prod_u.get_native_handle(), prod_h); + let cons = unsafe { data_pipe::Consumer::<u8>::from_untyped(cons_u) }; + let prod = unsafe { data_pipe::Producer::<u8>::from_untyped(prod_u) }; + assert_eq!(cons.get_native_handle(), cons_h); + assert_eq!(prod.get_native_handle(), prod_h); + // Test waiting on consumer + { + let (_s, r) = cons.wait(signals!(Signals::Readable), 0); + assert_eq!(r, mojo::MojoResult::DeadlineExceeded); + } + // Test waiting on producer + { + let (_s, r) = prod.wait(signals!(Signals::Writable), system::MOJO_INDEFINITE); + assert_eq!(r, mojo::MojoResult::Okay); + } + // Test one-phase read/write. + // Writing. + let hello = "hello".to_string().into_bytes(); + let bytes_written = prod.write(&hello, dpflags!(Write::None)).unwrap(); + assert_eq!(bytes_written, hello.len()); + // Reading. + { + let (_s, r) = cons.wait(signals!(Signals::Readable), system::MOJO_INDEFINITE); + assert_eq!(r, mojo::MojoResult::Okay); + } + let data_string = String::from_utf8(cons.read(dpflags!(Read::None)).unwrap()).unwrap(); + assert_eq!(data_string, "hello".to_string()); + { + // Test two-phase read/write. + // Writing. + let goodbye = "goodbye".to_string().into_bytes(); + let mut write_buf = match prod.begin(dpflags!(Write::None)) { + Ok(buf) => buf, + Err(err) => panic!("Error on write begin: {}", err), + }; + assert!(write_buf.len() >= goodbye.len()); + for i in 0..goodbye.len() { + write_buf[i] = goodbye[i]; + } + { + let (_s, r) = cons.wait(signals!(Signals::Readable), 0); + assert_eq!(r, mojo::MojoResult::DeadlineExceeded); + } + match write_buf.commit(goodbye.len()) { + Some((_buf, _err)) => assert!(false), + None => (), + } + // Reading. + { + let (_s, r) = cons.wait(signals!(Signals::Readable), system::MOJO_INDEFINITE); + assert_eq!(r, mojo::MojoResult::Okay); + } + let mut data_goodbye: Vec<u8> = Vec::with_capacity(goodbye.len()); + { + let read_buf = match cons.begin(dpflags!(Read::None)) { + Ok(buf) => buf, + Err(err) => panic!("Error on read begin: {}", err), + }; + for i in 0..read_buf.len() { + data_goodbye.push(read_buf[i]); + } + match cons.read(dpflags!(Read::None)) { + Ok(_bytes) => assert!(false), + Err(r) => assert_eq!(r, mojo::MojoResult::Busy), + } + match read_buf.commit(data_goodbye.len()) { + Some((_buf, _err)) => assert!(false), + None => (), + } + } + assert_eq!(data_goodbye.len(), goodbye.len()); + assert_eq!(String::from_utf8(data_goodbye).unwrap(), "goodbye".to_string()); + } + } + + fn wait_set() { + let set0 = wait_set::WaitSet::new(wsflags!(Create::None)).unwrap(); + let set_h = set0.get_native_handle(); + let set_u = set0.as_untyped(); + assert_eq!(set_u.get_native_handle(), set_h); + let mut set = unsafe { wait_set::WaitSet::from_untyped(set_u) }; + let (endpt0, endpt1) = message_pipe::create(mpflags!(Create::None)).unwrap(); + let signals = signals!(Signals::Readable); + let flags = wsflags!(Add::None); + assert_eq!(set.add(&endpt0, signals, 245, flags), mojo::MojoResult::Okay); + assert_eq!(set.add(&endpt0, signals, 245, flags), mojo::MojoResult::AlreadyExists); + assert_eq!(set.remove(245), mojo::MojoResult::Okay); + assert_eq!(set.remove(245), mojo::MojoResult::NotFound); + assert_eq!(set.add(&endpt0, signals, 123, flags), mojo::MojoResult::Okay); + thread::spawn(move || { + let hello = "hello".to_string().into_bytes(); + let write_result = endpt1.write(&hello, Vec::new(), mpflags!(Write::None)); + assert_eq!(write_result, mojo::MojoResult::Okay); + }); + let mut output = Vec::with_capacity(1); + let max = set.wait_on_set(system::MOJO_INDEFINITE, &mut output).unwrap(); + assert_eq!(output.len(), 1); + assert_eq!(output[0].cookie(), 123); + assert_eq!(output[0].result(), mojo::MojoResult::Okay); + assert!(output[0].state().satisfied().is_readable()); + assert_eq!(max, 1); + } +}
diff --git a/mojo/public/rust/tests/util/mod.rs b/mojo/public/rust/tests/util/mod.rs new file mode 100644 index 0000000..cb1b719 --- /dev/null +++ b/mojo/public/rust/tests/util/mod.rs
@@ -0,0 +1,98 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! This module contains useful functions and macros for testing. + +pub mod mojom_validation; + +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; +use std::ptr; +use std::slice; + +/// This macro sets up tests by adding in Mojo embedder +/// initialization. +macro_rules! tests { + ( $( $( #[ $attr:meta ] )* fn $i:ident() $b:block)* ) => { + use std::sync::{Once, ONCE_INIT}; + static START: Once = ONCE_INIT; + $( + #[test] + $( + #[ $attr ] + )* + fn $i() { + START.call_once(|| unsafe { + util::InitializeMojoEmbedder(); + }); + $b + } + )* + } +} + +#[link(name = "stdc++")] +extern "C" {} + +#[link(name = "c")] +extern "C" { + fn free(ptr: *mut u8); +} + +#[link(name = "rust_embedder")] +extern "C" { + pub fn InitializeMojoEmbedder(); +} + +#[link(name = "validation_parser")] +extern "C" { + #[allow(dead_code)] + fn ParseValidationTest( + input: *const c_char, + num_handles: *mut usize, + data: *mut *mut u8, + data_len: *mut usize, + ) -> *mut c_char; +} + +#[allow(dead_code)] +pub fn parse_validation_test(input: &str) -> Result<(Vec<u8>, usize), String> { + let input_c = CString::new(input.to_string()).unwrap(); + let mut num_handles: usize = 0; + let mut data: *mut u8 = ptr::null_mut(); + let mut data_len: usize = 0; + let error = unsafe { + ParseValidationTest( + input_c.as_ptr(), + &mut num_handles as *mut usize, + &mut data as *mut *mut u8, + &mut data_len as *mut usize, + ) + }; + if error == ptr::null_mut() { + if data == ptr::null_mut() || data_len == 0 { + // We assume we were just given an empty file + Ok((Vec::new(), 0)) + } else { + // Make a copy of the buffer + let buffer; + unsafe { + buffer = slice::from_raw_parts(data, data_len).to_vec(); + free(data); + } + Ok((buffer, num_handles)) + } + } else { + let err_str; + unsafe { + // Copy the error string + err_str = CStr::from_ptr(error) + .to_str() + .expect("Could not convert error message to UTF-8!") + .to_owned(); + free(error as *mut u8); + } + Err(err_str) + } +}
diff --git a/mojo/public/rust/tests/util/mojom_validation.rs b/mojo/public/rust/tests/util/mojom_validation.rs new file mode 100644 index 0000000..00d3f0c --- /dev/null +++ b/mojo/public/rust/tests/util/mojom_validation.rs
@@ -0,0 +1,3215 @@ +//! This file was auto-generated by the Rust bindings generator. +#![allow(bad_style)] +#![allow(unused_imports)] +#![allow(unused_variables)] +#![allow(dead_code)] + +use mojo::bindings::decoding; +use mojo::bindings::decoding::{Decoder, ValidationError}; +use mojo::bindings::encoding; +use mojo::bindings::encoding::{Context, DataHeaderValue, Encoder, DATA_HEADER_SIZE}; +use mojo::bindings::message; +use mojo::bindings::message::MessageHeader; +use mojo::bindings::mojom; +use mojo::bindings::mojom::{ + MojomEncodable, MojomInterface, MojomInterfaceRecv, MojomInterfaceSend, MojomMessage, + MojomMessageOption, MojomPointer, MojomStruct, MojomUnion, UNION_SIZE, +}; + +use mojo::system; +use mojo::system::message_pipe; +use mojo::system::{CastHandle, UntypedHandle}; + +use std::collections::HashMap; +use std::vec::Vec; + +// Top-level constants: + +// Structs: +// -- StructA -- + +// Constants +// Enums +// Struct version information +const StructAVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct StructA { + pub i: u64, +} + +impl MojomPointer for StructA { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.i, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructAVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let i = match <u64>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructA { i: i }) + } +} + +impl MojomEncodable for StructA { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.i.compute_size(context.clone()) + } +} + +impl MojomStruct for StructA {} +// -- StructB -- + +// Constants +// Enums +// Struct version information +const StructBVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct StructB { + pub struct_a: StructA, +} + +impl MojomPointer for StructB { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.struct_a, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructBVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let struct_a = match <StructA>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructB { struct_a: struct_a }) + } +} + +impl MojomEncodable for StructB { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.struct_a.compute_size(context.clone()) + } +} + +impl MojomStruct for StructB {} +// -- StructC -- + +// Constants +// Enums +// Struct version information +const StructCVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct StructC { + pub data: Vec<u8>, +} + +impl MojomPointer for StructC { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.data, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructCVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let data = match <Vec<u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructC { data: data }) + } +} + +impl MojomEncodable for StructC { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.data.compute_size(context.clone()) + } +} + +impl MojomStruct for StructC {} +// -- StructD -- + +// Constants +// Enums +// Struct version information +const StructDVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct StructD { + pub message_pipes: Vec<message_pipe::MessageEndpoint>, +} + +impl MojomPointer for StructD { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.message_pipes, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructDVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let message_pipes = + match <Vec<message_pipe::MessageEndpoint>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructD { message_pipes: message_pipes }) + } +} + +impl MojomEncodable for StructD { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.message_pipes.compute_size(context.clone()) + } +} + +impl MojomStruct for StructD {} +// -- StructE -- + +// Constants +// Enums +// Struct version information +const StructEVersions: [(u32, u32); 1] = [(0, 24)]; + +// Struct definition +pub struct StructE { + pub struct_d: StructD, + pub data_pipe_consumer: system::data_pipe::Consumer<u8>, +} + +impl MojomPointer for StructE { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 24 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.struct_d, encoder, context.clone()); + MojomEncodable::encode(self.data_pipe_consumer, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructEVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let struct_d = match <StructD>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let data_pipe_consumer = + match <system::data_pipe::Consumer<u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructE { struct_d: struct_d, data_pipe_consumer: data_pipe_consumer }) + } +} + +impl MojomEncodable for StructE { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.struct_d.compute_size(context.clone()) + + self.data_pipe_consumer.compute_size(context.clone()) + } +} + +impl MojomStruct for StructE {} +// -- StructF -- + +// Constants +// Enums +// Struct version information +const StructFVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct StructF { + pub fixed_size_array: [u8; 3], +} + +impl MojomPointer for StructF { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.fixed_size_array, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructFVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let fixed_size_array = match <[u8; 3]>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructF { fixed_size_array: fixed_size_array }) + } +} + +impl MojomEncodable for StructF { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.fixed_size_array.compute_size(context.clone()) + } +} + +impl MojomStruct for StructF {} +// -- StructG -- + +// Constants +// Enums +// Struct version information +const StructGVersions: [(u32, u32); 3] = [(0, 16), (1, 24), (3, 32)]; + +// Struct definition +pub struct StructG { + pub i: i32, + pub b: bool, + pub struct_a: Option<StructA>, + pub str: Option<String>, +} + +impl MojomPointer for StructG { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(3) + } + fn serialized_size(&self, _context: &Context) -> usize { + 32 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.i, encoder, context.clone()); + MojomEncodable::encode(self.b, encoder, context.clone()); + MojomEncodable::encode(self.struct_a, encoder, context.clone()); + MojomEncodable::encode(self.str, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructGVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let i = match <i32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let b = if version >= 3 { + match <bool>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + } else { + Default::default() + }; + let struct_a = if version >= 1 { + match <Option<StructA>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + } else { + Default::default() + }; + let str = if version >= 3 { + match <Option<String>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + } else { + Default::default() + }; + Ok(StructG { i: i, b: b, struct_a: struct_a, str: str }) + } +} + +impl MojomEncodable for StructG { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.i.compute_size(context.clone()) + + self.b.compute_size(context.clone()) + + self.struct_a.compute_size(context.clone()) + + self.str.compute_size(context.clone()) + } +} + +impl MojomStruct for StructG {} +// -- StructH -- + +// Constants +// Enums +// Struct version information +const StructHVersions: [(u32, u32); 1] = [(0, 48)]; + +// Struct definition +pub struct StructH { + pub a: bool, + pub b: u8, + pub c: Option<UnionA>, + pub d: Option<Vec<UnionA>>, + pub e: Option<HashMap<u8, UnionA>>, +} + +impl MojomPointer for StructH { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 48 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.a, encoder, context.clone()); + MojomEncodable::encode(self.b, encoder, context.clone()); + MojomEncodable::encode(self.c, encoder, context.clone()); + MojomEncodable::encode(self.d, encoder, context.clone()); + MojomEncodable::encode(self.e, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructHVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let a = match <bool>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let b = match <u8>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let c = match <Option<UnionA>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let d = match <Option<Vec<UnionA>>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let e = match <Option<HashMap<u8, UnionA>>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(StructH { a: a, b: b, c: c, d: d, e: e }) + } +} + +impl MojomEncodable for StructH { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.a.compute_size(context.clone()) + + self.b.compute_size(context.clone()) + + self.c.compute_size(context.clone()) + + self.d.compute_size(context.clone()) + + self.e.compute_size(context.clone()) + } +} + +impl MojomStruct for StructH {} +// -- BasicStruct -- + +// Constants +// Enums +// Struct version information +const BasicStructVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct BasicStruct { + pub a: i32, +} + +impl MojomPointer for BasicStruct { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.a, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&BasicStructVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let a = match <i32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(BasicStruct { a: a }) + } +} + +impl MojomEncodable for BasicStruct { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.a.compute_size(context.clone()) + } +} + +impl MojomStruct for BasicStruct {} +// -- StructWithEnum -- + +// Constants +// Enums +type StructWithEnumEnumWithin = i32; +const StructWithEnumEnumWithin_A: StructWithEnumEnumWithin = 0; +const StructWithEnumEnumWithin_B: StructWithEnumEnumWithin = 1; +const StructWithEnumEnumWithin_C: StructWithEnumEnumWithin = 2; +const StructWithEnumEnumWithin_D: StructWithEnumEnumWithin = 3; + +const StructWithEnumEnumWithin__UNKNOWN: StructWithEnumEnumWithin = 0x7FFFFFFF; + +// Struct version information +const StructWithEnumVersions: [(u32, u32); 1] = [(0, 8)]; + +// Struct definition +pub struct StructWithEnum {} + +impl MojomPointer for StructWithEnum { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 8 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) {} + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&StructWithEnumVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + Ok(StructWithEnum {}) + } +} + +impl MojomEncodable for StructWithEnum { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + } +} + +impl MojomStruct for StructWithEnum {} + +// Mojom Unions: +// -- UnionA -- + +type UnionATag = u32; +const UnionATag_a: UnionATag = 0; +const UnionATag_b: UnionATag = 1; +const UnionATag_c: UnionATag = 2; +const UnionATag_d: UnionATag = 3; +const UnionATag_e: UnionATag = 4; +const UnionATag_f: UnionATag = 5; +const UnionATag_g: UnionATag = 6; +const UnionATag_h: UnionATag = 7; +const UnionATag_i: UnionATag = 8; +const UnionATag_j: UnionATag = 9; + +const UnionATag__UNKNOWN: UnionATag = 0xFFFFFFFF; + +pub enum UnionA { + a(u16), + b(u32), + c(Option<StructA>), + d(Option<Vec<u8>>), + e(Option<HashMap<String, u8>>), + f(Option<UnionB>), + g(StructA), + h(Vec<u8>), + i(HashMap<String, u8>), + j(UnionB), + _Unknown(u64), +} + +impl MojomUnion for UnionA { + fn get_tag(&self) -> u32 { + match *self { + UnionA::a(_) => UnionATag_a, + UnionA::b(_) => UnionATag_b, + UnionA::c(_) => UnionATag_c, + UnionA::d(_) => UnionATag_d, + UnionA::e(_) => UnionATag_e, + UnionA::f(_) => UnionATag_f, + UnionA::g(_) => UnionATag_g, + UnionA::h(_) => UnionATag_h, + UnionA::i(_) => UnionATag_i, + UnionA::j(_) => UnionATag_j, + UnionA::_Unknown(_) => UnionATag__UNKNOWN, + } + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + match self { + UnionA::a(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::b(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::c(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::d(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::e(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::f(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::g(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::h(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::i(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::j(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionA::_Unknown(val) => MojomEncodable::encode(val, encoder, context.clone()), + } + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let tag = { + let mut state = decoder.get_mut(&context); + let bytes = state.decode::<u32>(); + if (bytes as usize) != UNION_SIZE { + return Err(ValidationError::UnexpectedNullUnion); + } + state.decode::<u32>() + }; + Ok(match tag { + UnionATag_a => UnionA::a({ + match <u16>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_b => UnionA::b({ + match <u32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_c => UnionA::c({ + match <Option<StructA>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_d => UnionA::d({ + match <Option<Vec<u8>>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_e => UnionA::e({ + match <Option<HashMap<String, u8>>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_f => UnionA::f({ + match <Option<UnionB>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_g => UnionA::g({ + match <StructA>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_h => UnionA::h({ + match <Vec<u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_i => UnionA::i({ + match <HashMap<String, u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionATag_j => UnionA::j({ + match <UnionB>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + _ => UnionA::_Unknown(u64::decode(decoder, context.clone()).unwrap()), + }) + } +} + +impl MojomEncodable for UnionA { + impl_encodable_for_union!(); + fn compute_size(&self, context: Context) -> usize { + UNION_SIZE + + match *self { + UnionA::a(ref val) => val.compute_size(context.clone()), + UnionA::b(ref val) => val.compute_size(context.clone()), + UnionA::c(ref val) => val.compute_size(context.clone()), + UnionA::d(ref val) => val.compute_size(context.clone()), + UnionA::e(ref val) => val.compute_size(context.clone()), + UnionA::f(ref val) => val.compute_size(context.clone()), + UnionA::g(ref val) => val.compute_size(context.clone()), + UnionA::h(ref val) => val.compute_size(context.clone()), + UnionA::i(ref val) => val.compute_size(context.clone()), + UnionA::j(ref val) => val.compute_size(context.clone()), + UnionA::_Unknown(ref val) => 0, + } + } +} + +// -- UnionB -- + +type UnionBTag = u32; +const UnionBTag_a: UnionBTag = 0; +const UnionBTag_b: UnionBTag = 1; +const UnionBTag_c: UnionBTag = 2; +const UnionBTag_d: UnionBTag = 3; + +const UnionBTag__UNKNOWN: UnionBTag = 0xFFFFFFFF; + +pub enum UnionB { + a(u16), + b(u32), + c(u64), + d(u32), + _Unknown(u64), +} + +impl MojomUnion for UnionB { + fn get_tag(&self) -> u32 { + match *self { + UnionB::a(_) => UnionBTag_a, + UnionB::b(_) => UnionBTag_b, + UnionB::c(_) => UnionBTag_c, + UnionB::d(_) => UnionBTag_d, + UnionB::_Unknown(_) => UnionBTag__UNKNOWN, + } + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + match self { + UnionB::a(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionB::b(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionB::c(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionB::d(val) => MojomEncodable::encode(val, encoder, context.clone()), + UnionB::_Unknown(val) => MojomEncodable::encode(val, encoder, context.clone()), + } + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let tag = { + let mut state = decoder.get_mut(&context); + let bytes = state.decode::<u32>(); + if (bytes as usize) != UNION_SIZE { + return Err(ValidationError::UnexpectedNullUnion); + } + state.decode::<u32>() + }; + Ok(match tag { + UnionBTag_a => UnionB::a({ + match <u16>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionBTag_b => UnionB::b({ + match <u32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionBTag_c => UnionB::c({ + match <u64>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + UnionBTag_d => UnionB::d({ + match <u32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + } + }), + _ => UnionB::_Unknown(u64::decode(decoder, context.clone()).unwrap()), + }) + } +} + +impl MojomEncodable for UnionB { + impl_encodable_for_union!(); + fn compute_size(&self, context: Context) -> usize { + UNION_SIZE + + match *self { + UnionB::a(ref val) => val.compute_size(context.clone()), + UnionB::b(ref val) => val.compute_size(context.clone()), + UnionB::c(ref val) => val.compute_size(context.clone()), + UnionB::d(ref val) => val.compute_size(context.clone()), + UnionB::_Unknown(ref val) => 0, + } + } +} + +// Mojom Enums: +type BasicEnum = i32; +const BasicEnum_A: BasicEnum = 0; +const BasicEnum_B: BasicEnum = 1; +const BasicEnum_C: BasicEnum = 0; +const BasicEnum_D: BasicEnum = -3; +const BasicEnum_E: BasicEnum = 10; + +const BasicEnum__UNKNOWN: BasicEnum = 0x7FFFFFFF; + +// Interfaces: +// --- InterfaceA --- + +pub mod InterfaceA { + pub const SERVICE_NAME: &'static str = ""; + pub const VERSION: u32 = 0; +} + +pub struct InterfaceAClient { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl InterfaceAClient { + pub fn new(pipe: message_pipe::MessageEndpoint) -> InterfaceAClient { + InterfaceAClient { pipe: pipe, version: InterfaceA::VERSION } + } + pub fn with_version(pipe: message_pipe::MessageEndpoint, version: u32) -> InterfaceAClient { + InterfaceAClient { pipe: pipe, version: version } + } +} + +impl MojomInterface for InterfaceAClient { + fn service_name() -> &'static str { + InterfaceA::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for InterfaceAClient { + unsafe fn from_untyped(handle: system::UntypedHandle) -> InterfaceAClient { + InterfaceAClient { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for InterfaceAClient { + impl_encodable_for_interface!(); +} + +impl<R: InterfaceARequest> MojomInterfaceSend<R> for InterfaceAClient {} +impl MojomInterfaceRecv for InterfaceAClient { + type Container = InterfaceAResponseOption; +} + +pub struct InterfaceAServer { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl InterfaceAServer { + pub fn new(pipe: message_pipe::MessageEndpoint) -> InterfaceAServer { + InterfaceAServer { pipe: pipe, version: InterfaceA::VERSION } + } + pub fn with_version(pipe: message_pipe::MessageEndpoint, version: u32) -> InterfaceAServer { + InterfaceAServer { pipe: pipe, version: version } + } +} + +impl MojomInterface for InterfaceAServer { + fn service_name() -> &'static str { + InterfaceA::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for InterfaceAServer { + unsafe fn from_untyped(handle: system::UntypedHandle) -> InterfaceAServer { + InterfaceAServer { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for InterfaceAServer { + impl_encodable_for_interface!(); +} + +impl<R: InterfaceAResponse> MojomInterfaceSend<R> for InterfaceAServer {} +impl MojomInterfaceRecv for InterfaceAServer { + type Container = InterfaceARequestOption; +} + +// Enums + +// Constants + +pub trait InterfaceARequest: MojomMessage {} +pub trait InterfaceAResponse: MojomMessage {} + +pub enum InterfaceARequestOption {} + +impl MojomMessageOption for InterfaceARequestOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + match header.name { + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +pub enum InterfaceAResponseOption {} + +impl MojomMessageOption for InterfaceAResponseOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + if header.flags != message::MESSAGE_HEADER_IS_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match header.name { + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +// --- BoundsCheckTestInterface --- + +pub mod BoundsCheckTestInterface { + pub const SERVICE_NAME: &'static str = "this.is.the.service.name.for.BoundsCheckTestInterface"; + pub const VERSION: u32 = 0; +} + +pub struct BoundsCheckTestInterfaceClient { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl BoundsCheckTestInterfaceClient { + pub fn new(pipe: message_pipe::MessageEndpoint) -> BoundsCheckTestInterfaceClient { + BoundsCheckTestInterfaceClient { pipe: pipe, version: BoundsCheckTestInterface::VERSION } + } + pub fn with_version( + pipe: message_pipe::MessageEndpoint, + version: u32, + ) -> BoundsCheckTestInterfaceClient { + BoundsCheckTestInterfaceClient { pipe: pipe, version: version } + } +} + +impl MojomInterface for BoundsCheckTestInterfaceClient { + fn service_name() -> &'static str { + BoundsCheckTestInterface::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for BoundsCheckTestInterfaceClient { + unsafe fn from_untyped(handle: system::UntypedHandle) -> BoundsCheckTestInterfaceClient { + BoundsCheckTestInterfaceClient { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for BoundsCheckTestInterfaceClient { + impl_encodable_for_interface!(); +} + +impl<R: BoundsCheckTestInterfaceRequest> MojomInterfaceSend<R> for BoundsCheckTestInterfaceClient {} +impl MojomInterfaceRecv for BoundsCheckTestInterfaceClient { + type Container = BoundsCheckTestInterfaceResponseOption; +} + +pub struct BoundsCheckTestInterfaceServer { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl BoundsCheckTestInterfaceServer { + pub fn new(pipe: message_pipe::MessageEndpoint) -> BoundsCheckTestInterfaceServer { + BoundsCheckTestInterfaceServer { pipe: pipe, version: BoundsCheckTestInterface::VERSION } + } + pub fn with_version( + pipe: message_pipe::MessageEndpoint, + version: u32, + ) -> BoundsCheckTestInterfaceServer { + BoundsCheckTestInterfaceServer { pipe: pipe, version: version } + } +} + +impl MojomInterface for BoundsCheckTestInterfaceServer { + fn service_name() -> &'static str { + BoundsCheckTestInterface::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for BoundsCheckTestInterfaceServer { + unsafe fn from_untyped(handle: system::UntypedHandle) -> BoundsCheckTestInterfaceServer { + BoundsCheckTestInterfaceServer { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for BoundsCheckTestInterfaceServer { + impl_encodable_for_interface!(); +} + +impl<R: BoundsCheckTestInterfaceResponse> MojomInterfaceSend<R> for BoundsCheckTestInterfaceServer {} +impl MojomInterfaceRecv for BoundsCheckTestInterfaceServer { + type Container = BoundsCheckTestInterfaceRequestOption; +} + +// Enums + +// Constants + +pub trait BoundsCheckTestInterfaceRequest: MojomMessage {} +pub trait BoundsCheckTestInterfaceResponse: MojomMessage {} + +pub enum BoundsCheckTestInterfaceRequestOption { + BoundsCheckTestInterfaceMethod0(BoundsCheckTestInterfaceMethod0Request), + BoundsCheckTestInterfaceMethod1(BoundsCheckTestInterfaceMethod1Request), +} + +impl MojomMessageOption for BoundsCheckTestInterfaceRequestOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + match header.name { + BoundsCheckTestInterfaceMethod0::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_EXPECT_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match BoundsCheckTestInterfaceMethod0Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(BoundsCheckTestInterfaceRequestOption::BoundsCheckTestInterfaceMethod0( + value, + )) + } + Err(err) => return Err(err), + } + } + BoundsCheckTestInterfaceMethod1::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match BoundsCheckTestInterfaceMethod1Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(BoundsCheckTestInterfaceRequestOption::BoundsCheckTestInterfaceMethod1( + value, + )) + } + Err(err) => return Err(err), + } + } + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +pub enum BoundsCheckTestInterfaceResponseOption { + BoundsCheckTestInterfaceMethod0(BoundsCheckTestInterfaceMethod0Response), +} + +impl MojomMessageOption for BoundsCheckTestInterfaceResponseOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + if header.flags != message::MESSAGE_HEADER_IS_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match header.name { + BoundsCheckTestInterfaceMethod0::ORDINAL => { + match BoundsCheckTestInterfaceMethod0Response::deserialize(buffer, handles) { + Ok(value) => { + Ok(BoundsCheckTestInterfaceResponseOption::BoundsCheckTestInterfaceMethod0( + value, + )) + } + Err(err) => return Err(err), + } + } + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +/// Message: BoundsCheckTestInterfaceMethod0 +pub mod BoundsCheckTestInterfaceMethod0 { + pub const ORDINAL: u32 = 0; + pub const MIN_VERSION: u32 = 0; +} +// -- BoundsCheckTestInterfaceMethod0Request -- + +// Constants +// Enums +// Struct version information +const BoundsCheckTestInterfaceMethod0RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct BoundsCheckTestInterfaceMethod0Request { + pub param0: u8, +} + +impl MojomPointer for BoundsCheckTestInterfaceMethod0Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&BoundsCheckTestInterfaceMethod0RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <u8>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(BoundsCheckTestInterfaceMethod0Request { param0: param0 }) + } +} + +impl MojomEncodable for BoundsCheckTestInterfaceMethod0Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for BoundsCheckTestInterfaceMethod0Request {} +impl MojomMessage for BoundsCheckTestInterfaceMethod0Request { + fn min_version() -> u32 { + BoundsCheckTestInterfaceMethod0::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 1, + BoundsCheckTestInterfaceMethod0::ORDINAL, + message::MESSAGE_HEADER_EXPECT_RESPONSE, + ) + } +} +impl BoundsCheckTestInterfaceRequest for BoundsCheckTestInterfaceMethod0Request {} + +// -- BoundsCheckTestInterfaceMethod0Response -- + +// Constants +// Enums +// Struct version information +const BoundsCheckTestInterfaceMethod0ResponseVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct BoundsCheckTestInterfaceMethod0Response { + pub param0: u8, +} + +impl MojomPointer for BoundsCheckTestInterfaceMethod0Response { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&BoundsCheckTestInterfaceMethod0ResponseVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <u8>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(BoundsCheckTestInterfaceMethod0Response { param0: param0 }) + } +} + +impl MojomEncodable for BoundsCheckTestInterfaceMethod0Response { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for BoundsCheckTestInterfaceMethod0Response {} + +impl MojomMessage for BoundsCheckTestInterfaceMethod0Response { + fn min_version() -> u32 { + BoundsCheckTestInterfaceMethod0::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 1, + BoundsCheckTestInterfaceMethod0::ORDINAL, + message::MESSAGE_HEADER_IS_RESPONSE, + ) + } +} +impl BoundsCheckTestInterfaceResponse for BoundsCheckTestInterfaceMethod0Response {} +/// Message: BoundsCheckTestInterfaceMethod1 +pub mod BoundsCheckTestInterfaceMethod1 { + pub const ORDINAL: u32 = 1; + pub const MIN_VERSION: u32 = 0; +} +// -- BoundsCheckTestInterfaceMethod1Request -- + +// Constants +// Enums +// Struct version information +const BoundsCheckTestInterfaceMethod1RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct BoundsCheckTestInterfaceMethod1Request { + pub param0: u8, +} + +impl MojomPointer for BoundsCheckTestInterfaceMethod1Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&BoundsCheckTestInterfaceMethod1RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <u8>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(BoundsCheckTestInterfaceMethod1Request { param0: param0 }) + } +} + +impl MojomEncodable for BoundsCheckTestInterfaceMethod1Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for BoundsCheckTestInterfaceMethod1Request {} +impl MojomMessage for BoundsCheckTestInterfaceMethod1Request { + fn min_version() -> u32 { + BoundsCheckTestInterfaceMethod1::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + BoundsCheckTestInterfaceMethod1::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl BoundsCheckTestInterfaceRequest for BoundsCheckTestInterfaceMethod1Request {} + +// --- ConformanceTestInterface --- + +pub mod ConformanceTestInterface { + pub const SERVICE_NAME: &'static str = ""; + pub const VERSION: u32 = 0; +} + +pub struct ConformanceTestInterfaceClient { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl ConformanceTestInterfaceClient { + pub fn new(pipe: message_pipe::MessageEndpoint) -> ConformanceTestInterfaceClient { + ConformanceTestInterfaceClient { pipe: pipe, version: ConformanceTestInterface::VERSION } + } + pub fn with_version( + pipe: message_pipe::MessageEndpoint, + version: u32, + ) -> ConformanceTestInterfaceClient { + ConformanceTestInterfaceClient { pipe: pipe, version: version } + } +} + +impl MojomInterface for ConformanceTestInterfaceClient { + fn service_name() -> &'static str { + ConformanceTestInterface::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for ConformanceTestInterfaceClient { + unsafe fn from_untyped(handle: system::UntypedHandle) -> ConformanceTestInterfaceClient { + ConformanceTestInterfaceClient { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for ConformanceTestInterfaceClient { + impl_encodable_for_interface!(); +} + +impl<R: ConformanceTestInterfaceRequest> MojomInterfaceSend<R> for ConformanceTestInterfaceClient {} +impl MojomInterfaceRecv for ConformanceTestInterfaceClient { + type Container = ConformanceTestInterfaceResponseOption; +} + +pub struct ConformanceTestInterfaceServer { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl ConformanceTestInterfaceServer { + pub fn new(pipe: message_pipe::MessageEndpoint) -> ConformanceTestInterfaceServer { + ConformanceTestInterfaceServer { pipe: pipe, version: ConformanceTestInterface::VERSION } + } + pub fn with_version( + pipe: message_pipe::MessageEndpoint, + version: u32, + ) -> ConformanceTestInterfaceServer { + ConformanceTestInterfaceServer { pipe: pipe, version: version } + } +} + +impl MojomInterface for ConformanceTestInterfaceServer { + fn service_name() -> &'static str { + ConformanceTestInterface::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for ConformanceTestInterfaceServer { + unsafe fn from_untyped(handle: system::UntypedHandle) -> ConformanceTestInterfaceServer { + ConformanceTestInterfaceServer { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for ConformanceTestInterfaceServer { + impl_encodable_for_interface!(); +} + +impl<R: ConformanceTestInterfaceResponse> MojomInterfaceSend<R> for ConformanceTestInterfaceServer {} +impl MojomInterfaceRecv for ConformanceTestInterfaceServer { + type Container = ConformanceTestInterfaceRequestOption; +} + +// Enums + +// Constants + +pub trait ConformanceTestInterfaceRequest: MojomMessage {} +pub trait ConformanceTestInterfaceResponse: MojomMessage {} + +pub enum ConformanceTestInterfaceRequestOption { + ConformanceTestInterfaceMethod3(ConformanceTestInterfaceMethod3Request), + ConformanceTestInterfaceMethod4(ConformanceTestInterfaceMethod4Request), + ConformanceTestInterfaceMethod5(ConformanceTestInterfaceMethod5Request), + ConformanceTestInterfaceMethod7(ConformanceTestInterfaceMethod7Request), + ConformanceTestInterfaceMethod12(ConformanceTestInterfaceMethod12Request), + ConformanceTestInterfaceMethod14(ConformanceTestInterfaceMethod14Request), + ConformanceTestInterfaceMethod15(ConformanceTestInterfaceMethod15Request), + ConformanceTestInterfaceMethod1(ConformanceTestInterfaceMethod1Request), + ConformanceTestInterfaceMethod2(ConformanceTestInterfaceMethod2Request), + ConformanceTestInterfaceMethod6(ConformanceTestInterfaceMethod6Request), + ConformanceTestInterfaceMethod8(ConformanceTestInterfaceMethod8Request), + ConformanceTestInterfaceMethod10(ConformanceTestInterfaceMethod10Request), + ConformanceTestInterfaceMethod11(ConformanceTestInterfaceMethod11Request), + ConformanceTestInterfaceMethod0(ConformanceTestInterfaceMethod0Request), + ConformanceTestInterfaceMethod9(ConformanceTestInterfaceMethod9Request), + ConformanceTestInterfaceMethod13(ConformanceTestInterfaceMethod13Request), +} + +impl MojomMessageOption for ConformanceTestInterfaceRequestOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + match header.name { + ConformanceTestInterfaceMethod3::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod3Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod3( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod4::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod4Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod4( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod5::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod5Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod5( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod7::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod7Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod7( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod12::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_EXPECT_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod12Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod12( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod14::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod14Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod14( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod15::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod15Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod15( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod1::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod1Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod1( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod2::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod2Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod2( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod6::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod6Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod6( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod8::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod8Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod8( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod10::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod10Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod10( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod11::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod11Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod11( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod0::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod0Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod0( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod9::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod9Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod9( + value, + )) + } + Err(err) => return Err(err), + } + } + ConformanceTestInterfaceMethod13::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_NO_FLAG { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match ConformanceTestInterfaceMethod13Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(ConformanceTestInterfaceRequestOption::ConformanceTestInterfaceMethod13( + value, + )) + } + Err(err) => return Err(err), + } + } + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +pub enum ConformanceTestInterfaceResponseOption { + ConformanceTestInterfaceMethod12(ConformanceTestInterfaceMethod12Response), +} + +impl MojomMessageOption for ConformanceTestInterfaceResponseOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + if header.flags != message::MESSAGE_HEADER_IS_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match header.name { + ConformanceTestInterfaceMethod12::ORDINAL => { + match ConformanceTestInterfaceMethod12Response::deserialize(buffer, handles) { + Ok(value) => Ok( + ConformanceTestInterfaceResponseOption::ConformanceTestInterfaceMethod12( + value, + ), + ), + Err(err) => return Err(err), + } + } + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +/// Message: ConformanceTestInterfaceMethod3 +pub mod ConformanceTestInterfaceMethod3 { + pub const ORDINAL: u32 = 3; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod3Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod3RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod3Request { + pub param0: Vec<bool>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod3Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod3RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <Vec<bool>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod3Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod3Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod3Request {} +impl MojomMessage for ConformanceTestInterfaceMethod3Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod3::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod3::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod3Request {} + +/// Message: ConformanceTestInterfaceMethod4 +pub mod ConformanceTestInterfaceMethod4 { + pub const ORDINAL: u32 = 4; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod4Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod4RequestVersions: [(u32, u32); 1] = [(0, 24)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod4Request { + pub param0: StructC, + pub param1: Vec<u8>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod4Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 24 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + MojomEncodable::encode(self.param1, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod4RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructC>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let param1 = match <Vec<u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod4Request { param0: param0, param1: param1 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod4Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + + self.param1.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod4Request {} +impl MojomMessage for ConformanceTestInterfaceMethod4Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod4::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod4::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod4Request {} + +/// Message: ConformanceTestInterfaceMethod5 +pub mod ConformanceTestInterfaceMethod5 { + pub const ORDINAL: u32 = 5; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod5Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod5RequestVersions: [(u32, u32); 1] = [(0, 24)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod5Request { + pub param0: StructE, + pub param1: system::data_pipe::Producer<u8>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod5Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 24 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + MojomEncodable::encode(self.param1, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod5RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructE>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let param1 = match <system::data_pipe::Producer<u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod5Request { param0: param0, param1: param1 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod5Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + + self.param1.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod5Request {} +impl MojomMessage for ConformanceTestInterfaceMethod5Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod5::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod5::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod5Request {} + +/// Message: ConformanceTestInterfaceMethod7 +pub mod ConformanceTestInterfaceMethod7 { + pub const ORDINAL: u32 = 7; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod7Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod7RequestVersions: [(u32, u32); 1] = [(0, 24)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod7Request { + pub param0: StructF, + pub param1: [Option<[u8; 3]>; 2], +} + +impl MojomPointer for ConformanceTestInterfaceMethod7Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 24 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + MojomEncodable::encode(self.param1, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod7RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructF>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let param1 = match <[Option<[u8; 3]>; 2]>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod7Request { param0: param0, param1: param1 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod7Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + + self.param1.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod7Request {} +impl MojomMessage for ConformanceTestInterfaceMethod7Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod7::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod7::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod7Request {} + +/// Message: ConformanceTestInterfaceMethod12 +pub mod ConformanceTestInterfaceMethod12 { + pub const ORDINAL: u32 = 12; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod12Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod12RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod12Request { + pub param0: f32, +} + +impl MojomPointer for ConformanceTestInterfaceMethod12Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod12RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <f32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod12Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod12Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod12Request {} +impl MojomMessage for ConformanceTestInterfaceMethod12Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod12::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 1, + ConformanceTestInterfaceMethod12::ORDINAL, + message::MESSAGE_HEADER_EXPECT_RESPONSE, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod12Request {} + +// -- ConformanceTestInterfaceMethod12Response -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod12ResponseVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod12Response { + pub param0: f32, +} + +impl MojomPointer for ConformanceTestInterfaceMethod12Response { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod12ResponseVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <f32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod12Response { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod12Response { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod12Response {} + +impl MojomMessage for ConformanceTestInterfaceMethod12Response { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod12::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 1, + ConformanceTestInterfaceMethod12::ORDINAL, + message::MESSAGE_HEADER_IS_RESPONSE, + ) + } +} +impl ConformanceTestInterfaceResponse for ConformanceTestInterfaceMethod12Response {} +/// Message: ConformanceTestInterfaceMethod14 +pub mod ConformanceTestInterfaceMethod14 { + pub const ORDINAL: u32 = 14; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod14Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod14RequestVersions: [(u32, u32); 1] = [(0, 24)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod14Request { + pub param0: UnionA, +} + +impl MojomPointer for ConformanceTestInterfaceMethod14Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 24 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod14RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <UnionA>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod14Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod14Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod14Request {} +impl MojomMessage for ConformanceTestInterfaceMethod14Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod14::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod14::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod14Request {} + +/// Message: ConformanceTestInterfaceMethod15 +pub mod ConformanceTestInterfaceMethod15 { + pub const ORDINAL: u32 = 15; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod15Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod15RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod15Request { + pub param0: StructH, +} + +impl MojomPointer for ConformanceTestInterfaceMethod15Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod15RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructH>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod15Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod15Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod15Request {} +impl MojomMessage for ConformanceTestInterfaceMethod15Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod15::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod15::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod15Request {} + +/// Message: ConformanceTestInterfaceMethod1 +pub mod ConformanceTestInterfaceMethod1 { + pub const ORDINAL: u32 = 1; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod1Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod1RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod1Request { + pub param0: StructA, +} + +impl MojomPointer for ConformanceTestInterfaceMethod1Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod1RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructA>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod1Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod1Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod1Request {} +impl MojomMessage for ConformanceTestInterfaceMethod1Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod1::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod1::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod1Request {} + +/// Message: ConformanceTestInterfaceMethod2 +pub mod ConformanceTestInterfaceMethod2 { + pub const ORDINAL: u32 = 2; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod2Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod2RequestVersions: [(u32, u32); 1] = [(0, 24)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod2Request { + pub param0: StructB, + pub param1: StructA, +} + +impl MojomPointer for ConformanceTestInterfaceMethod2Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 24 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + MojomEncodable::encode(self.param1, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod2RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructB>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let param1 = match <StructA>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod2Request { param0: param0, param1: param1 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod2Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + + self.param1.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod2Request {} +impl MojomMessage for ConformanceTestInterfaceMethod2Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod2::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod2::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod2Request {} + +/// Message: ConformanceTestInterfaceMethod6 +pub mod ConformanceTestInterfaceMethod6 { + pub const ORDINAL: u32 = 6; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod6Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod6RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod6Request { + pub param0: Vec<Vec<u8>>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod6Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod6RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <Vec<Vec<u8>>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod6Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod6Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod6Request {} +impl MojomMessage for ConformanceTestInterfaceMethod6Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod6::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod6::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod6Request {} + +/// Message: ConformanceTestInterfaceMethod8 +pub mod ConformanceTestInterfaceMethod8 { + pub const ORDINAL: u32 = 8; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod8Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod8RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod8Request { + pub param0: Vec<Option<Vec<String>>>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod8Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod8RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <Vec<Option<Vec<String>>>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod8Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod8Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod8Request {} +impl MojomMessage for ConformanceTestInterfaceMethod8Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod8::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod8::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod8Request {} + +/// Message: ConformanceTestInterfaceMethod10 +pub mod ConformanceTestInterfaceMethod10 { + pub const ORDINAL: u32 = 10; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod10Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod10RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod10Request { + pub param0: HashMap<String, u8>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod10Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod10RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <HashMap<String, u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod10Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod10Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod10Request {} +impl MojomMessage for ConformanceTestInterfaceMethod10Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod10::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod10::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod10Request {} + +/// Message: ConformanceTestInterfaceMethod11 +pub mod ConformanceTestInterfaceMethod11 { + pub const ORDINAL: u32 = 11; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod11Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod11RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod11Request { + pub param0: StructG, +} + +impl MojomPointer for ConformanceTestInterfaceMethod11Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod11RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <StructG>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod11Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod11Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod11Request {} +impl MojomMessage for ConformanceTestInterfaceMethod11Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod11::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod11::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod11Request {} + +/// Message: ConformanceTestInterfaceMethod0 +pub mod ConformanceTestInterfaceMethod0 { + pub const ORDINAL: u32 = 0; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod0Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod0RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod0Request { + pub param0: f32, +} + +impl MojomPointer for ConformanceTestInterfaceMethod0Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod0RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <f32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod0Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod0Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod0Request {} +impl MojomMessage for ConformanceTestInterfaceMethod0Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod0::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod0::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod0Request {} + +/// Message: ConformanceTestInterfaceMethod9 +pub mod ConformanceTestInterfaceMethod9 { + pub const ORDINAL: u32 = 9; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod9Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod9RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod9Request { + pub param0: Option<Vec<Vec<Option<system::UntypedHandle>>>>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod9Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod9RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <Option<Vec<Vec<Option<system::UntypedHandle>>>>>::decode( + decoder, + context.clone(), + ) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod9Request { param0: param0 }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod9Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod9Request {} +impl MojomMessage for ConformanceTestInterfaceMethod9Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod9::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod9::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod9Request {} + +/// Message: ConformanceTestInterfaceMethod13 +pub mod ConformanceTestInterfaceMethod13 { + pub const ORDINAL: u32 = 13; + pub const MIN_VERSION: u32 = 0; +} +// -- ConformanceTestInterfaceMethod13Request -- + +// Constants +// Enums +// Struct version information +const ConformanceTestInterfaceMethod13RequestVersions: [(u32, u32); 1] = [(0, 32)]; + +// Struct definition +pub struct ConformanceTestInterfaceMethod13Request { + pub param0: Option<InterfaceAClient>, + pub param1: u32, + pub param2: Option<InterfaceAClient>, +} + +impl MojomPointer for ConformanceTestInterfaceMethod13Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 32 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + MojomEncodable::encode(self.param1, encoder, context.clone()); + MojomEncodable::encode(self.param2, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&ConformanceTestInterfaceMethod13RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <Option<InterfaceAClient>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let param1 = match <u32>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + let param2 = match <Option<InterfaceAClient>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(ConformanceTestInterfaceMethod13Request { + param0: param0, + param1: param1, + param2: param2, + }) + } +} + +impl MojomEncodable for ConformanceTestInterfaceMethod13Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + + self.param1.compute_size(context.clone()) + + self.param2.compute_size(context.clone()) + } +} + +impl MojomStruct for ConformanceTestInterfaceMethod13Request {} +impl MojomMessage for ConformanceTestInterfaceMethod13Request { + fn min_version() -> u32 { + ConformanceTestInterfaceMethod13::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 0, + ConformanceTestInterfaceMethod13::ORDINAL, + message::MESSAGE_HEADER_NO_FLAG, + ) + } +} +impl ConformanceTestInterfaceRequest for ConformanceTestInterfaceMethod13Request {} + +// --- IntegrationTestInterface --- + +pub mod IntegrationTestInterface { + pub const SERVICE_NAME: &'static str = ""; + pub const VERSION: u32 = 0; +} + +pub struct IntegrationTestInterfaceClient { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl IntegrationTestInterfaceClient { + pub fn new(pipe: message_pipe::MessageEndpoint) -> IntegrationTestInterfaceClient { + IntegrationTestInterfaceClient { pipe: pipe, version: IntegrationTestInterface::VERSION } + } + pub fn with_version( + pipe: message_pipe::MessageEndpoint, + version: u32, + ) -> IntegrationTestInterfaceClient { + IntegrationTestInterfaceClient { pipe: pipe, version: version } + } +} + +impl MojomInterface for IntegrationTestInterfaceClient { + fn service_name() -> &'static str { + IntegrationTestInterface::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for IntegrationTestInterfaceClient { + unsafe fn from_untyped(handle: system::UntypedHandle) -> IntegrationTestInterfaceClient { + IntegrationTestInterfaceClient { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for IntegrationTestInterfaceClient { + impl_encodable_for_interface!(); +} + +impl<R: IntegrationTestInterfaceRequest> MojomInterfaceSend<R> for IntegrationTestInterfaceClient {} +impl MojomInterfaceRecv for IntegrationTestInterfaceClient { + type Container = IntegrationTestInterfaceResponseOption; +} + +pub struct IntegrationTestInterfaceServer { + pipe: message_pipe::MessageEndpoint, + version: u32, +} + +impl IntegrationTestInterfaceServer { + pub fn new(pipe: message_pipe::MessageEndpoint) -> IntegrationTestInterfaceServer { + IntegrationTestInterfaceServer { pipe: pipe, version: IntegrationTestInterface::VERSION } + } + pub fn with_version( + pipe: message_pipe::MessageEndpoint, + version: u32, + ) -> IntegrationTestInterfaceServer { + IntegrationTestInterfaceServer { pipe: pipe, version: version } + } +} + +impl MojomInterface for IntegrationTestInterfaceServer { + fn service_name() -> &'static str { + IntegrationTestInterface::SERVICE_NAME + } + fn version(&self) -> u32 { + self.version + } + fn pipe(&self) -> &message_pipe::MessageEndpoint { + &self.pipe + } + fn unwrap(self) -> message_pipe::MessageEndpoint { + self.pipe + } +} + +impl CastHandle for IntegrationTestInterfaceServer { + unsafe fn from_untyped(handle: system::UntypedHandle) -> IntegrationTestInterfaceServer { + IntegrationTestInterfaceServer { + pipe: message_pipe::MessageEndpoint::from_untyped(handle), + version: 0, // Since we have no other information, assume its the base + } + } + fn as_untyped(self) -> system::UntypedHandle { + self.pipe.as_untyped() + } +} + +impl MojomEncodable for IntegrationTestInterfaceServer { + impl_encodable_for_interface!(); +} + +impl<R: IntegrationTestInterfaceResponse> MojomInterfaceSend<R> for IntegrationTestInterfaceServer {} +impl MojomInterfaceRecv for IntegrationTestInterfaceServer { + type Container = IntegrationTestInterfaceRequestOption; +} + +// Enums + +// Constants + +pub trait IntegrationTestInterfaceRequest: MojomMessage {} +pub trait IntegrationTestInterfaceResponse: MojomMessage {} + +pub enum IntegrationTestInterfaceRequestOption { + IntegrationTestInterfaceMethod0(IntegrationTestInterfaceMethod0Request), +} + +impl MojomMessageOption for IntegrationTestInterfaceRequestOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + match header.name { + IntegrationTestInterfaceMethod0::ORDINAL => { + if header.flags != message::MESSAGE_HEADER_EXPECT_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match IntegrationTestInterfaceMethod0Request::deserialize(buffer, handles) { + Ok(value) => { + Ok(IntegrationTestInterfaceRequestOption::IntegrationTestInterfaceMethod0( + value, + )) + } + Err(err) => return Err(err), + } + } + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +pub enum IntegrationTestInterfaceResponseOption { + IntegrationTestInterfaceMethod0(IntegrationTestInterfaceMethod0Response), +} + +impl MojomMessageOption for IntegrationTestInterfaceResponseOption { + fn decode_payload( + header: MessageHeader, + buffer: &[u8], + handles: Vec<UntypedHandle>, + ) -> Result<Self, ValidationError> { + if header.flags != message::MESSAGE_HEADER_IS_RESPONSE { + return Err(ValidationError::MessageHeaderInvalidFlags); + } + match header.name { + IntegrationTestInterfaceMethod0::ORDINAL => { + match IntegrationTestInterfaceMethod0Response::deserialize(buffer, handles) { + Ok(value) => { + Ok(IntegrationTestInterfaceResponseOption::IntegrationTestInterfaceMethod0( + value, + )) + } + Err(err) => return Err(err), + } + } + _ => Err(ValidationError::MessageHeaderUnknownMethod), + } + } +} + +/// Message: IntegrationTestInterfaceMethod0 +pub mod IntegrationTestInterfaceMethod0 { + pub const ORDINAL: u32 = 0; + pub const MIN_VERSION: u32 = 0; +} +// -- IntegrationTestInterfaceMethod0Request -- + +// Constants +// Enums +// Struct version information +const IntegrationTestInterfaceMethod0RequestVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct IntegrationTestInterfaceMethod0Request { + pub param0: BasicStruct, +} + +impl MojomPointer for IntegrationTestInterfaceMethod0Request { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&IntegrationTestInterfaceMethod0RequestVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <BasicStruct>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(IntegrationTestInterfaceMethod0Request { param0: param0 }) + } +} + +impl MojomEncodable for IntegrationTestInterfaceMethod0Request { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for IntegrationTestInterfaceMethod0Request {} +impl MojomMessage for IntegrationTestInterfaceMethod0Request { + fn min_version() -> u32 { + IntegrationTestInterfaceMethod0::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 1, + IntegrationTestInterfaceMethod0::ORDINAL, + message::MESSAGE_HEADER_EXPECT_RESPONSE, + ) + } +} +impl IntegrationTestInterfaceRequest for IntegrationTestInterfaceMethod0Request {} + +// -- IntegrationTestInterfaceMethod0Response -- + +// Constants +// Enums +// Struct version information +const IntegrationTestInterfaceMethod0ResponseVersions: [(u32, u32); 1] = [(0, 16)]; + +// Struct definition +pub struct IntegrationTestInterfaceMethod0Response { + pub param0: Vec<u8>, +} + +impl MojomPointer for IntegrationTestInterfaceMethod0Response { + fn header_data(&self) -> DataHeaderValue { + DataHeaderValue::Version(0) + } + fn serialized_size(&self, _context: &Context) -> usize { + 16 + } + fn encode_value(self, encoder: &mut Encoder, context: Context) { + MojomEncodable::encode(self.param0, encoder, context.clone()); + } + fn decode_value(decoder: &mut Decoder, context: Context) -> Result<Self, ValidationError> { + let version = { + let mut state = decoder.get_mut(&context); + match state.decode_struct_header(&IntegrationTestInterfaceMethod0ResponseVersions) { + Ok(header) => header.data(), + Err(err) => return Err(err), + } + }; + let param0 = match <Vec<u8>>::decode(decoder, context.clone()) { + Ok(value) => value, + Err(err) => return Err(err), + }; + Ok(IntegrationTestInterfaceMethod0Response { param0: param0 }) + } +} + +impl MojomEncodable for IntegrationTestInterfaceMethod0Response { + impl_encodable_for_pointer!(); + fn compute_size(&self, context: Context) -> usize { + encoding::align_default(self.serialized_size(&context)) + + self.param0.compute_size(context.clone()) + } +} + +impl MojomStruct for IntegrationTestInterfaceMethod0Response {} + +impl MojomMessage for IntegrationTestInterfaceMethod0Response { + fn min_version() -> u32 { + IntegrationTestInterfaceMethod0::MIN_VERSION + } + fn create_header() -> MessageHeader { + MessageHeader::new( + 1, + IntegrationTestInterfaceMethod0::ORDINAL, + message::MESSAGE_HEADER_IS_RESPONSE, + ) + } +} +impl IntegrationTestInterfaceResponse for IntegrationTestInterfaceMethod0Response {}
diff --git a/mojo/public/rust/tests/validation.rs b/mojo/public/rust/tests/validation.rs new file mode 100644 index 0000000..5b85582 --- /dev/null +++ b/mojo/public/rust/tests/validation.rs
@@ -0,0 +1,131 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Tests encoding and decoding functionality in the bindings package +//! +//! Test failure is defined as the function returning via panicking +//! and the result being caught in the test! macro. If a test function +//! returns without panicking, it is assumed to pass. + +#[macro_use] +extern crate mojo; + +use mojo::bindings::mojom::MojomMessageOption; +use mojo::system; + +#[macro_use] +mod util; + +use util::mojom_validation::*; + +/// This macro is a wrapper for the tests! macro as it takes advantage of the +/// shared code between tests. +/// +/// Given a test name, it will generate a test function. In this test function +/// we perform the following steps: +/// 1. Decode the header of the validation input. +/// 2. Decode the payload of the validation input, expecting a validation +/// error. +/// +macro_rules! validation_tests { + ($($name:ident => $req_type:ident;)*) => { + tests! { + $( + fn $name() { + let data = include_str!(concat!("../../interfaces/bindings/tests/data/validation/", + stringify!($name), + ".data")); + let expected = include_str!(concat!("../../interfaces/bindings/tests/data/validation/", + stringify!($name), + ".expected")).trim(); + match util::parse_validation_test(data) { + Ok((data, num_handles)) => { + let mut mock_handles = Vec::with_capacity(num_handles); + for _ in 0..num_handles { + mock_handles.push(unsafe { system::acquire(0) }); + } + match $req_type::decode_message(data, mock_handles) { + Ok(_) => panic!("Should not be valid!"), + Err(err) => assert_eq!(err.as_str(), expected), + } + }, + Err(msg) => panic!("Error: {}", msg), + } + } + )* + } + } +} + +validation_tests! { + conformance_empty => ConformanceTestInterfaceRequestOption; + conformance_mthd0_incomplete_struct => ConformanceTestInterfaceRequestOption; + conformance_mthd0_incomplete_struct_header => ConformanceTestInterfaceRequestOption; + conformance_mthd0_invalid_request_flags => ConformanceTestInterfaceRequestOption; + conformance_mthd0_invalid_request_flags2 => ConformanceTestInterfaceRequestOption; + conformance_mthd0_struct_num_bytes_huge => ConformanceTestInterfaceRequestOption; + conformance_mthd0_struct_num_bytes_less_than_min_requirement => ConformanceTestInterfaceRequestOption; + conformance_mthd0_struct_num_bytes_less_than_struct_header => ConformanceTestInterfaceRequestOption; + conformance_mthd10_null_keys => ConformanceTestInterfaceRequestOption; + conformance_mthd10_null_values => ConformanceTestInterfaceRequestOption; + conformance_mthd10_one_null_key => ConformanceTestInterfaceRequestOption; + conformance_mthd10_unequal_array_size => ConformanceTestInterfaceRequestOption; + conformance_mthd11_num_bytes_version_mismatch_1 => ConformanceTestInterfaceRequestOption; + conformance_mthd11_num_bytes_version_mismatch_2 => ConformanceTestInterfaceRequestOption; + conformance_mthd12_invalid_request_flags => ConformanceTestInterfaceRequestOption; + conformance_mthd14_unexpected_null_array_in_union => ConformanceTestInterfaceRequestOption; + conformance_mthd14_unexpected_null_map_in_union => ConformanceTestInterfaceRequestOption; + conformance_mthd14_unexpected_null_struct_in_union => ConformanceTestInterfaceRequestOption; + conformance_mthd14_unexpected_null_union_in_union => ConformanceTestInterfaceRequestOption; + conformance_mthd15_unexpected_null_union_in_array => ConformanceTestInterfaceRequestOption; + conformance_mthd1_misaligned_struct => ConformanceTestInterfaceRequestOption; + conformance_mthd1_struct_pointer_overflow => ConformanceTestInterfaceRequestOption; + conformance_mthd1_unexpected_null_struct => ConformanceTestInterfaceRequestOption; + conformance_mthd2_multiple_pointers_to_same_struct => ConformanceTestInterfaceRequestOption; + conformance_mthd2_overlapped_objects => ConformanceTestInterfaceRequestOption; + conformance_mthd2_wrong_layout_order => ConformanceTestInterfaceRequestOption; + conformance_mthd3_array_num_bytes_huge => ConformanceTestInterfaceRequestOption; + conformance_mthd3_array_num_bytes_less_than_array_header => ConformanceTestInterfaceRequestOption; + conformance_mthd3_array_num_bytes_less_than_necessary_size => ConformanceTestInterfaceRequestOption; + conformance_mthd3_array_pointer_overflow => ConformanceTestInterfaceRequestOption; + conformance_mthd3_incomplete_array => ConformanceTestInterfaceRequestOption; + conformance_mthd3_incomplete_array_header => ConformanceTestInterfaceRequestOption; + conformance_mthd3_misaligned_array => ConformanceTestInterfaceRequestOption; + conformance_mthd3_unexpected_null_array => ConformanceTestInterfaceRequestOption; + conformance_mthd4_multiple_pointers_to_same_array => ConformanceTestInterfaceRequestOption; + conformance_mthd4_overlapped_objects => ConformanceTestInterfaceRequestOption; + conformance_mthd4_wrong_layout_order => ConformanceTestInterfaceRequestOption; + conformance_mthd5_handle_out_of_range => ConformanceTestInterfaceRequestOption; + conformance_mthd5_multiple_handles_with_same_value_1 => ConformanceTestInterfaceRequestOption; + conformance_mthd5_multiple_handles_with_same_value_2 => ConformanceTestInterfaceRequestOption; + conformance_mthd5_unexpected_invalid_handle => ConformanceTestInterfaceRequestOption; + conformance_mthd5_wrong_handle_order => ConformanceTestInterfaceRequestOption; + conformance_mthd6_nested_array_num_bytes_less_than_necessary_size => ConformanceTestInterfaceRequestOption; + conformance_mthd7_unexpected_null_fixed_array => ConformanceTestInterfaceRequestOption; + conformance_mthd7_unmatched_array_elements => ConformanceTestInterfaceRequestOption; + conformance_mthd7_unmatched_array_elements_nested => ConformanceTestInterfaceRequestOption; + conformance_mthd8_array_num_bytes_overflow => ConformanceTestInterfaceRequestOption; + conformance_mthd8_unexpected_null_array => ConformanceTestInterfaceRequestOption; + conformance_mthd8_unexpected_null_string => ConformanceTestInterfaceRequestOption; + conformance_mthd9_unexpected_null_array => ConformanceTestInterfaceRequestOption; + boundscheck_msghdr_no_such_method => BoundsCheckTestInterfaceRequestOption; + conformance_msghdr_incomplete_struct => ConformanceTestInterfaceRequestOption; + conformance_msghdr_incomplete_struct_header => ConformanceTestInterfaceRequestOption; + conformance_msghdr_invalid_flag_combo => ConformanceTestInterfaceRequestOption; + conformance_msghdr_missing_request_id => ConformanceTestInterfaceRequestOption; + conformance_msghdr_no_such_method => ConformanceTestInterfaceRequestOption; + conformance_msghdr_num_bytes_huge => ConformanceTestInterfaceRequestOption; + conformance_msghdr_num_bytes_less_than_min_requirement => ConformanceTestInterfaceRequestOption; + conformance_msghdr_num_bytes_less_than_struct_header => ConformanceTestInterfaceRequestOption; + conformance_msghdr_num_bytes_version_mismatch_1 => ConformanceTestInterfaceRequestOption; + conformance_msghdr_num_bytes_version_mismatch_2 => ConformanceTestInterfaceRequestOption; + conformance_msghdr_num_bytes_version_mismatch_3 => ConformanceTestInterfaceRequestOption; + resp_boundscheck_msghdr_no_such_method => BoundsCheckTestInterfaceResponseOption; + resp_conformance_msghdr_invalid_response_flags1 => ConformanceTestInterfaceResponseOption; + resp_conformance_msghdr_invalid_response_flags2 => ConformanceTestInterfaceResponseOption; + resp_conformance_msghdr_no_such_method => ConformanceTestInterfaceResponseOption; + integration_intf_resp_mthd0_unexpected_array_header => IntegrationTestInterfaceResponseOption; + integration_intf_rqst_mthd0_unexpected_struct_header => IntegrationTestInterfaceRequestOption; + integration_msghdr_invalid_flags => IntegrationTestInterfaceRequestOption; +}
diff --git a/net/cert/crl_set.cc b/net/cert/crl_set.cc index 5216565..24ceee3 100644 --- a/net/cert/crl_set.cc +++ b/net/cert/crl_set.cc
@@ -223,10 +223,8 @@ if (!header_dict.get()) return false; - std::string contents; - if (!header_dict->GetString("ContentType", &contents)) - return false; - if (contents != "CRLSet") + std::string* contents = header_dict->FindStringKey("ContentType"); + if (!contents || (*contents != "CRLSet")) return false; if (header_dict->FindIntKey("Version") != kCurrentFileVersion)
diff --git a/net/log/net_log_event_type_list.h b/net/log/net_log_event_type_list.h index e66ebe9..509d788 100644 --- a/net/log/net_log_event_type_list.h +++ b/net/log/net_log_event_type_list.h
@@ -1489,6 +1489,14 @@ // On sending an HTTP/2 SETTINGS frame with ACK flag. EVENT_TYPE(HTTP2_SESSION_SEND_SETTINGS_ACK) +// Receipt of an HTTP/2 ACCEPT_CH frame. +// The following parameters are attached: +// { +// "origin": <The origin associated with the settings>, +// "accept_ch": <the raw ACCEPT_CH setting for that origin>, +// } +EVENT_TYPE(HTTP2_SESSION_RECV_ACCEPT_CH) + // Receipt of an HTTP/2 SETTINGS frame without ACK flag. EVENT_TYPE(HTTP2_SESSION_RECV_SETTINGS) @@ -2305,6 +2313,13 @@ // Session received a HANDSHAKE_DONE frame. EVENT_TYPE(QUIC_SESSION_HANDSHAKE_DONE_FRAME_RECEIVED) +// Session received an ACCEPT_CH frame +// { +// "origin": <the origin the accept_ch settings apply to> +// "accept_ch": <the raw ACCEPT_CH data> +// } +EVENT_TYPE(QUIC_ACCEPT_CH_FRAME_RECEIVED) + // Session sent a coalesced QUIC packet. // { // "info": <coalesced packet info>
diff --git a/net/quic/quic_chromium_client_session.cc b/net/quic/quic_chromium_client_session.cc index f39c94c..03ca363 100644 --- a/net/quic/quic_chromium_client_session.cc +++ b/net/quic/quic_chromium_client_session.cc
@@ -215,6 +215,14 @@ return dict; } +base::Value NetLogAcceptChFrameReceivedParams( + spdy::AcceptChOriginValuePair entry) { + base::Value dict(base::Value::Type::DICTIONARY); + dict.SetStringKey("origin", entry.origin); + dict.SetStringKey("accept_ch", entry.value); + return dict; +} + // Histogram for recording the different reasons that a QUIC session is unable // to complete the handshake. enum HandshakeFailureReason { @@ -1241,6 +1249,9 @@ has_valid_entry = true; accept_ch_entries_received_via_alps_.insert( std::make_pair(std::move(scheme_host_port), entry.value)); + + net_log_.AddEvent(NetLogEventType::QUIC_ACCEPT_CH_FRAME_RECEIVED, + [&] { return NetLogAcceptChFrameReceivedParams(entry); }); } LogAcceptChFrameReceivedHistogram(has_valid_entry, has_invalid_entry); }
diff --git a/net/spdy/spdy_session.cc b/net/spdy/spdy_session.cc index f2b859f..4e88846 100644 --- a/net/spdy/spdy_session.cc +++ b/net/spdy/spdy_session.cc
@@ -316,6 +316,13 @@ return dict; } +base::Value NetLogSpdyRecvAcceptChParams(spdy::AcceptChOriginValuePair entry) { + base::Value dict(base::Value::Type::DICTIONARY); + dict.SetStringKey("origin", entry.origin); + dict.SetStringKey("accept_ch", entry.value); + return dict; +} + base::Value NetLogSpdyRecvSettingParams(spdy::SpdySettingsId id, uint32_t value) { base::Value dict(base::Value::Type::DICTIONARY); @@ -1157,6 +1164,9 @@ has_valid_entry = true; accept_ch_entries_received_via_alps_.insert( std::make_pair(std::move(scheme_host_port), entry.value)); + + net_log_.AddEvent(NetLogEventType::HTTP2_SESSION_RECV_ACCEPT_CH, + [&] { return NetLogSpdyRecvAcceptChParams(entry); }); } SpdyAcceptChEntries value;
diff --git a/pdf/pdf_view_plugin_base.cc b/pdf/pdf_view_plugin_base.cc index fa2aa0c..aef2402 100644 --- a/pdf/pdf_view_plugin_base.cc +++ b/pdf/pdf_view_plugin_base.cc
@@ -567,8 +567,6 @@ NotifyLinkUnderCursor(); } -// TODO(crbug.com/1191817): Add tests for input events. Unit testing should be -// feasible now that the Pepper dependency is removed for input events. bool PdfViewPluginBase::HandleInputEvent(const blink::WebInputEvent& event) { // Ignore user input in read-only mode. if (engine()->IsReadOnly())
diff --git a/pdf/pdf_view_plugin_base_unittest.cc b/pdf/pdf_view_plugin_base_unittest.cc index 111a6a9..1788865 100644 --- a/pdf/pdf_view_plugin_base_unittest.cc +++ b/pdf/pdf_view_plugin_base_unittest.cc
@@ -31,7 +31,10 @@ #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "third_party/abseil-cpp/absl/types/optional.h" +#include "third_party/blink/public/common/input/web_input_event.h" +#include "third_party/blink/public/common/input/web_mouse_event.h" #include "third_party/skia/include/core/SkColor.h" +#include "ui/gfx/geometry/point_f.h" #include "ui/gfx/geometry/size.h" namespace chrome_pdf { @@ -160,6 +163,7 @@ using PdfViewPluginBase::accessibility_state; using PdfViewPluginBase::engine; using PdfViewPluginBase::full_frame; + using PdfViewPluginBase::HandleInputEvent; using PdfViewPluginBase::HandleMessage; using PdfViewPluginBase::LoadUrl; using PdfViewPluginBase::SetZoom; @@ -780,6 +784,25 @@ fake_plugin_.GetNotifiedBrowserAboutUnsupportedFeatureForTesting()); } +TEST_F(PdfViewPluginBaseWithEngineTest, HandleInputEvent) { + auto* engine = static_cast<TestPDFiumEngine*>(fake_plugin_.engine()); + EXPECT_CALL(*engine, HandleInputEvent) + .WillRepeatedly([](const blink::WebInputEvent& event) { + const auto& mouse_event = + static_cast<const blink::WebMouseEvent&>(event); + EXPECT_EQ(blink::WebInputEvent::Type::kMouseDown, + mouse_event.GetType()); + EXPECT_EQ(gfx::PointF(10.0f, 20.0f), mouse_event.PositionInWidget()); + return true; + }); + + blink::WebMouseEvent mouse_event; + mouse_event.SetType(blink::WebInputEvent::Type::kMouseDown); + mouse_event.SetPositionInWidget(10.0f, 20.0f); + + EXPECT_TRUE(fake_plugin_.HandleInputEvent(mouse_event)); +} + TEST_F(PdfViewPluginBaseTest, EnteredEditMode) { EXPECT_CALL(fake_plugin_, SetPluginCanSave(true)); fake_plugin_.EnteredEditMode();
diff --git a/pdf/pdf_view_web_plugin_unittest.cc b/pdf/pdf_view_web_plugin_unittest.cc index d2aa1e9..2d728d2 100644 --- a/pdf/pdf_view_web_plugin_unittest.cc +++ b/pdf/pdf_view_web_plugin_unittest.cc
@@ -721,28 +721,8 @@ class PdfViewWebPluginImeTest : public PdfViewWebPluginTest { public: - class TestPDFiumEngineForIme : public TestPDFiumEngine { - public: - explicit TestPDFiumEngineForIme(PDFEngine::Client* client) - : TestPDFiumEngine(client) {} - - // TestPDFiumEngine: - MOCK_METHOD(bool, - HandleInputEvent, - (const blink::WebInputEvent&), - (override)); - }; - - std::unique_ptr<TestPDFiumEngine> CreateEngine() override { - return std::make_unique<NiceMock<TestPDFiumEngineForIme>>(plugin_.get()); - } - - TestPDFiumEngineForIme* engine() { - return static_cast<TestPDFiumEngineForIme*>(engine_ptr_); - } - void TestImeSetCompositionForPlugin(const blink::WebString& text) { - EXPECT_CALL(*engine(), HandleInputEvent).Times(0); + EXPECT_CALL(*engine_ptr_, HandleInputEvent).Times(0); plugin_->ImeSetCompositionForPlugin(text, std::vector<ui::ImeTextSpan>(), gfx::Range(), /*selection_start=*/0, @@ -756,12 +736,12 @@ if (expected_text16.size()) { for (const auto& c : expected_text16) { base::StringPiece16 expected_key(&c, 1); - EXPECT_CALL(*engine(), + EXPECT_CALL(*engine_ptr_, HandleInputEvent(IsExpectedImeKeyEvent(expected_key))) .WillOnce(Return(true)); } } else { - EXPECT_CALL(*engine(), HandleInputEvent).Times(0); + EXPECT_CALL(*engine_ptr_, HandleInputEvent).Times(0); } plugin_->ImeFinishComposingTextForPlugin(false); } @@ -772,11 +752,12 @@ if (expected_text16.size()) { for (const auto& c : expected_text16) { base::StringPiece16 event(&c, 1); - EXPECT_CALL(*engine(), HandleInputEvent(IsExpectedImeKeyEvent(event))) + EXPECT_CALL(*engine_ptr_, + HandleInputEvent(IsExpectedImeKeyEvent(event))) .WillOnce(Return(true)); } } else { - EXPECT_CALL(*engine(), HandleInputEvent).Times(0); + EXPECT_CALL(*engine_ptr_, HandleInputEvent).Times(0); } plugin_->ImeCommitTextForPlugin(text, std::vector<ui::ImeTextSpan>(), gfx::Range(),
diff --git a/pdf/test/test_pdfium_engine.h b/pdf/test/test_pdfium_engine.h index c1518ed..89afd9ea 100644 --- a/pdf/test/test_pdfium_engine.h +++ b/pdf/test/test_pdfium_engine.h
@@ -42,6 +42,11 @@ MOCK_METHOD(void, ScrolledToXPosition, (int), (override)); MOCK_METHOD(void, ScrolledToYPosition, (int), (override)); + MOCK_METHOD(bool, + HandleInputEvent, + (const blink::WebInputEvent&), + (override)); + MOCK_METHOD(void, ZoomUpdated, (double), (override)); MOCK_METHOD(gfx::Size,
diff --git a/remoting/host/mojom/webauthn_proxy.mojom b/remoting/host/mojom/webauthn_proxy.mojom index 8edb3f1..8ed55423a 100644 --- a/remoting/host/mojom/webauthn_proxy.mojom +++ b/remoting/host/mojom/webauthn_proxy.mojom
@@ -4,13 +4,37 @@ module remoting.mojom; +// The response object for WebAuthnProxy.Create(). +union WebAuthnCreateResponse { + // The `name` property of the `DOMException`, if any, yielded by the remote + // request. + string error_name; + + // A string-serialized representation of the `PublicKeyCredential` + // (https://w3c.github.io/webauthn/#publickeycredential), if any, yielded + // by the remote request. + // Note that it is opaque to chromoting host processes and will be passed + // verbatim to the Chrome webAuthenticationProxy extension API. + string response_data; +}; + // An interface for the host-side chromoting extension to pass intercepted Web // Authentication API requests to the client side chromoting security extension // through a chromoting host process. -// The remote of this interface is always owned by a low-trust native messaging -// host process; on Windows, the receiver is bound in a low-trust chromoting -// network process, while on Linux, the receiver is bound in a high-trust -// chromoting host process. +// +// The interface defined here generally matches the IDL of the Chrome +// webAuthenticationProxy extension API: +// chrome/common/extensions/api/web_authentication_proxy.idl +// +// The remote of this interface is always owned by a native messaging host +// process; on Windows, the receiver is bound in a chromoting network process, +// while on Linux, the receiver is bound in a chromoting host process. Both the +// remote and the receiver are generally trusted. +// +// There is an intrinsic risk of remote WebAuthn forwarding, but it's an +// accepted risk and the impact is limited given the limited scope of the +// feature. Please see the note here: go/crd-webauthn#heading=h.s445jjbbs1m2 +// // Note that both processes are chromoting-only and they don't make IPCs with // Chrome processes. interface WebAuthnProxy { @@ -18,4 +42,12 @@ // PublicKeyCredential.isUserVerifyingPlatformAuthenticatorAvailable() call // remotely. IsUserVerifyingPlatformAuthenticatorAvailable() => (bool is_available); + + // Handles a navigator.credentials.create() call remotely. + // |request_data| is the string-serialized representation of the parameters + // passed to the create() call. It is opaque to chromoting host processes and + // will be passed verbatim to the client. + // If |response| is null, it means that the remote create() call has yielded + // `null`, which is still a valid response according to the spec. + Create(string request_data) => (WebAuthnCreateResponse? response); };
diff --git a/remoting/host/webauthn/remote_webauthn_constants.cc b/remoting/host/webauthn/remote_webauthn_constants.cc index 134c56f..0fd4073 100644 --- a/remoting/host/webauthn/remote_webauthn_constants.cc +++ b/remoting/host/webauthn/remote_webauthn_constants.cc
@@ -10,8 +10,12 @@ const char kIsUvpaaMessageType[] = "isUvpaa"; const char kGetRemoteStateMessageType[] = "getRemoteState"; +const char kCreateMessageType[] = "create"; const char kIsUvpaaResponseIsAvailableKey[] = "isAvailable"; const char kGetRemoteStateResponseIsRemotedKey[] = "isRemoted"; +const char kCreateRequestDataKey[] = "requestData"; +const char kCreateResponseErrorNameKey[] = "errorName"; +const char kCreateResponseDataKey[] = "responseData"; } // namespace remoting
diff --git a/remoting/host/webauthn/remote_webauthn_constants.h b/remoting/host/webauthn/remote_webauthn_constants.h index 65ac883..2194881 100644 --- a/remoting/host/webauthn/remote_webauthn_constants.h +++ b/remoting/host/webauthn/remote_webauthn_constants.h
@@ -12,10 +12,14 @@ // NMH message types. extern const char kIsUvpaaMessageType[]; extern const char kGetRemoteStateMessageType[]; +extern const char kCreateMessageType[]; // NMH message keys. extern const char kIsUvpaaResponseIsAvailableKey[]; extern const char kGetRemoteStateResponseIsRemotedKey[]; +extern const char kCreateRequestDataKey[]; +extern const char kCreateResponseErrorNameKey[]; +extern const char kCreateResponseDataKey[]; } // namespace remoting
diff --git a/remoting/host/webauthn/remote_webauthn_message_handler.cc b/remoting/host/webauthn/remote_webauthn_message_handler.cc index 1baeaad..a22567f 100644 --- a/remoting/host/webauthn/remote_webauthn_message_handler.cc +++ b/remoting/host/webauthn/remote_webauthn_message_handler.cc
@@ -11,6 +11,23 @@ #include "remoting/proto/remote_webauthn.pb.h" #include "remoting/protocol/message_serialization.h" +namespace { + +template <typename CallbackType, typename ResponseType> +void FindAndRunCallback(base::flat_map<uint64_t, CallbackType>& callback_map, + uint64_t id, + ResponseType response) { + auto it = callback_map.find(id); + if (it == callback_map.end()) { + LOG(WARNING) << "No callback found associated with ID: " << id; + return; + } + std::move(it->second).Run(std::move(response)); + callback_map.erase(it); +} + +} // namespace + namespace remoting { RemoteWebAuthnMessageHandler::RemoteWebAuthnMessageHandler( @@ -48,6 +65,10 @@ OnIsUvpaaResponse(remote_webauthn->id(), remote_webauthn->is_uvpaa_response()); break; + case protocol::RemoteWebAuthn::kCreateResponse: + OnCreateResponse(remote_webauthn->id(), + remote_webauthn->create_response()); + break; default: LOG(ERROR) << "Unknown message case: " << remote_webauthn->message_case(); } @@ -83,6 +104,20 @@ Send(remote_webauthn, base::DoNothing()); } +void RemoteWebAuthnMessageHandler::Create(const std::string& request_data, + CreateCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + uint64_t id = AssignNextMessageId(); + create_callbacks_[id] = std::move(callback); + + protocol::RemoteWebAuthn remote_webauthn; + remote_webauthn.set_id(id); + remote_webauthn.mutable_create_request()->set_request_details_json( + request_data); + Send(remote_webauthn, base::DoNothing()); +} + void RemoteWebAuthnMessageHandler::AddReceiver( mojo::PendingReceiver<mojom::WebAuthnProxy> receiver) { if (!connected()) { @@ -121,13 +156,34 @@ const protocol::RemoteWebAuthn_IsUvpaaResponse& response) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - auto it = is_uvpaa_callbacks_.find(id); - if (it == is_uvpaa_callbacks_.end()) { - LOG(WARNING) << "No IsUvpaa IPC callback associated with ID: " << id; - return; + FindAndRunCallback(is_uvpaa_callbacks_, id, response.is_available()); +} + +void RemoteWebAuthnMessageHandler::OnCreateResponse( + uint64_t id, + const protocol::RemoteWebAuthn_CreateResponse& response) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + mojom::WebAuthnCreateResponsePtr mojo_response; + switch (response.result_case()) { + case protocol::RemoteWebAuthn::CreateResponse::ResultCase::kErrorName: + mojo_response = + mojom::WebAuthnCreateResponse::NewErrorName(response.error_name()); + break; + case protocol::RemoteWebAuthn::CreateResponse::ResultCase::kResponseJson: + mojo_response = mojom::WebAuthnCreateResponse::NewResponseData( + response.response_json()); + break; + case protocol::RemoteWebAuthn::CreateResponse::ResultCase::RESULT_NOT_SET: + // Do nothing and send a nullptr to the mojo client. This means the remote + // create() call has yielded `null`, which is still a valid response + // according to the spec. + break; + default: + NOTREACHED() << "Unknown create result case: " << response.result_case(); } - std::move(it->second).Run(response.is_available()); - is_uvpaa_callbacks_.erase(it); + + FindAndRunCallback(create_callbacks_, id, std::move(mojo_response)); } uint64_t RemoteWebAuthnMessageHandler::AssignNextMessageId() {
diff --git a/remoting/host/webauthn/remote_webauthn_message_handler.h b/remoting/host/webauthn/remote_webauthn_message_handler.h index 3116a0f..ff405bfdc4 100644 --- a/remoting/host/webauthn/remote_webauthn_message_handler.h +++ b/remoting/host/webauthn/remote_webauthn_message_handler.h
@@ -22,6 +22,7 @@ namespace remoting { namespace protocol { +class RemoteWebAuthn_CreateResponse; class RemoteWebAuthn_IsUvpaaResponse; } // namespace protocol @@ -44,6 +45,8 @@ // mojom::WebAuthnProxy implementation. void IsUserVerifyingPlatformAuthenticatorAvailable( IsUserVerifyingPlatformAuthenticatorAvailableCallback callback) override; + void Create(const std::string& request_data, + CreateCallback callback) override; void AddReceiver(mojo::PendingReceiver<mojom::WebAuthnProxy> receiver); void ClearReceivers(); @@ -55,10 +58,16 @@ base::WeakPtr<RemoteWebAuthnMessageHandler> GetWeakPtr(); private: + template <typename CallbackType> + using CallbackMap = base::flat_map<uint64_t, CallbackType>; + void OnReceiverDisconnected(); void OnIsUvpaaResponse( uint64_t id, const protocol::RemoteWebAuthn_IsUvpaaResponse& response); + void OnCreateResponse( + uint64_t id, + const protocol::RemoteWebAuthn_CreateResponse& response); uint64_t AssignNextMessageId(); @@ -68,9 +77,10 @@ mojo::ReceiverSet<mojom::WebAuthnProxy> receiver_set_; // message ID => mojo callback mappings. - base::flat_map<uint64_t, - IsUserVerifyingPlatformAuthenticatorAvailableCallback> + CallbackMap<IsUserVerifyingPlatformAuthenticatorAvailableCallback> is_uvpaa_callbacks_ GUARDED_BY_CONTEXT(sequence_checker_); + CallbackMap<CreateCallback> create_callbacks_ + GUARDED_BY_CONTEXT(sequence_checker_); uint64_t current_message_id_ GUARDED_BY_CONTEXT(sequence_checker_) = 0u;
diff --git a/remoting/host/webauthn/remote_webauthn_native_messaging_host.cc b/remoting/host/webauthn/remote_webauthn_native_messaging_host.cc index 7faf2eb..483d6fe 100644 --- a/remoting/host/webauthn/remote_webauthn_native_messaging_host.cc +++ b/remoting/host/webauthn/remote_webauthn_native_messaging_host.cc
@@ -48,6 +48,8 @@ ProcessIsUvpaa(request, std::move(response)); } else if (type == kGetRemoteStateMessageType) { ProcessGetRemoteState(std::move(response)); + } else if (type == kCreateMessageType) { + ProcessCreate(request, std::move(response)); } else { LOG(ERROR) << "Unsupported request type: " << type; } @@ -94,6 +96,40 @@ base::Unretained(this), std::move(response))); } +void RemoteWebAuthnNativeMessagingHost::ProcessCreate( + const base::Value& request, + base::Value response) { + // Create request: {id: string, type: 'create', requestData: string} + // Create response: { + // id: string, type: 'createResponse', responseData?: string, + // errorName?: string} + + DCHECK(task_runner_->BelongsToCurrentThread()); + + if (!EnsureIpcConnection()) { + // TODO(yuweih): See if this is the right error to use here. + response.SetStringKey(kCreateResponseErrorNameKey, "InvalidStateError"); + SendMessageToClient(std::move(response)); + return; + } + + const std::string* request_data = + request.FindStringKey(kCreateRequestDataKey); + if (!request_data) { + LOG(ERROR) << "Request data not found in create request."; + // navigator.credentials.create() throws NotSupportedError if the parameter + // is unexpected. + response.SetStringKey(kCreateResponseErrorNameKey, "NotSupportedError"); + SendMessageToClient(std::move(response)); + return; + } + + remote_->Create( + *request_data, + base::BindOnce(&RemoteWebAuthnNativeMessagingHost::OnCreateResponse, + base::Unretained(this), std::move(response))); +} + void RemoteWebAuthnNativeMessagingHost::ProcessGetRemoteState( base::Value response) { // GetRemoteState request: {id: string, type: 'getRemoteState'} @@ -136,6 +172,33 @@ SendMessageToClient(std::move(response)); } +void RemoteWebAuthnNativeMessagingHost::OnCreateResponse( + base::Value response, + mojom::WebAuthnCreateResponsePtr remote_response) { + DCHECK(task_runner_->BelongsToCurrentThread()); + + // If |remote_response| is null, it means that the remote create() call has + // yielded `null`, which is still a valid response according to the spec. In + // this case we just send back an empty create response. + if (!remote_response.is_null()) { + switch (remote_response->which()) { + case mojom::WebAuthnCreateResponse::Tag::kErrorName: + response.SetStringKey(kCreateResponseErrorNameKey, + remote_response->get_error_name()); + break; + case mojom::WebAuthnCreateResponse::Tag::kResponseData: + response.SetStringKey(kCreateResponseDataKey, + remote_response->get_response_data()); + break; + default: + NOTREACHED() << "Unexpected create response tag: " + << static_cast<uint32_t>(remote_response->which()); + } + } + + SendMessageToClient(std::move(response)); +} + void RemoteWebAuthnNativeMessagingHost::QueryNextRemoteState() { DCHECK(task_runner_->BelongsToCurrentThread());
diff --git a/remoting/host/webauthn/remote_webauthn_native_messaging_host.h b/remoting/host/webauthn/remote_webauthn_native_messaging_host.h index 97daf70..1ddd0ff 100644 --- a/remoting/host/webauthn/remote_webauthn_native_messaging_host.h +++ b/remoting/host/webauthn/remote_webauthn_native_messaging_host.h
@@ -13,6 +13,7 @@ #include "extensions/browser/api/messaging/native_message_host.h" #include "mojo/public/cpp/bindings/remote.h" #include "remoting/host/chromoting_host_services_client.h" +#include "remoting/host/mojom/webauthn_proxy.mojom-forward.h" #include "remoting/host/mojom/webauthn_proxy.mojom.h" namespace remoting { @@ -38,10 +39,13 @@ void ProcessHello(base::Value response); void ProcessGetRemoteState(base::Value response); void ProcessIsUvpaa(const base::Value& request, base::Value response); + void ProcessCreate(const base::Value& request, base::Value response); void OnQueryVersionResult(uint32_t version); void OnIpcDisconnected(); void OnIsUvpaaResponse(base::Value response, bool is_available); + void OnCreateResponse(base::Value response, + mojom::WebAuthnCreateResponsePtr remote_response); void QueryNextRemoteState(); void SendNextRemoteState(bool is_remoted);
diff --git a/remoting/proto/remote_webauthn.proto b/remoting/proto/remote_webauthn.proto index ca5884d..0d56f74 100644 --- a/remoting/proto/remote_webauthn.proto +++ b/remoting/proto/remote_webauthn.proto
@@ -10,7 +10,7 @@ // Composite message type for messages sent over the remote-webauthn data // channel. -// Next ID: 4 +// Next ID: 6 message RemoteWebAuthn { // Requests the client to handle a call to // PublicKeyCredential.isUserVerifyingPlatformAuthenticatorAvailable(). @@ -20,12 +20,38 @@ // Next ID: 2 message IsUvpaaResponse { optional bool is_available = 1; } - // Unique ID to pair a response with the request. - // Required for all message types. + // Requests the client to handle a navigator.credentials.create() call. + // Next ID: 2 + message CreateRequest { + // A JSON serialized representation of PublicKeyCredentialCreationOptions + // passed to navigator.credentials.create(). + optional string request_details_json = 1; + } + + // Response for CreateRequest. + // Next ID: 3 + message CreateResponse { + // If neither of the fields is set, it means that the remote create() call + // has yielded `null`, which is still a valid response according to the + // spec. + oneof result { + // The `name` property of the `DOMException`, if any. + string error_name = 1; + + // A JSON serialized representation of the `PublicKeyCredential` + // (https://w3c.github.io/webauthn/#publickeycredential), if any. + string response_json = 2; + } + } + + // Unique ID used to multiplex requests. optional uint64 id = 1; oneof message { IsUvpaaRequest is_uvpaa_request = 2; IsUvpaaResponse is_uvpaa_response = 3; + + CreateRequest create_request = 4; + CreateResponse create_response = 5; } }
diff --git a/services/network/BUILD.gn b/services/network/BUILD.gn index bf42c74..45070bd 100644 --- a/services/network/BUILD.gn +++ b/services/network/BUILD.gn
@@ -456,6 +456,7 @@ "ct_log_list_distributor_unittest.cc", "expect_ct_reporter_unittest.cc", "sct_auditing/sct_auditing_cache_unittest.cc", + "sct_auditing/sct_auditing_handler_unittest.cc", ] deps += [ "//components/certificate_transparency" ] }
diff --git a/services/network/first_party_sets/first_party_sets.cc b/services/network/first_party_sets/first_party_sets.cc index e4112e8..ed94752a 100644 --- a/services/network/first_party_sets/first_party_sets.cc +++ b/services/network/first_party_sets/first_party_sets.cc
@@ -9,6 +9,7 @@ #include <utility> #include <vector> +#include "base/check.h" #include "base/containers/contains.h" #include "base/files/file_util.h" #include "base/logging.h" @@ -100,8 +101,17 @@ void FirstPartySets::ParseAndSet(base::File sets_file) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - if (!enabled_) + if (!enabled_ || component_sets_parse_progress_ != Progress::kNotStarted) { return; + } + + component_sets_parse_progress_ = Progress::kStarted; + + if (!sets_file.IsValid()) { + OnReadSetsFile(""); + return; + } + base::ThreadPool::PostTaskAndReplyWithResult( FROM_HERE, {base::MayBlock(), base::TaskPriority::BEST_EFFORT}, base::BindOnce(&ReadSetsFile, std::move(sets_file)), @@ -111,8 +121,8 @@ void FirstPartySets::OnReadSetsFile(const std::string& raw_sets) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - if (!enabled_) - return; + DCHECK_EQ(component_sets_parse_progress_, Progress::kStarted); + DCHECK(enabled_); bool is_v1_format = raw_sets.find('[') < raw_sets.find('{'); if (is_v1_format) { @@ -128,7 +138,7 @@ is_v1_format); ApplyManuallySpecifiedSet(); - component_sets_ready_ = true; + component_sets_parse_progress_ = Progress::kFinished; ClearSiteDataOnChangedSetsIfReady(); } @@ -345,8 +355,9 @@ void FirstPartySets::ClearSiteDataOnChangedSetsIfReady() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - if (!persisted_sets_ready_ || !component_sets_ready_ || !manual_sets_ready_ || - on_site_data_cleared_.is_null()) + if (!persisted_sets_ready_ || + component_sets_parse_progress_ != Progress::kFinished || + !manual_sets_ready_ || on_site_data_cleared_.is_null()) return; base::flat_set<net::SchemefulSite> diff = ComputeSetsDiff(
diff --git a/services/network/first_party_sets/first_party_sets.h b/services/network/first_party_sets/first_party_sets.h index 5afda58..af9ce173 100644 --- a/services/network/first_party_sets/first_party_sets.h +++ b/services/network/first_party_sets/first_party_sets.h
@@ -51,8 +51,15 @@ // record is a set declaration in the format specified here: // https://github.com/privacycg/first-party-sets. // - // In case of invalid input, the members-to-owners map is cleared, but keeps - // any manually-specified set (i.e. a set provided on the command line). + // Only the first call to ParseAndSet can have any effect; subsequent + // invocations are ignored. + // + // If `sets_file.IsValid()` is false, then the set of sets is considered + // empty. + // + // In case of invalid input, the set of sets provided by the file is + // considered empty. Note that the FirstPartySets instance may still have some + // sets, from the command line or enterprise policies. // // Has no effect if `kFirstPartySets` is disabled. void ParseAndSet(base::File sets_file); @@ -159,8 +166,15 @@ std::string raw_persisted_sets_ GUARDED_BY_CONTEXT(sequence_checker_); + enum Progress { + kNotStarted, + kStarted, + kFinished, + }; + bool persisted_sets_ready_ GUARDED_BY_CONTEXT(sequence_checker_) = false; - bool component_sets_ready_ GUARDED_BY_CONTEXT(sequence_checker_) = false; + Progress component_sets_parse_progress_ + GUARDED_BY_CONTEXT(sequence_checker_) = kNotStarted; bool manual_sets_ready_ GUARDED_BY_CONTEXT(sequence_checker_) = false; bool enabled_ GUARDED_BY_CONTEXT(sequence_checker_) = false; // The callback runs after the site state clearing is completed.
diff --git a/services/network/first_party_sets/first_party_sets_unittest.cc b/services/network/first_party_sets/first_party_sets_unittest.cc index 5448b9e..11c7a92 100644 --- a/services/network/first_party_sets/first_party_sets_unittest.cc +++ b/services/network/first_party_sets/first_party_sets_unittest.cc
@@ -4,6 +4,8 @@ #include "services/network/first_party_sets/first_party_sets.h" +#include <string> + #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/json/json_reader.h" @@ -62,6 +64,8 @@ FirstPartySets& sets() { return sets_; } + base::test::TaskEnvironment& env() { return env_; } + private: base::test::ScopedFeatureList feature_list_; base::test::TaskEnvironment env_; @@ -155,6 +159,21 @@ EXPECT_THAT(sets().Sets(), IsEmpty()); } +TEST_F(FirstPartySetsEnabledTest, IgnoresInvalidFile) { + sets().ParseAndSet(base::File()); + env().RunUntilIdle(); + EXPECT_THAT(sets().Sets(), IsEmpty()); + + const std::string input = + "{\"owner\": \"https://example.test\",\"members\": " + "[\"https://aaaa.test\",],}"; + + // Subsequent ParseAndSet calls should be ignored, because the instance has + // already received sets from component updater. + SetComponentSetsAndWait(input); + EXPECT_THAT(sets().Sets(), IsEmpty()); +} + TEST_F(FirstPartySetsEnabledTest, ParsesJSON) { SetComponentSetsAndWait("[]"); EXPECT_THAT(sets().Sets(), IsEmpty()); @@ -236,8 +255,8 @@ SerializesTo("https://member2.test"))))); } -TEST_F(FirstPartySetsEnabledTest, ClearsPreloadedOnError) { - const std::string input = R"( +TEST_F(FirstPartySetsEnabledTest, ParseAndSet_Idempotent) { + std::string input = R"( [ { "owner": "https://example.test", @@ -262,9 +281,31 @@ UnorderedElementsAre(SerializesTo("https://foo.test"), SerializesTo("https://member2.test"))))); - SetComponentSetsAndWait("{}"); + input = R"( + [ + { + "owner": "https://example2.test", + "members": ["https://member1.test"] + }, + { + "owner": "https://foo2.test", + "members": ["https://member2.test"] + } + ] + )"; + ASSERT_TRUE(base::JSONReader::Read(input)); + SetComponentSetsAndWait(input); - EXPECT_THAT(sets().Sets(), IsEmpty()); + // The second call to ParseAndSet should have had no effect. + EXPECT_THAT( + sets().Sets(), + UnorderedElementsAre( + Pair(SerializesTo("https://example.test"), + UnorderedElementsAre(SerializesTo("https://example.test"), + SerializesTo("https://member1.test"))), + Pair(SerializesTo("https://foo.test"), + UnorderedElementsAre(SerializesTo("https://foo.test"), + SerializesTo("https://member2.test"))))); } TEST_F(FirstPartySetsEnabledTest, OwnerIsOnlyMember) { @@ -568,42 +609,6 @@ } TEST_F(FirstPartySetsEnabledTest, - SetsManuallySpecified_ClearsPreloadedOnError) { - const std::string input = R"( - [ - { - "owner": "https://bar.test", - "members": ["https://member3.test"] - } - ] - )"; - ASSERT_TRUE(base::JSONReader::Read(input)); - SetComponentSetsAndWait(input); - - sets().SetManuallySpecifiedSet( - "https://example.test,https://member1.test,https://member2.test"); - EXPECT_THAT( - sets().Sets(), - UnorderedElementsAre( - Pair(SerializesTo("https://example.test"), - UnorderedElementsAre(SerializesTo("https://example.test"), - SerializesTo("https://member1.test"), - SerializesTo("https://member2.test"))), - Pair(SerializesTo("https://bar.test"), - UnorderedElementsAre(SerializesTo("https://bar.test"), - SerializesTo("https://member3.test"))))); - - SetComponentSetsAndWait("{}"); - - EXPECT_THAT(sets().Sets(), - UnorderedElementsAre(Pair( - SerializesTo("https://example.test"), - UnorderedElementsAre(SerializesTo("https://example.test"), - SerializesTo("https://member1.test"), - SerializesTo("https://member2.test"))))); -} - -TEST_F(FirstPartySetsEnabledTest, SetsManuallySpecified_PrunesInducedSingletons) { const std::string input = R"( [
diff --git a/services/network/network_context.cc b/services/network/network_context.cc index b4cc320..b5b0180 100644 --- a/services/network/network_context.cc +++ b/services/network/network_context.cc
@@ -516,6 +516,16 @@ if (params_->ct_policy) SetCTPolicy(std::move(params_->ct_policy)); + base::FilePath sct_auditing_path; + if (base::FeatureList::IsEnabled( + features::kSCTAuditingRetryAndPersistReports)) { + GetFullDataFilePath(params_->file_paths, + &network::mojom::NetworkContextFilePaths:: + sct_auditing_pending_reports_file_name, + sct_auditing_path); + } + sct_auditing_handler_ = + std::make_unique<SCTAuditingHandler>(this, sct_auditing_path); sct_auditing_handler()->SetEnabled(params_->enable_sct_auditing); #endif
diff --git a/services/network/network_context.h b/services/network/network_context.h index 40cd0e3d..08d7391 100644 --- a/services/network/network_context.h +++ b/services/network/network_context.h
@@ -72,10 +72,6 @@ #include "net/reporting/reporting_report.h" #endif // BUILDFLAG(ENABLE_REPORTING) -#if BUILDFLAG(IS_CT_SUPPORTED) -#include "services/network/sct_auditing/sct_auditing_handler.h" -#endif // BUILDFLAG(IS_CT_SUPPORTED) - namespace base { class UnguessableToken; } // namespace base @@ -115,6 +111,7 @@ class ProxyLookupRequest; class ResourceScheduler; class ResourceSchedulerClient; +class SCTAuditingHandler; class SessionCleanupCookieStore; class SQLiteTrustTokenPersister; class WebSocketFactory; @@ -303,7 +300,9 @@ void OnCTLogListUpdated( const std::vector<network::mojom::CTLogInfoPtr>& log_list, base::Time update_time); - SCTAuditingHandler* sct_auditing_handler() { return &sct_auditing_handler_; } + SCTAuditingHandler* sct_auditing_handler() { + return sct_auditing_handler_.get(); + } #endif // BUILDFLAG(IS_CT_SUPPORTED) void CreateUDPSocket( mojo::PendingReceiver<mojom::UDPSocket> receiver, @@ -751,7 +750,7 @@ raw_ptr<certificate_transparency::ChromeCTPolicyEnforcer> ct_policy_enforcer_ = nullptr; - SCTAuditingHandler sct_auditing_handler_{this}; + std::unique_ptr<SCTAuditingHandler> sct_auditing_handler_; #endif // BUILDFLAG(IS_CT_SUPPORTED) #if BUILDFLAG(IS_CHROMEOS_ASH)
diff --git a/services/network/public/cpp/cookie_manager_mojom_traits.cc b/services/network/public/cpp/cookie_manager_mojom_traits.cc index 9c7778ae..12c5c13 100644 --- a/services/network/public/cpp/cookie_manager_mojom_traits.cc +++ b/services/network/public/cpp/cookie_manager_mojom_traits.cc
@@ -8,6 +8,7 @@ #include "net/cookies/cookie_constants.h" #include "net/cookies/cookie_options.h" #include "net/cookies/same_party_context.h" +#include "services/network/public/mojom/cookie_manager.mojom.h" namespace mojo {
diff --git a/services/network/public/cpp/cookie_manager_mojom_traits.h b/services/network/public/cpp/cookie_manager_mojom_traits.h index c2dc387..8bcd88c 100644 --- a/services/network/public/cpp/cookie_manager_mojom_traits.h +++ b/services/network/public/cpp/cookie_manager_mojom_traits.h
@@ -17,7 +17,8 @@ #include "net/cookies/cookie_options.h" #include "net/cookies/cookie_partition_key_collection.h" #include "net/cookies/same_party_context.h" -#include "services/network/public/mojom/cookie_manager.mojom.h" +#include "services/network/public/mojom/cookie_manager.mojom-forward.h" +#include "services/network/public/mojom/cookie_partition_key.mojom.h" #include "third_party/abseil-cpp/absl/types/optional.h" namespace mojo {
diff --git a/services/network/public/mojom/network_context.mojom b/services/network/public/mojom/network_context.mojom index 8cdb76c..2fe9068 100644 --- a/services/network/public/mojom/network_context.mojom +++ b/services/network/public/mojom/network_context.mojom
@@ -253,6 +253,11 @@ // flag is false. mojo_base.mojom.FilePath? reporting_and_nel_store_database_name; + // The name of the file to store pending SCT auditing reports inside + // `data_path`. If empty, or `data_path` is empty, or the file can't be + // opened, pending reports will instead be stored in-memory only. + mojo_base.mojom.FilePath? sct_auditing_pending_reports_file_name; + // Specifies whether or not a migration of existing data should occur from // `unsandboxed_data_path` to `data_path`. It is not valid to set this to true // if an `unsandboxed_data_path` is not specified.
diff --git a/services/network/public/mojom/network_service.mojom b/services/network/public/mojom/network_service.mojom index 245313f..713f0d7 100644 --- a/services/network/public/mojom/network_service.mojom +++ b/services/network/public/mojom/network_service.mojom
@@ -340,11 +340,15 @@ // collection of set declarations according to the format specified in this // document: https://github.com/privacycg/first-party-sets. The collection may // either be a single JSON array of such records, or a sequence of - // newline-delimited JSON objects (one per line). Note that by setting the - // First-Party Sets, any previous notion of First-Party Sets is cleared. On - // any kind of error, all First-Party Sets are cleared (except for the - // manually-specified set, if one exists). - SetFirstPartySets(mojo_base.mojom.ReadOnlyFile sets_file); + // newline-delimited JSON objects (one per line). Note that only the first + // invocation will have any effect. On any kind of error, this method has no + // effect. + // + // |sets_file| may be invalid ("nullopt"), if no such file exists on the + // client's device (i.e. if the First-Party Sets component has not been + // installed yet). Such a case is considered the same as if the file's + // contents were empty. + SetFirstPartySets(mojo_base.mojom.ReadOnlyFile? sets_file); // Sets the First-Party Sets data that was persisted to compare it with the // current First-Party Sets data set by `SetFirstPartySets()`, which is
diff --git a/services/network/sct_auditing/sct_auditing_cache.h b/services/network/sct_auditing/sct_auditing_cache.h index aad251b..c5c3bc6 100644 --- a/services/network/sct_auditing/sct_auditing_cache.h +++ b/services/network/sct_auditing/sct_auditing_cache.h
@@ -43,6 +43,13 @@ // // The SCTAuditingCache allows the embedder to configure SCT auditing via the // network service's ConfigureSCTAuditing() API. +// +// Note: The SCTAuditingCache's deduplication cache is not persisted to disk. +// Pending reports that are persisted to disk by SCTAuditingHandler do not +// repopulate the deduplication cache when loaded. Not persisting the dedupe +// cache slightly increases the probability weight of sampling and sending SCTs +// from sites a user commonly visits (i.e., those they are likely to visit in +// every session). class COMPONENT_EXPORT(NETWORK_SERVICE) SCTAuditingCache { public: explicit SCTAuditingCache(size_t cache_size = 1024);
diff --git a/services/network/sct_auditing/sct_auditing_handler.cc b/services/network/sct_auditing/sct_auditing_handler.cc index aecb59e..846e7f5 100644 --- a/services/network/sct_auditing/sct_auditing_handler.cc +++ b/services/network/sct_auditing/sct_auditing_handler.cc
@@ -4,12 +4,27 @@ #include "services/network/sct_auditing/sct_auditing_handler.h" +#include "base/base64.h" +#include "base/feature_list.h" +#include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/json/json_reader.h" +#include "base/json/json_writer.h" #include "base/metrics/histogram_functions.h" +#include "base/task/bind_post_task.h" +#include "base/task/post_task.h" +#include "base/task/sequenced_task_runner.h" +#include "base/task/task_runner_util.h" +#include "base/task/thread_pool.h" #include "base/time/time.h" +#include "base/values.h" +#include "net/base/backoff_entry.h" +#include "net/base/backoff_entry_serializer.h" #include "net/base/hash_value.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/network_context.h" #include "services/network/network_service.h" +#include "services/network/public/cpp/features.h" #include "services/network/public/mojom/network_context.mojom.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" #include "services/network/public/proto/sct_audit_report.pb.h" @@ -18,15 +33,169 @@ namespace network { -SCTAuditingHandler::SCTAuditingHandler(NetworkContext* context, - size_t cache_size) - : owner_network_context_(context), pending_reporters_(cache_size) {} +namespace { -SCTAuditingHandler::~SCTAuditingHandler() = default; +std::string LoadReports(const base::FilePath& path) { + std::string result; + if (!base::ReadFileToString(path, &result)) { + return ""; + } + return result; +} + +// Keys in dictionary for each serialized pending reporter entry in the +// top-level list of serialized entries. +const char kReporterKeyKey[] = "reporter_key"; +const char kBackoffEntryKey[] = "backoff_entry"; +const char kReportKey[] = "report"; + +} // namespace + +SCTAuditingHandler::SCTAuditingHandler(NetworkContext* context, + const base::FilePath& persistence_path, + size_t cache_size) + : owner_network_context_(context), + pending_reporters_(cache_size), + persistence_path_(persistence_path), + foreground_runner_(base::SequencedTaskRunnerHandle::Get()) { + if (base::FeatureList::IsEnabled( + features::kSCTAuditingRetryAndPersistReports)) { + // If no persistence path is set, only store pending reporters in memory. + if (persistence_path_.empty()) { + return; + } + + // Persisting reports uses a low priority task runner as it should not block + // anything user-visible, but it should block shutdown to ensure updates are + // persisted to disk (particularly clearing entries or the entire + // persisted state). + background_runner_ = base::ThreadPool::CreateSequencedTaskRunner( + {base::MayBlock(), base::TaskPriority::BEST_EFFORT, + base::TaskShutdownBehavior::BLOCK_SHUTDOWN}); + writer_ = std::make_unique<base::ImportantFileWriter>(persistence_path_, + background_runner_); + + // Post a task to load persisted state after startup has finished. + foreground_runner_->PostTask( + FROM_HERE, base::BindOnce(&SCTAuditingHandler::OnStartupFinished, + weak_factory_.GetWeakPtr())); + } +} + +SCTAuditingHandler::~SCTAuditingHandler() { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + if (writer_ && writer_->HasPendingWrite()) { + writer_->DoScheduledWrite(); + } +} + +bool SCTAuditingHandler::SerializeData(std::string* output) { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + + base::Value reports(base::Value::Type::LIST); + for (const auto& kv : pending_reporters_) { + auto reporter_key = kv.first; + auto* reporter = kv.second.get(); + + base::Value report_entry(base::Value::Type::DICTIONARY); + + report_entry.SetStringKey(kReporterKeyKey, reporter_key.ToString()); + + base::Value backoff_entry_value = + net::BackoffEntrySerializer::SerializeToValue( + *reporter->backoff_entry(), base::Time::Now()); + report_entry.SetKey(kBackoffEntryKey, std::move(backoff_entry_value)); + + std::string serialized_report; + reporter->report()->SerializeToString(&serialized_report); + base::Base64Encode(serialized_report, &serialized_report); + report_entry.SetStringKey(kReportKey, serialized_report); + + reports.Append(std::move(report_entry)); + } + return base::JSONWriter::Write(reports, output); +} + +void SCTAuditingHandler::DeserializeData(const std::string& serialized) { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + + // Parse the serialized reports. + absl::optional<base::Value> value = base::JSONReader::Read(serialized); + if (!value || !value->is_list()) { + return; + } + + size_t num_reporters_deserialized = 0u; + for (const base::Value& sct_entry : value->GetList()) { + if (!sct_entry.is_dict()) { + continue; + } + + const std::string* reporter_key_string = + sct_entry.FindStringKey(kReporterKeyKey); + const std::string* report_string = sct_entry.FindStringKey(kReportKey); + const base::Value* backoff_entry_value = + sct_entry.FindKey(kBackoffEntryKey); + + if (!reporter_key_string || !report_string || !backoff_entry_value) { + continue; + } + + // Try to read the reporter_key from the entry and convert back to a + // HashValue. If it fails, continue to the next entry. + net::HashValue cache_key(net::HASH_VALUE_SHA256); + if (!cache_key.FromString(*reporter_key_string)) { + continue; + } + + // Check if cache_key already exists. If it's already in the pending set, + // skip re-adding it. + auto it = pending_reporters_.Get(cache_key); + if (it != pending_reporters_.end()) { + continue; + } + + // Try to recreate the BackoffEntry from the serialized value. + std::unique_ptr<net::BackoffEntry> backoff_entry = + net::BackoffEntrySerializer::DeserializeFromValue( + *backoff_entry_value, &SCTAuditingReporter::kDefaultBackoffPolicy, + nullptr, base::Time::Now()); + if (!backoff_entry) { + continue; + } + + // Try parsing the serialized protobuf. If it fails, continue to next entry. + std::string decoded_report_string; + if (!base::Base64Decode(*report_string, &decoded_report_string)) { + continue; + } + auto audit_report = std::make_unique<sct_auditing::SCTClientReport>(); + if (!audit_report->ParseFromString(decoded_report_string)) { + continue; + } + + AddReporter(cache_key, std::move(audit_report), std::move(backoff_entry)); + ++num_reporters_deserialized; + } + // TODO(crbug.com/1144205): Add metrics for number of reporters deserialized. +} + +void SCTAuditingHandler::OnStartupFinished() { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + is_after_startup_ = true; + // Load the persisted pending reports from disk on a background sequence, and + // then process them. + background_runner_->PostTaskAndReplyWithResult( + FROM_HERE, base::BindOnce(&LoadReports, persistence_path_), + base::BindOnce(&SCTAuditingHandler::OnReportsLoadedFromDisk, + weak_factory_.GetWeakPtr())); +} void SCTAuditingHandler::AddReporter( net::HashValue reporter_key, - std::unique_ptr<sct_auditing::SCTClientReport> report) { + std::unique_ptr<sct_auditing::SCTClientReport> report, + std::unique_ptr<net::BackoffEntry> backoff_entry) { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); if (!enabled_) { return; } @@ -43,14 +212,37 @@ auto reporter = std::make_unique<SCTAuditingReporter>( reporter_key, std::move(report), GetURLLoaderFactory(), report_uri, traffic_annotation, - base::BindOnce(&SCTAuditingHandler::OnReporterFinished, GetWeakPtr())); + base::BindRepeating(&SCTAuditingHandler::OnReporterStateUpdated, + GetWeakPtr()), + base::BindOnce(&SCTAuditingHandler::OnReporterFinished, GetWeakPtr()), + std::move(backoff_entry)); reporter->Start(); pending_reporters_.Put(reporter->key(), std::move(reporter)); if (pending_reporters_.size() > pending_reporters_size_hwm_) pending_reporters_size_hwm_ = pending_reporters_.size(); + + // Trigger updating the persisted state. + if (writer_) { + writer_->ScheduleWrite(this); + } } +void SCTAuditingHandler::OnReportsLoadedFromDisk( + const std::string& serialized) { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + DCHECK(is_after_startup_); + DCHECK(!persisted_reports_read_); + + persisted_reports_read_ = true; + + DeserializeData(serialized); +} + +// TODOO(crbug.com/1144205): This method should take a completion callback (for +// callers like NetworkContext::ClearNetworkingHistoryBetween() that want to be +// able to wait for the write completing), and pass it through to the `writer_`, +// like TransportSecurityState does. void SCTAuditingHandler::ClearPendingReports() { // Delete any outstanding Reporters. This will delete any extant URLLoader // instances owned by the Reporters, which will cancel any outstanding @@ -58,7 +250,9 @@ // they trigger as they use a WeakPtr to the Reporter instance that posted the // task. pending_reporters_.Clear(); - // TODO(crbug.com/1144205): Clear any persisted state. + if (writer_) { + writer_->ScheduleWrite(this); + } } void SCTAuditingHandler::SetEnabled(bool enabled) { @@ -81,11 +275,27 @@ return weak_factory_.GetWeakPtr(); } +void SCTAuditingHandler::OnReporterStateUpdated() { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + + // Trigger updating the persisted state. + if (writer_) { + writer_->ScheduleWrite(this); + } +} + void SCTAuditingHandler::OnReporterFinished(net::HashValue reporter_key) { + DCHECK(foreground_runner_->RunsTasksInCurrentSequence()); + auto it = pending_reporters_.Get(reporter_key); if (it != pending_reporters_.end()) { pending_reporters_.Erase(it); } + + // Trigger updating the persisted state. + if (writer_) { + writer_->ScheduleWrite(this); + } } void SCTAuditingHandler::ReportHWMMetrics() {
diff --git a/services/network/sct_auditing/sct_auditing_handler.h b/services/network/sct_auditing/sct_auditing_handler.h index 47c6bd6..188eab1 100644 --- a/services/network/sct_auditing/sct_auditing_handler.h +++ b/services/network/sct_auditing/sct_auditing_handler.h
@@ -8,9 +8,12 @@ #include <memory> #include "base/containers/lru_cache.h" +#include "base/files/important_file_writer.h" #include "base/memory/weak_ptr.h" +#include "base/task/sequenced_task_runner.h" #include "base/timer/timer.h" #include "mojo/public/cpp/bindings/remote.h" +#include "net/base/backoff_entry.h" #include "net/base/hash_value.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" #include "url/gurl.h" @@ -25,21 +28,64 @@ class SCTAuditingReporter; // SCTAuditingHandler owns SCT auditing reports for a specific NetworkContext. -// Each SCTAuditingHandler is owned by its matching NetworkContext. -class SCTAuditingHandler { +// Each SCTAuditingHandler is owned by its matching NetworkContext. The +// SCTAuditingHandler is also responsible for persisting pending auditing +// reports to disk and loading them back on browser startup. +// +// Note: Persisted reports only repopulate the SCTAuditingHandler's +// `pending_reports_` cache, and *do not* repopulate the SCTAuditingCache's +// deduplication cache. +class COMPONENT_EXPORT(NETWORK_SERVICE) SCTAuditingHandler + : public base::ImportantFileWriter::DataSerializer { public: - explicit SCTAuditingHandler(NetworkContext* context, - size_t cache_size = 1024); - ~SCTAuditingHandler(); + SCTAuditingHandler(NetworkContext* context, + const base::FilePath& persistence_path, + size_t cache_size = 1024); + ~SCTAuditingHandler() override; SCTAuditingHandler(const SCTAuditingHandler&) = delete; SCTAuditingHandler& operator=(const SCTAuditingHandler&) = delete; + // base::ImportantFileWriter::DataSerializer: + // + // Serializes `pending_reporters_` into `*output`. Returns true if all + // reporters were serialized correctly. + // + // The serialization format is a JSON list-of-dicts of the form: + // [ + // { + // "reporter_key": <serialized HashValue>, + // "report": <serialized SCTClientReport>, + // "backoff_entry": <serialized BackoffEntry> + // } + // ] + // + // Each entry in the dictionary includes sufficient information to deserialize + // and recreate the entries in the SCTAuditingHandler's pending reporters set. + bool SerializeData(std::string* output) override; + + void DeserializeData(const std::string& serialized); + + void OnStartupFinished(); + // Creates a new SCTAuditingReporter for the report and adds it to this // SCTAuditingHandler's pending reporters set. After creating the reporter, // this will call SCTAuditingReporter::Start() to initiate sending the report. + // Optionally takes in a BackoffEntry for recreating reporter state from + // persisted storage. void AddReporter(net::HashValue reporter_key, - std::unique_ptr<sct_auditing::SCTClientReport> report); + std::unique_ptr<sct_auditing::SCTClientReport> report, + std::unique_ptr<net::BackoffEntry> backoff_entry = nullptr); + + // Loads serialized reports from `serialized` and creates a new + // SCTAuditingReporter for each (if a reporter for that report does not yet + // exist). This results in the set of loaded reports being merged with any + // existing pending reports in the SCTAuditingCache. + // Returns true if all entries were parsed and deserialized correctly. + // If data does not deserialize correctly, this drops the entries rather than + // trying to recover. This means that reports currently (and intentionally, + // for simplicity) do no persist over format/version changes. + void OnReportsLoadedFromDisk(const std::string& serialized); // Clears the set of pending reporters for this SCTAuditingHandler. void ClearPendingReports(); @@ -58,9 +104,12 @@ url_loader_factory_.Bind(std::move(factory)); } + base::ImportantFileWriter* GetFileWriterForTesting() { return writer_.get(); } + base::WeakPtr<SCTAuditingHandler> GetWeakPtr(); private: + void OnReporterStateUpdated(); void OnReporterFinished(net::HashValue reporter_key); void ReportHWMMetrics(); network::mojom::URLLoaderFactory* GetURLLoaderFactory(); @@ -81,9 +130,18 @@ bool enabled_ = false; base::RepeatingTimer histogram_timer_; + // Helper for safely writing data to disk. + std::unique_ptr<base::ImportantFileWriter> writer_; + // Used to send reports. mojo::Remote<mojom::URLLoaderFactory> url_loader_factory_; + base::FilePath persistence_path_; + scoped_refptr<base::SequencedTaskRunner> foreground_runner_; + scoped_refptr<base::SequencedTaskRunner> background_runner_; + bool is_after_startup_ = false; + bool persisted_reports_read_ = false; + base::WeakPtrFactory<SCTAuditingHandler> weak_factory_{this}; };
diff --git a/services/network/sct_auditing/sct_auditing_handler_unittest.cc b/services/network/sct_auditing/sct_auditing_handler_unittest.cc new file mode 100644 index 0000000..43b4e65 --- /dev/null +++ b/services/network/sct_auditing/sct_auditing_handler_unittest.cc
@@ -0,0 +1,507 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/sct_auditing/sct_auditing_handler.h" + +#include <memory> + +#include "base/base64.h" +#include "base/feature_list.h" +#include "base/files/file_util.h" +#include "base/files/scoped_temp_dir.h" +#include "base/test/bind.h" +#include "base/test/scoped_feature_list.h" +#include "base/test/task_environment.h" +#include "mojo/public/cpp/bindings/remote.h" +#include "net/base/hash_value.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "net/traffic_annotation/network_traffic_annotation_test_helper.h" +#include "services/network/network_context.h" +#include "services/network/network_service.h" +#include "services/network/public/cpp/features.h" +#include "services/network/public/cpp/resource_request.h" +#include "services/network/public/proto/sct_audit_report.pb.h" +#include "services/network/sct_auditing/sct_auditing_cache.h" +#include "services/network/sct_auditing/sct_auditing_reporter.h" +#include "services/network/test/fake_test_cert_verifier_params_factory.h" +#include "services/network/test/test_url_loader_factory.h" +#include "services/network/test/test_utils.h" +#include "services/network/url_loader_factory.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace network { + +namespace { + +class SCTAuditingHandlerTest : public testing::Test { + public: + SCTAuditingHandlerTest() + : network_service_(NetworkService::CreateForTesting()) {} + ~SCTAuditingHandlerTest() override = default; + + SCTAuditingHandlerTest(const SCTAuditingHandlerTest&) = delete; + SCTAuditingHandlerTest& operator=(const SCTAuditingHandlerTest&) = delete; + + void SetUp() override { + ASSERT_TRUE(persistence_dir_.CreateUniqueTempDir()); + persistence_path_ = persistence_dir_.GetPath().AppendASCII("SCT Auditing"); + + // Set up a NetworkContext. + mojom::NetworkContextParamsPtr context_params = + CreateNetworkContextParamsForTesting(); + context_params->cert_verifier_params = + FakeTestCertVerifierParamsFactory::GetCertVerifierParams(); + context_params->enable_sct_auditing = true; + network_context_ = std::make_unique<NetworkContext>( + network_service_.get(), + network_context_remote_.BindNewPipeAndPassReceiver(), + std::move(context_params)); + + // Set up SCT auditing configuration. + auto* cache = network_service_->sct_auditing_cache(); + cache->set_enabled(true); + cache->set_sampling_rate(1.0); + cache->set_report_uri(GURL("https://example.test")); + cache->set_traffic_annotation( + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS)); + } + + // Get the contents of `persistence_path_`. Pumps the message loop before + // returning the result. + std::string GetTestFileContents() { + task_environment_.RunUntilIdle(); + std::string file_contents; + base::ReadFileToString(persistence_path_, &file_contents); + return file_contents; + } + + // Check whether `substring` appears in the file contents at + // `persistence_path_`. + bool FileContentsHasString(const std::string& substring) { + auto contents = GetTestFileContents(); + auto position = contents.find(substring); + return position != std::string::npos; + } + + // Waits for `expected_requests` to be seen by the TestURLLoaderFactory. Note + // that this only counts HTTP requests, so network errors (e.g., cert errors) + // won't count. + void WaitForRequests(size_t expected_requests) { + // Initialize a new RunLoop, so that tests can call WaitForRequests() + // multiple times, if needed. + run_loop_ = std::make_unique<base::RunLoop>(); + + if (num_requests_seen_ >= expected_requests) { + return; + } + + // Add a TestURLLoaderFactory interceptor to count requests seen. + url_loader_factory_.SetInterceptor(base::BindLambdaForTesting( + [&](const network::ResourceRequest& request) { + ++num_requests_seen_; + if (run_loop_->running() && num_requests_seen_ >= expected_requests) { + run_loop_->QuitWhenIdle(); + } + })); + + run_loop_->Run(); + } + + protected: + base::test::TaskEnvironment task_environment_{ + base::test::TaskEnvironment::MainThreadType::IO, + base::test::TaskEnvironment::TimeSource::MOCK_TIME}; + base::ScopedTempDir persistence_dir_; + base::FilePath persistence_path_; + std::unique_ptr<NetworkService> network_service_; + std::unique_ptr<NetworkContext> network_context_; + + TestURLLoaderFactory url_loader_factory_; + + std::unique_ptr<base::RunLoop> run_loop_; + size_t num_requests_seen_ = 0; + + // Stores the mojo::Remote<mojom::NetworkContext> of the most recently created + // NetworkContext. + mojo::Remote<mojom::NetworkContext> network_context_remote_; +}; + +// Test that when the retry+persistence feature is disabled no reports will be +// persisted on disk. +TEST_F(SCTAuditingHandlerTest, PersistenceFeatureDisabled) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndDisableFeature( + features::kSCTAuditingRetryAndPersistReports); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + // `file_writer` should not be created for this handler. + auto* file_writer = handler.GetFileWriterForTesting(); + EXPECT_EQ(file_writer, nullptr); +} + +// Test that when the SCTAuditingHandler is created without a persistence path +// (e.g., as happens for ephemeral profiles), no file writer is created. +TEST_F(SCTAuditingHandlerTest, HandlerWithoutPersistencePath) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + // Set up a Handler with an empty `persistence_path`. + SCTAuditingHandler handler(network_context_.get(), base::FilePath()); + handler.SetEnabled(true); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + // `file_writer` should not be created for this handler. + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_EQ(file_writer, nullptr); +} + +// Test that when the SCTAuditingHandler is created with a valid persistence +// path, then pending reports get stored to disk. +TEST_F(SCTAuditingHandlerTest, HandlerWithPersistencePath) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_TRUE(file_writer); + + // Add a Reporter to the Handler and check that it gets scheduled to be + // persisted to disk. + auto report = std::make_unique<sct_auditing::SCTClientReport>(); + auto* tls_report = report->add_certificate_report(); + auto* connection_context = tls_report->mutable_context(); + auto* origin = connection_context->mutable_origin(); + origin->set_hostname("example.test"); + origin->set_port(443); + + // Fake a HashValue to use as the key. + net::HashValue reporter_key(net::HASH_VALUE_SHA256); + + handler.AddReporter(reporter_key, std::move(report)); + ASSERT_EQ(handler.GetPendingReportersForTesting()->size(), 1u); + ASSERT_TRUE(file_writer->HasPendingWrite()); + + // Check that file got written with the expected content. + file_writer->DoScheduledWrite(); + ASSERT_FALSE(file_writer->HasPendingWrite()); + EXPECT_TRUE(FileContentsHasString(reporter_key.ToString())); + + WaitForRequests(1u); + + EXPECT_EQ(1, url_loader_factory_.NumPending()); + + // Simulate the server returning 200 OK to the report request. + url_loader_factory_.SimulateResponseForPendingRequest( + "https://example.test", + /*content=*/"", + /*status=*/net::HTTP_OK); + + // Check that there are no pending requests anymore. + EXPECT_EQ(0, url_loader_factory_.NumPending()); + + // Check that the pending reporter was deleted on successful completion. + EXPECT_TRUE(handler.GetPendingReportersForTesting()->empty()); + + // Check that the Reporter is no longer in the file. + file_writer->DoScheduledWrite(); + ASSERT_FALSE(file_writer->HasPendingWrite()); + EXPECT_FALSE(FileContentsHasString(reporter_key.ToString())); +} + +// Tests that serializing reports and then deserializing them results in the +// same data. +TEST_F(SCTAuditingHandlerTest, DataRoundTrip) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + // Create a Handler, add a reporter, and wait for it to get persisted. + { + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_TRUE(file_writer); + + ASSERT_TRUE(handler.is_enabled()); + ASSERT_FALSE(file_writer->HasPendingWrite()); + + // Add a Reporter to the Handler and check that it gets scheduled to be + // persisted to disk. + auto report = std::make_unique<sct_auditing::SCTClientReport>(); + auto* tls_report = report->add_certificate_report(); + auto* connection_context = tls_report->mutable_context(); + auto* origin = connection_context->mutable_origin(); + origin->set_hostname("example.test"); + origin->set_port(443); + + // Fake a HashValue to use as the key. + net::HashValue reporter_key(net::HASH_VALUE_SHA256); + + handler.AddReporter(reporter_key, std::move(report)); + ASSERT_EQ(handler.GetPendingReportersForTesting()->size(), 1u); + ASSERT_TRUE(file_writer->HasPendingWrite()); + + // Check that file got written with the expected content. + file_writer->DoScheduledWrite(); + ASSERT_FALSE(file_writer->HasPendingWrite()); + EXPECT_TRUE(FileContentsHasString(reporter_key.ToString())); + } + + // Create a second Handler using the same persistence path. It should load + // the same data. + { + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_TRUE(file_writer); + + WaitForRequests(1u); + + auto* pending_reporters = handler.GetPendingReportersForTesting(); + ASSERT_EQ(1u, pending_reporters->size()); + + // Reporter should be for "example.test:443" as added in the first Handler. + for (const auto& reporter : *pending_reporters) { + auto origin = + reporter.second->report()->certificate_report(0).context().origin(); + EXPECT_EQ(origin.hostname(), "example.test"); + EXPECT_EQ(origin.port(), 443); + } + } +} + +// Test that deserializing bad data shouldn't result in any reporters being +// created. +TEST_F(SCTAuditingHandlerTest, DeserializeBadData) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + // Set an empty persistence path so no file IO is performed. + SCTAuditingHandler handler(network_context_.get(), base::FilePath()); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + // Non-JSON data. + handler.DeserializeData("Blorp"); + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 0u); + + // JSON data but non-sensical. + handler.DeserializeData("[15]"); + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 0u); + + // JSON data in the right format, but with invalid keys. + handler.DeserializeData(R"([{"blorp": "a", "bloop": "b", "bleep": "c"}])"); + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 0u); + + // JSON data with the right format and keys, but data is invalid. + handler.DeserializeData( + R"([{"reporter_key": "a", "report": "b", "backoff_entry": ["c"]}])"); + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 0u); + + // Check that no file got written to the persistence path. + EXPECT_EQ(GetTestFileContents(), std::string()); +} + +// Test that a handler loads valid persisted data from disk and creates pending +// reporters for each entry. +TEST_F(SCTAuditingHandlerTest, HandlerWithExistingPersistedData) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + // Set up previously persisted data on disk: + // - Default-initialized net::HashValue(net::HASH_VALUE_SHA256) + // - Empty SCTClientReport for origin "example.test:443". + // - A simple BackoffEntry. + std::string persisted_report = + R"( + [{ + "reporter_key": + "sha256/qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqo=", + "report": "EhUKExIRCgxleGFtcGxlLnRlc3QQuwM=", + "backoff_entry": [2,0,"30000000","11644578625551798"] + }] + )"; + ASSERT_TRUE(base::WriteFile(persistence_path_, persisted_report)); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_TRUE(file_writer); + + WaitForRequests(1u); + + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 1u); + EXPECT_EQ(1, url_loader_factory_.NumPending()); + + // Simulate the server returning 200 OK to the report request. + url_loader_factory_.SimulateResponseForPendingRequest( + "https://example.test", + /*content=*/"", + /*status=*/net::HTTP_OK); + + // Check that there is no pending requests anymore. + EXPECT_EQ(0, url_loader_factory_.NumPending()); + + // Check that the pending reporter was deleted on successful completion. + EXPECT_TRUE(handler.GetPendingReportersForTesting()->empty()); + + // Check that the Reporter is no longer in the file. + file_writer->DoScheduledWrite(); + ASSERT_FALSE(file_writer->HasPendingWrite()); + EXPECT_FALSE(FileContentsHasString( + "sha256/qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqo=")); +} + +// Test that scheduling a retry causes the failure count to increment in +// persisted storage. +TEST_F(SCTAuditingHandlerTest, RetryUpdatesPersistedBackoffEntry) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + // Set up previously persisted data on disk: + // - Default-initialized net::HashValue(net::HASH_VALUE_SHA256) + // - Empty SCTClientReport for origin "example.test:443". + // - A simple BackoffEntry with a failure count of "1". + std::string persisted_report = + R"( + [{ + "reporter_key": + "sha256/qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqo=", + "report": "EhUKExIRCgxleGFtcGxlLnRlc3QQuwM=", + "backoff_entry": [2,1,"30000000","11644578625551798"] + }] + )"; + ASSERT_TRUE(base::WriteFile(persistence_path_, persisted_report)); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_TRUE(file_writer); + + WaitForRequests(1u); + + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 1u); + EXPECT_EQ(url_loader_factory_.NumPending(), 1); + + // Simulate the server returning error to the report request. The Reporter + // should schedule a retry and trigger updating the persisted storage. + url_loader_factory_.SimulateResponseForPendingRequest( + "https://example.test", + /*content=*/"", + /*status=*/net::HTTP_TOO_MANY_REQUESTS); + EXPECT_EQ(url_loader_factory_.NumPending(), 0); + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 1u); + ASSERT_TRUE(file_writer->HasPendingWrite()); + + // Check that the Reporter is updated in the persisted storage file. + file_writer->DoScheduledWrite(); + ASSERT_FALSE(file_writer->HasPendingWrite()); + EXPECT_TRUE(FileContentsHasString( + "sha256/qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqo=")); + // Persisted backoff entry should have the failure count incrememted to 2. + // (The first value of a serialized BackoffEntry is the version, the second is + // the failure count.) + EXPECT_TRUE(FileContentsHasString(R"("backoff_entry":[2,2,)")); +} + +// Test that retries carry over correctly. Specifically, a persisted entry with +// 14 retries already (one less than kMaxRetries), if after being loaded from +// persisted storage tries and fails once more, should get deleted. +TEST_F(SCTAuditingHandlerTest, RestoringMaxRetries) { + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeature( + features::kSCTAuditingRetryAndPersistReports); + + // Set up previously persisted data on disk: + // - Default-initialized net::HashValue(net::HASH_VALUE_SHA256) + // - Empty SCTClientReport for origin "example.test:443". + // - A simple BackoffEntry with a failure count of "15" (so it is scheduled to + // retry for the 15th and final time). + std::string persisted_report = + R"( + [{ + "reporter_key": + "sha256/qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqo=", + "report": "EhUKExIRCgxleGFtcGxlLnRlc3QQuwM=", + "backoff_entry": [2,15,"30000000","11644578625551798"] + }] + )"; + ASSERT_TRUE(base::WriteFile(persistence_path_, persisted_report)); + + mojo::PendingRemote<network::mojom::URLLoaderFactory> factory_remote; + url_loader_factory_.Clone(factory_remote.InitWithNewPipeAndPassReceiver()); + + SCTAuditingHandler handler(network_context_.get(), persistence_path_); + handler.SetEnabled(true); + handler.SetURLLoaderFactoryForTesting(std::move(factory_remote)); + + auto* file_writer = handler.GetFileWriterForTesting(); + ASSERT_TRUE(file_writer); + + WaitForRequests(1u); + + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 1u); + EXPECT_EQ(url_loader_factory_.NumPending(), 1); + + // Simulate the server returning error to the report request. The Reporter + // should schedule a retry and trigger updating the persisted storage. + url_loader_factory_.SimulateResponseForPendingRequest( + "https://example.test", + /*content=*/"", + /*status=*/net::HTTP_TOO_MANY_REQUESTS); + EXPECT_EQ(url_loader_factory_.NumPending(), 0); + + // Pending reporter should get deleted as it has reached max retries. + EXPECT_EQ(handler.GetPendingReportersForTesting()->size(), 0u); + + // Reporter state on disk should get deleted as well. + file_writer->DoScheduledWrite(); + ASSERT_FALSE(file_writer->HasPendingWrite()); + EXPECT_FALSE(FileContentsHasString( + "sha256/qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqo=")); +} + +} // namespace + +} // namespace network
diff --git a/services/network/sct_auditing/sct_auditing_reporter.cc b/services/network/sct_auditing/sct_auditing_reporter.cc index 0cca80d..d664401d 100644 --- a/services/network/sct_auditing/sct_auditing_reporter.cc +++ b/services/network/sct_auditing/sct_auditing_reporter.cc
@@ -73,7 +73,7 @@ // roughly the next five days. // See more discussion in the SCT Auditing Retry and Persistence design doc: // https://docs.google.com/document/d/1YTUzoG6BDF1QIxosaQDp2H5IzYY7_fwH8qNJXSVX8OQ/edit -constexpr size_t kMaxRetries = 15; +constexpr int kMaxRetries = 15; SCTAuditingReporter::SCTAuditingReporter( net::HashValue reporter_key, @@ -81,13 +81,15 @@ mojom::URLLoaderFactory* url_loader_factory, const GURL& report_uri, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, - ReporterDoneCallback done_callback) + ReporterUpdatedCallback update_callback, + ReporterDoneCallback done_callback, + std::unique_ptr<net::BackoffEntry> persisted_backoff_entry) : reporter_key_(reporter_key), report_(std::move(report)), traffic_annotation_(traffic_annotation), report_uri_(report_uri), + update_callback_(std::move(update_callback)), done_callback_(std::move(done_callback)), - num_retries_(0), max_retries_(kMaxRetries) { // Clone the URLLoaderFactory to avoid any dependencies on its lifetime. The // Reporter instance can maintain its own copy. @@ -104,34 +106,31 @@ backoff_policy_.initial_delay_ms = g_retry_delay_for_testing->InMilliseconds(); } - backoff_entry_ = std::make_unique<net::BackoffEntry>(&backoff_policy_); + + // `persisted_backoff_entry` is only non-null when persistence is enabled and + // this SCTAuditingReporter is being created from a reporter that had been + // persisted to disk. + if (persisted_backoff_entry) { + backoff_entry_ = std::move(persisted_backoff_entry); + } else { + backoff_entry_ = std::make_unique<net::BackoffEntry>(&backoff_policy_); + // Informing the backoff entry of a success will force it to use the initial + // delay (and jitter) for the first attempt. Otherwise, + // ShouldRejectRequest() will return `true` despite the policy specifying + // `always_use_initial_delay = true`. + backoff_entry_->InformOfRequest(true); + } } SCTAuditingReporter::~SCTAuditingReporter() = default; void SCTAuditingReporter::Start() { - // Informing the backoff entry of a success will force it to use the initial - // delay (and jitter) for the first attempt. Otherwise, ShouldRejectRequest() - // will return `true` despite the policy specifying - // `always_use_initial_delay = true`. - backoff_entry_->InformOfRequest(true); - - // Start sending the report. - ScheduleReport(); -} - -void SCTAuditingReporter::SetRetryDelayForTesting( - absl::optional<base::TimeDelta> delay) { - g_retry_delay_for_testing = delay; -} - -void SCTAuditingReporter::ScheduleReport() { if (base::FeatureList::IsEnabled( features::kSCTAuditingRetryAndPersistReports) && backoff_entry_->ShouldRejectRequest()) { // TODO(crbug.com/1199827): Investigate if explicit task traits should be // used for these tasks (e.g., BEST_EFFORT and SKIP_ON_SHUTDOWN). - base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( + base::SequencedTaskRunnerHandle::Get()->PostDelayedTask( FROM_HERE, base::BindOnce(&SCTAuditingReporter::SendReport, weak_factory_.GetWeakPtr()), @@ -141,6 +140,11 @@ } } +void SCTAuditingReporter::SetRetryDelayForTesting( + absl::optional<base::TimeDelta> delay) { + g_retry_delay_for_testing = delay; +} + void SCTAuditingReporter::SendReport() { DCHECK(url_loader_factory_remote_); @@ -191,7 +195,7 @@ features::kSCTAuditingRetryAndPersistReports)) { if (success) { // Report succeeded. - if (num_retries_ == 0) { + if (backoff_entry_->failure_count() == 0) { ReportSCTAuditingCompletionStatusMetrics( CompletionStatus::kSuccessFirstTry); } else { @@ -205,7 +209,7 @@ return; } // Sending the report failed. - if (num_retries_ >= max_retries_) { + if (backoff_entry_->failure_count() >= max_retries_) { // Retry limit reached. ReportSCTAuditingCompletionStatusMetrics( CompletionStatus::kRetriesExhausted); @@ -215,10 +219,11 @@ std::move(done_callback_).Run(reporter_key_); return; } else { - // Schedule a retry. - ++num_retries_; + // Schedule a retry and alert the SCTAuditingHandler to trigger a write so + // it can persist the updated backoff entry. backoff_entry_->InformOfRequest(false); - ScheduleReport(); + update_callback_.Run(); + Start(); } } else { // Retry is not enabled, so just notify the Cache that this Reporter is
diff --git a/services/network/sct_auditing/sct_auditing_reporter.h b/services/network/sct_auditing/sct_auditing_reporter.h index 9bcdd2120..7e89aee3 100644 --- a/services/network/sct_auditing/sct_auditing_reporter.h +++ b/services/network/sct_auditing/sct_auditing_reporter.h
@@ -45,6 +45,9 @@ // The SHA256HashValue `reporter_key` is passed to uniquely identify this // reporter instance. using ReporterDoneCallback = base::OnceCallback<void(net::HashValue)>; + // Callback to notify the SCTAuditingHandler that the reporter has updated + // (e.g., the retry counter has been incremented). + using ReporterUpdatedCallback = base::RepeatingCallback<void()>; SCTAuditingReporter( net::HashValue reporter_key, @@ -52,7 +55,9 @@ mojom::URLLoaderFactory* url_loader_factory, const GURL& report_uri, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, - ReporterDoneCallback done_callback); + ReporterUpdatedCallback update_callback, + ReporterDoneCallback done_callback, + std::unique_ptr<net::BackoffEntry> backoff_entry = nullptr); ~SCTAuditingReporter(); SCTAuditingReporter(const SCTAuditingReporter&) = delete; @@ -64,6 +69,7 @@ net::HashValue key() { return reporter_key_; } sct_auditing::SCTClientReport* report() { return report_.get(); } + net::BackoffEntry* backoff_entry() { return backoff_entry_.get(); } // These values are persisted to logs. Entries should not be renumbered and // numeric values should never be reused. @@ -77,7 +83,6 @@ static void SetRetryDelayForTesting(absl::optional<base::TimeDelta> delay); private: - void ScheduleReport(); void SendReport(); void OnSendReportComplete(scoped_refptr<net::HttpResponseHeaders> headers); @@ -87,13 +92,13 @@ std::unique_ptr<SimpleURLLoader> url_loader_; net::NetworkTrafficAnnotationTag traffic_annotation_; GURL report_uri_; + ReporterUpdatedCallback update_callback_; ReporterDoneCallback done_callback_; net::BackoffEntry::Policy backoff_policy_; std::unique_ptr<net::BackoffEntry> backoff_entry_; - size_t num_retries_; - size_t max_retries_; + int max_retries_; base::WeakPtrFactory<SCTAuditingReporter> weak_factory_{this}; };
diff --git a/testing/buildbot/chromium.fyi.json b/testing/buildbot/chromium.fyi.json index bddb60c..484b328 100644 --- a/testing/buildbot/chromium.fyi.json +++ b/testing/buildbot/chromium.fyi.json
@@ -5006,6 +5006,2343 @@ } ] }, + "Win11 Tests x64": { + "gtest_tests": [ + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "absl_hardening_tests", + "test_id_prefix": "ninja://third_party/abseil-cpp:absl_hardening_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "accessibility_unittests", + "test_id_prefix": "ninja://ui/accessibility:accessibility_unittests/" + }, + { + "args": [ + "angle_unittests" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "angle_unittests", + "test_id_prefix": "ninja://third_party/angle/src/tests:angle_unittests/", + "use_isolated_scripts_api": true + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "app_shell_unittests", + "test_id_prefix": "ninja://extensions/shell:app_shell_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "aura_unittests", + "test_id_prefix": "ninja://ui/aura:aura_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "base_unittests", + "test_id_prefix": "ninja://base:base_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "blink_common_unittests", + "test_id_prefix": "ninja://third_party/blink/common:blink_common_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "blink_fuzzer_unittests", + "test_id_prefix": "ninja://third_party/blink/renderer/platform:blink_fuzzer_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "blink_heap_unittests", + "test_id_prefix": "ninja://third_party/blink/renderer/platform/heap:blink_heap_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "blink_platform_unittests", + "test_id_prefix": "ninja://third_party/blink/renderer/platform:blink_platform_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "name": "webkit_unit_tests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "blink_unittests", + "test_id_prefix": "ninja://third_party/blink/renderer/controller:blink_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "boringssl_crypto_tests", + "test_id_prefix": "ninja://third_party/boringssl:boringssl_crypto_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "boringssl_ssl_tests", + "test_id_prefix": "ninja://third_party/boringssl:boringssl_ssl_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "browser_switcher_bho_unittests", + "test_id_prefix": "ninja://chrome/browser/browser_switcher/bho:browser_switcher_bho_unittests/" + }, + { + "args": [ + "--disable-features=WebRTC-H264WithOpenH264FFmpeg" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "quickrun_shards": 30, + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com", + "shards": 15 + }, + "test": "browser_tests", + "test_id_prefix": "ninja://chrome/test:browser_tests/" + }, + { + "args": [ + "--browser-ui-tests-verify-pixels", + "--enable-pixel-output-in-tests", + "--test-launcher-filter-file=../../testing/buildbot/filters/pixel_browser_tests.filter", + "--git-revision=${got_revision}" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "name": "pixel_browser_tests", + "precommit_args": [ + "--gerrit-issue=${patch_issue}", + "--gerrit-patchset=${patch_set}", + "--buildbucket-id=${buildbucket_build_id}" + ], + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chrome-gold@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "browser_tests", + "test_id_prefix": "ninja://chrome/test:browser_tests/" + }, + { + "args": [ + "--gtest_filter=-*UsingRealWebcam*" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "capture_unittests", + "test_id_prefix": "ninja://media/capture:capture_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "cast_unittests", + "test_id_prefix": "ninja://media/cast:cast_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "cc_unittests", + "test_id_prefix": "ninja://cc:cc_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "chrome_app_unittests", + "test_id_prefix": "ninja://chrome/test:chrome_app_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "chrome_cleaner_unittests", + "test_id_prefix": "ninja://chrome/chrome_cleaner:chrome_cleaner_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "chrome_elf_unittests", + "test_id_prefix": "ninja://chrome/chrome_elf:chrome_elf_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "chromedriver_unittests", + "test_id_prefix": "ninja://chrome/test/chromedriver:chromedriver_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "color_unittests", + "test_id_prefix": "ninja://ui/color:color_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "components_browsertests", + "test_id_prefix": "ninja://components:components_browsertests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "components_unittests", + "test_id_prefix": "ninja://components:components_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "compositor_unittests", + "test_id_prefix": "ninja://ui/compositor:compositor_unittests/" + }, + { + "args": [ + "--disable-features=WebRTC-H264WithOpenH264FFmpeg" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com", + "shards": 6 + }, + "test": "content_browsertests", + "test_id_prefix": "ninja://content/test:content_browsertests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "content_unittests", + "test_id_prefix": "ninja://content/test:content_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "courgette_unittests", + "test_id_prefix": "ninja://courgette:courgette_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "crashpad_tests", + "test_id_prefix": "ninja://third_party/crashpad/crashpad:crashpad_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "cronet_tests", + "test_id_prefix": "ninja://components/cronet:cronet_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "cronet_unittests", + "test_id_prefix": "ninja://components/cronet:cronet_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "crypto_unittests", + "test_id_prefix": "ninja://crypto:crypto_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "delayloads_unittests", + "test_id_prefix": "ninja://chrome/test:delayloads_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "device_unittests", + "test_id_prefix": "ninja://device:device_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "display_unittests", + "test_id_prefix": "ninja://ui/display:display_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "elevation_service_unittests", + "test_id_prefix": "ninja://chrome/elevation_service:elevation_service_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "events_unittests", + "test_id_prefix": "ninja://ui/events:events_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "extensions_browsertests", + "test_id_prefix": "ninja://extensions:extensions_browsertests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "extensions_unittests", + "test_id_prefix": "ninja://extensions:extensions_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "filesystem_service_unittests", + "test_id_prefix": "ninja://components/services/filesystem:filesystem_service_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "gcm_unit_tests", + "test_id_prefix": "ninja://google_apis/gcm:gcm_unit_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "gcp_unittests", + "test_id_prefix": "ninja://chrome/credential_provider/test:gcp_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "gfx_unittests", + "test_id_prefix": "ninja://ui/gfx:gfx_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "gin_unittests", + "test_id_prefix": "ninja://gin:gin_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "google_apis_unittests", + "test_id_prefix": "ninja://google_apis:google_apis_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "gpu_unittests", + "test_id_prefix": "ninja://gpu:gpu_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "gwp_asan_unittests", + "test_id_prefix": "ninja://components/gwp_asan:gwp_asan_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "headless_browsertests", + "test_id_prefix": "ninja://headless:headless_browsertests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "headless_unittests", + "test_id_prefix": "ninja://headless:headless_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "install_static_unittests", + "test_id_prefix": "ninja://chrome/install_static:install_static_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "integrity": "high", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "installer_util_unittests", + "test_id_prefix": "ninja://chrome/installer/util:installer_util_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com", + "shards": 3 + }, + "test": "interactive_ui_tests", + "test_id_prefix": "ninja://chrome/test:interactive_ui_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "ipc_tests", + "test_id_prefix": "ninja://ipc:ipc_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "jingle_unittests", + "test_id_prefix": "ninja://jingle:jingle_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "latency_unittests", + "test_id_prefix": "ninja://ui/latency:latency_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "libjingle_xmpp_unittests", + "test_id_prefix": "ninja://third_party/libjingle_xmpp:libjingle_xmpp_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "liburlpattern_unittests", + "test_id_prefix": "ninja://third_party/liburlpattern:liburlpattern_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "media_unittests", + "test_id_prefix": "ninja://media:media_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "message_center_unittests", + "test_id_prefix": "ninja://ui/message_center:message_center_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "midi_unittests", + "test_id_prefix": "ninja://media/midi:midi_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "mojo_core_unittests", + "test_id_prefix": "ninja://mojo/core:mojo_core_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "mojo_unittests", + "test_id_prefix": "ninja://mojo:mojo_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "nacl_loader_unittests", + "test_id_prefix": "ninja://components/nacl/loader:nacl_loader_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "native_theme_unittests", + "test_id_prefix": "ninja://ui/native_theme:native_theme_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "net_unittests", + "test_id_prefix": "ninja://net:net_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "notification_helper_unittests", + "test_id_prefix": "ninja://chrome/notification_helper:notification_helper_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "pdf_unittests", + "test_id_prefix": "ninja://pdf:pdf_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "perfetto_unittests", + "test_id_prefix": "ninja://third_party/perfetto:perfetto_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "ppapi_unittests", + "test_id_prefix": "ninja://ppapi:ppapi_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "printing_unittests", + "test_id_prefix": "ninja://printing:printing_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "remoting_unittests", + "test_id_prefix": "ninja://remoting:remoting_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "integrity": "high", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "sbox_integration_tests", + "test_id_prefix": "ninja://sandbox/win:sbox_integration_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "sbox_unittests", + "test_id_prefix": "ninja://sandbox/win:sbox_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "sbox_validation_tests", + "test_id_prefix": "ninja://sandbox/win:sbox_validation_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "service_manager_unittests", + "test_id_prefix": "ninja://services/service_manager/tests:service_manager_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "services_unittests", + "test_id_prefix": "ninja://services:services_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "integrity": "high", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "setup_unittests", + "test_id_prefix": "ninja://chrome/installer/setup:setup_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "shell_dialogs_unittests", + "test_id_prefix": "ninja://ui/shell_dialogs:shell_dialogs_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "skia_unittests", + "test_id_prefix": "ninja://skia:skia_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "snapshot_unittests", + "test_id_prefix": "ninja://ui/snapshot:snapshot_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "sql_unittests", + "test_id_prefix": "ninja://sql:sql_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "storage_unittests", + "test_id_prefix": "ninja://storage:storage_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "sync_integration_tests", + "test_id_prefix": "ninja://chrome/test:sync_integration_tests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "ui_base_unittests", + "test_id_prefix": "ninja://ui/base:ui_base_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "ui_touch_selection_unittests", + "test_id_prefix": "ninja://ui/touch_selection:ui_touch_selection_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "unit_tests", + "test_id_prefix": "ninja://chrome/test:unit_tests/" + }, + { + "args": [ + "--test-launcher-timeout=90000", + "--ui-test-action-max-timeout=45000", + "--ui-test-action-timeout=40000" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "updater_tests", + "test_id_prefix": "ninja://chrome/updater:updater_tests/" + }, + { + "args": [ + "--test-launcher-print-test-stdio=always", + "--test-launcher-timeout=90000", + "--ui-test-action-max-timeout=45000", + "--ui-test-action-timeout=40000", + "--ui-test-action-timeout=40000" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "updater_tests_system", + "test_id_prefix": "ninja://chrome/updater:updater_tests_system/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "url_unittests", + "test_id_prefix": "ninja://url:url_unittests/" + }, + { + "args": [ + "--git-revision=${got_revision}" + ], + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "precommit_args": [ + "--gerrit-issue=${patch_issue}", + "--gerrit-patchset=${patch_set}", + "--buildbucket-id=${buildbucket_build_id}" + ], + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chrome-gold@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "views_examples_unittests", + "test_id_prefix": "ninja://ui/views/examples:views_examples_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "views_unittests", + "test_id_prefix": "ninja://ui/views:views_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "viz_unittests", + "test_id_prefix": "ninja://components/viz:viz_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "vr_common_unittests", + "test_id_prefix": "ninja://chrome/browser/vr:vr_common_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "vr_pixeltests", + "test_id_prefix": "ninja://chrome/browser/vr:vr_pixeltests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "weblayer_browsertests", + "test_id_prefix": "ninja://weblayer/test:weblayer_browsertests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "weblayer_unittests", + "test_id_prefix": "ninja://weblayer/test:weblayer_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "wm_unittests", + "test_id_prefix": "ninja://ui/wm:wm_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "wtf_unittests", + "test_id_prefix": "ninja://third_party/blink/renderer/platform/wtf:wtf_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "zlib_unittests", + "test_id_prefix": "ninja://third_party/zlib:zlib_unittests/" + }, + { + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_gtest_merge.py" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test": "zucchini_unittests", + "test_id_prefix": "ninja://components/zucchini:zucchini_unittests/" + } + ], + "isolated_scripts": [ + { + "isolate_name": "blink_python_tests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "blink_python_tests", + "resultdb": { + "enable": true + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://:blink_python_tests/" + }, + { + "args": [ + "--num-retries=3", + "--target", + "Release_x64" + ], + "isolate_name": "blink_web_tests", + "isolate_profile_data": true, + "merge": { + "args": [ + "--verbose" + ], + "script": "//third_party/blink/tools/merge_web_test_results.py" + }, + "name": "blink_web_tests", + "resultdb": { + "enable": true + }, + "results_handler": "layout tests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com", + "shards": 28 + }, + "test_id_prefix": "ninja://:blink_web_tests/" + }, + { + "args": [ + "--test-type=integration" + ], + "isolate_name": "chromedriver_py_tests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "chromedriver_py_tests", + "resultdb": { + "enable": true + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://chrome/test/chromedriver:chromedriver_py_tests/" + }, + { + "isolate_name": "chromedriver_replay_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "chromedriver_replay_unittests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://chrome/test/chromedriver:chromedriver_replay_unittests/" + }, + { + "args": [ + "--gtest-benchmark-name=components_perftests" + ], + "isolate_name": "components_perftests", + "isolate_profile_data": true, + "merge": { + "args": [ + "--smoke-test-mode" + ], + "script": "//tools/perf/process_perf_results.py" + }, + "name": "components_perftests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://components:components_perftests/" + }, + { + "isolate_name": "content_shell_crash_test", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "content_shell_crash_test", + "resultdb": { + "enable": true, + "result_format": "single" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://content/shell:content_shell_crash_test/" + }, + { + "isolate_name": "flatbuffers_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "flatbuffers_unittests", + "resultdb": { + "enable": true, + "result_format": "single" + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://third_party/flatbuffers:flatbuffers_unittests/" + }, + { + "isolate_name": "grit_python_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "grit_python_unittests", + "resultdb": { + "enable": true + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://tools/grit:grit_python_unittests/" + }, + { + "isolate_name": "mini_installer_tests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "mini_installer_tests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "integrity": "high", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://chrome/test/mini_installer:mini_installer_tests/" + }, + { + "isolate_name": "mojo_python_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "mojo_python_unittests", + "resultdb": { + "enable": true + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://mojo/public/tools:mojo_python_unittests/" + }, + { + "experiment_percentage": 100, + "isolate_name": "polymer_tools_python_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "polymer_tools_python_unittests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://tools/polymer:polymer_tools_python_unittests/" + }, + { + "args": [ + "BrowserMinidumpTest", + "-v", + "--passthrough", + "--retry-limit=2" + ], + "isolate_name": "telemetry_perf_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "telemetry_desktop_minidump_unittests", + "resultdb": { + "enable": true + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://chrome/test:telemetry_perf_unittests/" + }, + { + "isolate_name": "telemetry_gpu_unittests", + "isolate_profile_data": true, + "merge": { + "args": [], + "script": "//testing/merge_scripts/standard_isolated_script_merge.py" + }, + "name": "telemetry_gpu_unittests", + "resultdb": { + "enable": true + }, + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "idempotent": false, + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://chrome/test:telemetry_gpu_unittests/" + }, + { + "args": [ + "--gtest-benchmark-name=views_perftests" + ], + "isolate_name": "views_perftests", + "isolate_profile_data": true, + "merge": { + "args": [ + "--smoke-test-mode" + ], + "script": "//tools/perf/process_perf_results.py" + }, + "name": "views_perftests", + "swarming": { + "can_use_on_swarming_builders": true, + "dimension_sets": [ + { + "cpu": "x86-64", + "os": "Windows-11-22000" + } + ], + "service_account": "chromium-tester@chops-service-accounts.iam.gserviceaccount.com" + }, + "test_id_prefix": "ninja://ui/views:views_perftests/" + } + ] + }, "android-backuprefptr-arm-fyi-rel": { "gtest_tests": [ {
diff --git a/testing/buildbot/filters/fuchsia.components_unittests.filter b/testing/buildbot/filters/fuchsia.components_unittests.filter index a29c809..91efe36 100644 --- a/testing/buildbot/filters/fuchsia.components_unittests.filter +++ b/testing/buildbot/filters/fuchsia.components_unittests.filter
@@ -9,7 +9,6 @@ -DownloadFile/DownloadFileTestWithRename.RenameError* -DownloadFile/DownloadFileTestWithRename.RenameWith* -DownloadPathReservationTrackerTest.UnwriteableDirectory --EncryptedReportingJobConfigurationTest.Validate* -FilesystemProxyTest.DeleteFile* -FilesystemProxyTest.OpenFileAppend* -FilesystemProxyTest.OpenFileWrite* @@ -19,6 +18,5 @@ -PaintPreviewSerialUtils.TestSerialTypeface -PaintPreviewSubsetFontTest.TestBasicSubset -ProfilingJsonExporterTest.MemoryMaps --RealtimeReportingJobConfigurationTest.ValidatePayload -SslCastSocketTest.TestConnectEndToEndWithRealSSL -UIDevToolsServerTest.ConnectionToViewsServer
diff --git a/testing/buildbot/filters/linux-lacros.interactive_ui_tests.filter b/testing/buildbot/filters/linux-lacros.interactive_ui_tests.filter index d52a20a8..eb7b804 100644 --- a/testing/buildbot/filters/linux-lacros.interactive_ui_tests.filter +++ b/testing/buildbot/filters/linux-lacros.interactive_ui_tests.filter
@@ -15,6 +15,7 @@ -ExtensionApiTest.WindowOpenFocus -MenuViewDragAndDropTestNestedDrag.MenuViewDragAndDropNestedDrag -MenuViewDragAndDropTestTestInMenuDrag.TestInMenuDrag +-OmniboxViewViewsTest.DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade -OmniboxViewViewsTest.SelectionClipboard -SameSiteSubframe* -SitePerProcessInteractiveBrowserTest.TabAndMouseFocusNavigation
diff --git a/testing/buildbot/filters/ozone-linux.interactive_ui_tests_wayland.filter b/testing/buildbot/filters/ozone-linux.interactive_ui_tests_wayland.filter index 73af361..58a7b0c 100644 --- a/testing/buildbot/filters/ozone-linux.interactive_ui_tests_wayland.filter +++ b/testing/buildbot/filters/ozone-linux.interactive_ui_tests_wayland.filter
@@ -28,6 +28,7 @@ -KeyboardLockInteractiveBrowserTest.ActiveWithSomeKeysLocked -MediaDialogViewBrowserTest.PictureInPicture -NotificationsTestWithFakeMediaStream.ShouldQueueDuringScreenPresent +-OmniboxViewViewsTest.DefaultTypedNavigationsToHttps_ZeroSuggest_NoUpgrade -OmniboxViewViewsTest.SelectionClipboard -PasswordBubbleInteractiveUiTest.AutoSigninNoFocus -PopupBlockerBrowserTest.ModalPopUnder
diff --git a/testing/buildbot/mixins.pyl b/testing/buildbot/mixins.pyl index 40a714f..ad452e5 100644 --- a/testing/buildbot/mixins.pyl +++ b/testing/buildbot/mixins.pyl
@@ -1106,6 +1106,13 @@ ], }, }, + 'win11': { + 'swarming': { + 'dimensions': { + 'os': 'Windows-11-22000', + }, + }, + }, 'win7': { 'swarming': { 'dimensions': {
diff --git a/testing/buildbot/test_suite_exceptions.pyl b/testing/buildbot/test_suite_exceptions.pyl index 631efff..ca15da0 100644 --- a/testing/buildbot/test_suite_exceptions.pyl +++ b/testing/buildbot/test_suite_exceptions.pyl
@@ -376,6 +376,15 @@ 'Debug_x64', ], }, + 'Win11 Tests x64': { + 'args': [ + '--target', + 'Release_x64', + ], + 'swarming': { + "shards": 28 + }, + }, 'Win7 Tests (dbg)(1)': { 'args': [ '--debug', @@ -673,6 +682,18 @@ 'quickrun_shards': 30, } }, + 'Win11 Tests x64': { + # crbug.com/868082 + 'args': [ + '--disable-features=WebRTC-H264WithOpenH264FFmpeg', + ], + 'swarming': { + # This is for slow test execution that often becomes a critical path of + # swarming jobs. crbug.com/868114 + 'shards': 15, + 'quickrun_shards': 30, + } + }, 'Win7 Tests (1)': { # This is for slow test execution that often becomes a critical path of # swarming jobs. crbug.com/868114 @@ -1237,6 +1258,12 @@ '--disable-features=WebRTC-H264WithOpenH264FFmpeg', ], }, + 'Win11 Tests x64': { + # crbug.com/868082 + 'args': [ + '--disable-features=WebRTC-H264WithOpenH264FFmpeg', + ], + }, 'android-11-x86-rel': { # TODO(crbug.com/1137474): Remove after the test suite is green. 'experiment_percentage': 100, @@ -2963,6 +2990,7 @@ 'Linux - Future (dbg)', # client.v8.chromium 'Win10 Tests x64', 'Win10 Tests x64 (dbg)', + 'Win11 Tests x64', ], }, 'telemetry_unittests': { @@ -2980,6 +3008,7 @@ 'Mac10.11 Tests', 'Win10 Tests x64', + 'Win11 Tests x64', # TODO(https://crbug.com/1267161): Re-enable when platform is supported. 'mac11-arm64-rel-tests',
diff --git a/testing/buildbot/waterfalls.pyl b/testing/buildbot/waterfalls.pyl index d7a91ae..cc4d1a2 100644 --- a/testing/buildbot/waterfalls.pyl +++ b/testing/buildbot/waterfalls.pyl
@@ -2678,6 +2678,17 @@ 'scripts': 'chromium_win_scripts', }, }, + 'Win11 Tests x64': { + 'mixins': [ + 'x86-64', + 'win11', + 'isolate_profile_data', + ], + 'test_suites': { + 'gtest_tests': 'chromium_win10_gtests', + 'isolated_scripts': 'chromium_win_rel_isolated_scripts', + }, + }, 'android-backuprefptr-arm-fyi-rel': { 'test_suites': { 'gtest_tests': 'backuprefptr_gtests',
diff --git a/testing/variations/fieldtrial_testing_config.json b/testing/variations/fieldtrial_testing_config.json index 16f5e83..b063a85 100644 --- a/testing/variations/fieldtrial_testing_config.json +++ b/testing/variations/fieldtrial_testing_config.json
@@ -526,21 +526,6 @@ ] } ], - "AndroidMessagesNearOomReduction": [ - { - "platforms": [ - "android" - ], - "experiments": [ - { - "name": "Enabled_2021-10-06", - "enable_features": [ - "MessagesForAndroidNearOomReduction" - ] - } - ] - } - ], "AndroidMessagesPWAInstall": [ { "platforms": [ @@ -2837,6 +2822,21 @@ ] } ], + "DWriteFontProxyOnIO": [ + { + "platforms": [ + "windows" + ], + "experiments": [ + { + "name": "Enabled", + "enable_features": [ + "DWriteFontProxyOnIO" + ] + } + ] + } + ], "DataReductionProxyFREPromo": [ { "platforms": [
diff --git a/third_party/android_platform/development/scripts/stack.py b/third_party/android_platform/development/scripts/stack.py index bcd1685..6420406 100755 --- a/third_party/android_platform/development/scripts/stack.py +++ b/third_party/android_platform/development/scripts/stack.py
@@ -50,12 +50,11 @@ print(" usage: " + sys.argv[0] + " [options] [FILE]") print() print(" --symbols-dir=path") - print(" the path to a symbols dir, such as") + print(" path to the Android OS symbols dir, such as") print(" =/tmp/out/target/product/dream/symbols") print() print(" --chrome-symbols-dir=path") - print(" the path to a Chrome symbols dir (can be absolute or relative") - print(" to src), such as =out/Debug/lib.unstripped") + print(" path to a Chrome symbols dir. E.g.: out/Debug/lib.unstripped") print() print(" --output-directory=path") print(" the path to the build output directory, such as out/Debug.") @@ -169,8 +168,7 @@ symbol.ARCH = value arch_defined = True elif option == "--chrome-symbols-dir": - symbol.CHROME_SYMBOLS_DIR = os.path.join(constants.DIR_SOURCE_ROOT, - value) + symbol.CHROME_SYMBOLS_DIR = value elif option == "--output-directory": constants.SetOutputDirectory(os.path.abspath(value)) elif option == "--apks-directory":
diff --git a/third_party/blink/public/mojom/input/input_handler.mojom b/third_party/blink/public/mojom/input/input_handler.mojom index 0216057b..773ac2bf 100644 --- a/third_party/blink/public/mojom/input/input_handler.mojom +++ b/third_party/blink/public/mojom/input/input_handler.mojom
@@ -199,6 +199,11 @@ // selection (which is a caret). int32 extended_start_adjust; int32 extended_end_adjust; + // The offset differences between the word selection (regardless of the + // extended selection granularity) and the initial selection (which is a + // caret). + int32 word_start_adjust; + int32 word_end_adjust; }; // GENERATED_JAVA_ENUM_PACKAGE: org.chromium.blink_public.input
diff --git a/third_party/blink/public/mojom/payments/payment_request.mojom b/third_party/blink/public/mojom/payments/payment_request.mojom index bcc1ab1..af662edd4 100644 --- a/third_party/blink/public/mojom/payments/payment_request.mojom +++ b/third_party/blink/public/mojom/payments/payment_request.mojom
@@ -36,8 +36,7 @@ struct SecurePaymentConfirmationResponse { blink.mojom.CommonCredentialInfo credential_info; array<uint8> signature; - bool has_transport; - blink.mojom.AuthenticatorTransport transport; + blink.mojom.AuthenticatorAttachment authenticator_attachment; array<uint8>? user_handle; };
diff --git a/third_party/blink/public/mojom/web_feature/web_feature.mojom b/third_party/blink/public/mojom/web_feature/web_feature.mojom index 2899546..f89b116 100644 --- a/third_party/blink/public/mojom/web_feature/web_feature.mojom +++ b/third_party/blink/public/mojom/web_feature/web_feature.mojom
@@ -3197,7 +3197,7 @@ kWebAppManifestProtocolHandlers = 3884, kRTCPeerConnectionOfferAllowExtmapMixedFalse = 3885, kNewCanvas2DAPI = 3886, - kServiceWorkerSubresourceFilterBypassedRequest = 3887, + kOBSOLETE_ServiceWorkerSubresourceFilterBypassedRequest = 3887, kWebGPU = 3888, kCSSFilterColorMatrix = 3889, kHTMLFencedFrameElement = 3890,
diff --git a/third_party/blink/public/mojom/webauthn/authenticator.mojom b/third_party/blink/public/mojom/webauthn/authenticator.mojom index ed9f764de..4156827 100644 --- a/third_party/blink/public/mojom/webauthn/authenticator.mojom +++ b/third_party/blink/public/mojom/webauthn/authenticator.mojom
@@ -88,11 +88,8 @@ struct MakeCredentialAuthenticatorResponse { CommonCredentialInfo info; - // True if transport exists. False if transport does not exist, e.g. on old - // Windows. - bool has_transport; - // The transport method used to authenticate. - AuthenticatorTransport transport; + // The attachment of the authenticator that created the credential. + AuthenticatorAttachment authenticator_attachment; // A blob of data returned by the authenticator after creating a credential. array<uint8> attestation_object; @@ -156,11 +153,8 @@ struct GetAssertionAuthenticatorResponse { CommonCredentialInfo info; - // True if transport exists. False if transport does not exist, e.g. on old - // Windows. - bool has_transport; - // The transport method used to authenticate. - AuthenticatorTransport transport; + // The attachment of the authenticator that created the credential. + AuthenticatorAttachment authenticator_attachment; // Cryptographic signature proving possession of the credential private key. array<uint8> signature;
diff --git a/third_party/blink/public/web/web_local_frame.h b/third_party/blink/public/web/web_local_frame.h index a0e3492..562af77 100644 --- a/third_party/blink/public/web/web_local_frame.h +++ b/third_party/blink/public/web/web_local_frame.h
@@ -875,8 +875,6 @@ // Reset TextFinder state for the web test runner in between two tests. virtual void ClearActiveFindMatchForTesting() = 0; - virtual bool ServiceWorkerSubresourceFilterEnabled() = 0; - // Sets a local storage area which can be used for this frame. This storage // area is ignored if a cached storage area already exists for the storage // key.
diff --git a/third_party/blink/renderer/bindings/scripts/bind_gen/callback_function.py b/third_party/blink/renderer/bindings/scripts/bind_gen/callback_function.py index cae8eb9..b36fced5 100644 --- a/third_party/blink/renderer/bindings/scripts/bind_gen/callback_function.py +++ b/third_party/blink/renderer/bindings/scripts/bind_gen/callback_function.py
@@ -412,7 +412,7 @@ T("v8::TryCatch try_catch(${isolate});"), T("try_catch.SetVerbose(true);"), EmptyNode(), - F("ignore_result({api_func_name}({arg_names}));", + F("std::ignore = {api_func_name}({arg_names});", api_func_name=api_func_name, arg_names=", ".join(arg_names)), ]) @@ -594,8 +594,10 @@ "third_party/blink/renderer/platform/bindings/callback_function_base.h", "third_party/blink/renderer/platform/bindings/v8_value_or_script_wrappable_adapter.h", ]) + source_node.accumulator.add_stdcpp_include_headers([ + "tuple", + ]) source_node.accumulator.add_include_headers([ - "base/ignore_result.h", "third_party/blink/renderer/bindings/core/v8/callback_invoke_helper.h", "third_party/blink/renderer/bindings/core/v8/generated_code_helper.h", "third_party/blink/renderer/bindings/core/v8/to_v8_traits.h",
diff --git a/third_party/blink/renderer/bindings/scripts/bind_gen/callback_interface.py b/third_party/blink/renderer/bindings/scripts/bind_gen/callback_interface.py index 4bb0545..024ca13 100644 --- a/third_party/blink/renderer/bindings/scripts/bind_gen/callback_interface.py +++ b/third_party/blink/renderer/bindings/scripts/bind_gen/callback_interface.py
@@ -283,8 +283,10 @@ "third_party/blink/renderer/platform/bindings/callback_interface_base.h", "third_party/blink/renderer/platform/bindings/v8_value_or_script_wrappable_adapter.h", ]) + source_node.accumulator.add_stdcpp_include_headers([ + "tuple", + ]) source_node.accumulator.add_include_headers([ - "base/ignore_result.h", "third_party/blink/renderer/bindings/core/v8/callback_invoke_helper.h", "third_party/blink/renderer/bindings/core/v8/generated_code_helper.h", "third_party/blink/renderer/bindings/core/v8/to_v8_traits.h",
diff --git a/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_accumulator.py b/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_accumulator.py index 2c822edf..c661802 100644 --- a/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_accumulator.py +++ b/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_accumulator.py
@@ -10,16 +10,18 @@ """ def __init__(self): - # Headers to be included + # Headers of non-standard library to be included self._include_headers = set() + # Headers of C++ standard library to be included + self._stdcpp_include_headers = set() # Forward declarations of C++ class self._class_decls = set() # Forward declarations of C++ struct self._struct_decls = set() def total_size(self): - return (len(self.include_headers) + len(self.class_decls) + len( - self.struct_decls)) + return (len(self.include_headers) + len(self.class_decls) + + len(self.struct_decls) + len(self.stdcpp_include_headers)) @property def include_headers(self): @@ -33,6 +35,18 @@ return lambda accumulator: accumulator.add_include_headers(headers) @property + def stdcpp_include_headers(self): + return self._stdcpp_include_headers + + def add_stdcpp_include_headers(self, headers): + self._stdcpp_include_headers.update(filter(None, headers)) + + @staticmethod + def require_stdcpp_include_headers(headers): + return lambda accumulator: accumulator.add_stdcpp_include_headers( + headers) + + @property def class_decls(self): return self._class_decls
diff --git a/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_utils.py b/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_utils.py index b21e325..0d5dfab 100644 --- a/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_utils.py +++ b/third_party/blink/renderer/bindings/scripts/bind_gen/codegen_utils.py
@@ -55,11 +55,22 @@ self._accumulator = accumulator def __str__(self): - return "\n".join([ + lines = [] + + if self._accumulator.stdcpp_include_headers: + lines.extend([ + "#include <{}>".format(header) for header in sorted( + self._accumulator.stdcpp_include_headers) + ]) + lines.append("") + + lines.extend([ "#include \"{}\"".format(header) for header in sorted(self._accumulator.include_headers) ]) + return "\n".join(lines) + return LiteralNode(HeaderIncludeDirectives(accumulator))
diff --git a/third_party/blink/renderer/core/BUILD.gn b/third_party/blink/renderer/core/BUILD.gn index b2a43ce..62eee2d8 100644 --- a/third_party/blink/renderer/core/BUILD.gn +++ b/third_party/blink/renderer/core/BUILD.gn
@@ -304,7 +304,6 @@ deps = [ ":generated_settings_macros", "//build:chromeos_buildflags", - "//build:os_buildflags", "//components/paint_preview/common", "//components/performance_manager/public/mojom:mojom_blink", "//components/power_scheduler",
diff --git a/third_party/blink/renderer/core/display_lock/display_lock_utilities.h b/third_party/blink/renderer/core/display_lock/display_lock_utilities.h index e620b94..a0992f2 100644 --- a/third_party/blink/renderer/core/display_lock/display_lock_utilities.h +++ b/third_party/blink/renderer/core/display_lock/display_lock_utilities.h
@@ -9,6 +9,7 @@ #include "third_party/blink/renderer/core/core_export.h" #include "third_party/blink/renderer/core/display_lock/display_lock_context.h" #include "third_party/blink/renderer/core/dom/range.h" +#include "third_party/blink/renderer/core/editing/commands/apply_style_command.h" #include "third_party/blink/renderer/core/editing/ephemeral_range.h" #include "third_party/blink/renderer/core/editing/frame_selection.h" #include "third_party/blink/renderer/core/paint/paint_layer.h" @@ -48,6 +49,9 @@ friend void Document::UpdateStyleAndLayoutForNode( const Node* node, DocumentUpdateReason reason); + friend void Document::UpdateStyleAndLayoutForRange( + const Range* range, + DocumentUpdateReason reason); friend void Document::UpdateStyleAndLayoutTreeForNode(const Node*); friend void Document::UpdateStyleAndLayoutTreeForSubtree(const Node* node); friend void Document::EnsurePaintLocationDataValidForNode(
diff --git a/third_party/blink/renderer/core/dom/document.cc b/third_party/blink/renderer/core/dom/document.cc index 10227b4..297b0b9 100644 --- a/third_party/blink/renderer/core/dom/document.cc +++ b/third_party/blink/renderer/core/dom/document.cc
@@ -2322,6 +2322,13 @@ } } +void Document::UpdateStyleAndLayoutForRange(const Range* range, + DocumentUpdateReason reason) { + DisplayLockUtilities::ScopedForcedUpdate scoped_update_forced( + range, DisplayLockContext::ForcedPhase::kLayout); + UpdateStyleAndLayout(reason); +} + void Document::UpdateStyleAndLayoutForNode(const Node* node, DocumentUpdateReason reason) { DCHECK(node);
diff --git a/third_party/blink/renderer/core/dom/document.h b/third_party/blink/renderer/core/dom/document.h index 796f548..26f606f6 100644 --- a/third_party/blink/renderer/core/dom/document.h +++ b/third_party/blink/renderer/core/dom/document.h
@@ -639,6 +639,7 @@ kRunPostLayoutTasksSynchronously, }; void UpdateStyleAndLayoutForNode(const Node*, DocumentUpdateReason); + void UpdateStyleAndLayoutForRange(const Range*, DocumentUpdateReason); void IncLayoutCallsCounter() { ++layout_calls_counter_; } void IncLayoutCallsCounterNG() { ++layout_calls_counter_ng_; } void IncLayoutBlockCounter() { ++layout_blocks_counter_; }
diff --git a/third_party/blink/renderer/core/editing/commands/apply_style_command.cc b/third_party/blink/renderer/core/editing/commands/apply_style_command.cc index 433b090..97cc32b 100644 --- a/third_party/blink/renderer/core/editing/commands/apply_style_command.cc +++ b/third_party/blink/renderer/core/editing/commands/apply_style_command.cc
@@ -32,6 +32,7 @@ #include "third_party/blink/renderer/core/css/css_property_names.h" #include "third_party/blink/renderer/core/css/css_property_value_set.h" #include "third_party/blink/renderer/core/css_value_keywords.h" +#include "third_party/blink/renderer/core/display_lock/display_lock_utilities.h" #include "third_party/blink/renderer/core/dom/document.h" #include "third_party/blink/renderer/core/dom/node_list.h" #include "third_party/blink/renderer/core/dom/node_traversal.h" @@ -965,7 +966,10 @@ if (remove_only_) return; - GetDocument().UpdateStyleAndLayout(DocumentUpdateReason::kEditing); + Range* range = MakeGarbageCollected<Range>(GetDocument(), StartPosition(), + EndPosition()); + GetDocument().UpdateStyleAndLayoutForRange(range, + DocumentUpdateReason::kEditing); HeapVector<InlineRunToApplyStyle> runs; Node* node = start_node; @@ -1041,7 +1045,8 @@ } } - GetDocument().UpdateStyleAndLayout(DocumentUpdateReason::kEditing); + GetDocument().UpdateStyleAndLayoutForRange(range, + DocumentUpdateReason::kEditing); for (auto& run : runs) { if (run.position_for_style_computation.IsNotNull())
diff --git a/third_party/blink/renderer/core/editing/frame_selection.cc b/third_party/blink/renderer/core/editing/frame_selection.cc index 99a2f5b..9386c18 100644 --- a/third_party/blink/renderer/core/editing/frame_selection.cc +++ b/third_party/blink/renderer/core/editing/frame_selection.cc
@@ -1148,69 +1148,44 @@ TextGranularity text_granularity, HandleVisibility handle_visibility, ContextMenuVisibility context_menu_visibility) { - // Only supports word and sentence granularities for now. - DCHECK(text_granularity == TextGranularity::kWord || - text_granularity == TextGranularity::kSentence); + CHECK(text_granularity == TextGranularity::kWord || + text_granularity == TextGranularity::kSentence) + << "Only word and sentence granularities are supported for now"; - const VisibleSelection& selection = ComputeVisibleSelectionInDOMTree(); - // TODO(editing-dev): The use of VisibleSelection needs to be audited. See - // http://crbug.com/657237 for more details. - if (!selection.IsCaret()) + EphemeralRange selection_range = + GetSelectionRangeAroundCaret(text_granularity); + if (selection_range.IsNull()) { return false; - const Position position = selection.Start(); - static const WordSide kWordSideList[2] = {kNextWordIfOnBoundary, - kPreviousWordIfOnBoundary}; - for (WordSide word_side : kWordSideList) { - Position start; - Position end; - // Use word granularity by default unless sentence granularity is explicitly - // requested. - if (text_granularity == TextGranularity::kSentence) { - start = StartOfSentencePosition(position); - end = EndOfSentence(position, SentenceTrailingSpaceBehavior::kOmitSpace) - .GetPosition(); - } else { - start = StartOfWordPosition(position, word_side); - end = EndOfWordPosition(position, word_side); - } - - // TODO(editing-dev): |StartOfWord()| and |EndOfWord()| should not make null - // for non-null parameter. - // See http://crbug.com/872443 - if (start.IsNull() || end.IsNull()) - continue; - - if (start > end) { - // Since word boundaries are computed on flat tree, they can be reversed - // when mapped back to DOM. - std::swap(start, end); - } - - String text = PlainText(EphemeralRange(start, end)); - if (!text.IsEmpty() && !IsSeparator(text.CharacterStartingAt(0))) { - SetSelection( - SelectionInDOMTree::Builder().Collapse(start).Extend(end).Build(), - SetSelectionOptions::Builder() - .SetShouldCloseTyping(true) - .SetShouldClearTypingStyle(true) - .SetGranularity(text_granularity == TextGranularity::kSentence - ? TextGranularity::kSentence - : TextGranularity::kWord) - .SetShouldShowHandle(handle_visibility == - HandleVisibility::kVisible) - .Build()); - - if (context_menu_visibility == ContextMenuVisibility::kVisible) { - ContextMenuAllowedScope scope; - frame_->GetEventHandler().ShowNonLocatedContextMenu( - /*override_target_element=*/nullptr, kMenuSourceTouch); - } - - return true; - } } - return false; + SetSelection( + SelectionInDOMTree::Builder() + .Collapse(selection_range.StartPosition()) + .Extend(selection_range.EndPosition()) + .Build(), + SetSelectionOptions::Builder() + .SetShouldCloseTyping(true) + .SetShouldClearTypingStyle(true) + .SetGranularity(text_granularity) + .SetShouldShowHandle(handle_visibility == HandleVisibility::kVisible) + .Build()); + + if (context_menu_visibility == ContextMenuVisibility::kVisible) { + ContextMenuAllowedScope scope; + frame_->GetEventHandler().ShowNonLocatedContextMenu( + /*override_target_element=*/nullptr, kMenuSourceTouch); + } + + return true; +} + +EphemeralRange FrameSelection::GetWordSelectionRangeAroundCaret() const { + return GetSelectionRangeAroundCaret(TextGranularity::kWord); +} + +EphemeralRange FrameSelection::GetSelectionRangeAroundCaretForTesting( + TextGranularity text_granularity) const { + return GetSelectionRangeAroundCaret(text_granularity); } GranularityStrategy* FrameSelection::GetGranularityStrategy() { @@ -1345,6 +1320,59 @@ selection_editor_->MarkCacheDirty(); } +EphemeralRange FrameSelection::GetSelectionRangeAroundCaret( + TextGranularity text_granularity) const { + DCHECK(text_granularity == TextGranularity::kWord || + text_granularity == TextGranularity::kSentence) + << "Only word and sentence granularities are supported for now"; + + const VisibleSelection& selection = ComputeVisibleSelectionInDOMTree(); + // TODO(editing-dev): The use of VisibleSelection needs to be audited. See + // http://crbug.com/657237 for more details. + if (!selection.IsCaret()) { + return EphemeralRange(); + } + const Position position = selection.Start(); + static const WordSide kWordSideList[2] = {kNextWordIfOnBoundary, + kPreviousWordIfOnBoundary}; + for (WordSide word_side : kWordSideList) { + Position start; + Position end; + // Use word granularity by default unless sentence granularity is explicitly + // requested. + if (text_granularity == TextGranularity::kSentence) { + start = StartOfSentencePosition(position); + end = EndOfSentence(position, SentenceTrailingSpaceBehavior::kOmitSpace) + .GetPosition(); + } else { + start = StartOfWordPosition(position, word_side); + end = EndOfWordPosition(position, word_side); + } + + // TODO(editing-dev): |StartOfWord()| and |EndOfWord()| should not make null + // for non-null parameter. + // See http://crbug.com/872443 + if (start.IsNull() || end.IsNull()) { + continue; + } + + if (start > end) { + // Since word boundaries are computed on flat tree, they can be reversed + // when mapped back to DOM. + std::swap(start, end); + } + + String text = PlainText(EphemeralRange(start, end)); + if (text.IsEmpty() || IsSeparator(text.CharacterStartingAt(0))) { + continue; + } + + return EphemeralRange(start, end); + } + + return EphemeralRange(); +} + } // namespace blink #if DCHECK_IS_ON()
diff --git a/third_party/blink/renderer/core/editing/frame_selection.h b/third_party/blink/renderer/core/editing/frame_selection.h index d97ed56..f614f06 100644 --- a/third_party/blink/renderer/core/editing/frame_selection.h +++ b/third_party/blink/renderer/core/editing/frame_selection.h
@@ -32,6 +32,7 @@ #include "base/dcheck_is_on.h" #include "third_party/blink/renderer/core/core_export.h" #include "third_party/blink/renderer/core/dom/synchronous_mutation_observer.h" +#include "third_party/blink/renderer/core/editing/ephemeral_range.h" #include "third_party/blink/renderer/core/editing/forward.h" #include "third_party/blink/renderer/core/editing/set_selection_options.h" #include "third_party/blink/renderer/core/scroll/scroll_alignment.h" @@ -258,6 +259,18 @@ HandleVisibility handle_visibility, ContextMenuVisibility context_menu_visibility); + // Returns the range corresponding to a word selection around the caret. + // Returns a null range if the selection failed, either because the current + // selection was not a caret or if a word selection could not be made. + EphemeralRange GetWordSelectionRangeAroundCaret() const; + + // Returns the range corresponding to a |text_granularity| selection around + // the caret. Returns a null range if the selection failed, either because + // the current selection was not a caret or if a |text_granularity| selection + // could not be made. + EphemeralRange GetSelectionRangeAroundCaretForTesting( + TextGranularity text_granularity) const; + #if DCHECK_IS_ON() void ShowTreeForThis() const; #endif @@ -327,6 +340,13 @@ void NodeChildrenWillBeRemoved(ContainerNode&) final; void NodeWillBeRemoved(Node&) final; + // Returns the range corresponding to a |text_granularity| selection around + // the caret. Returns a null range if the selection failed, either because + // the current selection was not a caret or if a |text_granularity| selection + // could not be made. + EphemeralRange GetSelectionRangeAroundCaret( + TextGranularity text_granularity) const; + Member<LocalFrame> frame_; const Member<LayoutSelection> layout_selection_; const Member<SelectionEditor> selection_editor_;
diff --git a/third_party/blink/renderer/core/editing/frame_selection_test.cc b/third_party/blink/renderer/core/editing/frame_selection_test.cc index 8e60a15..9da8b02 100644 --- a/third_party/blink/renderer/core/editing/frame_selection_test.cc +++ b/third_party/blink/renderer/core/editing/frame_selection_test.cc
@@ -7,11 +7,13 @@ #include <memory> #include "base/memory/scoped_refptr.h" #include "testing/gtest/include/gtest/gtest.h" +#include "third_party/blink/public/web/web_range.h" #include "third_party/blink/renderer/core/dom/document.h" #include "third_party/blink/renderer/core/dom/element.h" #include "third_party/blink/renderer/core/dom/text.h" #include "third_party/blink/renderer/core/editing/ephemeral_range.h" #include "third_party/blink/renderer/core/editing/frame_caret.h" +#include "third_party/blink/renderer/core/editing/iterators/text_iterator.h" #include "third_party/blink/renderer/core/editing/selection_controller.h" #include "third_party/blink/renderer/core/editing/selection_modifier.h" #include "third_party/blink/renderer/core/editing/selection_template.h" @@ -67,11 +69,19 @@ // Returns if a word is is selected. bool SelectWordAroundPosition(const Position&); + // Returns whether the selection was accomplished. + bool SelectWordAroundCaret(); + + // Returns whether the selection was accomplished. + bool SelectSentenceAroundCaret(); + // Places the caret on the |text| at |selection_index|. - void ResetAndPlaceCaret(Text* text, int selection_index) { + void ResetAndPlaceCaret(Text* text, size_t selection_index) { + ASSERT_LE(selection_index, + static_cast<size_t>(std::numeric_limits<int>::max())); Selection().SetSelectionAndEndTyping( SelectionInDOMTree::Builder() - .Collapse(Position(text, selection_index)) + .Collapse(Position(text, static_cast<int>(selection_index))) .Build()); } @@ -107,6 +117,18 @@ return Selection().SelectWordAroundCaret(); } +bool FrameSelectionTest::SelectWordAroundCaret() { + return Selection().SelectAroundCaret(TextGranularity::kWord, + HandleVisibility::kNotVisible, + ContextMenuVisibility::kNotVisible); +} + +bool FrameSelectionTest::SelectSentenceAroundCaret() { + return Selection().SelectAroundCaret(TextGranularity::kSentence, + HandleVisibility::kNotVisible, + ContextMenuVisibility::kNotVisible); +} + TEST_F(FrameSelectionTest, FirstEphemeralRangeOf) { SetBodyContent("<div id=sample>0123456789</div>abc"); Element* const sample = GetDocument().getElementById("sample"); @@ -206,6 +228,41 @@ EXPECT_EQ_SELECTED_TEXT("baz"); } +TEST_F(FrameSelectionTest, SelectAroundCaret_Word) { + Text* text = AppendTextNode("This is a sentence."); + UpdateAllLifecyclePhasesForTest(); + + // Beginning of text: |This is a sentence. + ResetAndPlaceCaret(text, strlen("")); + EXPECT_TRUE(SelectWordAroundCaret()); + EXPECT_EQ_SELECTED_TEXT("This"); + + // Beginning of a word: This |is a sentence. + ResetAndPlaceCaret(text, strlen("This ")); + EXPECT_TRUE(SelectWordAroundCaret()); + EXPECT_EQ_SELECTED_TEXT("is"); + + // Somewhere in a word: This is a s|entence. + ResetAndPlaceCaret(text, strlen("This is a s")); + EXPECT_TRUE(SelectWordAroundCaret()); + EXPECT_EQ_SELECTED_TEXT("sentence"); + + // End a word: This| is a sentence. + ResetAndPlaceCaret(text, strlen("This")); + EXPECT_TRUE(SelectWordAroundCaret()); + EXPECT_EQ_SELECTED_TEXT("This"); + + // End a word with punctuation: This is a sentence|. + ResetAndPlaceCaret(text, strlen("This is a sentence")); + EXPECT_TRUE(SelectWordAroundCaret()); + EXPECT_EQ_SELECTED_TEXT("sentence"); + + // End a word after punctuation: This is a sentence.| + ResetAndPlaceCaret(text, strlen("This is a sentence.")); + EXPECT_FALSE(SelectWordAroundCaret()); + EXPECT_EQ_SELECTED_TEXT(""); +} + TEST_F(FrameSelectionTest, SelectAroundCaret_Sentence) { Text* text = AppendTextNode( "This is the first sentence. This is the second sentence. This is the " @@ -214,44 +271,34 @@ // This is the first sentence. Th|is is the second sentence. This is the last // sentence. - ResetAndPlaceCaret(text, 30); - EXPECT_TRUE(Selection().SelectAroundCaret( - TextGranularity::kSentence, HandleVisibility::kNotVisible, - ContextMenuVisibility::kNotVisible)); + ResetAndPlaceCaret(text, strlen("This is the first sentence. Th")); + EXPECT_TRUE(SelectSentenceAroundCaret()); EXPECT_EQ_SELECTED_TEXT("This is the second sentence."); // This is the first sentence|. This is the second sentence. This is the last // sentence. - ResetAndPlaceCaret(text, 26); - EXPECT_TRUE(Selection().SelectAroundCaret( - TextGranularity::kSentence, HandleVisibility::kNotVisible, - ContextMenuVisibility::kNotVisible)); + ResetAndPlaceCaret(text, strlen("This is the first sentence")); + EXPECT_TRUE(SelectSentenceAroundCaret()); EXPECT_EQ_SELECTED_TEXT("This is the first sentence."); // This is the first sentence.| This is the second sentence. This is the last // sentence. - ResetAndPlaceCaret(text, 27); - EXPECT_TRUE(Selection().SelectAroundCaret( - TextGranularity::kSentence, HandleVisibility::kNotVisible, - ContextMenuVisibility::kNotVisible)); + ResetAndPlaceCaret(text, strlen("This is the first sentence.")); + EXPECT_TRUE(SelectSentenceAroundCaret()); EXPECT_EQ_SELECTED_TEXT( "This is the first sentence. This is the second sentence."); // This is the first sentence. |This is the second sentence. This is the last // sentence. - ResetAndPlaceCaret(text, 28); - EXPECT_TRUE(Selection().SelectAroundCaret( - TextGranularity::kSentence, HandleVisibility::kNotVisible, - ContextMenuVisibility::kNotVisible)); + ResetAndPlaceCaret(text, strlen("This is the first sentence. ")); + EXPECT_TRUE(SelectSentenceAroundCaret()); EXPECT_EQ_SELECTED_TEXT( "This is the first sentence. This is the second sentence."); // This is the first sentence. T|his is the second sentence. This is the last // sentence. - ResetAndPlaceCaret(text, 29); - EXPECT_TRUE(Selection().SelectAroundCaret( - TextGranularity::kSentence, HandleVisibility::kNotVisible, - ContextMenuVisibility::kNotVisible)); + ResetAndPlaceCaret(text, strlen("This is the first sentence. T")); + EXPECT_TRUE(SelectSentenceAroundCaret()); EXPECT_EQ_SELECTED_TEXT("This is the second sentence."); } @@ -360,6 +407,117 @@ EXPECT_TRUE(HasContextMenu()); } +TEST_F(FrameSelectionTest, GetSelectionRangeAroundCaret_Word) { + Text* text = AppendTextNode("This is a sentence."); + UpdateAllLifecyclePhasesForTest(); + + // Beginning of a text: |This is a sentence. + ResetAndPlaceCaret(text, strlen("")); + EphemeralRange range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("This", PlainText(range)); + + // Beginning of a word: This |is a sentence. + ResetAndPlaceCaret(text, strlen("This ")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("is", PlainText(range)); + + // Somewhere in a word: This is a s|entence. + ResetAndPlaceCaret(text, strlen("This is a s")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("sentence", PlainText(range)); + + // End a word: This| is a sentence. + ResetAndPlaceCaret(text, strlen("This")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("This", PlainText(range)); + + // End a word before punctuation: This is a sentence|. + ResetAndPlaceCaret(text, strlen("This is a sentence")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("sentence", PlainText(range)); + + // End of text after punctuation (no selection): This is a sentence.| + ResetAndPlaceCaret(text, strlen("This is a sentence.")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("", PlainText(range)); + + // End of text without punctuation: This is a sentence| + ResetAndPlaceCaret(text, strlen("This is a sentence")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("sentence", PlainText(range)); + + // After punctuation before whitespace (no selection): A word.| Another. + text = AppendTextNode("A word. Another."); + UpdateAllLifecyclePhasesForTest(); + ResetAndPlaceCaret(text, strlen("A word.")); + range = Selection().GetWordSelectionRangeAroundCaret(); + EXPECT_EQ("", PlainText(range)); +} + +TEST_F(FrameSelectionTest, GetSelectionRangeAroundCaret_Sentence) { + Text* text = AppendTextNode( + "This is the first sentence. This is the second sentence. This is the " + "last sentence."); + UpdateAllLifecyclePhasesForTest(); + + // |This is the first sentence. This is the second sentence. This is the last + // sentence. + ResetAndPlaceCaret(text, strlen("")); + EphemeralRange range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the first sentence.", PlainText(range)); + + // This is the first sentence|. This is the second sentence. This is the last + // sentence. + ResetAndPlaceCaret(text, strlen("This is the first sentence")); + range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the first sentence.", PlainText(range)); + + // TODO(crbug.com/1273856): This should only select one sentence. + // This is the first sentence.| This is the second sentence. This is the last + // sentence. + ResetAndPlaceCaret(text, strlen("This is the first sentence.")); + range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the first sentence. This is the second sentence.", + PlainText(range)); + + // TODO(crbug.com/1273856): This should only select one sentence. + // This is the first sentence. |This is the second sentence. This is the last + // sentence. + ResetAndPlaceCaret(text, strlen("This is the first sentence. ")); + range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the first sentence. This is the second sentence.", + PlainText(range)); + + // This is the first sentence. Th|is is the second sentence. This is the last + // sentence. + ResetAndPlaceCaret(text, strlen("This is the first sentence. Th")); + range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the second sentence.", PlainText(range)); + + // This is the first sentence. This is the second sentence. This is the last + // sentence|. + ResetAndPlaceCaret(text, + strlen("This is the first sentence. This is the second " + "sentence. This is the last sentence")); + range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the last sentence.", PlainText(range)); + + // This is the first sentence. This is the second sentence. This is the last + // sentence.| + ResetAndPlaceCaret(text, + strlen("This is the first sentence. This is the second " + "sentence. This is the last sentence.")); + range = Selection().GetSelectionRangeAroundCaretForTesting( + TextGranularity::kSentence); + EXPECT_EQ("This is the last sentence.", PlainText(range)); +} + TEST_F(FrameSelectionTest, ModifyExtendWithFlatTree) { SetBodyContent("<span id=host></span>one"); SetShadowContent("two<slot></slot>", "host");
diff --git a/third_party/blink/renderer/core/frame/find_in_page.cc b/third_party/blink/renderer/core/frame/find_in_page.cc index 11bc63d..22c97f56 100644 --- a/third_party/blink/renderer/core/frame/find_in_page.cc +++ b/third_party/blink/renderer/core/frame/find_in_page.cc
@@ -40,7 +40,6 @@ #include "third_party/blink/renderer/core/editing/finder/text_finder.h" #include "third_party/blink/renderer/core/frame/web_local_frame_impl.h" #include "third_party/blink/renderer/core/layout/layout_view.h" -#include "third_party/blink/renderer/core/page/chrome_client.h" #include "third_party/blink/renderer/core/page/focus_controller.h" #include "third_party/blink/renderer/core/page/page.h" @@ -343,23 +342,13 @@ : mojom::blink::FindMatchUpdateType::kMoreUpdatesComing); } -void FindInPage::ReportFindInPageSelection( - int request_id, - int active_match_ordinal, - const gfx::Rect& local_selection_rect, - bool final_update) { +void FindInPage::ReportFindInPageSelection(int request_id, + int active_match_ordinal, + const gfx::Rect& selection_rect, + bool final_update) { // In tests, |client_| might not be set. if (!client_) return; - - float device_scale_factor = 1.f; - if (LocalFrame* local_frame = frame_->GetFrame()) { - device_scale_factor = - local_frame->GetPage()->GetChromeClient().WindowToViewportScalar( - local_frame, 1.0f); - } - auto selection_rect = gfx::ScaleToEnclosingRect(local_selection_rect, - 1.f / device_scale_factor); client_->SetActiveMatch( request_id, selection_rect, active_match_ordinal, final_update ? mojom::blink::FindMatchUpdateType::kFinalUpdate
diff --git a/third_party/blink/renderer/core/frame/web_frame_widget_impl.cc b/third_party/blink/renderer/core/frame/web_frame_widget_impl.cc index d55d1e3..25e256a 100644 --- a/third_party/blink/renderer/core/frame/web_frame_widget_impl.cc +++ b/third_party/blink/renderer/core/frame/web_frame_widget_impl.cc
@@ -3691,32 +3691,57 @@ return; } - // TODO(crbug.com/1278134): Calculate extended adjustments. - bool did_select = false; int extended_start_adjust = 0; int extended_end_adjust = 0; + int word_start_adjust = 0; + int word_end_adjust = 0; blink::WebRange initial_range = focused_frame->SelectionRange(); SetHandlingInputEvent(true); - if (!initial_range.IsNull()) { - did_select = focused_frame->SelectAroundCaret( - granularity, should_show_handle, should_show_context_menu); - } - if (!did_select) { + if (initial_range.IsNull()) { std::move(callback).Run(std::move(nullptr)); return; } - blink::WebRange adjusted_range = focused_frame->SelectionRange(); - DCHECK(!adjusted_range.IsNull()); + // If the requested granularity is not word, still calculate the hypothetical + // word selection offsets. This is needed for contextual search to support + // legacy semantics for the word that was tapped. + blink::WebRange word_range; + if (granularity != mojom::blink::SelectionGranularity::kWord) { + word_range = focused_frame->GetWordSelectionRangeAroundCaret(); + } + + // Select around the caret at the specified |granularity|. + if (!focused_frame->SelectAroundCaret(granularity, should_show_handle, + should_show_context_menu)) { + std::move(callback).Run(std::move(nullptr)); + return; + } + + blink::WebRange extended_range = focused_frame->SelectionRange(); + DCHECK(!extended_range.IsNull()); extended_start_adjust = - adjusted_range.StartOffset() - initial_range.StartOffset(); - extended_end_adjust = adjusted_range.EndOffset() - initial_range.EndOffset(); + extended_range.StartOffset() - initial_range.StartOffset(); + extended_end_adjust = extended_range.EndOffset() - initial_range.EndOffset(); + + if (granularity == mojom::blink::SelectionGranularity::kWord) { + // Since the requested granularity was word, simply set the word offset + // to be the same as the extended offset values. + word_start_adjust = extended_start_adjust; + word_end_adjust = extended_end_adjust; + } else { + // Calculate the word offset compared to the initial selection (caret). + DCHECK(!word_range.IsNull()); + word_start_adjust = word_range.StartOffset() - initial_range.StartOffset(); + word_end_adjust = word_range.EndOffset() - initial_range.EndOffset(); + } SetHandlingInputEvent(false); auto result = mojom::blink::SelectAroundCaretResult::New(); result->extended_start_adjust = extended_start_adjust; result->extended_end_adjust = extended_end_adjust; + result->word_start_adjust = word_start_adjust; + result->word_end_adjust = word_end_adjust; std::move(callback).Run(std::move(result)); } #endif
diff --git a/third_party/blink/renderer/core/frame/web_local_frame_impl.cc b/third_party/blink/renderer/core/frame/web_local_frame_impl.cc index 79ec0d0d..a63f132 100644 --- a/third_party/blink/renderer/core/frame/web_local_frame_impl.cc +++ b/third_party/blink/renderer/core/frame/web_local_frame_impl.cc
@@ -1429,6 +1429,11 @@ : ContextMenuVisibility ::kNotVisible); } +EphemeralRange WebLocalFrameImpl::GetWordSelectionRangeAroundCaret() const { + TRACE_EVENT0("blink", "WebLocalFrameImpl::getWordSelectionRangeAroundCaret"); + return GetFrame()->Selection().GetWordSelectionRangeAroundCaret(); +} + void WebLocalFrameImpl::SelectRange(const gfx::Point& base_in_viewport, const gfx::Point& extent_in_viewport) { MoveRangeSelection(base_in_viewport, extent_in_viewport); @@ -2894,12 +2899,4 @@ has_scrolled_focused_editable_node_into_rect_ = false; } -bool WebLocalFrameImpl::ServiceWorkerSubresourceFilterEnabled() { - if (GetFrame() && GetFrame()->GetDocument()) { - return RuntimeEnabledFeatures::ServiceWorkerSubresourceFilterEnabled( - GetFrame()->GetDocument()->GetExecutionContext()); - } - return false; -} - } // namespace blink
diff --git a/third_party/blink/renderer/core/frame/web_local_frame_impl.h b/third_party/blink/renderer/core/frame/web_local_frame_impl.h index 45c6211..f22f5a2 100644 --- a/third_party/blink/renderer/core/frame/web_local_frame_impl.h +++ b/third_party/blink/renderer/core/frame/web_local_frame_impl.h
@@ -218,6 +218,7 @@ bool SelectAroundCaret(mojom::blink::SelectionGranularity granularity, bool should_show_handle, bool should_show_context_menu); + EphemeralRange GetWordSelectionRangeAroundCaret() const; void SelectRange(const gfx::Point& base, const gfx::Point& extent) override; void SelectRange(const WebRange&, HandleVisibilityBehavior, @@ -335,7 +336,6 @@ void UpdateCurrentHistoryItem() override; PageState CurrentHistoryItemToPageState() override; const WebHistoryItem& GetCurrentHistoryItem() const override; - bool ServiceWorkerSubresourceFilterEnabled() override; void SetLocalStorageArea( CrossVariantMojoRemote<mojom::StorageAreaInterfaceBase> local_storage_area) override;
diff --git a/third_party/blink/renderer/core/html/canvas/canvas_rendering_context.h b/third_party/blink/renderer/core/html/canvas/canvas_rendering_context.h index d8bbec9..e212237 100644 --- a/third_party/blink/renderer/core/html/canvas/canvas_rendering_context.h +++ b/third_party/blink/renderer/core/html/canvas/canvas_rendering_context.h
@@ -287,9 +287,9 @@ return CanvasColorParams(); } - private: - void Dispose(); + virtual void Dispose(); + private: Member<CanvasRenderingContextHost> host_; CanvasColorParams color_params_; CanvasContextCreationAttributesCore creation_attributes_;
diff --git a/third_party/blink/renderer/core/html/canvas/html_canvas_element.cc b/third_party/blink/renderer/core/html/canvas/html_canvas_element.cc index 2e45c9dc..b12205b 100644 --- a/third_party/blink/renderer/core/html/canvas/html_canvas_element.cc +++ b/third_party/blink/renderer/core/html/canvas/html_canvas_element.cc
@@ -40,7 +40,6 @@ #include "base/numerics/checked_math.h" #include "base/numerics/safe_conversions.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "services/metrics/public/cpp/ukm_recorder.h" #include "services/metrics/public/cpp/ukm_source_id.h" #include "third_party/blink/public/common/features.h"
diff --git a/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_multi_line.cc b/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_multi_line.cc index 4469ef6..19fb387 100644 --- a/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_multi_line.cc +++ b/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_multi_line.cc
@@ -34,20 +34,19 @@ const HitTestLocation& hit_test_location, const PhysicalOffset& accumulated_offset, HitTestAction hit_test_action) { - if (!LayoutNGBlockFlow::NodeAtPoint(result, hit_test_location, - accumulated_offset, hit_test_action)) - return false; + bool stop_hit_testing = LayoutNGBlockFlow::NodeAtPoint( + result, hit_test_location, accumulated_offset, hit_test_action); const LayoutObject* stop_node = result.GetHitTestRequest().GetStopNode(); if (stop_node && stop_node->NodeForHitTest() == result.InnerNode()) - return true; + return stop_hit_testing; HTMLElement* inner_editor = InnerEditorElement(); if (result.InnerNode() == GetNode() || result.InnerNode() == inner_editor) { LayoutTextControl::HitInnerEditorElement( *this, *inner_editor, result, hit_test_location, accumulated_offset); } - return true; + return stop_hit_testing; } } // namespace blink
diff --git a/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_single_line.cc b/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_single_line.cc index b425de0..f3157e6 100644 --- a/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_single_line.cc +++ b/third_party/blink/renderer/core/layout/ng/layout_ng_text_control_single_line.cc
@@ -49,13 +49,14 @@ const PhysicalOffset& accumulated_offset, HitTestAction hit_test_action) { NOT_DESTROYED(); - if (!LayoutNGBlockFlow::NodeAtPoint(result, hit_test_location, - accumulated_offset, hit_test_action)) - return false; + bool stop_hit_testing = LayoutNGBlockFlow::NodeAtPoint( + result, hit_test_location, accumulated_offset, hit_test_action); const LayoutObject* stop_node = result.GetHitTestRequest().GetStopNode(); - if (stop_node && stop_node->NodeForHitTest() == result.InnerNode()) - return true; + if (!result.InnerNode() || + (stop_node && stop_node->NodeForHitTest() == result.InnerNode())) { + return stop_hit_testing; + } // Say that we hit the inner text element if // - we hit a node inside the inner editor element, @@ -77,7 +78,7 @@ LayoutTextControl::HitInnerEditorElement( *this, *inner_editor, result, hit_test_location, accumulated_offset); } - return true; + return stop_hit_testing; } bool LayoutNGTextControlSingleLine::AllowsNonVisibleOverflow() const {
diff --git a/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.cc b/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.cc index 9deba9e3..ebe6556 100644 --- a/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.cc +++ b/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.cc
@@ -40,6 +40,11 @@ image_layer_bridge_->Dispose(); } +void ImageBitmapRenderingContextBase::Dispose() { + Stop(); + CanvasRenderingContext::Dispose(); +} + void ImageBitmapRenderingContextBase::ResetInternalBitmapToBlackTransparent( int width, int height) {
diff --git a/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.h b/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.h index efbb8b58..7e38191 100644 --- a/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.h +++ b/third_party/blink/renderer/modules/canvas/imagebitmap/image_bitmap_rendering_context_base.h
@@ -64,6 +64,8 @@ bool IsPaintable() const final; protected: + void Dispose() override; + Member<ImageLayerBridge> image_layer_bridge_; // This function resets the internal image resource to a image of the same
diff --git a/third_party/blink/renderer/modules/credentialmanager/credentials_container.cc b/third_party/blink/renderer/modules/credentialmanager/credentials_container.cc index 93abfc8d..0be83f1 100644 --- a/third_party/blink/renderer/modules/credentialmanager/credentials_container.cc +++ b/third_party/blink/renderer/modules/credentialmanager/credentials_container.cc
@@ -609,7 +609,7 @@ } resolver->Resolve(MakeGarbageCollected<PublicKeyCredential>( credential->info->id, raw_id, authenticator_response, - credential->has_transport, credential->transport, extension_outputs)); + credential->authenticator_attachment, extension_outputs)); } bool IsForPayment(const CredentialCreationOptions* options, ExecutionContext* context) { @@ -725,8 +725,8 @@ resolver->Resolve(MakeGarbageCollected<PublicKeyCredential>( credential->info->id, VectorToDOMArrayBuffer(std::move(credential->info->raw_id)), - authenticator_response, credential->has_transport, - credential->transport, extension_outputs)); + authenticator_response, credential->authenticator_attachment, + extension_outputs)); return; } DCHECK(!credential);
diff --git a/third_party/blink/renderer/modules/credentialmanager/public_key_credential.cc b/third_party/blink/renderer/modules/credentialmanager/public_key_credential.cc index 44378a4..e99f340 100644 --- a/third_party/blink/renderer/modules/credentialmanager/public_key_credential.cc +++ b/third_party/blink/renderer/modules/credentialmanager/public_key_credential.cc
@@ -6,6 +6,7 @@ #include <utility> +#include "third_party/blink/public/mojom/webauthn/authenticator.mojom-shared.h" #include "third_party/blink/renderer/bindings/core/v8/script_promise.h" #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" #include "third_party/blink/renderer/core/dom/dom_exception.h" @@ -27,24 +28,32 @@ bool available) { scoped_resolver->Release()->Resolve(available); } + +absl::optional<std::string> AuthenticatorAttachmentToString( + mojom::blink::AuthenticatorAttachment authenticator_attachment) { + switch (authenticator_attachment) { + case mojom::blink::AuthenticatorAttachment::PLATFORM: + return "platform"; + case mojom::blink::AuthenticatorAttachment::CROSS_PLATFORM: + return "cross-platform"; + case mojom::blink::AuthenticatorAttachment::NO_PREFERENCE: + return absl::nullopt; + } +} } // namespace PublicKeyCredential::PublicKeyCredential( const String& id, DOMArrayBuffer* raw_id, AuthenticatorResponse* response, - bool has_transport, - mojom::AuthenticatorTransport transport, + mojom::blink::AuthenticatorAttachment authenticator_attachment, const AuthenticationExtensionsClientOutputs* extension_outputs, const String& type) : Credential(id, type.IsEmpty() ? kPublicKeyCredentialType : type), raw_id_(raw_id), response_(response), authenticator_attachment_( - has_transport ? (transport == mojom::AuthenticatorTransport::INTERNAL - ? absl::make_optional("platform") - : absl::make_optional("cross-platform")) - : absl::nullopt), + AuthenticatorAttachmentToString(authenticator_attachment)), extension_outputs_(extension_outputs) {} ScriptPromise
diff --git a/third_party/blink/renderer/modules/credentialmanager/public_key_credential.h b/third_party/blink/renderer/modules/credentialmanager/public_key_credential.h index 55b44cd..fb49a8f 100644 --- a/third_party/blink/renderer/modules/credentialmanager/public_key_credential.h +++ b/third_party/blink/renderer/modules/credentialmanager/public_key_credential.h
@@ -6,7 +6,6 @@ #define THIRD_PARTY_BLINK_RENDERER_MODULES_CREDENTIALMANAGER_PUBLIC_KEY_CREDENTIAL_H_ #include "third_party/abseil-cpp/absl/types/optional.h" -#include "third_party/blink/public/mojom/webauthn/authenticator.mojom-shared.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_authentication_extensions_client_outputs.h" #include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer.h" #include "third_party/blink/renderer/modules/credentialmanager/authenticator_response.h" @@ -17,8 +16,11 @@ namespace blink { +namespace mojom { +enum class AuthenticatorAttachment; +} + class AuthenticatorResponse; -class AuthenticatorTransport; class ScriptPromise; class ScriptState; @@ -30,8 +32,7 @@ const String& id, DOMArrayBuffer* raw_id, AuthenticatorResponse*, - bool has_transport, - mojom::AuthenticatorTransport transport, + mojom::AuthenticatorAttachment authenticator_attachment, const AuthenticationExtensionsClientOutputs* extension_outputs, const String& type = "");
diff --git a/third_party/blink/renderer/modules/mediastream/BUILD.gn b/third_party/blink/renderer/modules/mediastream/BUILD.gn index 736cecf..5b24736 100644 --- a/third_party/blink/renderer/modules/mediastream/BUILD.gn +++ b/third_party/blink/renderer/modules/mediastream/BUILD.gn
@@ -107,7 +107,6 @@ deps = [ "//build:chromecast_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//media/capture/mojom:image_capture_blink", "//media/webrtc", "//services/viz/public/cpp/gpu:gpu",
diff --git a/third_party/blink/renderer/modules/mediastream/media_constraints_impl.cc b/third_party/blink/renderer/modules/mediastream/media_constraints_impl.cc index d56667d..31c1e51 100644 --- a/third_party/blink/renderer/modules/mediastream/media_constraints_impl.cc +++ b/third_party/blink/renderer/modules/mediastream/media_constraints_impl.cc
@@ -30,7 +30,7 @@ #include "third_party/blink/renderer/modules/mediastream/media_constraints_impl.h" -#include "build/os_buildflags.h" +#include "build/build_config.h" #include "third_party/blink/public/platform/web_string.h" #include "third_party/blink/renderer/bindings/core/v8/array_value.h" #include "third_party/blink/renderer/bindings/core/v8/dictionary.h"
diff --git a/third_party/blink/renderer/modules/payments/payment_response.cc b/third_party/blink/renderer/modules/payments/payment_response.cc index 5a7a284d..1389620 100644 --- a/third_party/blink/renderer/modules/payments/payment_response.cc +++ b/third_party/blink/renderer/modules/payments/payment_response.cc
@@ -42,8 +42,8 @@ secure_payment_confirmation->credential_info->id, DOMArrayBuffer::Create(static_cast<const void*>(info->raw_id.data()), info->raw_id.size()), - authenticator_response, secure_payment_confirmation->has_transport, - secure_payment_confirmation->transport, + authenticator_response, + secure_payment_confirmation->authenticator_attachment, AuthenticationExtensionsClientOutputs::Create()); return result->Wrap(script_state).ToLocalChecked(); }
diff --git a/third_party/blink/renderer/modules/webcodecs/BUILD.gn b/third_party/blink/renderer/modules/webcodecs/BUILD.gn index 4053a4a..85d2cf0 100644 --- a/third_party/blink/renderer/modules/webcodecs/BUILD.gn +++ b/third_party/blink/renderer/modules/webcodecs/BUILD.gn
@@ -81,7 +81,6 @@ "webcodecs_logger.h", ] deps = [ - "//build:os_buildflags", "//media", "//media/mojo:buildflags", "//media/mojo/clients",
diff --git a/third_party/blink/renderer/modules/webcodecs/video_encoder.cc b/third_party/blink/renderer/modules/webcodecs/video_encoder.cc index ff1a8bd..fd4d78a4 100644 --- a/third_party/blink/renderer/modules/webcodecs/video_encoder.cc +++ b/third_party/blink/renderer/modules/webcodecs/video_encoder.cc
@@ -17,7 +17,6 @@ #include "base/trace_event/common/trace_event_common.h" #include "base/trace_event/trace_event.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "components/viz/common/gpu/raster_context_provider.h" #include "gpu/GLES2/gl2extchromium.h" #include "gpu/command_buffer/client/raster_interface.h"
diff --git a/third_party/blink/renderer/platform/BUILD.gn b/third_party/blink/renderer/platform/BUILD.gn index 72bd14c..2896e16 100644 --- a/third_party/blink/renderer/platform/BUILD.gn +++ b/third_party/blink/renderer/platform/BUILD.gn
@@ -1662,7 +1662,6 @@ "//base/allocator:buildflags", "//build:chromecast_buildflags", "//build:chromeos_buildflags", - "//build:os_buildflags", "//cc/ipc", "//cc/mojo_embedder", "//components/paint_preview/common",
diff --git a/third_party/blink/renderer/platform/network/encoded_form_data.cc b/third_party/blink/renderer/platform/network/encoded_form_data.cc index 665d291..d17dfda 100644 --- a/third_party/blink/renderer/platform/network/encoded_form_data.cc +++ b/third_party/blink/renderer/platform/network/encoded_form_data.cc
@@ -121,13 +121,6 @@ return result; } -scoped_refptr<EncodedFormData> EncodedFormData::Create( - const Vector<char>& vector) { - scoped_refptr<EncodedFormData> result = Create(); - result->AppendData(vector.data(), vector.size()); - return result; -} - scoped_refptr<EncodedFormData> EncodedFormData::Copy() const { return base::AdoptRef(new EncodedFormData(*this)); }
diff --git a/third_party/blink/renderer/platform/network/encoded_form_data.h b/third_party/blink/renderer/platform/network/encoded_form_data.h index 79653a1..9b8e00f 100644 --- a/third_party/blink/renderer/platform/network/encoded_form_data.h +++ b/third_party/blink/renderer/platform/network/encoded_form_data.h
@@ -103,7 +103,6 @@ static scoped_refptr<EncodedFormData> Create(); static scoped_refptr<EncodedFormData> Create(const void*, wtf_size_t); static scoped_refptr<EncodedFormData> Create(base::span<const char>); - static scoped_refptr<EncodedFormData> Create(const Vector<char>&); scoped_refptr<EncodedFormData> Copy() const; scoped_refptr<EncodedFormData> DeepCopy() const; ~EncodedFormData();
diff --git a/third_party/blink/renderer/platform/runtime_enabled_features.json5 b/third_party/blink/renderer/platform/runtime_enabled_features.json5 index fd09b455..c0cad641 100644 --- a/third_party/blink/renderer/platform/runtime_enabled_features.json5 +++ b/third_party/blink/renderer/platform/runtime_enabled_features.json5
@@ -2098,11 +2098,6 @@ status: "experimental", }, { - name: "ServiceWorkerSubresourceFilter", - origin_trial_feature_name: "ServiceWorkerSubresourceFilter", - status: "experimental", - }, - { name: "SharedArrayBuffer", }, {
diff --git a/third_party/blink/renderer/platform/storage/blink_storage_key_hash.h b/third_party/blink/renderer/platform/storage/blink_storage_key_hash.h index 1fcf115..5c187b8 100644 --- a/third_party/blink/renderer/platform/storage/blink_storage_key_hash.h +++ b/third_party/blink/renderer/platform/storage/blink_storage_key_hash.h
@@ -13,8 +13,8 @@ namespace blink { // TODO(https://crbug.com/1199077): This needs to be re-implemented for the -// actual StorageKey content once it's stable. Right now it's just a shim for -// `SecurityOriginHash`. +// actual StorageKey content once it's stable. Right now it's just (almost) a +// shim for `SecurityOriginHash`. struct BlinkStorageKeyHash { STATIC_ONLY(BlinkStorageKeyHash); @@ -28,8 +28,7 @@ } static bool Equal(const BlinkStorageKey* a, const BlinkStorageKey* b) { - return SecurityOriginHash::Equal(a->GetSecurityOrigin(), - b->GetSecurityOrigin()); + return *a == *b; } static bool Equal(const std::unique_ptr<const BlinkStorageKey>& a,
diff --git a/third_party/blink/renderer/platform/webrtc/webrtc_video_frame_adapter.cc b/third_party/blink/renderer/platform/webrtc/webrtc_video_frame_adapter.cc index 80771ee..b019cb4a 100644 --- a/third_party/blink/renderer/platform/webrtc/webrtc_video_frame_adapter.cc +++ b/third_party/blink/renderer/platform/webrtc/webrtc_video_frame_adapter.cc
@@ -13,7 +13,6 @@ #include "base/synchronization/waitable_event.h" #include "base/threading/thread_restrictions.h" #include "build/build_config.h" -#include "build/os_buildflags.h" #include "gpu/command_buffer/client/gpu_memory_buffer_manager.h" #include "gpu/command_buffer/client/raster_interface.h" #include "gpu/command_buffer/client/shared_image_interface.h"
diff --git a/third_party/blink/tools/blinkpy/web_tests/port/base.py b/third_party/blink/tools/blinkpy/web_tests/port/base.py index 5095962..d6bd435 100644 --- a/third_party/blink/tools/blinkpy/web_tests/port/base.py +++ b/third_party/blink/tools/blinkpy/web_tests/port/base.py
@@ -1459,7 +1459,6 @@ ] clean_env['DISPLAY'] = self.host.environ.get('DISPLAY', ':1') if self.host.platform.is_mac(): - clean_env['DYLD_LIBRARY_PATH'] = self._build_path() variables_to_copy += [ 'HOME', ]
diff --git a/third_party/blink/web_tests/TestExpectations b/third_party/blink/web_tests/TestExpectations index 4005bf61..a8febfb 100644 --- a/third_party/blink/web_tests/TestExpectations +++ b/third_party/blink/web_tests/TestExpectations
@@ -3015,6 +3015,16 @@ crbug.com/626703 [ Mac11 ] external/wpt/css/css-flexbox/abspos/flex-abspos-staticpos-align-self-006.html [ Failure ] # ====== New tests from wpt-importer added here ====== +crbug.com/626703 [ Linux ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 [ Mac10.12 ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 [ Mac10.13 ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 [ Mac10.14 ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 [ Mac10.15 ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 [ Mac11 ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 [ Win ] external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 external/wpt/css/css-color/tagged-images-004.html [ Failure ] +crbug.com/626703 virtual/system-color-compute/external/wpt/css/css-color/tagged-images-003.html [ Failure ] +crbug.com/626703 virtual/system-color-compute/external/wpt/css/css-color/tagged-images-004.html [ Failure ] crbug.com/626703 [ Mac10.13 ] virtual/plz-dedicated-worker/external/wpt/resource-timing/object-not-found-adds-entry.html [ Timeout ] crbug.com/626703 [ Mac11 ] external/wpt/css/css-color-adjust/inheritance.html [ Crash Failure ] crbug.com/626703 [ Mac11-arm64 ] external/wpt/fetch/private-network-access/fetch.window.html [ Timeout ] @@ -6519,9 +6529,6 @@ # DevTools roll crbug.com/1050549 http/tests/devtools/console/console-correct-suggestions.js [ Failure Pass ] -crbug.com/1016266 http/tests/devtools/security/interstitial-sidebar.js [ Failure Pass ] -crbug.com/1016266 http/tests/devtools/security/mixed-content-sidebar.js [ Failure Pass ] -crbug.com/1016266 http/tests/devtools/console/console-format-classes.js [ Failure Pass ] # Flaky test crbug.com/1173439 http/tests/devtools/service-workers/service-worker-manager.js [ Pass Skip Timeout ] @@ -7090,7 +7097,7 @@ crbug.com/1201406 [ Mac11 ] fast/events/touch/gesture/touch-gesture-scroll-listbox.html [ Failure ] crbug.com/1201406 [ Mac11 ] http/tests/credentialmanager/credentialscontainer-create-from-nested-frame.html [ Crash Timeout ] crbug.com/1201406 [ Mac11 ] http/tests/credentialmanager/credentialscontainer-create-origins.html [ Crash Timeout ] -crbug.com/1201406 [ Mac11 ] http/tests/credentialmanager/publickeycredential-same-origin-with-ancestors.html [ Crash Timeout ] +crbug.com/1201406 http/tests/credentialmanager/publickeycredential-same-origin-with-ancestors.html [ Crash Pass Timeout ] crbug.com/1201406 [ Mac11 ] http/tests/credentialmanager/register-then-sign.html [ Crash Skip Timeout ] crbug.com/1201406 [ Mac11 ] http/tests/inspector-protocol/webauthn/webauthn-add-virtual-authenticator.js [ Crash Timeout ] crbug.com/1201406 [ Mac11 ] http/tests/inspector-protocol/webauthn/webauthn-clear-credentials.js [ Crash ] @@ -7883,3 +7890,13 @@ # Sheriff 2022-01-05 crbug.com/1284572 [ Linux ] virtual/threaded/http/tests/devtools/isolated-code-cache/same-origin-module-test.js [ Failure Pass ] + +# Sheriff 2022-01-07 +crbug.com/1285348 [ Mac ] virtual/off-main-thread-css-paint/external/wpt/css/css-paint-api/hidpi/canvas-transform.https.html [ Failure Pass ] +crbug.com/1285350 virtual/off-main-thread-css-paint/external/wpt/css/css-paint-api/parse-input-arguments-011.https.html [ Failure Pass ] +crbug.com/1285373 [ Linux ] fast/multicol/dynamic/insert-block-before-spanner-before-content.html [ Crash Failure Pass ] +crbug.com/1285426 [ Mac ] virtual/controls-refresh-hc/fast/forms/color-scheme/week-picker/week-picker-appearance-highlight-es.html [ Crash Pass ] +crbug.com/1285438 fast/layers/clip-rects-transformed.html [ Failure Pass ] +crbug.com/1285431 [ Mac ] virtual/backface-visibility-interop/compositing/overflow/transform-should-update-absolute-clip-rects.html [ Crash Failure Pass ] +crbug.com/1285437 virtual/scroll-unification-prefer_compositing_to_lcd_text/fast/scroll-behavior/smooth-scroll/ongoing-smooth-scroll-vertical-rl-anchors.html [ Pass Timeout ] +crbug.com/1285436 virtual/partitioned-cookies/http/tests/inspector-protocol/network/blocked-setcookie-same-site-lax.js [ Crash Failure Pass Timeout ]
diff --git a/third_party/blink/web_tests/external/WPT_BASE_MANIFEST_8.json b/third_party/blink/web_tests/external/WPT_BASE_MANIFEST_8.json index f44264c..2bff58a9 100644 --- a/third_party/blink/web_tests/external/WPT_BASE_MANIFEST_8.json +++ b/third_party/blink/web_tests/external/WPT_BASE_MANIFEST_8.json
@@ -601,6 +601,13 @@ {} ] ], + "hidden-execcommand-crash.html": [ + "67e297e6520b49325dd6091ca474c0151af5e0ea", + [ + null, + {} + ] + ], "hidden-pseudo-element-removed-crash.html": [ "62e38f214be715299f59bfc5f85697d28dfe4e92", [ @@ -85208,6 +85215,71 @@ {} ] ], + "tagged-images-001.html": [ + "0af9c0ce51e109588772a291e7b0250549c7fbe8", + [ + null, + [ + [ + "/css/css-color/009900-image-ref.html", + "==" + ] + ], + {} + ] + ], + "tagged-images-002.html": [ + "b7c76b35502e13e65c4e8204ea1f05d15940f0ea", + [ + null, + [ + [ + "/css/css-color/009900-image-ref.html", + "==" + ] + ], + {} + ] + ], + "tagged-images-003.html": [ + "09ce9e9418f4e3edfb0977b52b24906f45e5a715", + [ + null, + [ + [ + "/css/css-color/greensquare-090-ref.html", + "==" + ] + ], + {} + ] + ], + "tagged-images-004.html": [ + "9e67776f680901d05c6c89df044f8b201f33d600", + [ + null, + [ + [ + "/css/css-color/greensquare-090-ref.html", + "==" + ] + ], + {} + ] + ], + "untagged-images-001.html": [ + "a1aafb759c0d5994cc45ddd9422a1f7e39fa66e8", + [ + null, + [ + [ + "/css/css-color/009900-tagged-image-ref.html", + "==" + ] + ], + {} + ] + ], "xyz-001.html": [ "6fdf98e2dfc9f6d5341d12e5367b3159127e5df3", [ @@ -136080,6 +136152,45 @@ {} ] ], + "overflow-auto-scrollbar-gutter-intrinsic-001.html": [ + "061339b49cd29b0f68d53c5ecce6d2eb03fdaa6c", + [ + null, + [ + [ + "/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001-ref.html", + "==" + ] + ], + {} + ] + ], + "overflow-auto-scrollbar-gutter-intrinsic-002.html": [ + "535f2c4d5298df70e2eb417531093bb7fa753cbd", + [ + null, + [ + [ + "/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002-ref.html", + "==" + ] + ], + {} + ] + ], + "overflow-auto-scrollbar-gutter-intrinsic-003.html": [ + "ab247d9ca5a7cac660200cb475feb6b69093e520", + [ + null, + [ + [ + "/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003-ref.html", + "==" + ] + ], + {} + ] + ], "overflow-body-propagation-001.html": [ "0998fe68e007d5a46ff11d5ff87fdbad12d1dfe2", [ @@ -136496,6 +136607,19 @@ {} ] ], + "overflow-scroll-intrinsic-001.html": [ + "093fd283c77b71e7f73914761cc086e10c12a0d6", + [ + null, + [ + [ + "/css/css-overflow/overflow-scroll-intrinsic-001-ref.html", + "==" + ] + ], + {} + ] + ], "overflow-scroll-resize-visibility-hidden.html": [ "e8d0bc91440c7ed810079cc577d1e02689d3b220", [ @@ -242009,6 +242133,14 @@ } }, "css-color": { + "009900-image-ref.html": [ + "cf1ba5c48f70bba18ac44e3acf35378aa4f62fd5", + [] + ], + "009900-tagged-image-ref.html": [ + "21086fc62185a65da89513daf0c8bc1b66688064", + [] + ], "LICENSE": [ "d47f50cca8a2d9dc40dee384ae256f8aecf44e0a", [] @@ -242275,6 +242407,20 @@ "a1574caa30328d9b053283515f9d2b54d6326a54", [] ], + "support": { + "009900-sRGB.png": [ + "d6b0c96642f2552a75fb2be54b3bb3a3d49730dd", + [] + ], + "009900.png": [ + "7ddaa10b44e83d5a64a66370da1b1dcaa1ea7ec4", + [] + ], + "swap-990000-iCCP.png": [ + "766d0f76797a31c3d1c041680b8d04dfcb72ab1a", + [] + ] + }, "t31-color-currentColor-b-ref.html": [ "3013c7050c3c6f057e295923d43c87da6c09751f", [] @@ -259441,6 +259587,18 @@ "59864d0f4d185ece259879a299f597b80f9babdc", [] ], + "overflow-auto-scrollbar-gutter-intrinsic-001-ref.html": [ + "8c092bd0c876bc328919f4b39358110dd902c726", + [] + ], + "overflow-auto-scrollbar-gutter-intrinsic-002-ref.html": [ + "bb0742bfd0f130021ed4234920365aa83a2efed5", + [] + ], + "overflow-auto-scrollbar-gutter-intrinsic-003-ref.html": [ + "2fb2eb4479a876fb2c990b5a16d1cbbd44aade37", + [] + ], "overflow-body-propagation-007-ref.html": [ "66f9b1c3b0098c0cc448775cde0d94b4325a94a8", [] @@ -259517,6 +259675,10 @@ "c7ea1807443ef1b2d454edd547e65c89a59cea16", [] ], + "overflow-scroll-intrinsic-001-ref.html": [ + "8870d339196d8ccc14db575437190c728fa7e49c", + [] + ], "overflow-scroll-resize-visibility-hidden-ref.html": [ "571ba348df4b2dfdd05b31d074496c95c340635d", [] @@ -261011,7 +261173,7 @@ ], "parsing": { "marker-supported-properties-expected.txt": [ - "b599ad933fc63e2535c7740ac8cca8e933e17883", + "3b1189ccdf31f07452875606ebf38a736363409b", [] ], "marker-supported-properties-in-animation-expected.txt": [ @@ -278108,7 +278270,7 @@ [] ], "inserthtml.js": [ - "376d988bec125d1272a8137b34220facf69b006d", + "1f0a8f588dd630e3bcdcccc37d6ba3d021cab084", [] ], "insertimage.js": [ @@ -278494,7 +278656,7 @@ [] ], "inserthtml-expected.txt": [ - "2cae0dffb68410e3daf9c4e94b92bbe9dcee1605", + "06c23700443abaff3eb0e9e872ee308d5befa517", [] ], "insertimage-expected.txt": [ @@ -282137,11 +282299,15 @@ [] ], "preflight.py": [ - "8accc689f3bffe38dca735361ea8f7f582ccf682", + "bc2250456e04dbb35edb0496685d0af698fded6b", [] ], - "service-worker-fetcher.html": [ - "d9c0df30fd4bfcc2eb780a839037d0c2ab88af4b", + "service-worker-bridge.html": [ + "9bee90d210580b371dfadbff3061b666b0b83dbe", + [] + ], + "service-worker.js": [ + "bca71ad910cb189c2de6298b4ea59b5594aba637", [] ], "shared-worker-fetcher.html": [ @@ -282165,8 +282331,16 @@ [] ] }, + "service-worker-fetch.https.window-expected.txt": [ + "716cc67bdef69bfb4a99cafcc8e0af6c16e94787", + [] + ], + "service-worker-update.https.window-expected.txt": [ + "2884c271eecfd46ee3fee89707cbd5c81c624939", + [] + ], "service-worker.https.window-expected.txt": [ - "862f41ae97d255c84fe160a70cb5d2eb14e72f39", + "6c834a89c8d481231d211249ff5451aaa9164f5b", [] ], "shared-worker.https.window-expected.txt": [ @@ -284059,6 +284233,12 @@ "5f10361d5cffd939d1cc12c73cf49218098d7351", [] ], + "eligibility": { + "broadcast-channel-expected.txt": [ + "307bd14234d626c558b4070b1f9225ae6b722981", + [] + ] + }, "events-expected.txt": [ "6dad05534728dc6a05db46b5c112a2fd9546e8ce", [] @@ -287948,10 +288128,6 @@ "377c7296a781adb8ce5792cda50a4ded28472339", [] ], - "local-storage.tentative.https.window-expected.txt": [ - "0f6267c7c83b6d4183dc8d273e43e12b80fe07fc", - [] - ], "require-corp-embed-anonymous-iframe.tentative.https.window.js.headers": [ "6604450991a122e3e241e40b1b9e0516c525389d", [] @@ -309168,10 +309344,6 @@ "3f8456354bd08a3cb87e767ca0eb508603f45f1b", [] ], - "basic-dedicated-worker-expected.txt": [ - "12e06f3594177c46c8e01fcd3873c83d8d03cc0b", - [] - ], "basic-popup-and-iframe-tests.https-expected.txt": [ "712c701f4f7230566278fd0df2cf7d350dab8707", [] @@ -315781,7 +315953,7 @@ [] ], "accumulation-per-property-002-expected.txt": [ - "b477cd5cfb940ffef79b82973abbec1c049b66e8", + "357a3c326e931046ef5cf8a67898fc67f44c2804", [] ], "addition-per-property-001-expected.txt": [ @@ -315789,7 +315961,7 @@ [] ], "addition-per-property-002-expected.txt": [ - "9a5c6364dc42aa14461a5be0b67cc8eb49593b09", + "0e7f2b0b8eee56f51b02421646784b1ce4ea540c", [] ], "interpolation-per-property-001-expected.txt": [ @@ -315797,11 +315969,11 @@ [] ], "interpolation-per-property-002-expected.txt": [ - "c4a07ae6b232dd902c94032bc86803073cfc67ad", + "74ac012b78c356fe0337d24f7d54e680834e2d7a", [] ], "property-list.js": [ - "60b0b595150cd92486d3ed41abba6188fd4d4011", + "e9b7c524a749f2a8372868a1245224cf0f7d01d5", [] ], "property-types.js": [ @@ -369216,14 +369388,14 @@ ] ], "marker-supported-properties-in-animation.html": [ - "259a5b84eb5122c309294973b5ad745da11216e5", + "e581a786367ac13a90f95cb043b95c629babd9ab", [ null, {} ] ], "marker-supported-properties.html": [ - "ddcf98bfcf283fbb72d08396a6676a686bf02da5", + "ab03b9825deee02142bd0a03a21e9ff95feaa82f", [ null, {} @@ -375228,7 +375400,7 @@ }, "css-text-decor": { "inheritance.html": [ - "b106343742e03aa305a3017610272c8692b2a428", + "9ee65b4e5926f043b4841bc1f5a7cf84e214ce49", [ null, {} @@ -375348,7 +375520,7 @@ ] ], "text-emphasis-computed.html": [ - "460035a55179ada38a78be2dd7a280c8a0bd3518", + "7765f87d5b1eef59e9ae05f00de11daebeda038c", [ null, {} @@ -375362,7 +375534,7 @@ ] ], "text-emphasis-style-computed.html": [ - "5aa84ab08b615568b96812670674caade7b663ab", + "9153e82c30a43b88596ef239c30bf6b1fe095edf", [ null, {} @@ -415327,8 +415499,44 @@ } ] ], + "service-worker-fetch.https.window.js": [ + "edb20c04940e3fa76a2dd99e71785f97285c31c6", + [ + "fetch/private-network-access/service-worker-fetch.https.window.html", + { + "script_metadata": [ + [ + "script", + "/common/utils.js" + ], + [ + "script", + "resources/support.sub.js" + ] + ] + } + ] + ], + "service-worker-update.https.window.js": [ + "703c7650b0dbeeb4559a74e3557a91bf6b484e19", + [ + "fetch/private-network-access/service-worker-update.https.window.html", + { + "script_metadata": [ + [ + "script", + "/common/utils.js" + ], + [ + "script", + "resources/support.sub.js" + ] + ] + } + ] + ], "service-worker.https.window.js": [ - "969f359700e4c1c145743eb4c826973e7400a3af", + "3d3845b4d7cb832dcc4ad7238c54d655c84656fc", [ "fetch/private-network-access/service-worker.https.window.html", { @@ -418160,6 +418368,15 @@ "browsers": { "browsing-the-web": { "back-forward-cache": { + "eligibility": { + "broadcast-channel.html": [ + "bc04a5ed7f1c1feea66a6db74d6537ea3f48b577", + [ + null, + {} + ] + ] + }, "events.html": [ "4b1d3e408ebc25a4bf5f18087fc9bac67b2a1889", [ @@ -495833,7 +496050,7 @@ }, "secure-contexts": { "basic-dedicated-worker.html": [ - "581d761f5e795a2734f306c239d4a96220d4219a", + "ecd6138ac25d575a76a8625323513b9ec42646d6", [ null, {}
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/009900-image-ref.html b/third_party/blink/web_tests/external/wpt/css/css-color/009900-image-ref.html new file mode 100644 index 0000000..cf1ba5c --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/009900-image-ref.html
@@ -0,0 +1,11 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Space of Tagged Images. HTML img</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> + + +<body> + <p>Test passes if you see a green square, and no red.</p> + <!-- solid color #009900 PNG image, untagged --> + <p><img src="./support/009900.png" alt="should be green"/></p> +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/009900-tagged-image-ref.html b/third_party/blink/web_tests/external/wpt/css/css-color/009900-tagged-image-ref.html new file mode 100644 index 0000000..21086fc --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/009900-tagged-image-ref.html
@@ -0,0 +1,14 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Space of Tagged Images. HTML img</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> + + +<body> + <p>Test passes if you see a green square, and no red.</p> + <!-- + solid color #990000 PNG image, iCCP with v2 profile + red-green swapped to be sure the profile is actually applied + --> + <p><img src="./support/swap-990000-iCCP.png" alt="should be green"/></p> +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/support/009900-sRGB.png b/third_party/blink/web_tests/external/wpt/css/css-color/support/009900-sRGB.png new file mode 100644 index 0000000..d6b0c96 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/support/009900-sRGB.png Binary files differ
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/support/009900.png b/third_party/blink/web_tests/external/wpt/css/css-color/support/009900.png new file mode 100644 index 0000000..7ddaa10 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/support/009900.png Binary files differ
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/support/swap-990000-iCCP.png b/third_party/blink/web_tests/external/wpt/css/css-color/support/swap-990000-iCCP.png new file mode 100644 index 0000000..766d0f767 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/support/swap-990000-iCCP.png Binary files differ
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-001.html b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-001.html new file mode 100644 index 0000000..0af9c0c --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-001.html
@@ -0,0 +1,14 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Space of Tagged Images. HTML img</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> +<link rel="help" href="https://drafts.csswg.org/css-color-4/#tagged-images"> +<link rel="help" href="https://w3c.github.io/PNG-spec/#11iCCP"> +<meta name="assert" content="Tagged RGB images... if the color profile or other identifying information is valid, must be treated as being in the specified color space."> +<link rel="match" href="009900-image-ref.html"> + +<body> + <p>Test passes if you see a green square, and no red.</p> + <!-- solid color #990000 PNG image, iCCP with v2 ICC swapped red-green sRGB profile --> + <p><img src="./support/swap-990000-iCCP.png" alt="should be green"/></p> +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-002.html b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-002.html new file mode 100644 index 0000000..b7c76b3 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-002.html
@@ -0,0 +1,14 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Space of Tagged Images. HTML img</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> +<link rel="help" href="https://drafts.csswg.org/css-color-4/#tagged-images"> +<link rel="help" href="https://w3c.github.io/PNG-spec/#11sRGB"> +<meta name="assert" content="Tagged RGB images... if the color profile or other identifying information is valid, must be treated as being in the specified color space."> +<link rel="match" href="009900-image-ref.html"> + +<body> + <p>Test passes if you see a green square, and no red.</p> + <!-- solid color #009900 PNG image, sRGB chunk, relative colorimetric --> + <p><img src="./support/009900-sRGB.png" alt="should be green"/></p> +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-003.html b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-003.html new file mode 100644 index 0000000..09ce9e9 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-003.html
@@ -0,0 +1,22 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Space of Tagged Images. CSS background property</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> +<link rel="help" href="https://drafts.csswg.org/css-color-4/#tagged-images"> +<link rel="help" href="https://w3c.github.io/PNG-spec/#11iCCP"> +<meta name="assert" content="Tagged RGB images... if the color profile or other identifying information is valid, must be treated as being in the specified color space."> +<link rel="match" href="./greensquare-090-ref.html"> +<style> + .test { background-color: red; width: 12em; height: 6em; margin-top: 0; } + .ref { background-color: #090; width: 12em; height: 6em; margin-bottom: 0; } /* red-green swap of #900 sRGB */ + /* solid color #990000 PNG image, iCCP with v2 ICC swapped red-green sRGB profile */ + .test { background: url(./support/swap-990000-iCCP.png); } +</style> +<body> + <p>Test passes if you see a single green square, + not two rectangles of different colors, + and no red.</p> + <div class="ref"></div> + <div class="test"></div> + +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-004.html b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-004.html new file mode 100644 index 0000000..9e67776 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/tagged-images-004.html
@@ -0,0 +1,28 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Space of Tagged Images. CSS background property</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> +<link rel="help" href="https://drafts.csswg.org/css-color-4/#tagged-images"> +<link rel="help" href="https://w3c.github.io/PNG-spec/#11iCCP"> +<link rel="help" href="https://w3c.github.io/PNG-spec/#11sRGB"> +<meta name="assert" content="Image formats may also use other, equivalent methods, often for brevity. +For example, PNG specifies a means (the sRGB chunk) "> +<link rel="match" href="./greensquare-090-ref.html"> +<style> + div { width: 12em; height: 6em; margin-top: 0; margin-bottom: 0} + .test { background-color: red; } + + /* solid color #990000 PNG image, iCCP with v2 ICC swapped red-green sRGB profile */ + .ref { background: url(./support/swap-990000-iCCP.png); } + + /* solid color #009900 PNG image, sRGB chunk */ + .test { background: url(./support/009900-sRGB.png); } +</style> +<body> + <p>Test passes if you see a single green square, + not two rectangles of different colors, + and no red.</p> + <div class="ref"></div> + <div class="test"></div> + +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-color/untagged-images-001.html b/third_party/blink/web_tests/external/wpt/css/css-color/untagged-images-001.html new file mode 100644 index 0000000..a1aafb7 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-color/untagged-images-001.html
@@ -0,0 +1,13 @@ +<!DOCTYPE html> +<meta charset="utf-8"> +<title>CSS Color 4: Color Spaces of Untagged Colors</title> +<link rel="author" title="Chris Lilley" href="mailto:chris@w3.org"> +<link rel="help" href="https://drafts.csswg.org/css-color-4/#untagged"> +<meta name="assert" content=" untagged images must be treated as being in the sRGB color space"> +<link rel="match" href="009900-tagged-image-ref.html"> + +<body> + <p>Test passes if you see a green square, and no red.</p> + <!-- solid color #009900 PNG image, no color information --> + <p><img src="./support/009900.png" alt="should be green"/></p> +</body> \ No newline at end of file
diff --git a/third_party/blink/web_tests/external/wpt/css/css-contain/content-visibility/hidden-execcommand-crash.html b/third_party/blink/web_tests/external/wpt/css/css-contain/content-visibility/hidden-execcommand-crash.html new file mode 100644 index 0000000..67e297e6 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-contain/content-visibility/hidden-execcommand-crash.html
@@ -0,0 +1,10 @@ +<!DOCTYPE html> +<link rel=author href="mailto:jarhar@chromium.org"> +<link rel=help href="https://bugs.chromium.org/p/chromium/issues/detail?id=1280134"> + +<html contenteditable=true> + X<div style="content-visibility:hidden">Y</div>Z +<script> +document.execCommand("selectall"); +document.execCommand("fontSize", false, 6); +</script>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001-ref.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001-ref.html new file mode 100644 index 0000000..8c092bd --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001-ref.html
@@ -0,0 +1,60 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Reference: scrollbar-gutter size contributes to the scroll container's intrinsic size with "overflow:auto"</title> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-4/#scrollbar-gutter-property"> + + <style> + .line { + display: flex; + } + .container { + block-size: 50px; + border: 5px solid black; + scrollbar-gutter: stable; + margin: 10px; + } + .hidden { + overflow: hidden; + } + .scroll-x { + overflow-x: scroll; + } + .scroll-y { + overflow-y: scroll; + } + .tall { + /* trigger overflow */ + block-size: 5000px; + } + </style> + + <div class="line"> + <div class="container hidden"> + <div>I should not wrap</div> + </div> + + <div class="container scroll-y"> + <div class="tall">I should not wrap</div> + </div> + </div> + + <div class="line"> + <div class="container hidden" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container scroll-x" style="writing-mode: vertical-rl"> + <div class="tall">I should not wrap</div> + </div> + + <div class="container hidden" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + + <div class="container scroll-x" style="writing-mode: vertical-lr"> + <div class="tall">I should not wrap</div> + </div> + </div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001.html new file mode 100644 index 0000000..061339b --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-001.html
@@ -0,0 +1,53 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Test: scrollbar-gutter size contributes to the scroll container's intrinsic size with "overflow:auto"</title> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-4/#scrollbar-gutter-property"> + <link rel="match" href="overflow-auto-scrollbar-gutter-intrinsic-001-ref.html"> + + <style> + .line { + display: flex; + } + .container { + block-size: 50px; + border: 5px solid black; + overflow: auto; + scrollbar-gutter: stable; + margin: 10px; + } + .tall { + /* trigger overflow */ + block-size: 5000px; + } + </style> + + <div class="line"> + <div class="container"> + <div>I should not wrap</div> + </div> + + <div class="container"> + <div class="tall">I should not wrap</div> + </div> + </div> + + <div class="line"> + <div class="container" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-rl"> + <div class="tall">I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div class="tall">I should not wrap</div> + </div> + </div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002-ref.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002-ref.html new file mode 100644 index 0000000..bb0742bf --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002-ref.html
@@ -0,0 +1,60 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Reference: scrollbar-gutter size contributes to the scroll container's intrinsic size with "overflow:auto"</title> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-4/#scrollbar-gutter-property"> + + <style> + .line { + display: flex; + } + .container { + block-size: 50px; + border: 5px solid black; + scrollbar-gutter: stable both-edges; + margin: 10px; + } + .hidden { + overflow: hidden; + } + .scroll-x { + overflow-x: scroll; + } + .scroll-y { + overflow-y: scroll; + } + .tall { + /* trigger overflow */ + block-size: 5000px; + } + </style> + + <div class="line"> + <div class="container hidden"> + <div>I should not wrap</div> + </div> + + <div class="container scroll-y"> + <div class="tall">I should not wrap</div> + </div> + </div> + + <div class="line"> + <div class="container hidden" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container scroll-x" style="writing-mode: vertical-rl"> + <div class="tall">I should not wrap</div> + </div> + + <div class="container hidden" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + + <div class="container scroll-x" style="writing-mode: vertical-lr"> + <div class="tall">I should not wrap</div> + </div> + </div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002.html new file mode 100644 index 0000000..535f2c4 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-002.html
@@ -0,0 +1,53 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Test: scrollbar-gutter size contributes to the scroll container's intrinsic size with "overflow:auto"</title> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-4/#scrollbar-gutter-property"> + <link rel="match" href="overflow-auto-scrollbar-gutter-intrinsic-002-ref.html"> + + <style> + .line { + display: flex; + } + .container { + block-size: 50px; + border: 5px solid black; + overflow: auto; + scrollbar-gutter: stable both-edges; + margin: 10px; + } + .tall { + /* trigger overflow */ + block-size: 5000px; + } + </style> + + <div class="line"> + <div class="container"> + <div>I should not wrap</div> + </div> + + <div class="container"> + <div class="tall">I should not wrap</div> + </div> + </div> + + <div class="line"> + <div class="container" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-rl"> + <div class="tall">I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div class="tall">I should not wrap</div> + </div> + </div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003-ref.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003-ref.html new file mode 100644 index 0000000..2fb2eb4 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003-ref.html
@@ -0,0 +1,47 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Test: scrollbar-gutter size doesn't contribute to the scroll container's intrinsic size with "overflow:auto" and "scrollbar-width: none"</title> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-4/#scrollbar-gutter-property"> + <link rel="help" href="https://drafts.csswg.org/css-scrollbars/#scrollbar-width"> + + <style> + .line { + display: flex; + } + .container { + block-size: 50px; + border: 5px solid black; + margin: 10px; + } + </style> + + <div class="line"> + <div class="container"> + <div>I should not wrap</div> + </div> + + <div class="container"> + <div>I should not wrap</div> + </div> + </div> + + <div class="line"> + <div class="container" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + </div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003.html new file mode 100644 index 0000000..ab247d9c --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-auto-scrollbar-gutter-intrinsic-003.html
@@ -0,0 +1,55 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Test: scrollbar-gutter size doesn't contribute to the scroll container's intrinsic size with "overflow:auto" and "scrollbar-width: none"</title> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-4/#scrollbar-gutter-property"> + <link rel="help" href="https://drafts.csswg.org/css-scrollbars/#scrollbar-width"> + <link rel="match" href="overflow-auto-scrollbar-gutter-intrinsic-003-ref.html"> + + <style> + .line { + display: flex; + } + .container { + block-size: 50px; + border: 5px solid black; + overflow: auto; + scrollbar-gutter: stable; + scrollbar-width: none; + margin: 10px; + } + .tall { + /* trigger overflow */ + block-size: 5000px; + } + </style> + + <div class="line"> + <div class="container"> + <div>I should not wrap</div> + </div> + + <div class="container"> + <div class="tall">I should not wrap</div> + </div> + </div> + + <div class="line"> + <div class="container" style="writing-mode: vertical-rl"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-rl"> + <div class="tall">I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div>I should not wrap</div> + </div> + + <div class="container" style="writing-mode: vertical-lr"> + <div class="tall">I should not wrap</div> + </div> + </div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-scroll-intrinsic-001-ref.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-scroll-intrinsic-001-ref.html new file mode 100644 index 0000000..8870d33 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-scroll-intrinsic-001-ref.html
@@ -0,0 +1,20 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Reference: Intrinsic size of a "overflow:auto" vertical scroll container</title> + <link rel="author" title="Daniel Holbert" href="mailto:dholbert@mozilla.com"> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + + <style> + .container { + border: 1px solid black; + width: 100px; + display: inline-block; + } + </style> + + <div class="container" style="overflow-x: scroll;"></div> + <div class="container" style="overflow-y: scroll;"></div> + <div class="container" style="overflow-x: scroll;"></div> + <div class="container" style="overflow-y: scroll;"></div> +</html>
diff --git a/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-scroll-intrinsic-001.html b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-scroll-intrinsic-001.html new file mode 100644 index 0000000..093fd283 --- /dev/null +++ b/third_party/blink/web_tests/external/wpt/css/css-overflow/overflow-scroll-intrinsic-001.html
@@ -0,0 +1,23 @@ +<!DOCTYPE html> +<html> + <meta charset="utf-8"> + <title>CSS Overflow Test: Intrinsic size of a "overflow:auto" vertical scroll container</title> + <link rel="author" title="Daniel Holbert" href="mailto:dholbert@mozilla.com"> + <link rel="author" title="Ting-Yu Lin" href="mailto:tlin@mozilla.com"> + <link rel="help" href="https://drafts.csswg.org/css-overflow-3/#overflow-properties"> + <link rel="match" href="overflow-scroll-intrinsic-001-ref.html"> + + <style> + .container { + border: 1px solid black; + width: 100px; + display: inline-block; + } + </style> + + <div class="container" style="writing-mode: vertical-rl; overflow-x: scroll;"></div> + <div class="container" style="writing-mode: vertical-rl; overflow-y: scroll;"></div> + <div class="container" style="writing-mode: vertical-lr; overflow-x: scroll;"></div> + <div class="container" style="writing-mode: vertical-lr; overflow-y: scroll;"></div> + +</html>
diff --git a/third_party/blink/web_tests/external/wpt/editing/data/inserthtml.js b/third_party/blink/web_tests/external/wpt/editing/data/inserthtml.js index 376d988b..1f0a8f5 100644 --- a/third_party/blink/web_tests/external/wpt/editing/data/inserthtml.js +++ b/third_party/blink/web_tests/external/wpt/editing/data/inserthtml.js
@@ -571,4 +571,18 @@ "<pre contenteditable=\"false\"><span contenteditable=\"\">abc<br></span></pre>"], [true], {"inserthtml":[false,false,"",false,false,""]}], + +// Empty inline elements shouldn't be deleted if they are inserted intentionally +["<div>a[]b</div>", + [["inserthtml","<span></span>"]], + ["<div>a<span></span>b</div>", + "<div>a<span></span>b<br></div>"], + [true], + {"inserthtml":[false,false,"",false,false,""]}], +["<div>a[]c</div>", + [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]], + ["<div>a<span class=\"s1\"></span>b<span class=\"s2\"></span>c</div>", + "<div>a<span class=\"s1\"></span>b<span class=\"s2\"></span>c<br></div>"], + [true], + {"inserthtml":[false,false,"",false,false,""]}], ]
diff --git a/third_party/blink/web_tests/external/wpt/editing/run/inserthtml-expected.txt b/third_party/blink/web_tests/external/wpt/editing/run/inserthtml-expected.txt index 2cae0df..06c23700 100644 --- a/third_party/blink/web_tests/external/wpt/editing/run/inserthtml-expected.txt +++ b/third_party/blink/web_tests/external/wpt/editing/run/inserthtml-expected.txt
@@ -1,5 +1,5 @@ This is a testharness.js-based test. -Found 1498 tests; 1444 PASS, 54 FAIL, 0 TIMEOUT, 0 NOTRUN. +Found 1516 tests; 1460 PASS, 56 FAIL, 0 TIMEOUT, 0 NOTRUN. PASS [["stylewithcss","true"],["inserthtml","ab<b>c</b>d"]] "foo[]bar": execCommand("stylewithcss", false, "true") return value PASS [["stylewithcss","true"],["inserthtml","ab<b>c</b>d"]] "foo[]bar": execCommand("inserthtml", false, "ab<b>c</b>d") return value PASS [["stylewithcss","true"],["inserthtml","ab<b>c</b>d"]] "foo[]bar" checks for modifications to non-editable content @@ -1498,5 +1498,23 @@ PASS [["inserthtml","<pre>abc</pre>"]] "<pre contenteditable=\"false\"><span contenteditable>[1234]</span></pre>" queryCommandIndeterm("inserthtml") after PASS [["inserthtml","<pre>abc</pre>"]] "<pre contenteditable=\"false\"><span contenteditable>[1234]</span></pre>" queryCommandState("inserthtml") after PASS [["inserthtml","<pre>abc</pre>"]] "<pre contenteditable=\"false\"><span contenteditable>[1234]</span></pre>" queryCommandValue("inserthtml") after +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>": execCommand("inserthtml", false, "<span></span>") return value +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" checks for modifications to non-editable content +FAIL [["inserthtml","<span></span>"]] "<div>a[]b</div>" compare innerHTML assert_in_array: Unexpected innerHTML (after normalizing inline style) value "<div>ab</div>" not in array ["<div>a<span></span>b</div>", "<div>a<span></span>b<br></div>"] +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" queryCommandIndeterm("inserthtml") before +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" queryCommandState("inserthtml") before +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" queryCommandValue("inserthtml") before +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" queryCommandIndeterm("inserthtml") after +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" queryCommandState("inserthtml") after +PASS [["inserthtml","<span></span>"]] "<div>a[]b</div>" queryCommandValue("inserthtml") after +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>": execCommand("inserthtml", false, "<span class=\"s1\"></span>b<span class=\"s2\"></span>") return value +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" checks for modifications to non-editable content +FAIL [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" compare innerHTML assert_in_array: Unexpected innerHTML (after normalizing inline style) value "<div>ab</div><span class=\"s1\"></span><span class=\"s2\"></span><div>c<br></div>" not in array ["<div>a<span class=\"s1\"></span>b<span class=\"s2\"></span>c</div>", "<div>a<span class=\"s1\"></span>b<span class=\"s2\"></span>c<br></div>"] +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" queryCommandIndeterm("inserthtml") before +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" queryCommandState("inserthtml") before +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" queryCommandValue("inserthtml") before +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" queryCommandIndeterm("inserthtml") after +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" queryCommandState("inserthtml") after +PASS [["inserthtml","<span class=\"s1\"></span>b<span class=\"s2\"></span>"]] "<div>a[]c</div>" queryCommandValue("inserthtml") after Harness: the test ran to completion.
diff --git a/third_party/blink/web_tests/external/wpt/html/cross-origin-embedder-policy/anonymous-iframe/local-storage.tentative.https.window-expected.txt b/third_party/blink/web_tests/external/wpt/html/cross-origin-embedder-policy/anonymous-iframe/local-storage.tentative.https.window-expected.txt deleted file mode 100644 index 0f6267c..0000000 --- a/third_party/blink/web_tests/external/wpt/html/cross-origin-embedder-policy/anonymous-iframe/local-storage.tentative.https.window-expected.txt +++ /dev/null
@@ -1,6 +0,0 @@ -This is a testharness.js-based test. -PASS Setup -FAIL same_origin anonymous iframe can't access the localStorage assert_equals: expected "" but got "same_origin" -FAIL cross_origin anonymous iframe can't access the localStorage assert_equals: expected "" but got "cross_origin" -Harness: the test ran to completion. -
diff --git a/third_party/blink/web_tests/platform/mac-mac10.13/virtual/plz-dedicated-worker/external/wpt/referrer-policy/gen/srcdoc.meta/strict-origin/iframe-tag.http-expected.txt b/third_party/blink/web_tests/platform/mac-mac10.13/virtual/plz-dedicated-worker/external/wpt/referrer-policy/gen/srcdoc.meta/strict-origin/iframe-tag.http-expected.txt new file mode 100644 index 0000000..04a9370 --- /dev/null +++ b/third_party/blink/web_tests/platform/mac-mac10.13/virtual/plz-dedicated-worker/external/wpt/referrer-policy/gen/srcdoc.meta/strict-origin/iframe-tag.http-expected.txt
@@ -0,0 +1,15 @@ +This is a testharness.js-based test. +PASS Referrer Policy: Expects origin for iframe-tag to cross-http origin and keep-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to cross-http origin and no-redirect redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to cross-http origin and swap-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to cross-https origin and keep-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to cross-https origin and no-redirect redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to cross-https origin and swap-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to same-http origin and keep-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to same-http origin and no-redirect redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to same-http origin and swap-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to same-https origin and keep-origin redirection from http context. +PASS Referrer Policy: Expects origin for iframe-tag to same-https origin and no-redirect redirection from http context. +FAIL Referrer Policy: Expects origin for iframe-tag to same-https origin and swap-origin redirection from http context. assert_in_array: document.referrer value "http://web-platform.test:8001/referrer-policy/gen/srcdoc.meta/strict-origin/iframe-tag.http.html" not in array ["http://web-platform.test:8001/", undefined] +Harness: the test ran to completion. +
diff --git a/third_party/blink/web_tests/virtual/third-party-storage-partitioning/external/wpt/webstorage/localstorage-basic-partitioned.tentative.sub-expected.txt b/third_party/blink/web_tests/virtual/third-party-storage-partitioning/external/wpt/webstorage/localstorage-basic-partitioned.tentative.sub-expected.txt index 6cdbef6..9d229db 100644 --- a/third_party/blink/web_tests/virtual/third-party-storage-partitioning/external/wpt/webstorage/localstorage-basic-partitioned.tentative.sub-expected.txt +++ b/third_party/blink/web_tests/virtual/third-party-storage-partitioning/external/wpt/webstorage/localstorage-basic-partitioned.tentative.sub-expected.txt
@@ -1,4 +1,4 @@ This is a testharness.js-based test. -FAIL Simple test for partitioned localStorage assert_true: IDs pulled from two partitioned iframes are different. expected true got false +PASS Simple test for partitioned localStorage Harness: the test ran to completion.
diff --git a/third_party/boringssl/BUILD.generated.gni b/third_party/boringssl/BUILD.generated.gni index fed3f02..5e57473 100644 --- a/third_party/boringssl/BUILD.generated.gni +++ b/third_party/boringssl/BUILD.generated.gni
@@ -78,14 +78,15 @@ "src/crypto/conf/conf.c", "src/crypto/conf/conf_def.h", "src/crypto/conf/internal.h", - "src/crypto/cpu-aarch64-fuchsia.c", - "src/crypto/cpu-aarch64-linux.c", - "src/crypto/cpu-aarch64-win.c", - "src/crypto/cpu-arm-linux.c", - "src/crypto/cpu-arm-linux.h", - "src/crypto/cpu-arm.c", - "src/crypto/cpu-intel.c", - "src/crypto/cpu-ppc64le.c", + "src/crypto/cpu_aarch64_apple.c", + "src/crypto/cpu_aarch64_fuchsia.c", + "src/crypto/cpu_aarch64_linux.c", + "src/crypto/cpu_aarch64_win.c", + "src/crypto/cpu_arm.c", + "src/crypto/cpu_arm_linux.c", + "src/crypto/cpu_arm_linux.h", + "src/crypto/cpu_intel.c", + "src/crypto/cpu_ppc64le.c", "src/crypto/crypto.c", "src/crypto/curve25519/curve25519.c", "src/crypto/curve25519/curve25519_tables.h",
diff --git a/third_party/boringssl/BUILD.generated_tests.gni b/third_party/boringssl/BUILD.generated_tests.gni index 20a4212d..9e612df8 100644 --- a/third_party/boringssl/BUILD.generated_tests.gni +++ b/third_party/boringssl/BUILD.generated_tests.gni
@@ -40,7 +40,7 @@ "src/crypto/compiler_test.cc", "src/crypto/conf/conf_test.cc", "src/crypto/constant_time_test.cc", - "src/crypto/cpu-arm-linux_test.cc", + "src/crypto/cpu_arm_linux_test.cc", "src/crypto/crypto_test.cc", "src/crypto/curve25519/ed25519_test.cc", "src/crypto/curve25519/spake25519_test.cc", @@ -91,7 +91,6 @@ "src/crypto/x509/x509_test.cc", "src/crypto/x509/x509_time_test.cc", "src/crypto/x509v3/tab_test.cc", - "src/crypto/x509v3/v3name_test.cc", ] crypto_test_data = [
diff --git a/third_party/boringssl/BUILD.gn b/third_party/boringssl/BUILD.gn index 91ce539..e31d0921 100644 --- a/third_party/boringssl/BUILD.gn +++ b/third_party/boringssl/BUILD.gn
@@ -35,9 +35,6 @@ "BORINGSSL_NO_STATIC_INITIALIZER", "OPENSSL_SMALL", ] - if (is_posix || is_fuchsia) { - defines += [ "_XOPEN_SOURCE=700" ] - } } config("no_asm_config") {
diff --git a/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha256-armv8.S b/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha256-armv8.S index c9b7991..b40b260 100644 --- a/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha256-armv8.S +++ b/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha256-armv8.S
@@ -12,7 +12,7 @@ #if defined(BORINGSSL_PREFIX) #include <boringssl_prefix_symbols_asm.h> #endif -// Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. +// Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. // // Licensed under the OpenSSL license (the "License"). You may not use // this file except in compliance with the License. You can obtain a copy @@ -40,6 +40,7 @@ // Denver 2.01 10.5 (+26%) 6.70 (+8%) // X-Gene 20.0 (+100%) 12.8 (+300%(***)) // Mongoose 2.36 13.0 (+50%) 8.36 (+33%) +// Kryo 1.92 17.4 (+30%) 11.2 (+8%) // // (*) Software SHA256 results are of lesser relevance, presented // mostly for informational purposes. @@ -48,7 +49,7 @@ // on Cortex-A53 (or by 4 cycles per round). // (***) Super-impressive coefficients over gcc-generated code are // indication of some compiler "pathology", most notably code -// generated with -mgeneral-regs-only is significanty faster +// generated with -mgeneral-regs-only is significantly faster // and the gap is only 40-90%. #ifndef __KERNEL__ @@ -100,7 +101,7 @@ ldr w19,[x30],#4 // *K++ eor w28,w21,w22 // magic seed str x1,[x29,#112] -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w3,w3 // 0 #endif ror w16,w24,#6 @@ -123,7 +124,7 @@ add w27,w27,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w27,w27,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w4,w4 // 1 #endif ldp w5,w6,[x1],#2*4 @@ -148,7 +149,7 @@ add w26,w26,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w26,w26,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w5,w5 // 2 #endif add w26,w26,w17 // h+=Sigma0(a) @@ -172,7 +173,7 @@ add w25,w25,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w25,w25,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w6,w6 // 3 #endif ldp w7,w8,[x1],#2*4 @@ -197,7 +198,7 @@ add w24,w24,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w24,w24,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w7,w7 // 4 #endif add w24,w24,w17 // h+=Sigma0(a) @@ -221,7 +222,7 @@ add w23,w23,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w23,w23,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w8,w8 // 5 #endif ldp w9,w10,[x1],#2*4 @@ -246,7 +247,7 @@ add w22,w22,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w22,w22,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w9,w9 // 6 #endif add w22,w22,w17 // h+=Sigma0(a) @@ -270,7 +271,7 @@ add w21,w21,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w21,w21,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w10,w10 // 7 #endif ldp w11,w12,[x1],#2*4 @@ -295,7 +296,7 @@ add w20,w20,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w20,w20,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w11,w11 // 8 #endif add w20,w20,w17 // h+=Sigma0(a) @@ -319,7 +320,7 @@ add w27,w27,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w27,w27,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w12,w12 // 9 #endif ldp w13,w14,[x1],#2*4 @@ -344,7 +345,7 @@ add w26,w26,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w26,w26,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w13,w13 // 10 #endif add w26,w26,w17 // h+=Sigma0(a) @@ -368,7 +369,7 @@ add w25,w25,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w25,w25,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w14,w14 // 11 #endif ldp w15,w0,[x1],#2*4 @@ -394,7 +395,7 @@ add w24,w24,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w24,w24,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w15,w15 // 12 #endif add w24,w24,w17 // h+=Sigma0(a) @@ -419,7 +420,7 @@ add w23,w23,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w23,w23,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w0,w0 // 13 #endif ldp w1,w2,[x1] @@ -445,7 +446,7 @@ add w22,w22,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w22,w22,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w1,w1 // 14 #endif ldr w6,[sp,#12] @@ -471,7 +472,7 @@ add w21,w21,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w21,w21,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w2,w2 // 15 #endif ldr w7,[sp,#0]
diff --git a/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha512-armv8.S b/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha512-armv8.S index 97b3230..b2d366d 100644 --- a/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha512-armv8.S +++ b/third_party/boringssl/ios-aarch64/crypto/fipsmodule/sha512-armv8.S
@@ -12,7 +12,7 @@ #if defined(BORINGSSL_PREFIX) #include <boringssl_prefix_symbols_asm.h> #endif -// Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. +// Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. // // Licensed under the OpenSSL license (the "License"). You may not use // this file except in compliance with the License. You can obtain a copy @@ -40,6 +40,7 @@ // Denver 2.01 10.5 (+26%) 6.70 (+8%) // X-Gene 20.0 (+100%) 12.8 (+300%(***)) // Mongoose 2.36 13.0 (+50%) 8.36 (+33%) +// Kryo 1.92 17.4 (+30%) 11.2 (+8%) // // (*) Software SHA256 results are of lesser relevance, presented // mostly for informational purposes. @@ -48,7 +49,7 @@ // on Cortex-A53 (or by 4 cycles per round). // (***) Super-impressive coefficients over gcc-generated code are // indication of some compiler "pathology", most notably code -// generated with -mgeneral-regs-only is significanty faster +// generated with -mgeneral-regs-only is significantly faster // and the gap is only 40-90%. #ifndef __KERNEL__ @@ -64,6 +65,17 @@ .align 6 _sha512_block_data_order: + AARCH64_VALID_CALL_TARGET +#ifndef __KERNEL__ +#if __has_feature(hwaddress_sanitizer) && __clang_major__ >= 10 + adrp x16,:pg_hi21_nc:_OPENSSL_armcap_P +#else + adrp x16,_OPENSSL_armcap_P@PAGE +#endif + ldr w16,[x16,_OPENSSL_armcap_P@PAGEOFF] + tst w16,#ARMV8_SHA512 + b.ne Lv8_entry +#endif AARCH64_SIGN_LINK_REGISTER stp x29,x30,[sp,#-128]! add x29,sp,#0 @@ -89,7 +101,7 @@ ldr x19,[x30],#8 // *K++ eor x28,x21,x22 // magic seed str x1,[x29,#112] -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x3,x3 // 0 #endif ror x16,x24,#14 @@ -112,7 +124,7 @@ add x27,x27,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x27,x27,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x4,x4 // 1 #endif ldp x5,x6,[x1],#2*8 @@ -137,7 +149,7 @@ add x26,x26,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x26,x26,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x5,x5 // 2 #endif add x26,x26,x17 // h+=Sigma0(a) @@ -161,7 +173,7 @@ add x25,x25,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x25,x25,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x6,x6 // 3 #endif ldp x7,x8,[x1],#2*8 @@ -186,7 +198,7 @@ add x24,x24,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x24,x24,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x7,x7 // 4 #endif add x24,x24,x17 // h+=Sigma0(a) @@ -210,7 +222,7 @@ add x23,x23,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x23,x23,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x8,x8 // 5 #endif ldp x9,x10,[x1],#2*8 @@ -235,7 +247,7 @@ add x22,x22,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x22,x22,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x9,x9 // 6 #endif add x22,x22,x17 // h+=Sigma0(a) @@ -259,7 +271,7 @@ add x21,x21,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x21,x21,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x10,x10 // 7 #endif ldp x11,x12,[x1],#2*8 @@ -284,7 +296,7 @@ add x20,x20,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x20,x20,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x11,x11 // 8 #endif add x20,x20,x17 // h+=Sigma0(a) @@ -308,7 +320,7 @@ add x27,x27,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x27,x27,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x12,x12 // 9 #endif ldp x13,x14,[x1],#2*8 @@ -333,7 +345,7 @@ add x26,x26,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x26,x26,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x13,x13 // 10 #endif add x26,x26,x17 // h+=Sigma0(a) @@ -357,7 +369,7 @@ add x25,x25,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x25,x25,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x14,x14 // 11 #endif ldp x15,x0,[x1],#2*8 @@ -383,7 +395,7 @@ add x24,x24,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x24,x24,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x15,x15 // 12 #endif add x24,x24,x17 // h+=Sigma0(a) @@ -408,7 +420,7 @@ add x23,x23,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x23,x23,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x0,x0 // 13 #endif ldp x1,x2,[x1] @@ -434,7 +446,7 @@ add x22,x22,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x22,x22,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x1,x1 // 14 #endif ldr x6,[sp,#24] @@ -460,7 +472,7 @@ add x21,x21,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x21,x21,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x2,x2 // 15 #endif ldr x7,[sp,#0] @@ -1078,4 +1090,525 @@ .byte 83,72,65,53,49,50,32,98,108,111,99,107,32,116,114,97,110,115,102,111,114,109,32,102,111,114,32,65,82,77,118,56,44,32,67,82,89,80,84,79,71,65,77,83,32,98,121,32,60,97,112,112,114,111,64,111,112,101,110,115,115,108,46,111,114,103,62,0 .align 2 .align 2 +.text +#ifndef __KERNEL__ + +.align 6 +sha512_block_armv8: +Lv8_entry: + stp x29,x30,[sp,#-16]! + add x29,sp,#0 + + ld1 {v16.16b,v17.16b,v18.16b,v19.16b},[x1],#64 // load input + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + + ld1 {v0.2d,v1.2d,v2.2d,v3.2d},[x0] // load context + adrp x3,LK512@PAGE + add x3,x3,LK512@PAGEOFF + + rev64 v16.16b,v16.16b + rev64 v17.16b,v17.16b + rev64 v18.16b,v18.16b + rev64 v19.16b,v19.16b + rev64 v20.16b,v20.16b + rev64 v21.16b,v21.16b + rev64 v22.16b,v22.16b + rev64 v23.16b,v23.16b + b Loop_hw + +.align 4 +Loop_hw: + ld1 {v24.2d},[x3],#16 + subs x2,x2,#1 + sub x4,x1,#128 + orr v26.16b,v0.16b,v0.16b // offload + orr v27.16b,v1.16b,v1.16b + orr v28.16b,v2.16b,v2.16b + orr v29.16b,v3.16b,v3.16b + csel x1,x1,x4,ne // conditional rewind + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v16.2d + ld1 {v16.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b + rev64 v16.16b,v16.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v17.2d + ld1 {v17.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b + rev64 v17.16b,v17.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v18.2d + ld1 {v18.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b + rev64 v18.16b,v18.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v19.2d + ld1 {v19.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b + rev64 v19.16b,v19.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v20.2d + ld1 {v20.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b + rev64 v20.16b,v20.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v21.2d + ld1 {v21.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b + rev64 v21.16b,v21.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v22.2d + ld1 {v22.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b + rev64 v22.16b,v22.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + sub x3,x3,#80*8 // rewind + add v25.2d,v25.2d,v23.2d + ld1 {v23.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b + rev64 v23.16b,v23.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v0.2d,v0.2d,v26.2d // accumulate + add v1.2d,v1.2d,v27.2d + add v2.2d,v2.2d,v28.2d + add v3.2d,v3.2d,v29.2d + + cbnz x2,Loop_hw + + st1 {v0.2d,v1.2d,v2.2d,v3.2d},[x0] // store context + + ldr x29,[sp],#16 + ret + +#endif #endif // !OPENSSL_NO_ASM
diff --git a/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha256-armv8.S b/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha256-armv8.S index a4f170e..c777ec8 100644 --- a/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha256-armv8.S +++ b/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha256-armv8.S
@@ -13,7 +13,7 @@ #if defined(BORINGSSL_PREFIX) #include <boringssl_prefix_symbols_asm.h> #endif -// Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. +// Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. // // Licensed under the OpenSSL license (the "License"). You may not use // this file except in compliance with the License. You can obtain a copy @@ -41,6 +41,7 @@ // Denver 2.01 10.5 (+26%) 6.70 (+8%) // X-Gene 20.0 (+100%) 12.8 (+300%(***)) // Mongoose 2.36 13.0 (+50%) 8.36 (+33%) +// Kryo 1.92 17.4 (+30%) 11.2 (+8%) // // (*) Software SHA256 results are of lesser relevance, presented // mostly for informational purposes. @@ -49,7 +50,7 @@ // on Cortex-A53 (or by 4 cycles per round). // (***) Super-impressive coefficients over gcc-generated code are // indication of some compiler "pathology", most notably code -// generated with -mgeneral-regs-only is significanty faster +// generated with -mgeneral-regs-only is significantly faster // and the gap is only 40-90%. #ifndef __KERNEL__ @@ -101,7 +102,7 @@ ldr w19,[x30],#4 // *K++ eor w28,w21,w22 // magic seed str x1,[x29,#112] -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w3,w3 // 0 #endif ror w16,w24,#6 @@ -124,7 +125,7 @@ add w27,w27,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w27,w27,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w4,w4 // 1 #endif ldp w5,w6,[x1],#2*4 @@ -149,7 +150,7 @@ add w26,w26,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w26,w26,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w5,w5 // 2 #endif add w26,w26,w17 // h+=Sigma0(a) @@ -173,7 +174,7 @@ add w25,w25,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w25,w25,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w6,w6 // 3 #endif ldp w7,w8,[x1],#2*4 @@ -198,7 +199,7 @@ add w24,w24,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w24,w24,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w7,w7 // 4 #endif add w24,w24,w17 // h+=Sigma0(a) @@ -222,7 +223,7 @@ add w23,w23,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w23,w23,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w8,w8 // 5 #endif ldp w9,w10,[x1],#2*4 @@ -247,7 +248,7 @@ add w22,w22,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w22,w22,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w9,w9 // 6 #endif add w22,w22,w17 // h+=Sigma0(a) @@ -271,7 +272,7 @@ add w21,w21,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w21,w21,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w10,w10 // 7 #endif ldp w11,w12,[x1],#2*4 @@ -296,7 +297,7 @@ add w20,w20,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w20,w20,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w11,w11 // 8 #endif add w20,w20,w17 // h+=Sigma0(a) @@ -320,7 +321,7 @@ add w27,w27,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w27,w27,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w12,w12 // 9 #endif ldp w13,w14,[x1],#2*4 @@ -345,7 +346,7 @@ add w26,w26,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w26,w26,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w13,w13 // 10 #endif add w26,w26,w17 // h+=Sigma0(a) @@ -369,7 +370,7 @@ add w25,w25,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w25,w25,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w14,w14 // 11 #endif ldp w15,w0,[x1],#2*4 @@ -395,7 +396,7 @@ add w24,w24,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w24,w24,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w15,w15 // 12 #endif add w24,w24,w17 // h+=Sigma0(a) @@ -420,7 +421,7 @@ add w23,w23,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w23,w23,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w0,w0 // 13 #endif ldp w1,w2,[x1] @@ -446,7 +447,7 @@ add w22,w22,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w22,w22,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w1,w1 // 14 #endif ldr w6,[sp,#12] @@ -472,7 +473,7 @@ add w21,w21,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w21,w21,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w2,w2 // 15 #endif ldr w7,[sp,#0]
diff --git a/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha512-armv8.S b/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha512-armv8.S index 98b7a7e2..a3b458a 100644 --- a/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha512-armv8.S +++ b/third_party/boringssl/linux-aarch64/crypto/fipsmodule/sha512-armv8.S
@@ -13,7 +13,7 @@ #if defined(BORINGSSL_PREFIX) #include <boringssl_prefix_symbols_asm.h> #endif -// Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. +// Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. // // Licensed under the OpenSSL license (the "License"). You may not use // this file except in compliance with the License. You can obtain a copy @@ -41,6 +41,7 @@ // Denver 2.01 10.5 (+26%) 6.70 (+8%) // X-Gene 20.0 (+100%) 12.8 (+300%(***)) // Mongoose 2.36 13.0 (+50%) 8.36 (+33%) +// Kryo 1.92 17.4 (+30%) 11.2 (+8%) // // (*) Software SHA256 results are of lesser relevance, presented // mostly for informational purposes. @@ -49,7 +50,7 @@ // on Cortex-A53 (or by 4 cycles per round). // (***) Super-impressive coefficients over gcc-generated code are // indication of some compiler "pathology", most notably code -// generated with -mgeneral-regs-only is significanty faster +// generated with -mgeneral-regs-only is significantly faster // and the gap is only 40-90%. #ifndef __KERNEL__ @@ -65,6 +66,17 @@ .type sha512_block_data_order,%function .align 6 sha512_block_data_order: + AARCH64_VALID_CALL_TARGET +#ifndef __KERNEL__ +#if __has_feature(hwaddress_sanitizer) && __clang_major__ >= 10 + adrp x16,:pg_hi21_nc:OPENSSL_armcap_P +#else + adrp x16,OPENSSL_armcap_P +#endif + ldr w16,[x16,:lo12:OPENSSL_armcap_P] + tst w16,#ARMV8_SHA512 + b.ne .Lv8_entry +#endif AARCH64_SIGN_LINK_REGISTER stp x29,x30,[sp,#-128]! add x29,sp,#0 @@ -90,7 +102,7 @@ ldr x19,[x30],#8 // *K++ eor x28,x21,x22 // magic seed str x1,[x29,#112] -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x3,x3 // 0 #endif ror x16,x24,#14 @@ -113,7 +125,7 @@ add x27,x27,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x27,x27,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x4,x4 // 1 #endif ldp x5,x6,[x1],#2*8 @@ -138,7 +150,7 @@ add x26,x26,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x26,x26,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x5,x5 // 2 #endif add x26,x26,x17 // h+=Sigma0(a) @@ -162,7 +174,7 @@ add x25,x25,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x25,x25,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x6,x6 // 3 #endif ldp x7,x8,[x1],#2*8 @@ -187,7 +199,7 @@ add x24,x24,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x24,x24,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x7,x7 // 4 #endif add x24,x24,x17 // h+=Sigma0(a) @@ -211,7 +223,7 @@ add x23,x23,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x23,x23,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x8,x8 // 5 #endif ldp x9,x10,[x1],#2*8 @@ -236,7 +248,7 @@ add x22,x22,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x22,x22,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x9,x9 // 6 #endif add x22,x22,x17 // h+=Sigma0(a) @@ -260,7 +272,7 @@ add x21,x21,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x21,x21,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x10,x10 // 7 #endif ldp x11,x12,[x1],#2*8 @@ -285,7 +297,7 @@ add x20,x20,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x20,x20,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x11,x11 // 8 #endif add x20,x20,x17 // h+=Sigma0(a) @@ -309,7 +321,7 @@ add x27,x27,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x27,x27,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x12,x12 // 9 #endif ldp x13,x14,[x1],#2*8 @@ -334,7 +346,7 @@ add x26,x26,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x26,x26,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x13,x13 // 10 #endif add x26,x26,x17 // h+=Sigma0(a) @@ -358,7 +370,7 @@ add x25,x25,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x25,x25,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x14,x14 // 11 #endif ldp x15,x0,[x1],#2*8 @@ -384,7 +396,7 @@ add x24,x24,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x24,x24,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x15,x15 // 12 #endif add x24,x24,x17 // h+=Sigma0(a) @@ -409,7 +421,7 @@ add x23,x23,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x23,x23,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x0,x0 // 13 #endif ldp x1,x2,[x1] @@ -435,7 +447,7 @@ add x22,x22,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x22,x22,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x1,x1 // 14 #endif ldr x6,[sp,#24] @@ -461,7 +473,7 @@ add x21,x21,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x21,x21,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x2,x2 // 15 #endif ldr x7,[sp,#0] @@ -1079,6 +1091,527 @@ .byte 83,72,65,53,49,50,32,98,108,111,99,107,32,116,114,97,110,115,102,111,114,109,32,102,111,114,32,65,82,77,118,56,44,32,67,82,89,80,84,79,71,65,77,83,32,98,121,32,60,97,112,112,114,111,64,111,112,101,110,115,115,108,46,111,114,103,62,0 .align 2 .align 2 +.text +#ifndef __KERNEL__ +.type sha512_block_armv8,%function +.align 6 +sha512_block_armv8: +.Lv8_entry: + stp x29,x30,[sp,#-16]! + add x29,sp,#0 + + ld1 {v16.16b,v17.16b,v18.16b,v19.16b},[x1],#64 // load input + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + + ld1 {v0.2d,v1.2d,v2.2d,v3.2d},[x0] // load context + adrp x3,.LK512 + add x3,x3,:lo12:.LK512 + + rev64 v16.16b,v16.16b + rev64 v17.16b,v17.16b + rev64 v18.16b,v18.16b + rev64 v19.16b,v19.16b + rev64 v20.16b,v20.16b + rev64 v21.16b,v21.16b + rev64 v22.16b,v22.16b + rev64 v23.16b,v23.16b + b .Loop_hw + +.align 4 +.Loop_hw: + ld1 {v24.2d},[x3],#16 + subs x2,x2,#1 + sub x4,x1,#128 + orr v26.16b,v0.16b,v0.16b // offload + orr v27.16b,v1.16b,v1.16b + orr v28.16b,v2.16b,v2.16b + orr v29.16b,v3.16b,v3.16b + csel x1,x1,x4,ne // conditional rewind + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.inst 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.inst 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.inst 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.inst 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.inst 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.inst 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.inst 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.inst 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.inst 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.inst 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.inst 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.inst 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.inst 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.inst 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.inst 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.inst 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.inst 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.inst 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.inst 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.inst 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v16.2d + ld1 {v16.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b + rev64 v16.16b,v16.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v17.2d + ld1 {v17.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b + rev64 v17.16b,v17.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v18.2d + ld1 {v18.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b + rev64 v18.16b,v18.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v19.2d + ld1 {v19.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b + rev64 v19.16b,v19.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.inst 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v20.2d + ld1 {v20.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b + rev64 v20.16b,v20.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.inst 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v21.2d + ld1 {v21.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b + rev64 v21.16b,v21.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.inst 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v22.2d + ld1 {v22.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.inst 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b + rev64 v22.16b,v22.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.inst 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + sub x3,x3,#80*8 // rewind + add v25.2d,v25.2d,v23.2d + ld1 {v23.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.inst 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b + rev64 v23.16b,v23.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.inst 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v0.2d,v0.2d,v26.2d // accumulate + add v1.2d,v1.2d,v27.2d + add v2.2d,v2.2d,v28.2d + add v3.2d,v3.2d,v29.2d + + cbnz x2,.Loop_hw + + st1 {v0.2d,v1.2d,v2.2d,v3.2d},[x0] // store context + + ldr x29,[sp],#16 + ret +.size sha512_block_armv8,.-sha512_block_armv8 +#endif #endif #endif // !OPENSSL_NO_ASM .section .note.GNU-stack,"",%progbits
diff --git a/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha256-armv8.S b/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha256-armv8.S index 5f233f9..15970af 100644 --- a/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha256-armv8.S +++ b/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha256-armv8.S
@@ -13,7 +13,7 @@ #if defined(BORINGSSL_PREFIX) #include <boringssl_prefix_symbols_asm.h> #endif -// Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. +// Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. // // Licensed under the OpenSSL license (the "License"). You may not use // this file except in compliance with the License. You can obtain a copy @@ -41,6 +41,7 @@ // Denver 2.01 10.5 (+26%) 6.70 (+8%) // X-Gene 20.0 (+100%) 12.8 (+300%(***)) // Mongoose 2.36 13.0 (+50%) 8.36 (+33%) +// Kryo 1.92 17.4 (+30%) 11.2 (+8%) // // (*) Software SHA256 results are of lesser relevance, presented // mostly for informational purposes. @@ -49,7 +50,7 @@ // on Cortex-A53 (or by 4 cycles per round). // (***) Super-impressive coefficients over gcc-generated code are // indication of some compiler "pathology", most notably code -// generated with -mgeneral-regs-only is significanty faster +// generated with -mgeneral-regs-only is significantly faster // and the gap is only 40-90%. #ifndef __KERNEL__ @@ -103,7 +104,7 @@ ldr w19,[x30],#4 // *K++ eor w28,w21,w22 // magic seed str x1,[x29,#112] -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w3,w3 // 0 #endif ror w16,w24,#6 @@ -126,7 +127,7 @@ add w27,w27,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w27,w27,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w4,w4 // 1 #endif ldp w5,w6,[x1],#2*4 @@ -151,7 +152,7 @@ add w26,w26,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w26,w26,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w5,w5 // 2 #endif add w26,w26,w17 // h+=Sigma0(a) @@ -175,7 +176,7 @@ add w25,w25,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w25,w25,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w6,w6 // 3 #endif ldp w7,w8,[x1],#2*4 @@ -200,7 +201,7 @@ add w24,w24,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w24,w24,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w7,w7 // 4 #endif add w24,w24,w17 // h+=Sigma0(a) @@ -224,7 +225,7 @@ add w23,w23,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w23,w23,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w8,w8 // 5 #endif ldp w9,w10,[x1],#2*4 @@ -249,7 +250,7 @@ add w22,w22,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w22,w22,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w9,w9 // 6 #endif add w22,w22,w17 // h+=Sigma0(a) @@ -273,7 +274,7 @@ add w21,w21,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w21,w21,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w10,w10 // 7 #endif ldp w11,w12,[x1],#2*4 @@ -298,7 +299,7 @@ add w20,w20,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w20,w20,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w11,w11 // 8 #endif add w20,w20,w17 // h+=Sigma0(a) @@ -322,7 +323,7 @@ add w27,w27,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w27,w27,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w12,w12 // 9 #endif ldp w13,w14,[x1],#2*4 @@ -347,7 +348,7 @@ add w26,w26,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w26,w26,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w13,w13 // 10 #endif add w26,w26,w17 // h+=Sigma0(a) @@ -371,7 +372,7 @@ add w25,w25,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w25,w25,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w14,w14 // 11 #endif ldp w15,w0,[x1],#2*4 @@ -397,7 +398,7 @@ add w24,w24,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w24,w24,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w15,w15 // 12 #endif add w24,w24,w17 // h+=Sigma0(a) @@ -422,7 +423,7 @@ add w23,w23,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w23,w23,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w0,w0 // 13 #endif ldp w1,w2,[x1] @@ -448,7 +449,7 @@ add w22,w22,w19 // h+=Maj(a,b,c) ldr w19,[x30],#4 // *K++, w28 in next round //add w22,w22,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w1,w1 // 14 #endif ldr w6,[sp,#12] @@ -474,7 +475,7 @@ add w21,w21,w28 // h+=Maj(a,b,c) ldr w28,[x30],#4 // *K++, w19 in next round //add w21,w21,w17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev w2,w2 // 15 #endif ldr w7,[sp,#0]
diff --git a/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha512-armv8.S b/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha512-armv8.S index d01304f..b2b5a7e 100644 --- a/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha512-armv8.S +++ b/third_party/boringssl/win-aarch64/crypto/fipsmodule/sha512-armv8.S
@@ -13,7 +13,7 @@ #if defined(BORINGSSL_PREFIX) #include <boringssl_prefix_symbols_asm.h> #endif -// Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. +// Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. // // Licensed under the OpenSSL license (the "License"). You may not use // this file except in compliance with the License. You can obtain a copy @@ -41,6 +41,7 @@ // Denver 2.01 10.5 (+26%) 6.70 (+8%) // X-Gene 20.0 (+100%) 12.8 (+300%(***)) // Mongoose 2.36 13.0 (+50%) 8.36 (+33%) +// Kryo 1.92 17.4 (+30%) 11.2 (+8%) // // (*) Software SHA256 results are of lesser relevance, presented // mostly for informational purposes. @@ -49,7 +50,7 @@ // on Cortex-A53 (or by 4 cycles per round). // (***) Super-impressive coefficients over gcc-generated code are // indication of some compiler "pathology", most notably code -// generated with -mgeneral-regs-only is significanty faster +// generated with -mgeneral-regs-only is significantly faster // and the gap is only 40-90%. #ifndef __KERNEL__ @@ -67,6 +68,17 @@ .endef .align 6 sha512_block_data_order: + AARCH64_VALID_CALL_TARGET +#ifndef __KERNEL__ +#if __has_feature(hwaddress_sanitizer) && __clang_major__ >= 10 + adrp x16,:pg_hi21_nc:OPENSSL_armcap_P +#else + adrp x16,OPENSSL_armcap_P +#endif + ldr w16,[x16,:lo12:OPENSSL_armcap_P] + tst w16,#ARMV8_SHA512 + b.ne Lv8_entry +#endif AARCH64_SIGN_LINK_REGISTER stp x29,x30,[sp,#-128]! add x29,sp,#0 @@ -92,7 +104,7 @@ ldr x19,[x30],#8 // *K++ eor x28,x21,x22 // magic seed str x1,[x29,#112] -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x3,x3 // 0 #endif ror x16,x24,#14 @@ -115,7 +127,7 @@ add x27,x27,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x27,x27,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x4,x4 // 1 #endif ldp x5,x6,[x1],#2*8 @@ -140,7 +152,7 @@ add x26,x26,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x26,x26,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x5,x5 // 2 #endif add x26,x26,x17 // h+=Sigma0(a) @@ -164,7 +176,7 @@ add x25,x25,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x25,x25,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x6,x6 // 3 #endif ldp x7,x8,[x1],#2*8 @@ -189,7 +201,7 @@ add x24,x24,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x24,x24,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x7,x7 // 4 #endif add x24,x24,x17 // h+=Sigma0(a) @@ -213,7 +225,7 @@ add x23,x23,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x23,x23,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x8,x8 // 5 #endif ldp x9,x10,[x1],#2*8 @@ -238,7 +250,7 @@ add x22,x22,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x22,x22,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x9,x9 // 6 #endif add x22,x22,x17 // h+=Sigma0(a) @@ -262,7 +274,7 @@ add x21,x21,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x21,x21,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x10,x10 // 7 #endif ldp x11,x12,[x1],#2*8 @@ -287,7 +299,7 @@ add x20,x20,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x20,x20,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x11,x11 // 8 #endif add x20,x20,x17 // h+=Sigma0(a) @@ -311,7 +323,7 @@ add x27,x27,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x27,x27,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x12,x12 // 9 #endif ldp x13,x14,[x1],#2*8 @@ -336,7 +348,7 @@ add x26,x26,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x26,x26,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x13,x13 // 10 #endif add x26,x26,x17 // h+=Sigma0(a) @@ -360,7 +372,7 @@ add x25,x25,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x25,x25,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x14,x14 // 11 #endif ldp x15,x0,[x1],#2*8 @@ -386,7 +398,7 @@ add x24,x24,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x24,x24,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x15,x15 // 12 #endif add x24,x24,x17 // h+=Sigma0(a) @@ -411,7 +423,7 @@ add x23,x23,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x23,x23,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x0,x0 // 13 #endif ldp x1,x2,[x1] @@ -437,7 +449,7 @@ add x22,x22,x19 // h+=Maj(a,b,c) ldr x19,[x30],#8 // *K++, x28 in next round //add x22,x22,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x1,x1 // 14 #endif ldr x6,[sp,#24] @@ -463,7 +475,7 @@ add x21,x21,x28 // h+=Maj(a,b,c) ldr x28,[x30],#8 // *K++, x19 in next round //add x21,x21,x17 // h+=Sigma0(a) -#ifndef __ARMEB__ +#ifndef __AARCH64EB__ rev x2,x2 // 15 #endif ldr x7,[sp,#0] @@ -1081,5 +1093,528 @@ .byte 83,72,65,53,49,50,32,98,108,111,99,107,32,116,114,97,110,115,102,111,114,109,32,102,111,114,32,65,82,77,118,56,44,32,67,82,89,80,84,79,71,65,77,83,32,98,121,32,60,97,112,112,114,111,64,111,112,101,110,115,115,108,46,111,114,103,62,0 .align 2 .align 2 +.text +#ifndef __KERNEL__ +.def sha512_block_armv8 + .type 32 +.endef +.align 6 +sha512_block_armv8: +Lv8_entry: + stp x29,x30,[sp,#-16]! + add x29,sp,#0 + + ld1 {v16.16b,v17.16b,v18.16b,v19.16b},[x1],#64 // load input + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + + ld1 {v0.2d,v1.2d,v2.2d,v3.2d},[x0] // load context + adrp x3,LK512 + add x3,x3,:lo12:LK512 + + rev64 v16.16b,v16.16b + rev64 v17.16b,v17.16b + rev64 v18.16b,v18.16b + rev64 v19.16b,v19.16b + rev64 v20.16b,v20.16b + rev64 v21.16b,v21.16b + rev64 v22.16b,v22.16b + rev64 v23.16b,v23.16b + b Loop_hw + +.align 4 +Loop_hw: + ld1 {v24.2d},[x3],#16 + subs x2,x2,#1 + sub x4,x1,#128 + orr v26.16b,v0.16b,v0.16b // offload + orr v27.16b,v1.16b,v1.16b + orr v28.16b,v2.16b,v2.16b + orr v29.16b,v3.16b,v3.16b + csel x1,x1,x4,ne // conditional rewind + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v24.2d,v24.2d,v16.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08230 //sha512su0 v16.16b,v17.16b + ext v7.16b,v20.16b,v21.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678af0 //sha512su1 v16.16b,v23.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v25.2d,v25.2d,v17.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08251 //sha512su0 v17.16b,v18.16b + ext v7.16b,v21.16b,v22.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678a11 //sha512su1 v17.16b,v16.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v24.2d,v24.2d,v18.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec08272 //sha512su0 v18.16b,v19.16b + ext v7.16b,v22.16b,v23.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678a32 //sha512su1 v18.16b,v17.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + add v25.2d,v25.2d,v19.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08293 //sha512su0 v19.16b,v20.16b + ext v7.16b,v23.16b,v16.16b,#8 +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b +.long 0xce678a53 //sha512su1 v19.16b,v18.16b,v7.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + add v24.2d,v24.2d,v20.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082b4 //sha512su0 v20.16b,v21.16b + ext v7.16b,v16.16b,v17.16b,#8 +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b +.long 0xce678a74 //sha512su1 v20.16b,v19.16b,v7.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + add v25.2d,v25.2d,v21.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec082d5 //sha512su0 v21.16b,v22.16b + ext v7.16b,v17.16b,v18.16b,#8 +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b +.long 0xce678a95 //sha512su1 v21.16b,v20.16b,v7.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v24.2d,v24.2d,v22.2d + ld1 {v25.2d},[x3],#16 + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v24.2d // "T1 + H + K512[i]" +.long 0xcec082f6 //sha512su0 v22.16b,v23.16b + ext v7.16b,v18.16b,v19.16b,#8 +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b +.long 0xce678ab6 //sha512su1 v22.16b,v21.16b,v7.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + add v25.2d,v25.2d,v23.2d + ld1 {v24.2d},[x3],#16 + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v25.2d // "T1 + H + K512[i]" +.long 0xcec08217 //sha512su0 v23.16b,v16.16b + ext v7.16b,v19.16b,v20.16b,#8 +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b +.long 0xce678ad7 //sha512su1 v23.16b,v22.16b,v7.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v16.2d + ld1 {v16.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b + rev64 v16.16b,v16.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v17.2d + ld1 {v17.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b + rev64 v17.16b,v17.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v18.2d + ld1 {v18.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b + rev64 v18.16b,v18.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v19.2d + ld1 {v19.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v2.16b,v3.16b,#8 + ext v6.16b,v1.16b,v2.16b,#8 + add v3.2d,v3.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a3 //sha512h v3.16b,v5.16b,v6.16b + rev64 v19.16b,v19.16b + add v4.2d,v1.2d,v3.2d // "D + T1" +.long 0xce608423 //sha512h2 v3.16b,v1.16b,v0.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v20.2d + ld1 {v20.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v4.16b,v2.16b,#8 + ext v6.16b,v0.16b,v4.16b,#8 + add v2.2d,v2.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a2 //sha512h v2.16b,v5.16b,v6.16b + rev64 v20.16b,v20.16b + add v1.2d,v0.2d,v2.2d // "D + T1" +.long 0xce638402 //sha512h2 v2.16b,v0.16b,v3.16b + ld1 {v24.2d},[x3],#16 + add v25.2d,v25.2d,v21.2d + ld1 {v21.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v1.16b,v4.16b,#8 + ext v6.16b,v3.16b,v1.16b,#8 + add v4.2d,v4.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a4 //sha512h v4.16b,v5.16b,v6.16b + rev64 v21.16b,v21.16b + add v0.2d,v3.2d,v4.2d // "D + T1" +.long 0xce628464 //sha512h2 v4.16b,v3.16b,v2.16b + ld1 {v25.2d},[x3],#16 + add v24.2d,v24.2d,v22.2d + ld1 {v22.16b},[x1],#16 // load next input + ext v24.16b,v24.16b,v24.16b,#8 + ext v5.16b,v0.16b,v1.16b,#8 + ext v6.16b,v2.16b,v0.16b,#8 + add v1.2d,v1.2d,v24.2d // "T1 + H + K512[i]" +.long 0xce6680a1 //sha512h v1.16b,v5.16b,v6.16b + rev64 v22.16b,v22.16b + add v3.2d,v2.2d,v1.2d // "D + T1" +.long 0xce648441 //sha512h2 v1.16b,v2.16b,v4.16b + sub x3,x3,#80*8 // rewind + add v25.2d,v25.2d,v23.2d + ld1 {v23.16b},[x1],#16 // load next input + ext v25.16b,v25.16b,v25.16b,#8 + ext v5.16b,v3.16b,v0.16b,#8 + ext v6.16b,v4.16b,v3.16b,#8 + add v0.2d,v0.2d,v25.2d // "T1 + H + K512[i]" +.long 0xce6680a0 //sha512h v0.16b,v5.16b,v6.16b + rev64 v23.16b,v23.16b + add v2.2d,v4.2d,v0.2d // "D + T1" +.long 0xce618480 //sha512h2 v0.16b,v4.16b,v1.16b + add v0.2d,v0.2d,v26.2d // accumulate + add v1.2d,v1.2d,v27.2d + add v2.2d,v2.2d,v28.2d + add v3.2d,v3.2d,v29.2d + + cbnz x2,Loop_hw + + st1 {v0.2d,v1.2d,v2.2d,v3.2d},[x0] // store context + + ldr x29,[sp],#16 + ret + +#endif #endif #endif // !OPENSSL_NO_ASM
diff --git a/third_party/crashpad/README.chromium b/third_party/crashpad/README.chromium index ff83c0db..226d2dc 100644 --- a/third_party/crashpad/README.chromium +++ b/third_party/crashpad/README.chromium
@@ -2,7 +2,7 @@ Short Name: crashpad URL: https://crashpad.chromium.org/ Version: unknown -Revision: 0ea32e0c7bc2569e597e412a20151ad8a7e64623 +Revision: 23375ab37c3dee6f442bddbf13b84a15116bcd6e License: Apache 2.0 License File: crashpad/LICENSE Security Critical: yes
diff --git a/third_party/crashpad/crashpad/DEPS b/third_party/crashpad/crashpad/DEPS index e13bea8..73adbd44 100644 --- a/third_party/crashpad/crashpad/DEPS +++ b/third_party/crashpad/crashpad/DEPS
@@ -36,10 +36,10 @@ '5bcd8e3bb929714e031a542d303f818e5a5af45d', 'crashpad/third_party/lss/lss': Var('chromium_git') + '/linux-syscall-support.git@' + - '7bde79cc274d06451bf65ae82c012a5d3e476b5a', + 'e1e7b0ad8ee99a875b272c8e33e308472e897660', 'crashpad/third_party/mini_chromium/mini_chromium': Var('chromium_git') + '/chromium/mini_chromium@' + - '502930381b23c5fa3911c8b82ec3e4ba6ceb3658', + 'bbb68fcec19ff7c268fadeebd2ef79f7203fa2f2', 'crashpad/third_party/libfuzzer/src': Var('chromium_git') + '/chromium/llvm-project/compiler-rt/lib/fuzzer.git@' + 'fda403cf93ecb8792cb1d061564d89a6553ca020',
diff --git a/third_party/crashpad/crashpad/client/crash_report_database_generic.cc b/third_party/crashpad/crashpad/client/crash_report_database_generic.cc index f5dfa89..754f578 100644 --- a/third_party/crashpad/crashpad/client/crash_report_database_generic.cc +++ b/third_party/crashpad/crashpad/client/crash_report_database_generic.cc
@@ -18,9 +18,9 @@ #include <sys/stat.h> #include <sys/types.h> +#include <tuple> #include <utility> -#include "base/ignore_result.h" #include "base/logging.h" #include "build/build_config.h" #include "client/settings.h" @@ -356,14 +356,14 @@ return kFileSystemError; } // We've moved the report to pending, so it no longer needs to be removed. - ignore_result(report->file_remover_.release()); + std::ignore = report->file_remover_.release(); // Close all the attachments and disarm their removers too. for (auto& writer : report->attachment_writers_) { writer->Close(); } for (auto& remover : report->attachment_removers_) { - ignore_result(remover.release()); + std::ignore = remover.release(); } *uuid = report->ReportID();
diff --git a/third_party/crashpad/crashpad/client/crash_report_database_mac.mm b/third_party/crashpad/crashpad/client/crash_report_database_mac.mm index d1c0a3ed..52319be 100644 --- a/third_party/crashpad/crashpad/client/crash_report_database_mac.mm +++ b/third_party/crashpad/crashpad/client/crash_report_database_mac.mm
@@ -25,8 +25,9 @@ #include <unistd.h> #include <uuid/uuid.h> +#include <tuple> + #include "base/cxx17_backports.h" -#include "base/ignore_result.h" #include "base/logging.h" #include "base/mac/scoped_nsautorelease_pool.h" #include "base/posix/eintr_wrapper.h" @@ -385,14 +386,14 @@ PLOG(ERROR) << "rename " << path.value() << " to " << new_path.value(); return kFileSystemError; } - ignore_result(report->file_remover_.release()); + std::ignore = report->file_remover_.release(); // Close all the attachments and disarm their removers too. for (auto& writer : report->attachment_writers_) { writer->Close(); } for (auto& remover : report->attachment_removers_) { - ignore_result(remover.release()); + std::ignore = remover.release(); } Metrics::CrashReportPending(Metrics::PendingReportReason::kNewlyCreated);
diff --git a/third_party/crashpad/crashpad/client/crash_report_database_win.cc b/third_party/crashpad/crashpad/client/crash_report_database_win.cc index 65feaa4..6331f650 100644 --- a/third_party/crashpad/crashpad/client/crash_report_database_win.cc +++ b/third_party/crashpad/crashpad/client/crash_report_database_win.cc
@@ -21,9 +21,9 @@ #include <time.h> #include <wchar.h> +#include <tuple> #include <utility> -#include "base/ignore_result.h" #include "base/logging.h" #include "base/numerics/safe_math.h" #include "base/strings/utf_string_conversions.h" @@ -735,14 +735,14 @@ time(nullptr), ReportState::kPending)); - ignore_result(report->file_remover_.release()); + std::ignore = report->file_remover_.release(); // Close all the attachments and disarm their removers too. for (auto& writer : report->attachment_writers_) { writer->Close(); } for (auto& remover : report->attachment_removers_) { - ignore_result(remover.release()); + std::ignore = remover.release(); } *uuid = report->ReportID();
diff --git a/third_party/crashpad/crashpad/client/crashpad_client_mac.cc b/third_party/crashpad/crashpad/client/crashpad_client_mac.cc index b4a365d..d25bfb7 100644 --- a/third_party/crashpad/crashpad/client/crashpad_client_mac.cc +++ b/third_party/crashpad/crashpad/client/crashpad_client_mac.cc
@@ -21,9 +21,9 @@ #include <stdint.h> #include <memory> +#include <tuple> #include <utility> -#include "base/ignore_result.h" #include "base/logging.h" #include "base/mac/mach_logging.h" #include "base/strings/stringprintf.h" @@ -178,7 +178,7 @@ handler_restarter->StartRestartThread( handler, database, metrics_dir, url, annotations, arguments)) { // The thread owns the object now. - ignore_result(handler_restarter.release()); + std::ignore = handler_restarter.release(); } // If StartRestartThread() failed, proceed without the ability to restart. @@ -362,7 +362,7 @@ return false; } - ignore_result(receive_right.release()); + std::ignore = receive_right.release(); return true; }
diff --git a/third_party/crashpad/crashpad/test/mac/mach_multiprocess.cc b/third_party/crashpad/crashpad/test/mac/mach_multiprocess.cc index 91a06e3..5f1b3ac 100644 --- a/third_party/crashpad/crashpad/test/mac/mach_multiprocess.cc +++ b/third_party/crashpad/crashpad/test/mac/mach_multiprocess.cc
@@ -19,9 +19,9 @@ #include <memory> #include <string> +#include <tuple> #include "base/auto_reset.h" -#include "base/ignore_result.h" #include "base/mac/scoped_mach_port.h" #include "gtest/gtest.h" #include "test/errors.h" @@ -212,7 +212,7 @@ ScopedForbidReturn forbid_return; // local_port is not valid in the forked child process. - ignore_result(info_->local_port.release()); + std::ignore = info_->local_port.release(); info_->local_port.reset(NewMachPort(MACH_PORT_RIGHT_RECEIVE)); ASSERT_NE(info_->local_port, kMachPortNull);
diff --git a/third_party/crashpad/crashpad/third_party/lss/lss.h b/third_party/crashpad/crashpad/third_party/lss/lss.h index 2646b6c..0bdd381 100644 --- a/third_party/crashpad/crashpad/third_party/lss/lss.h +++ b/third_party/crashpad/crashpad/third_party/lss/lss.h
@@ -20,7 +20,7 @@ #elif defined(CRASHPAD_LSS_SOURCE_EMBEDDED) #include "third_party/lss/lss/linux_syscall_support.h" #elif defined(CRASHPAD_LSS_SOURCE_FUCHSIA) -#include "../../../../third_party/lss/linux_syscall_support.h" +#include "../../../../third_party/linux-syscall-support/linux_syscall_support.h" #else #error Unknown lss source #endif
diff --git a/third_party/crashpad/crashpad/util/BUILD.gn b/third_party/crashpad/crashpad/util/BUILD.gn index 9f178d1..8251312 100644 --- a/third_party/crashpad/crashpad/util/BUILD.gn +++ b/third_party/crashpad/crashpad/util/BUILD.gn
@@ -663,7 +663,7 @@ if (crashpad_http_transport_impl == "socket") { sources += [ "net/http_transport_socket.cc" ] if (crashpad_use_boringssl_for_http_transport_socket) { - defines += [ "CRASHPAD_USE_BORINGSSL" ] + defines = [ "CRASHPAD_USE_BORINGSSL" ] if (crashpad_is_in_chromium || crashpad_is_in_fuchsia) { deps += [ "//third_party/boringssl" ]
diff --git a/third_party/crashpad/crashpad/util/mach/child_port_handshake.h b/third_party/crashpad/crashpad/util/mach/child_port_handshake.h index c61f734..54168691 100644 --- a/third_party/crashpad/crashpad/util/mach/child_port_handshake.h +++ b/third_party/crashpad/crashpad/util/mach/child_port_handshake.h
@@ -18,9 +18,9 @@ #include <mach/mach.h> #include <string> +#include <tuple> #include "base/files/scoped_file.h" -#include "base/ignore_result.h" #include "util/mach/child_port_types.h" namespace crashpad { @@ -124,7 +124,7 @@ //! // for use in the parent process. //! if (child_port_handshake.RunClient(receive_right.get(), //! MACH_MSG_TYPE_MOVE_RECEIVE)) { -//! ignore_result(receive_right.release()); +//! std::ignore = receive_right.release(); //! } //! \endcode //!
diff --git a/third_party/crashpad/crashpad/util/mach/mach_message_test.cc b/third_party/crashpad/crashpad/util/mach/mach_message_test.cc index 0bee04b..0c578b9 100644 --- a/third_party/crashpad/crashpad/util/mach/mach_message_test.cc +++ b/third_party/crashpad/crashpad/util/mach/mach_message_test.cc
@@ -16,7 +16,8 @@ #include <unistd.h> -#include "base/ignore_result.h" +#include <tuple> + #include "base/mac/scoped_mach_port.h" #include "gtest/gtest.h" #include "test/mac/mach_errors.h" @@ -154,7 +155,7 @@ ASSERT_EQ(right_type, implicit_cast<mach_msg_type_name_t>(MACH_MSG_TYPE_PORT_SEND)); EXPECT_TRUE(MachMessageDestroyReceivedPort(port, MACH_MSG_TYPE_PORT_RECEIVE)); - ignore_result(receive.release()); + std::ignore = receive.release(); EXPECT_TRUE(MachMessageDestroyReceivedPort(port, MACH_MSG_TYPE_PORT_SEND)); }
diff --git a/third_party/crashpad/crashpad/util/posix/close_stdio.cc b/third_party/crashpad/crashpad/util/posix/close_stdio.cc index a8efc81..37c40a4 100644 --- a/third_party/crashpad/crashpad/util/posix/close_stdio.cc +++ b/third_party/crashpad/crashpad/util/posix/close_stdio.cc
@@ -18,9 +18,10 @@ #include <paths.h> #include <unistd.h> +#include <tuple> + #include "base/check.h" #include "base/files/scoped_file.h" -#include "base/ignore_result.h" #include "base/posix/eintr_wrapper.h" namespace crashpad { @@ -32,7 +33,7 @@ HANDLE_EINTR(open(_PATH_DEVNULL, oflag | O_NOCTTY | O_CLOEXEC))); if (fd == desired_fd) { // Weird, but play along. - ignore_result(fd.release()); + std::ignore = fd.release(); } else { PCHECK(fd.get() >= 0) << "open"; PCHECK(HANDLE_EINTR(dup2(fd.get(), desired_fd)) != -1) << "dup2";
diff --git a/third_party/crashpad/crashpad/util/stream/file_encoder.cc b/third_party/crashpad/crashpad/util/stream/file_encoder.cc index 5cebbfcc..d39b5b09 100644 --- a/third_party/crashpad/crashpad/util/stream/file_encoder.cc +++ b/third_party/crashpad/crashpad/util/stream/file_encoder.cc
@@ -15,8 +15,8 @@ #include "util/stream/file_encoder.h" #include <memory> +#include <tuple> -#include "base/ignore_result.h" #include "util/file/file_io.h" #include "util/file/file_reader.h" #include "util/file/scoped_remove_file.h" @@ -77,7 +77,7 @@ if (!output->Flush()) return false; - ignore_result(file_remover.release()); + std::ignore = file_remover.release(); return true; }
diff --git a/third_party/tflite/BUILD.gn b/third_party/tflite/BUILD.gn index dd01620..8880e6b 100644 --- a/third_party/tflite/BUILD.gn +++ b/third_party/tflite/BUILD.gn
@@ -414,8 +414,18 @@ "src/tensorflow/lite/core/api/tensor_utils.cc", "src/tensorflow/lite/core/api/tensor_utils.h", "src/tensorflow/lite/core/subgraph.cc", + "src/tensorflow/lite/delegates/interpreter_utils.cc", + "src/tensorflow/lite/delegates/interpreter_utils.h", "src/tensorflow/lite/delegates/nnapi/nnapi_delegate.h", "src/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc", + "src/tensorflow/lite/experimental/acceleration/configuration/delegate_registry.cc", + "src/tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h", + "src/tensorflow/lite/experimental/acceleration/configuration/flatbuffer_to_proto.cc", + "src/tensorflow/lite/experimental/acceleration/configuration/flatbuffer_to_proto.h", + "src/tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.cc", + "src/tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.h", + "src/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.cc", + "src/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h", "src/tensorflow/lite/experimental/resource/initialization_status.cc", "src/tensorflow/lite/experimental/resource/initialization_status.h", "src/tensorflow/lite/experimental/resource/lookup_interfaces.h", @@ -474,7 +484,11 @@ } if (is_ios) { - sources += [ "src/tensorflow/lite/minimal_logging_ios.cc" ] + sources += [ + "src/tensorflow/lite/minimal_logging_ios.cc", + "src/tensorflow/lite/profiling/signpost_profiler.h", + "src/tensorflow/lite/profiling/signpost_profiler.mm", + ] } else if (is_android) { sources += [ "src/tensorflow/lite/minimal_logging_android.cc",
diff --git a/third_party/tflite_support/BUILD.gn b/third_party/tflite_support/BUILD.gn index cc651378..30f80c0b 100644 --- a/third_party/tflite_support/BUILD.gn +++ b/third_party/tflite_support/BUILD.gn
@@ -16,13 +16,19 @@ proto_library("tflite_support_proto") { proto_in_dir = "src" sources = [ + "src/tensorflow_lite_support/cc/task/core/proto/base_options.proto", "src/tensorflow_lite_support/cc/task/core/proto/external_file.proto", + "src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto", + "src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options.proto", "src/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto", "src/tensorflow_lite_support/cc/task/vision/proto/class.proto", "src/tensorflow_lite_support/cc/task/vision/proto/classifications.proto", "src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto", ] cc_generator_options = "lite=true:" + + import_dirs = [ "//third_party/tflite/src" ] + proto_deps = [ "//third_party/tflite:tflite-config-proto" ] } config("tflite_support_flags") { @@ -31,6 +37,9 @@ "-Wno-extern-c-compat", "-Wno-implicit-function-declaration", "-Wno-sign-compare", + "-Wno-ignored-attributes", + "-Wno-deprecated-declarations", + "-Wno-unused-variable", ] if (!is_win) { cflags_cc = [ "-frtti" ] @@ -51,15 +60,14 @@ sources = [ "src/tensorflow_lite_support/cc/common.cc", "src/tensorflow_lite_support/cc/common.h", - "src/tensorflow_lite_support/cc/port/default/statusor.cc", - "src/tensorflow_lite_support/cc/port/default/statusor.h", - "src/tensorflow_lite_support/cc/port/default/statusor_internals.h", "src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc", "src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h", "src/tensorflow_lite_support/cc/port/status_macros.h", "src/tensorflow_lite_support/cc/port/statusor.h", "src/tensorflow_lite_support/cc/task/core/base_task_api.h", "src/tensorflow_lite_support/cc/task/core/category.h", + "src/tensorflow_lite_support/cc/task/core/error_reporter.cc", + "src/tensorflow_lite_support/cc/task/core/error_reporter.h", "src/tensorflow_lite_support/cc/task/core/external_file_handler.cc", "src/tensorflow_lite_support/cc/task/core/external_file_handler.h", "src/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h", @@ -68,10 +76,15 @@ "src/tensorflow_lite_support/cc/task/core/task_utils.h", "src/tensorflow_lite_support/cc/task/core/tflite_engine.cc", "src/tensorflow_lite_support/cc/task/core/tflite_engine.h", - "src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc", - "src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h", + "src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc", + "src/tensorflow_lite_support/cc/task/processor/image_preprocessor.h", + "src/tensorflow_lite_support/cc/task/processor/processor.cc", + "src/tensorflow_lite_support/cc/task/processor/processor.h", + "src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc", + "src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h", "src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc", "src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h", + "src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h", "src/tensorflow_lite_support/cc/task/vision/core/classification_head.cc", "src/tensorflow_lite_support/cc/task/vision/core/classification_head.h", "src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc", @@ -126,8 +139,10 @@ configs -= [ "//build/config/compiler:chromium_code" ] configs += [ - ":tflite_support_flags", "//build/config/compiler:no_chromium_code", + + # Must be after no_chromium_code for warning flags to be ordered correctly. + ":tflite_support_flags", ] public_configs = [ ":tflite_support_config" ]
diff --git a/third_party/tflite_support/README.chromium b/third_party/tflite_support/README.chromium index d60ea916..4995893e 100644 --- a/third_party/tflite_support/README.chromium +++ b/third_party/tflite_support/README.chromium
@@ -1,8 +1,8 @@ Name: TensorFlow Lite Support Short Name: tflite-support URL: https://github.com/tensorflow/tflite-support -Version: 3faaca9c6a3b22dec4d636b6b092431c9ac409e8 -Date: 2021/01/05 +Version: v0.3.1 +Date: 2021/12/16 License: Apache 2.0 License File: LICENSE Security Critical: Yes @@ -13,27 +13,32 @@ models onto mobile devices. It works cross-Platform and is supported on Java, C++ (WIP), and Swift (WIP). -Modifications: -- Use chromium's logging utility in place of glog (patches/0001-use-base-logging.patch) -- Use size_t rather than int for loops according to chromium style (0001-use-size_t.patch) -- Rely on re::StringPiece instead of absl::string_view (0001-use_StringPiece-for_string_view.patch) -- Remove unsafe use of conversions between NSString and string by using SysNSStringToUTF8. Note, this - is unused code but required for presubmit checks. (0001-use-SysNSStringToUTF8.patch) -- Remove usage of absl::Cord in tflite::support::CreateStatusWithPayload (0001-no-absl-cord.patch) -- Use _Exit instead of _exit to work on all platforms (0001-use-exit.patch) -- Remove external file handlers support for memory mapping files to support Windows - (0001-remove-unsupported-memory-map-from-file-handler.patch) -- Fixes sign compare issues in tflite-support (0001-task-utils-sign-compare.patch) -- Remove support for sentencepiece tokenizers (0001-no-sentencepiece-tokenizer.patch) -- Allows for the max sequence used by BERT models to be 512 instead of 128 (0001-bert-max-seq-len.patch) -- Ensure name field in metadata exists before checking for tflite metadata (0001-add-metadata-name-check.patch) - Third party dependencies: - tflite - libzip - utf - tensorflow-text +Modifications: +01) Use re2::StringPiece instead of absl::string_view in regex_tokenizer.cc +02) Remove support for sentencepiece tokenization because the required overhead +isn't worth adding this functionality, esp since no feature team needs it. +03) [To Be Upstreamed] Remove unused functions. +04) Remove the ABSL_DEPRECATED annotation from a deprecated struct since this +is a no-op in chromium builds and upsets clang. +05) [To Be Upstreamed] Use size_t in for loop in nl_classifier.h +06) [To Be Upstreamed] Remove unused variable in task_utils.h +07) Do not use absl::any since it is not supported in chromium +08) [To Be Upstreamed] Remove unused stl include in tokenizer_jni_lib.h +09) Remove unbuilt files that triggered checkdeps warnings, and fix file perms. +10) Remove memory mapped file support in external_file_handler.cc since it is +only available on POSIX systems. +11) Run clang-format. +12) Remove an unneeded static initializer. + Update Process: 1) Clone the tflite-support github repo at the desired commit into src/ 2) Apply each patch listed above residing in patches/ using `git apply patches/$PATCHFILE` +3) Get the build working. +4) Record the patches made with `git format-patches HEAD -<number of changes>` +
diff --git a/third_party/tflite_support/patches/0001-Fix-signed-comparison-in-base_vision_task_api.h.patch b/third_party/tflite_support/patches/0001-Fix-signed-comparison-in-base_vision_task_api.h.patch deleted file mode 100644 index 6a56310a..0000000 --- a/third_party/tflite_support/patches/0001-Fix-signed-comparison-in-base_vision_task_api.h.patch +++ /dev/null
@@ -1,34 +0,0 @@ -From c8bdfe3f6b3ce087c36b551d668b97101f620bdc Mon Sep 17 00:00:00 2001 -From: Daniel Rubery <drubery@chromium.org> -Date: Thu, 6 May 2021 11:45:48 -0700 -Subject: [PATCH] Fix signed comparison in base_vision_task_api.h - ---- - .../cc/task/vision/core/base_vision_task_api.h | 4 ++-- - 1 file changed, 2 insertions(+), 2 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h -index 3d1359685f3f..c787876bec33 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h -@@ -204,7 +204,7 @@ class BaseVisionTaskApi - if (normalization_options.num_values == 1) { - float mean_value = normalization_options.mean_values[0]; - float inv_std_value = (1.0f / normalization_options.std_values[0]); -- for (int i = 0; i < input_data_byte_size / sizeof(uint8); -+ for (size_t i = 0; i < input_data_byte_size / sizeof(uint8); - i++, input_data++, normalized_input_data++) { - *normalized_input_data = - inv_std_value * (static_cast<float>(*input_data) - mean_value); -@@ -214,7 +214,7 @@ class BaseVisionTaskApi - 1.0f / normalization_options.std_values[0], - 1.0f / normalization_options.std_values[1], - 1.0f / normalization_options.std_values[2]}; -- for (int i = 0; i < input_data_byte_size / sizeof(uint8); -+ for (size_t i = 0; i < input_data_byte_size / sizeof(uint8); - i++, input_data++, normalized_input_data++) { - *normalized_input_data = inv_std_values[i % 3] * - (static_cast<float>(*input_data) - --- -2.31.1.607.g51e8a6a459-goog -
diff --git a/third_party/tflite_support/patches/0001-Remove-signed-comparison-in-frame_buffer.h.patch b/third_party/tflite_support/patches/0001-Remove-signed-comparison-in-frame_buffer.h.patch deleted file mode 100644 index 757ec2e6..0000000 --- a/third_party/tflite_support/patches/0001-Remove-signed-comparison-in-frame_buffer.h.patch +++ /dev/null
@@ -1,25 +0,0 @@ -From 368b317061ba7deb1f42c52c5443c261bb6c03ea Mon Sep 17 00:00:00 2001 -From: Daniel Rubery <drubery@chromium.org> -Date: Thu, 6 May 2021 11:40:37 -0700 -Subject: [PATCH] Remove signed comparison in frame_buffer.h - ---- - .../tensorflow_lite_support/cc/task/vision/core/frame_buffer.h | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -index 22f63fc34d36..42ac080c4749 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -@@ -246,7 +246,7 @@ class FrameBuffer { - - // Returns plane indexed by the input `index`. - const Plane plane(int index) const { -- if (index > -1 && index < planes_.size()) { -+ if (index > -1 && static_cast<size_t>(index) < planes_.size()) { - return planes_[index]; - } - return {}; --- -2.31.1.607.g51e8a6a459-goog -
diff --git a/third_party/tflite_support/patches/0001-Remove-unused-qualifiers-in-frame_buffer.h.patch b/third_party/tflite_support/patches/0001-Remove-unused-qualifiers-in-frame_buffer.h.patch deleted file mode 100644 index d80bdd1..0000000 --- a/third_party/tflite_support/patches/0001-Remove-unused-qualifiers-in-frame_buffer.h.patch +++ /dev/null
@@ -1,38 +0,0 @@ -From b23fcde4753dbf5e4adc325e9ded16800f1d1bc5 Mon Sep 17 00:00:00 2001 -From: Daniel Rubery <drubery@chromium.org> -Date: Thu, 6 May 2021 11:38:06 -0700 -Subject: [PATCH] Remove unused qualifiers in frame_buffer.h - ---- - .../cc/task/vision/core/frame_buffer.h | 6 +++--- - 1 file changed, 3 insertions(+), 3 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -index 1556b7dfabef..22f63fc34d36 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -@@ -242,7 +242,7 @@ class FrameBuffer { - timestamp_(timestamp) {} - - // Returns number of planes. -- const int plane_count() const { return planes_.size(); } -+ int plane_count() const { return planes_.size(); } - - // Returns plane indexed by the input `index`. - const Plane plane(int index) const { -@@ -256,10 +256,10 @@ class FrameBuffer { - const Dimension dimension() const { return dimension_; } - - // Returns FrameBuffer format. -- const Format format() const { return format_; } -+ Format format() const { return format_; } - - // Returns FrameBuffer orientation. -- const Orientation orientation() const { return orientation_; } -+ Orientation orientation() const { return orientation_; } - - // Returns FrameBuffer timestamp. - const absl::Time timestamp() const { return timestamp_; } --- -2.31.1.607.g51e8a6a459-goog -
diff --git a/third_party/tflite_support/patches/0001-Remove-use-of-banned-absl-any.patch b/third_party/tflite_support/patches/0001-Remove-use-of-banned-absl-any.patch deleted file mode 100644 index fdd6997b..0000000 --- a/third_party/tflite_support/patches/0001-Remove-use-of-banned-absl-any.patch +++ /dev/null
@@ -1,64 +0,0 @@ -From 670dfffa386fd0ff28e66cfe1238af43b4e587ce Mon Sep 17 00:00:00 2001 -From: Daniel Rubery <drubery@chromium.org> -Date: Thu, 6 May 2021 11:22:13 -0700 -Subject: [PATCH] Remove use of banned absl::any - ---- - .../cc/task/vision/core/frame_buffer.h | 27 ------------------- - 1 file changed, 27 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -index 2bea92883c4d..1556b7dfabef 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h -@@ -27,7 +27,6 @@ limitations under the License. - #include "absl/strings/str_cat.h" - #include "absl/time/clock.h" - #include "absl/time/time.h" --#include "absl/types/any.h" - #include "absl/types/optional.h" - #include "tensorflow_lite_support/cc/port/integral_types.h" - #include "tensorflow_lite_support/cc/port/statusor.h" -@@ -253,31 +252,6 @@ class FrameBuffer { - return {}; - } - -- // Returns the tag associated to the tag_key. -- absl::any GetTag(const std::string& tag_key) const { -- auto iter = tags_.find(tag_key); -- if (iter != tags_.end()) { -- return iter->second; -- } -- return absl::any(); -- } -- -- // Inserts or updates the tags map with key value pair (tag_key, tag_value). -- void InsertOrUpdateTag(const std::string& tag_key, absl::any tag_value) { -- tags_[tag_key] = std::move(tag_value); -- } -- -- // Inserts the key value pair (tag_key, tag_value) into tags map. If the -- // tag_key already exists, an internal error will return. -- absl::Status InsertTag(const std::string& tag_key, absl::any tag_value) { -- auto iter = tags_.emplace(tag_key, tag_value); -- if (iter.second) { -- return absl::OkStatus(); -- } -- return absl::InternalError(absl::StrCat( -- "tag_key already exists in tags.tag_key was not inserted: ", tag_key)); -- } -- - // Returns FrameBuffer dimension. - const Dimension dimension() const { return dimension_; } - -@@ -292,7 +266,6 @@ class FrameBuffer { - - private: - std::vector<Plane> planes_; -- std::map<std::string, absl::any> tags_; - Dimension dimension_; - Format format_; - Orientation orientation_; --- -2.31.1.607.g51e8a6a459-goog -
diff --git a/third_party/tflite_support/patches/0001-Use-third_party-libyuv.patch b/third_party/tflite_support/patches/0001-Use-third_party-libyuv.patch deleted file mode 100644 index 1210ba8..0000000 --- a/third_party/tflite_support/patches/0001-Use-third_party-libyuv.patch +++ /dev/null
@@ -1,30 +0,0 @@ -From 226d36a5d12ca3080b1d0d9b450be949e418e318 Mon Sep 17 00:00:00 2001 -From: Daniel Rubery <drubery@chromium.org> -Date: Thu, 6 May 2021 11:26:23 -0700 -Subject: [PATCH] Use third_party/libyuv - ---- - .../cc/task/vision/utils/libyuv_frame_buffer_utils.cc | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc -index f3cd0c70fe1b..b50b500bb5a4 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc -@@ -23,12 +23,12 @@ limitations under the License. - #include "absl/status/status.h" - #include "absl/strings/str_cat.h" - #include "absl/strings/str_format.h" --#include "include/libyuv.h" - #include "tensorflow_lite_support/cc/common.h" - #include "tensorflow_lite_support/cc/port/integral_types.h" - #include "tensorflow_lite_support/cc/port/status_macros.h" - #include "tensorflow_lite_support/cc/port/statusor.h" - #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -+#include "third_party/libyuv/include/libyuv.h" - - namespace tflite { - namespace task { --- -2.31.1.607.g51e8a6a459-goog -
diff --git a/third_party/tflite_support/patches/0001-add-metadata-name-check.patch b/third_party/tflite_support/patches/0001-add-metadata-name-check.patch deleted file mode 100644 index fd2b258..0000000 --- a/third_party/tflite_support/patches/0001-add-metadata-name-check.patch +++ /dev/null
@@ -1,28 +0,0 @@ -From 5e3e4b63a6bfd871afa16f8d27f2daa8b99d84e9 Mon Sep 17 00:00:00 2001 -From: mcrouse <mcrouse@google.com> -Date: Thu, 19 Aug 2021 11:31:28 -0700 -Subject: [PATCH] add metadata name check - ---- - .../metadata/cc/metadata_extractor.cc | 5 ++++- - 1 file changed, 4 insertions(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc -index ad5df76f1c27b..42f2a7c13a516 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc -@@ -159,7 +159,10 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( - // Look for the "TFLITE_METADATA" field, if any. - for (int i = 0; i < model_->metadata()->size(); ++i) { - const auto metadata = model_->metadata()->Get(i); -- if (metadata->name() && metadata->name()->str() != kMetadataBufferName) { -+ if (!metadata->name()) { -+ continue; -+ } -+ if (metadata->name()->str() != kMetadataBufferName) { - continue; - } - const auto buffer_index = metadata->buffer(); --- -2.33.0.rc2.250.ged5fa647cd-goog -
diff --git a/third_party/tflite_support/patches/0001-bert-max-seq-len.patch b/third_party/tflite_support/patches/0001-bert-max-seq-len.patch deleted file mode 100644 index d984907b..0000000 --- a/third_party/tflite_support/patches/0001-bert-max-seq-len.patch +++ /dev/null
@@ -1,25 +0,0 @@ -From 49cd597b3c1fbfef2e3772682aa98575654131ba Mon Sep 17 00:00:00 2001 -From: Sophie Chang <sophiechang@chromium.org> -Date: Mon, 1 Mar 2021 19:33:21 +0000 -Subject: [PATCH] allow for more tokens - ---- - .../cc/task/text/nlclassifier/bert_nl_classifier.h | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h -index cd5c5a3ade03..e78085d98761 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h -@@ -52,7 +52,7 @@ class BertNLClassifier : public NLClassifier { - public: - using NLClassifier::NLClassifier; - // Max number of tokens to pass to the model. -- static constexpr int kMaxSeqLen = 128; -+ static constexpr int kMaxSeqLen = 512; - - // Factory function to create a BertNLClassifier from TFLite model with - // metadata. --- -2.30.1.766.gb4fecdf3b7-goog -
diff --git a/third_party/tflite_support/patches/0001-no-absl-cord.patch b/third_party/tflite_support/patches/0001-no-absl-cord.patch deleted file mode 100644 index 5f18f38..0000000 --- a/third_party/tflite_support/patches/0001-no-absl-cord.patch +++ /dev/null
@@ -1,33 +0,0 @@ -From 61fb20a08d2325d03759a5b9394c033901fc0a7f Mon Sep 17 00:00:00 2001 -From: Sophie Chang <sophiechang@chromium.org> -Date: Wed, 3 Feb 2021 04:21:19 +0000 -Subject: [PATCH] do not use cord in tflite status payload - ---- - .../tflite-support/src/tensorflow_lite_support/cc/common.cc | 3 --- - 1 file changed, 3 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc -index 47dd3bcc6581..ed373e96d555 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc -@@ -15,7 +15,6 @@ limitations under the License. - - #include "tensorflow_lite_support/cc/common.h" - --#include "absl/strings/cord.h" - #include "absl/strings/str_cat.h" - - namespace tflite { -@@ -26,8 +25,6 @@ absl::Status CreateStatusWithPayload(absl::StatusCode canonical_code, - TfLiteSupportStatus tfls_code) { - // NOTE: Ignores `message` if the canonical code is ok. - absl::Status status = absl::Status(canonical_code, message); -- // NOTE: Does nothing if the canonical code is ok. -- status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code))); - return status; - } - --- -2.30.0.365.g02bc693789-goog -
diff --git a/third_party/tflite_support/patches/0001-no-sentencepiece-tokenizer.patch b/third_party/tflite_support/patches/0001-no-sentencepiece-tokenizer.patch deleted file mode 100644 index df91f5c..0000000 --- a/third_party/tflite_support/patches/0001-no-sentencepiece-tokenizer.patch +++ /dev/null
@@ -1,40 +0,0 @@ -From 7faac3ddcbc05275d797dda64a9b9d7f2279ae1c Mon Sep 17 00:00:00 2001 -From: Sophie Chang <sophiechang@chromium.org> -Date: Thu, 11 Feb 2021 00:53:47 +0000 -Subject: [PATCH] no sentencepiece tokenizer - ---- - .../cc/text/tokenizers/tokenizer_utils.cc | 11 ----------- - 1 file changed, 11 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc -index 352c4a8c5e4f..46786fd7faf8 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc -@@ -20,7 +20,6 @@ limitations under the License. - #include "tensorflow_lite_support/cc/port/status_macros.h" - #include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" - #include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" --#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" - #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" - - namespace tflite { -@@ -73,16 +72,6 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit( - return absl::make_unique<BertTokenizer>(vocab_buffer.data(), - vocab_buffer.size()); - } -- case ProcessUnitOptions_SentencePieceTokenizerOptions: { -- const tflite::SentencePieceTokenizerOptions* options = -- tokenizer_process_unit->options_as<SentencePieceTokenizerOptions>(); -- ASSIGN_OR_RETURN(absl::string_view model_buffer, -- CheckAndLoadFirstAssociatedFile( -- options->sentencePiece_model(), metadata_extractor)); -- // TODO(b/160647204): Extract sentence piece model vocabulary -- return absl::make_unique<SentencePieceTokenizer>(model_buffer.data(), -- model_buffer.size()); -- } - case ProcessUnitOptions_RegexTokenizerOptions: { - const tflite::RegexTokenizerOptions* options = - tokenizer_process_unit->options_as<RegexTokenizerOptions>(); --- -2.30.0.478.g8a0d178c01-goog
diff --git a/third_party/tflite_support/patches/0001-remove-unsupported-memory-map-from-file-handler.patch b/third_party/tflite_support/patches/0001-remove-unsupported-memory-map-from-file-handler.patch deleted file mode 100644 index 55c2e7b..0000000 --- a/third_party/tflite_support/patches/0001-remove-unsupported-memory-map-from-file-handler.patch +++ /dev/null
@@ -1,224 +0,0 @@ -From 8c5a37f7324b4a03f123c6faecedb4abc8eb0066 Mon Sep 17 00:00:00 2001 -From: mcrouse <mcrouse@google.com> -Date: Fri, 5 Feb 2021 15:30:25 +0000 -Subject: [PATCH] remove unsupported memory map from file handler - ---- - .../cc/task/core/external_file_handler.cc | 126 +----------------- - .../cc/task/core/external_file_handler.h | 7 - - .../cc/task/core/tflite_engine.cc | 2 - - .../cc/task/core/tflite_engine.h | 2 - - 4 files changed, 6 insertions(+), 131 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc -index ee689e41c6e5..55b662f0926f 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc -@@ -18,9 +18,6 @@ limitations under the License. - #include <errno.h> - #include <fcntl.h> - #include <stddef.h> --#include <sys/mman.h> --#include <unistd.h> -- - #include <memory> - #include <string> - -@@ -40,18 +37,6 @@ using ::tflite::support::CreateStatusWithPayload; - using ::tflite::support::StatusOr; - using ::tflite::support::TfLiteSupportStatus; - --// Gets the offset aligned to page size for mapping given files into memory by --// file descriptor correctly, as according to mmap(2), the offset used in mmap --// must be a multiple of sysconf(_SC_PAGE_SIZE). --int64 GetPageSizeAlignedOffset(int64 offset) { -- int64 aligned_offset = offset; -- int64 page_size = sysconf(_SC_PAGE_SIZE); -- if (offset % page_size != 0) { -- aligned_offset = offset / page_size * page_size; -- } -- return aligned_offset; --} -- - } // namespace - - /* static */ -@@ -71,103 +56,11 @@ absl::Status ExternalFileHandler::MapExternalFile() { - if (!external_file_.file_content().empty()) { - return absl::OkStatus(); - } -- if (external_file_.file_name().empty() && -- !external_file_.has_file_descriptor_meta()) { -- return CreateStatusWithPayload( -- StatusCode::kInvalidArgument, -- "ExternalFile must specify at least one of 'file_content', file_name' " -- "or 'file_descriptor_meta'.", -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- // Obtain file descriptor, offset and size. -- int fd = -1; -- if (!external_file_.file_name().empty()) { -- owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY); -- if (owned_fd_ < 0) { -- const std::string error_message = absl::StrFormat( -- "Unable to open file at %s", external_file_.file_name()); -- switch (errno) { -- case ENOENT: -- return CreateStatusWithPayload( -- StatusCode::kNotFound, error_message, -- TfLiteSupportStatus::kFileNotFoundError); -- case EACCES: -- case EPERM: -- return CreateStatusWithPayload( -- StatusCode::kPermissionDenied, error_message, -- TfLiteSupportStatus::kFilePermissionDeniedError); -- case EINTR: -- return CreateStatusWithPayload(StatusCode::kUnavailable, -- error_message, -- TfLiteSupportStatus::kFileReadError); -- case EBADF: -- return CreateStatusWithPayload(StatusCode::kFailedPrecondition, -- error_message, -- TfLiteSupportStatus::kFileReadError); -- default: -- return CreateStatusWithPayload( -- StatusCode::kUnknown, -- absl::StrFormat("%s, errno=%d", error_message, errno), -- TfLiteSupportStatus::kFileReadError); -- } -- } -- fd = owned_fd_; -- } else { -- fd = external_file_.file_descriptor_meta().fd(); -- if (fd < 0) { -- return CreateStatusWithPayload( -- StatusCode::kInvalidArgument, -- absl::StrFormat("Provided file descriptor is invalid: %d < 0", fd), -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- buffer_offset_ = external_file_.file_descriptor_meta().offset(); -- buffer_size_ = external_file_.file_descriptor_meta().length(); -- } -- // Get actual file size. Always use 0 as offset to lseek(2) to get the actual -- // file size, as SEEK_END returns the size of the file *plus* offset. -- size_t file_size = lseek(fd, /*offset=*/0, SEEK_END); -- if (file_size <= 0) { -- return CreateStatusWithPayload( -- StatusCode::kUnknown, -- absl::StrFormat("Unable to get file size, errno=%d", errno), -- TfLiteSupportStatus::kFileReadError); -- } -- // Deduce buffer size if not explicitly provided through file descriptor. -- if (buffer_size_ <= 0) { -- buffer_size_ = file_size - buffer_offset_; -- } -- // Check for out of range issues. -- if (file_size <= buffer_offset_) { -- return CreateStatusWithPayload( -- StatusCode::kInvalidArgument, -- absl::StrFormat("Provided file offset (%d) exceeds or matches actual " -- "file length (%d)", -- buffer_offset_, file_size), -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- if (file_size < buffer_size_ + buffer_offset_) { -- return CreateStatusWithPayload( -- StatusCode::kInvalidArgument, -- absl::StrFormat("Provided file length + offset (%d) exceeds actual " -- "file length (%d)", -- buffer_size_ + buffer_offset_, file_size), -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with -- // extra leading bytes and adjust buffer_size_ to account for the extra -- // leading bytes. -- buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_); -- buffer_aligned_size_ = buffer_size_ + buffer_offset_ - buffer_aligned_offset_; -- // Map into memory. -- buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED, -- fd, buffer_aligned_offset_); -- if (buffer_ == MAP_FAILED) { -- return CreateStatusWithPayload( -- StatusCode::kUnknown, -- absl::StrFormat("Unable to map file to memory buffer, errno=%d", errno), -- TfLiteSupportStatus::kFileMmapError); -- } -- return absl::OkStatus(); -+ return CreateStatusWithPayload( -+ StatusCode::kInvalidArgument, -+ "ExternalFile must have 'file_content' set, loading from" -+ "'file_name' is not supported.", -+ TfLiteSupportStatus::kInvalidArgumentError); - } - - absl::string_view ExternalFileHandler::GetFileContent() { -@@ -180,14 +73,7 @@ absl::string_view ExternalFileHandler::GetFileContent() { - } - } - --ExternalFileHandler::~ExternalFileHandler() { -- if (buffer_ != MAP_FAILED) { -- munmap(buffer_, buffer_aligned_size_); -- } -- if (owned_fd_ >= 0) { -- close(owned_fd_); -- } --} -+ExternalFileHandler::~ExternalFileHandler() = default; - - } // namespace core - } // namespace task -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h -index 236d90347698..ad292dcc3702 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h -@@ -65,10 +65,6 @@ class ExternalFileHandler { - // Reference to the input ExternalFile. - const ExternalFile& external_file_; - -- // The file descriptor of the ExternalFile if provided by path, as it is -- // opened and owned by this class. Set to -1 otherwise. -- int owned_fd_{-1}; -- - // Points to the memory buffer mapped from the file descriptor of the - // ExternalFile, if provided by path or file descriptor. - void* buffer_{}; -@@ -82,9 +78,6 @@ class ExternalFileHandler { - - // The aligned mapped memory buffer offset, if any. - int64 buffer_aligned_offset_{}; -- // The aligned mapped memory buffer size in bytes taking into account the -- // offset shift introduced by buffer_aligned_memory_offset_, if any. -- int64 buffer_aligned_size_{}; - }; - - } // namespace core -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc -index 0317d5f8ea34..6230e5c645c0 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc -@@ -15,8 +15,6 @@ limitations under the License. - - #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" - --#include <unistd.h> -- - #include "absl/strings/match.h" - #include "absl/strings/str_cat.h" - #include "tensorflow/lite/builtin_ops.h" -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h -index 6a7e97dd264e..bc55f6b0fe72 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h -@@ -16,8 +16,6 @@ limitations under the License. - #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ - #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ - --#include <sys/mman.h> -- - #include <memory> - - #include "absl/memory/memory.h" --- -2.30.0.365.g02bc693789-goog -
diff --git a/third_party/tflite_support/patches/0001-task-utils-sign-compare.patch b/third_party/tflite_support/patches/0001-task-utils-sign-compare.patch deleted file mode 100644 index 85172c99..0000000 --- a/third_party/tflite_support/patches/0001-task-utils-sign-compare.patch +++ /dev/null
@@ -1,34 +0,0 @@ -From f84b50f175efff54ee6a6ef795703907245260cd Mon Sep 17 00:00:00 2001 -From: Sophie Chang <sophiechang@chromium.org> -Date: Wed, 10 Feb 2021 17:55:30 +0000 -Subject: [PATCH] fix sign issues - ---- - .../src/tensorflow_lite_support/cc/task/core/task_utils.h | 4 ++-- - 1 file changed, 2 insertions(+), 2 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h -index 744dbbfb0f80..ced3dbcae9e4 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h -@@ -119,7 +119,7 @@ inline void PopulateVector(const TfLiteTensor* tensor, std::vector<T>* data) { - const T* results = GetTensorData<T>(tensor); - size_t num = tensor->bytes / sizeof(tensor->type); - data->reserve(num); -- for (int i = 0; i < num; i++) { -+ for (size_t i = 0; i < num; i++) { - data->emplace_back(results[i]); - } - } -@@ -169,7 +169,7 @@ static TensorType* FindTensorByName( - tensor_metadatas->size() != tensors.size()) { - return nullptr; - } -- for (int i = 0; i < tensor_metadatas->size(); i++) { -+ for (flatbuffers::uoffset_t i = 0; i < tensor_metadatas->size(); i++) { - if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) { - return tensors[i]; - } --- -2.30.0.478.g8a0d178c01-goog -
diff --git a/third_party/tflite_support/patches/0001-use-StringPiece-for-string_view.patch b/third_party/tflite_support/patches/0001-use-StringPiece-for-string_view.patch deleted file mode 100644 index 28f4740..0000000 --- a/third_party/tflite_support/patches/0001-use-StringPiece-for-string_view.patch +++ /dev/null
@@ -1,38 +0,0 @@ -From 81287a62d65139f29c512fed88ed734bef2c33f5 Mon Sep 17 00:00:00 2001 -From: Michael Crouse <mcrouse@chromium.org> -Date: Tue, 22 Dec 2020 14:25:39 -0800 -Subject: [PATCH] use StringPiece for string_view - ---- - .../cc/text/tokenizers/regex_tokenizer.cc | 10 +++++----- - 1 file changed, 5 insertions(+), 5 deletions(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc -index 38aff8805b30..44c43b2d5086 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc -@@ -61,16 +61,16 @@ RegexTokenizer::RegexTokenizer(const std::string& regex_pattern, - } - - TokenizerResult RegexTokenizer::Tokenize(const std::string& input) { -- absl::string_view leftover(input.data()); -- absl::string_view last_end = leftover; -+ re2::StringPiece leftover(input.data()); -+ re2::StringPiece last_end = leftover; - - TokenizerResult result; - - // Keep looking for split points until we have reached the end of the input. -- absl::string_view extracted_delim_token; -+ re2::StringPiece extracted_delim_token; - while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) { -- absl::string_view token(last_end.data(), -- extracted_delim_token.data() - last_end.data()); -+ re2::StringPiece token(last_end.data(), -+ extracted_delim_token.data() - last_end.data()); - bool has_non_empty_token = token.length() > 0; - - last_end = leftover; --- -2.29.2.729.g45daf8777d-goog -
diff --git a/third_party/tflite_support/patches/0001-use-SysNSStringToUTF8.patch b/third_party/tflite_support/patches/0001-use-SysNSStringToUTF8.patch deleted file mode 100644 index 4ca0e41..0000000 --- a/third_party/tflite_support/patches/0001-use-SysNSStringToUTF8.patch +++ /dev/null
@@ -1,29 +0,0 @@ -From e4b8790a56487279b084fb59a2186a8bfd24b838 Mon Sep 17 00:00:00 2001 -From: Michael Crouse <mcrouse@chromium.org> -Date: Thu, 7 Jan 2021 08:20:06 -0800 -Subject: [PATCH] use SysNSStringToUTF8 - ---- - .../tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm | 3 ++- - 1 file changed, 2 insertions(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm -index 2a11bb673047..b82be34a9ab9 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm -+++ b/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm -@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -+#inmport "base/strings/sys_string_conversions.h" - #import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" - - std::string MakeString(NSString* str) { -- return std::string([str UTF8String]); -+ return SysNSStringToUTF8(str); - } - - NSString* MakeNSString(const std::string& str) { --- -2.29.2.729.g45daf8777d-goog -
diff --git a/third_party/tflite_support/patches/0001-use-base-logging.patch b/third_party/tflite_support/patches/0001-use-base-logging.patch deleted file mode 100644 index 0cdb9d6..0000000 --- a/third_party/tflite_support/patches/0001-use-base-logging.patch +++ /dev/null
@@ -1,26 +0,0 @@ -From 5307d81798215dae084b5079e797fd4408040340 Mon Sep 17 00:00:00 2001 -From: Michael Crouse <mcrouse@chromium.org> -Date: Tue, 22 Dec 2020 14:18:09 -0800 -Subject: [PATCH] use base logging - ---- - .../src/tensorflow_lite_support/cc/port/default/statusor.cc | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc -index 547a79192324..182c37e4aaf6 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc -@@ -18,8 +18,8 @@ limitations under the License. - - #include <utility> - --#include <glog/logging.h> - #include "absl/strings/str_cat.h" -+#include "base/logging.h" - - namespace tflite { - namespace support { --- -2.29.2.729.g45daf8777d-goog -
diff --git a/third_party/tflite_support/patches/0001-use-exit.patch b/third_party/tflite_support/patches/0001-use-exit.patch deleted file mode 100644 index 986543f..0000000 --- a/third_party/tflite_support/patches/0001-use-exit.patch +++ /dev/null
@@ -1,25 +0,0 @@ -From 49de34d7489ba5218a822461a42786844a1e344b Mon Sep 17 00:00:00 2001 -From: Sophie Chang <sophiechang@chromium.org> -Date: Wed, 3 Feb 2021 04:30:56 +0000 -Subject: [PATCH] use _Exit - ---- - .../src/tensorflow_lite_support/cc/port/default/statusor.cc | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc -index 182c37e4aaf6..058c0070f0da 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc -@@ -50,7 +50,7 @@ void Helper::HandleInvalidStatusCtorArg(absl::Status* status) { - void Helper::Crash(const absl::Status& status) { - LOG(FATAL) << "Attempting to fetch value instead of handling error " - << status; -- _exit(1); -+ _Exit(1); - } - - void ThrowBadStatusOrAccess(absl::Status status) { --- -2.30.0.365.g02bc693789-goog -
diff --git a/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch b/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch new file mode 100644 index 0000000..950a6db --- /dev/null +++ b/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch
@@ -0,0 +1,36 @@ +From b16b7af6f58ede0718fabf9c0da7495c79400c90 Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Wed, 15 Dec 2021 14:39:48 -0800 +Subject: [PATCH 01/11] use re2 StringPiece for RegexTokenizer::Tokenize + +--- + .../cc/text/tokenizers/regex_tokenizer.cc | 8 ++++---- + 1 file changed, 4 insertions(+), 4 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc +index 564f5f63a0584..832f9df42f824 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc +@@ -61,15 +61,15 @@ RegexTokenizer::RegexTokenizer(const std::string& regex_pattern, + } + + TokenizerResult RegexTokenizer::Tokenize(const std::string& input) { +- absl::string_view leftover(input.data()); +- absl::string_view last_end = leftover; ++ re2::StringPiece leftover(input.data()); ++ re2::StringPiece last_end = leftover; + + TokenizerResult result; + + // Keep looking for split points until we have reached the end of the input. +- absl::string_view extracted_delim_token; ++ re2::StringPiece extracted_delim_token; + while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) { +- absl::string_view token(last_end.data(), ++ re2::StringPiece token(last_end.data(), + extracted_delim_token.data() - last_end.data()); + bool has_non_empty_token = token.length() > 0; + +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0001-use-size_t.patch b/third_party/tflite_support/patches/0001-use-size_t.patch deleted file mode 100644 index 447bdb81..0000000 --- a/third_party/tflite_support/patches/0001-use-size_t.patch +++ /dev/null
@@ -1,25 +0,0 @@ -From ecb535154168358a72de6b51099a9549b970bce5 Mon Sep 17 00:00:00 2001 -From: Michael Crouse <mcrouse@chromium.org> -Date: Tue, 22 Dec 2020 14:34:12 -0800 -Subject: [PATCH] use size_t - ---- - .../cc/task/text/nlclassifier/nl_classifier.h | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h -index d3c3bed2083d..4055a93467d4 100644 ---- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h -+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h -@@ -151,7 +151,7 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, - const std::string& name, - int index) { - if (metadata_array != nullptr && metadata_array->size() == tensors.size()) { -- for (int i = 0; i < metadata_array->size(); i++) { -+ for (size_t i = 0; i < metadata_array->size(); i++) { - if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) { - return tensors[i]; - } --- -2.29.2.729.g45daf8777d-goog -
diff --git a/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch b/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch new file mode 100644 index 0000000..95774a5 --- /dev/null +++ b/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch
@@ -0,0 +1,51 @@ +From 9aa45ea43f8d84db1e20674c294f9ab958b12d7e Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Wed, 15 Dec 2021 14:57:16 -0800 +Subject: [PATCH 02/11] sentencepiece tokenization not supported + +--- + .../cc/text/tokenizers/tokenizer_utils.cc | 14 ++++---------- + 1 file changed, 4 insertions(+), 10 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc +index 9abca9691f058..28f0137f54278 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc +@@ -20,7 +20,6 @@ limitations under the License. + #include "tensorflow_lite_support/cc/port/status_macros.h" + #include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" + #include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" +-#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + + namespace tflite { +@@ -29,7 +28,6 @@ namespace text { + namespace tokenizer { + + using ::tflite::ProcessUnit; +-using ::tflite::SentencePieceTokenizerOptions; + using ::tflite::support::CreateStatusWithPayload; + using ::tflite::support::StatusOr; + using ::tflite::support::TfLiteSupportStatus; +@@ -74,14 +72,10 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit( + vocab_buffer.size()); + } + case ProcessUnitOptions_SentencePieceTokenizerOptions: { +- const tflite::SentencePieceTokenizerOptions* options = +- tokenizer_process_unit->options_as<SentencePieceTokenizerOptions>(); +- ASSIGN_OR_RETURN(absl::string_view model_buffer, +- CheckAndLoadFirstAssociatedFile( +- options->sentencePiece_model(), metadata_extractor)); +- // TODO(b/160647204): Extract sentence piece model vocabulary +- return absl::make_unique<SentencePieceTokenizer>(model_buffer.data(), +- model_buffer.size()); ++ return CreateStatusWithPayload( ++ absl::StatusCode::kInvalidArgument, ++ "Chromium does not support sentencepiece tokenization", ++ TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + case ProcessUnitOptions_RegexTokenizerOptions: { + const tflite::RegexTokenizerOptions* options = +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0003-rm-unused-func.patch b/third_party/tflite_support/patches/0003-rm-unused-func.patch new file mode 100644 index 0000000..5441faf --- /dev/null +++ b/third_party/tflite_support/patches/0003-rm-unused-func.patch
@@ -0,0 +1,224 @@ +From 7bd2e5f0e2bb560e55efc3dd86249ff42a10d08c Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Wed, 15 Dec 2021 15:08:59 -0800 +Subject: [PATCH 03/11] rm unused func + +--- + .../vision/utils/libyuv_frame_buffer_utils.cc | 201 +----------------- + 1 file changed, 1 insertion(+), 200 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +index 0ece48636504e..6fd3ca81c984c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +@@ -1326,206 +1326,7 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1, + } + } + +-// Returns the scaled dimension of the input_size best fit within the +-// output_size bound while respecting the aspect ratio. +-FrameBuffer::Dimension GetScaledDimension(FrameBuffer::Dimension input_size, +- FrameBuffer::Dimension output_size) { +- int original_width = input_size.width; +- int original_height = input_size.height; +- int bound_width = output_size.width; +- int bound_height = output_size.height; +- int new_width = original_width; +- int new_height = original_height; +- +- // Try to fit the width first. +- new_width = bound_width; +- new_height = (new_width * original_height) / original_width; +- +- // Try to fit the height if needed. +- if (new_height > bound_height) { +- new_height = bound_height; +- new_width = (new_height * original_width) / original_height; +- } +- return FrameBuffer::Dimension{.width = new_width, .height = new_height}; +-} +- +-// This method only supports kGRAY, kRGBA, and kRGB formats. +-absl::Status UniformCropResizePlane(const FrameBuffer& buffer, +- std::vector<int> crop_coordinates, +- FrameBuffer* output_buffer) { +- int x0 = 0, y0 = 0; +- FrameBuffer::Dimension input_dimension = buffer.dimension(); +- if (!crop_coordinates.empty()) { +- x0 = crop_coordinates[0]; +- y0 = crop_coordinates[1]; +- input_dimension = +- GetCropDimension(x0, crop_coordinates[2], y0, crop_coordinates[3]); +- } +- if (input_dimension == output_buffer->dimension()) { +- // Cropping only case. +- return CropPlane(buffer, x0, y0, crop_coordinates[2], crop_coordinates[3], +- output_buffer); +- } +- +- // Cropping is achieved by adjusting origin to (x0, y0). +- ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format())); +- int adjusted_offset = +- buffer.plane(0).stride.row_stride_bytes * y0 + x0 * pixel_stride; +- FrameBuffer::Plane plane = { +- /*buffer=*/buffer.plane(0).buffer + adjusted_offset, +- /*stride=*/{buffer.plane(0).stride.row_stride_bytes, pixel_stride}}; +- auto adjusted_buffer = +- FrameBuffer::Create({plane}, input_dimension, buffer.format(), +- buffer.orientation(), buffer.timestamp()); +- +- // Uniform resize is achieved by adjusting the resize dimension to fit the +- // output_buffer and respect the input aspect ratio at the same time. We +- // create an intermediate output buffer with adjusted dimension and point its +- // backing buffer to the output_buffer. Note the stride information on the +- // adjusted_output_buffer is not used in the Resize* methods. +- FrameBuffer::Dimension adjusted_dimension = +- GetScaledDimension(input_dimension, output_buffer->dimension()); +- FrameBuffer::Plane output_plane = {/*buffer=*/output_buffer->plane(0).buffer, +- /*stride=*/output_buffer->plane(0).stride}; +- auto adjusted_output_buffer = FrameBuffer::Create( +- {output_plane}, adjusted_dimension, output_buffer->format(), +- output_buffer->orientation(), output_buffer->timestamp()); +- +- switch (buffer.format()) { +- case FrameBuffer::Format::kRGB: +- return ResizeRgb(*adjusted_buffer, adjusted_output_buffer.get()); +- case FrameBuffer::Format::kRGBA: +- return ResizeRgba(*adjusted_buffer, adjusted_output_buffer.get()); +- case FrameBuffer::Format::kGRAY: +- return ResizeGray(*adjusted_buffer, adjusted_output_buffer.get()); +- default: +- return CreateStatusWithPayload( +- StatusCode::kInternal, +- absl::StrFormat("Format %i is not supported.", buffer.format()), +- TfLiteSupportStatus::kImageProcessingError); +- } +-} +- +-absl::Status UniformCropResizeYuv(const FrameBuffer& buffer, +- std::vector<int> crop_coordinates, +- FrameBuffer* output_buffer) { +- int x0 = 0, y0 = 0; +- FrameBuffer::Dimension input_dimension = buffer.dimension(); +- if (!crop_coordinates.empty()) { +- x0 = crop_coordinates[0]; +- y0 = crop_coordinates[1]; +- input_dimension = +- GetCropDimension(x0, crop_coordinates[2], y0, crop_coordinates[3]); +- } +- if (input_dimension == output_buffer->dimension()) { +- // Cropping only case. +- int x1 = crop_coordinates[2]; +- int y1 = crop_coordinates[3]; +- switch (buffer.format()) { +- case FrameBuffer::Format::kNV12: +- case FrameBuffer::Format::kNV21: +- return CropNv(buffer, x0, y0, x1, y1, output_buffer); +- case FrameBuffer::Format::kYV12: +- case FrameBuffer::Format::kYV21: +- return CropYv(buffer, x0, y0, x1, y1, output_buffer); +- default: +- return CreateStatusWithPayload( +- StatusCode::kInternal, +- absl::StrFormat("Format %i is not supported.", buffer.format()), +- TfLiteSupportStatus::kImageProcessingError); +- } +- } +- +- // Cropping is achieved by adjusting origin to (x0, y0). +- ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, +- FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); +- // Cropping YUV planes by offsetting the origins of each plane. +- // TODO(b/152629712): Investigate the impact of color shifting caused by the +- // bounding box with odd X or Y starting positions. +- const int plane_y_offset = input_data.y_row_stride * y0 + x0; +- const int plane_uv_offset = input_data.uv_row_stride * (y0 / 2) + +- input_data.uv_pixel_stride * (x0 / 2); +- FrameBuffer::Plane adjusted_plane_y = { +- /*buffer=*/input_data.y_buffer + plane_y_offset, +- /*stride=*/{input_data.y_row_stride, /*pixel_stride_bytes=*/1}}; +- FrameBuffer::Plane adjusted_plane_u = { +- /*buffer=*/input_data.u_buffer + plane_uv_offset, +- /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}}; +- FrameBuffer::Plane adjusted_plane_v = { +- /*buffer=*/input_data.v_buffer + plane_uv_offset, +- /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}}; +- +- // Uniform resize is achieved by adjusting the resize dimension to fit the +- // output_buffer and respect the input aspect ratio at the same time. For +- // YUV formats, we need access to the actual output dimension to get the +- // correct address of each plane. For this, we are not calling ResizeNv or +- // ResizeYv but the libyuv scale methods directly. +- FrameBuffer::Dimension adjusted_dimension = +- GetScaledDimension(input_dimension, output_buffer->dimension()); +- ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, +- FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); +- +- switch (buffer.format()) { +- case FrameBuffer::Format::kNV12: { +- int ret = libyuv::NV12Scale( +- adjusted_plane_y.buffer, adjusted_plane_y.stride.row_stride_bytes, +- adjusted_plane_u.buffer, adjusted_plane_u.stride.row_stride_bytes, +- input_dimension.width, input_dimension.height, +- const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride, +- const_cast<uint8_t*>(output_data.u_buffer), output_data.uv_row_stride, +- adjusted_dimension.width, adjusted_dimension.height, +- libyuv::FilterMode::kFilterBilinear); +- if (ret != 0) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, "Libyuv NV12Scale operation failed.", +- TfLiteSupportStatus::kImageProcessingBackendError); +- } +- return absl::OkStatus(); +- } +- case FrameBuffer::Format::kNV21: { +- int ret = libyuv::NV12Scale( +- adjusted_plane_y.buffer, adjusted_plane_y.stride.row_stride_bytes, +- adjusted_plane_v.buffer, adjusted_plane_v.stride.row_stride_bytes, +- input_dimension.width, input_dimension.height, +- const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride, +- const_cast<uint8_t*>(output_data.v_buffer), output_data.uv_row_stride, +- adjusted_dimension.width, adjusted_dimension.height, +- libyuv::FilterMode::kFilterBilinear); +- if (ret != 0) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, "Libyuv NV12Scale operation failed.", +- TfLiteSupportStatus::kImageProcessingBackendError); +- } +- return absl::OkStatus(); +- } +- case FrameBuffer::Format::kYV12: +- case FrameBuffer::Format::kYV21: { +- int ret = libyuv::I420Scale( +- adjusted_plane_y.buffer, adjusted_plane_y.stride.row_stride_bytes, +- adjusted_plane_u.buffer, adjusted_plane_u.stride.row_stride_bytes, +- adjusted_plane_v.buffer, adjusted_plane_v.stride.row_stride_bytes, +- input_dimension.width, input_dimension.height, +- const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride, +- const_cast<uint8_t*>(output_data.u_buffer), output_data.uv_row_stride, +- const_cast<uint8_t*>(output_data.v_buffer), output_data.uv_row_stride, +- adjusted_dimension.width, adjusted_dimension.height, +- libyuv::FilterMode::kFilterBilinear); +- if (ret != 0) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, "Libyuv I420Scale operation failed.", +- TfLiteSupportStatus::kImageProcessingBackendError); +- } +- return absl::OkStatus(); +- } +- default: +- return CreateStatusWithPayload( +- StatusCode::kInternal, +- absl::StrFormat("Format %i is not supported.", buffer.format()), +- TfLiteSupportStatus::kImageProcessingError); +- } +- return absl::OkStatus(); +-} +-} // namespace ++} // namespace + + absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, + int y0, int x1, int y1, +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0004-rm-noop-deprecated-attribute.patch b/third_party/tflite_support/patches/0004-rm-noop-deprecated-attribute.patch new file mode 100644 index 0000000..d59a5ad --- /dev/null +++ b/third_party/tflite_support/patches/0004-rm-noop-deprecated-attribute.patch
@@ -0,0 +1,26 @@ +From 243aadd7dcea9be980aa89d183bf2dea7cba202b Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Wed, 15 Dec 2021 15:49:40 -0800 +Subject: [PATCH 04/11] rm noop deprecated attribute + +--- + .../cc/task/text/nlclassifier/nl_classifier.h | 3 --- + 1 file changed, 3 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +index d5b49dfd75277..ac12536355db4 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +@@ -43,9 +43,6 @@ namespace text { + namespace nlclassifier { + + // Options to identify input and output tensors of the model +-ABSL_DEPRECATED( +- "Prefer using `tflite::task::text::NLClassifierOptions` and " +- "`CreateFromOptions`") + struct NLClassifierOptions { + int input_tensor_index = 0; + int output_score_tensor_index = 0; +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0005-use-size_t-in-for-loop.patch b/third_party/tflite_support/patches/0005-use-size_t-in-for-loop.patch new file mode 100644 index 0000000..015edf6f --- /dev/null +++ b/third_party/tflite_support/patches/0005-use-size_t-in-for-loop.patch
@@ -0,0 +1,25 @@ +From 5d67933d8d4440816a02f7a319d7323041c3f7bf Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Wed, 15 Dec 2021 15:51:22 -0800 +Subject: [PATCH 05/11] use size_t in for loop + +--- + .../cc/task/text/nlclassifier/nl_classifier.h | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +index ac12536355db4..2adafba8f2fa9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +@@ -179,7 +179,7 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, + metadata_array, + const std::string& name, int index) { + if (metadata_array != nullptr && metadata_array->size() == tensors.size()) { +- for (int i = 0; i < metadata_array->size(); i++) { ++ for (size_t i = 0; i < metadata_array->size(); i++) { + if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) { + return tensors[i]; + } +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0006-unused-variable.patch b/third_party/tflite_support/patches/0006-unused-variable.patch new file mode 100644 index 0000000..fa433fa --- /dev/null +++ b/third_party/tflite_support/patches/0006-unused-variable.patch
@@ -0,0 +1,29 @@ +From b2d06daf8ab5cff8748489407b6ad10ea600948d Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Thu, 16 Dec 2021 08:35:07 -0800 +Subject: [PATCH 06/11] unused variable + +--- + .../src/tensorflow_lite_support/cc/task/core/task_utils.h | 6 ++++-- + 1 file changed, 4 insertions(+), 2 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h +index 03ef9ade9af41..e95ea73a4a812 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h +@@ -144,8 +144,10 @@ inline absl::Status PopulateVector(const TfLiteTensor* tensor, + template <> + inline absl::Status PopulateVector<std::string>( + const TfLiteTensor* tensor, std::vector<std::string>* data) { +- std::string* v; +- ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<std::string>(tensor)); ++ if (tensor->type != typeToTfLiteType<std::string>()) { ++ return absl::InvalidArgumentError("not of type string"); ++ } ++ + int num = GetStringCount(tensor); + data->reserve(num); + for (int i = 0; i < num; i++) { +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0007-do-not-use-absl-any.patch b/third_party/tflite_support/patches/0007-do-not-use-absl-any.patch new file mode 100644 index 0000000..0fbe740 --- /dev/null +++ b/third_party/tflite_support/patches/0007-do-not-use-absl-any.patch
@@ -0,0 +1,64 @@ +From d3d4385132632282fc91c735875ebfc90697b067 Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Thu, 16 Dec 2021 13:28:16 -0800 +Subject: [PATCH 07/11] do not use absl any + +--- + .../cc/task/vision/core/frame_buffer.h | 27 ------------------- + 1 file changed, 27 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h +index c1289673cb82b..1668447393e9e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h +@@ -27,7 +27,6 @@ limitations under the License. + #include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/time/clock.h" // from @com_google_absl + #include "absl/time/time.h" // from @com_google_absl +-#include "absl/types/any.h" // from @com_google_absl + #include "absl/types/optional.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -250,31 +249,6 @@ class FrameBuffer { + return {}; + } + +- // Returns the tag associated to the tag_key. +- absl::any GetTag(const std::string& tag_key) const { +- auto iter = tags_.find(tag_key); +- if (iter != tags_.end()) { +- return iter->second; +- } +- return absl::any(); +- } +- +- // Inserts or updates the tags map with key value pair (tag_key, tag_value). +- void InsertOrUpdateTag(const std::string& tag_key, absl::any tag_value) { +- tags_[tag_key] = std::move(tag_value); +- } +- +- // Inserts the key value pair (tag_key, tag_value) into tags map. If the +- // tag_key already exists, an internal error will return. +- absl::Status InsertTag(const std::string& tag_key, absl::any tag_value) { +- auto iter = tags_.emplace(tag_key, tag_value); +- if (iter.second) { +- return absl::OkStatus(); +- } +- return absl::InternalError(absl::StrCat( +- "tag_key already exists in tags.tag_key was not inserted: ", tag_key)); +- } +- + // Returns FrameBuffer dimension. + const Dimension dimension() const { return dimension_; } + +@@ -289,7 +263,6 @@ class FrameBuffer { + + private: + std::vector<Plane> planes_; +- std::map<std::string, absl::any> tags_; + Dimension dimension_; + Format format_; + Orientation orientation_; +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0008-unused-string-include.patch b/third_party/tflite_support/patches/0008-unused-string-include.patch new file mode 100644 index 0000000..2f5c20aeb --- /dev/null +++ b/third_party/tflite_support/patches/0008-unused-string-include.patch
@@ -0,0 +1,25 @@ +From 5feffc2cdd8c970490fadd812401be4eb57174d5 Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Thu, 16 Dec 2021 13:43:57 -0800 +Subject: [PATCH 08/11] unused string include + +--- + .../cc/text/tokenizers/tokenizer_jni_lib.h | 2 -- + 1 file changed, 2 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h +index fc7285c6807b0..33677d305a853 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h +@@ -17,8 +17,6 @@ limitations under the License. + + #include <jni.h> + +-#include <string> +- + #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + #include "tensorflow_lite_support/cc/utils/jni_utils.h" + +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0009-remove-unbuilt-files-and-change-exec-bit-where-neede.patch b/third_party/tflite_support/patches/0009-remove-unbuilt-files-and-change-exec-bit-where-neede.patch new file mode 100644 index 0000000..4ea8bdce --- /dev/null +++ b/third_party/tflite_support/patches/0009-remove-unbuilt-files-and-change-exec-bit-where-neede.patch
@@ -0,0 +1,666 @@ +From 515f1ef8496e5c73318aa41f6295bbfbefb6bbae Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Thu, 16 Dec 2021 13:56:14 -0800 +Subject: [PATCH 09/11] remove unbuilt files and change exec bit where needed + +--- + .../cc/port/benchmark.h | 21 --- + .../cc/port/default/status_matchers.h | 55 ------- + .../tensorflow_lite_support/cc/port/gmock.h | 21 --- + .../tensorflow_lite_support/cc/port/gtest.h | 21 --- + .../tensorflow_lite_support/cc/port/proto2.h | 32 ---- + .../examples/task/audio/desktop/python/BUILD | 0 + .../task/audio/desktop/python/README.md | 0 + .../desktop/python/audio_classifier_demo.py | 0 + .../examples/task/vision/desktop/python/BUILD | 0 + .../desktop/python/image_classifier_demo.py | 0 + .../desktop/python/image_segmenter_demo.py | 0 + .../desktop/python/object_detector_demo.py | 0 + .../ios/utils/Sources/TFLStringUtil.mm | 23 --- + .../metadata/cc/metadata_populator.cc | 150 ------------------ + .../metadata/cc/utils/zip_mem_file.cc | 124 --------------- + .../metadata/cc/utils/zip_mem_file.h | 71 --------- + .../odml/ios/image/resources/grace_hopper.jpg | Bin + .../tools/ci_build/build_all.sh | 0 + .../ci_build/builds/build_ios_framework.sh | 0 + .../tools/ci_build/builds/pip_smoke_test.sh | 0 + .../tools/ci_build/common.sh | 0 + .../tools/ci_build/common_win.bat | 0 + 22 files changed, 518 deletions(-) + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h + mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg + mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh + mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh + mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh + mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh + mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h +deleted file mode 100644 +index 74bc1a6857664..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h ++++ /dev/null +@@ -1,21 +0,0 @@ +-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ +-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ +- +-#include "gtest/benchmark.h" +- +-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h +deleted file mode 100644 +index 6d9668043c183..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h ++++ /dev/null +@@ -1,55 +0,0 @@ +-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_ +-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_ +- +-#include "gmock/gmock.h" +-#include "gtest/gtest.h" +- +-#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +-#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_(x, y) \ +- SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) +- +-#undef SUPPORT_ASSERT_OK +-#define SUPPORT_ASSERT_OK(expr) \ +- SUPPORT_ASSERT_OK_IMPL_( \ +- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr) +- +-#define SUPPORT_ASSERT_OK_IMPL_(status, expr) \ +- auto status = (expr); \ +- ASSERT_TRUE(status.ok()); +- +-#undef SUPPORT_EXPECT_OK +-#define SUPPORT_EXPECT_OK(expr) \ +- SUPPORT_EXPECT_OK_IMPL_( \ +- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr) +- +-#define SUPPORT_EXPECT_OK_IMPL_(status, expr) \ +- auto status = (expr); \ +- EXPECT_TRUE(status.ok()); +- +-#undef SUPPORT_ASSERT_OK_AND_ASSIGN +-#define SUPPORT_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ +- SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_( \ +- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \ +- rexpr) +- +-#define SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \ +- auto statusor = (rexpr); \ +- ASSERT_TRUE(statusor.ok()); \ +- lhs = std::move(statusor.value()) +- +-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h +deleted file mode 100644 +index 5e4334db323d6..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h ++++ /dev/null +@@ -1,21 +0,0 @@ +-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ +-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ +- +-#include "gmock/gmock.h" +- +-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h +deleted file mode 100644 +index dbe2e5e6f9d7c..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h ++++ /dev/null +@@ -1,21 +0,0 @@ +-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ +-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ +- +-#include "gtest/gtest.h" +- +-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h +deleted file mode 100644 +index 3cde2ab81d6ee..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h ++++ /dev/null +@@ -1,32 +0,0 @@ +-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_ +-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_ +- +-#include "google/protobuf/message_lite.h" +-#include "google/protobuf/text_format.h" +- +-namespace tflite { +-namespace support { +-namespace proto { +- +-using TextFormat = ::google::protobuf::TextFormat; +-using MessageLite = ::google::protobuf::MessageLite; +- +-} // namespace proto +-} // namespace support +-} // namespace tflite +- +-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm +deleted file mode 100644 +index 6e9cf23802427..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm ++++ /dev/null +@@ -1,23 +0,0 @@ +-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +-#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" +- +-std::string MakeString(NSString* str) { return std::string([str UTF8String]); } +- +-NSString* MakeNSString(const std::string& str) { +- return [[NSString alloc] initWithBytes:const_cast<void*>(static_cast<const void*>(str.data())) +- length:str.length() +- encoding:NSUTF8StringEncoding]; +-} +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc +deleted file mode 100644 +index e21d426369e2e..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc ++++ /dev/null +@@ -1,150 +0,0 @@ +-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#include "tensorflow_lite_support/metadata/cc/metadata_populator.h" +- +-#include <cstdlib> +-#include <cstring> +-#include <functional> +- +-#include "flatbuffers/flatbuffers.h" // from @flatbuffers +-#include "contrib/minizip/ioapi.h" +-#include "contrib/minizip/zip.h" +-#include "tensorflow/lite/schema/schema_generated.h" +-#include "tensorflow_lite_support/cc/common.h" +-#include "tensorflow_lite_support/cc/port/status_macros.h" +-#include "tensorflow_lite_support/cc/port/statusor.h" +-#include "tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h" +-#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +- +-namespace tflite { +-namespace metadata { +- +-namespace { +-constexpr char kMetadataBufferName[] = "TFLITE_METADATA"; +- +-using ::absl::StatusCode; +-using ::tflite::support::CreateStatusWithPayload; +-using ::tflite::support::TfLiteSupportStatus; +- +-} // namespace +- +-ModelMetadataPopulator::ModelMetadataPopulator(const tflite::Model& model) { +- model.UnPackTo(&model_t_); +-} +- +-/* static */ +-tflite::support::StatusOr<std::unique_ptr<ModelMetadataPopulator>> +-ModelMetadataPopulator::CreateFromModelBuffer(const char* buffer_data, +- size_t buffer_size) { +- // Rely on the simplest, base flatbuffers verifier. Here is not the place to +- // e.g. use an OpResolver: we just want to make sure the buffer is valid to +- // access the metadata. +- flatbuffers::Verifier verifier = flatbuffers::Verifier( +- reinterpret_cast<const uint8_t*>(buffer_data), buffer_size); +- if (!tflite::VerifyModelBuffer(verifier)) { +- return CreateStatusWithPayload( +- StatusCode::kInvalidArgument, +- "The model is not a valid FlatBuffer buffer.", +- TfLiteSupportStatus::kInvalidFlatBufferError); +- } +- // Use absl::WrapUnique() to call private constructor: +- // https://abseil.io/tips/126. +- return absl::WrapUnique( +- new ModelMetadataPopulator(*tflite::GetModel(buffer_data))); +-} +- +-void ModelMetadataPopulator::LoadMetadata(const char* metadata_buffer_data, +- size_t metadata_buffer_size) { +- // Pack the model metadata in a buffer. +- auto model_metadata_buffer = std::make_unique<tflite::BufferT>(); +- model_metadata_buffer->data = {metadata_buffer_data, +- metadata_buffer_data + metadata_buffer_size}; +- // Check if the model already has metadata. If so, just override the buffer +- // and exit. +- for (const auto& metadata_t : model_t_.metadata) { +- if (metadata_t->name == kMetadataBufferName) { +- model_t_.buffers[metadata_t->buffer] = std::move(model_metadata_buffer); +- return; +- } +- } +- // Model doesn't already have metadata: add metadata buffer and pointer to the +- // buffer in the model metadata section. +- model_t_.buffers.push_back(std::move(model_metadata_buffer)); +- auto metadata_t = std::make_unique<tflite::MetadataT>(); +- metadata_t->name = kMetadataBufferName; +- metadata_t->buffer = model_t_.buffers.size() - 1; +- model_t_.metadata.push_back(std::move(metadata_t)); +-} +- +-void ModelMetadataPopulator::LoadAssociatedFiles( +- const absl::flat_hash_map<std::string, std::string>& associated_files) { +- associated_files_ = associated_files; +-} +- +-tflite::support::StatusOr<std::string> +-ModelMetadataPopulator::AppendAssociatedFiles(const char* model_buffer_data, +- size_t model_buffer_size) { +- // Create in-memory zip file. +- ZipMemFile mem_file = ZipMemFile(model_buffer_data, model_buffer_size); +- // Open zip. +- zipFile zf = zipOpen2(/*pathname=*/nullptr, APPEND_STATUS_CREATEAFTER, +- /*globalcomment=*/nullptr, &mem_file.GetFileFuncDef()); +- if (zf == nullptr) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, "Unable to open zip archive", +- TfLiteSupportStatus::kMetadataAssociatedFileZipError); +- } +- // Write associated files. +- for (const auto& [name, contents] : associated_files_) { +- if ((zipOpenNewFileInZip(zf, name.c_str(), +- /*zipfi=*/nullptr, +- /*extrafield_local=*/nullptr, +- /*size_extrafield_local=*/0, +- /*extrafield_global=*/nullptr, +- /*size_extrafield_global=*/0, +- /*comment=*/nullptr, +- /*method=*/0, +- /*level=*/Z_DEFAULT_COMPRESSION) != ZIP_OK) || +- (zipWriteInFileInZip(zf, contents.data(), contents.length()) != +- ZIP_OK) || +- (zipCloseFileInZip(zf) != ZIP_OK)) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, "Unable to write file to zip archive", +- TfLiteSupportStatus::kMetadataAssociatedFileZipError); +- } +- } +- // Close zip. +- if (zipClose(zf, /*global_comment=*/nullptr) != ZIP_OK) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, "Unable to close zip archive", +- TfLiteSupportStatus::kMetadataAssociatedFileZipError); +- } +- // Return as a string. +- return std::string(mem_file.GetFileContent()); +-} +- +-tflite::support::StatusOr<std::string> ModelMetadataPopulator::Populate() { +- // Build model. +- flatbuffers::FlatBufferBuilder model_fbb; +- model_fbb.Finish(tflite::Model::Pack(model_fbb, &model_t_), +- tflite::ModelIdentifier()); +- return AppendAssociatedFiles( +- reinterpret_cast<char*>(model_fbb.GetBufferPointer()), +- model_fbb.GetSize()); +-} +- +-} // namespace metadata +-} // namespace tflite +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc +deleted file mode 100644 +index 2e4d9107c8c31..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc ++++ /dev/null +@@ -1,124 +0,0 @@ +-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#include "tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h" +- +-#include <algorithm> +-#include <cstdio> +- +-#include "absl/strings/string_view.h" // from @com_google_absl +-#include "contrib/minizip/ioapi.h" +- +-namespace tflite { +-namespace metadata { +- +-ZipMemFile::ZipMemFile(const char* buffer, size_t size) +- : data_(buffer, size), offset_(0) { +- zlib_filefunc_def_.zopen_file = OpenFile; +- zlib_filefunc_def_.zread_file = ReadFile; +- zlib_filefunc_def_.zwrite_file = WriteFile; +- zlib_filefunc_def_.ztell_file = TellFile; +- zlib_filefunc_def_.zseek_file = SeekFile; +- zlib_filefunc_def_.zclose_file = CloseFile; +- zlib_filefunc_def_.zerror_file = ErrorFile; +- zlib_filefunc_def_.opaque = this; +-} +- +-zlib_filefunc_def& ZipMemFile::GetFileFuncDef() { return zlib_filefunc_def_; } +- +-absl::string_view ZipMemFile::GetFileContent() const { return data_; } +- +-/* static */ +-voidpf ZipMemFile::OpenFile(voidpf opaque, const char* filename, int mode) { +- // Result is never used, but needs to be non-null for `zipOpen2` not to fail. +- return opaque; +-} +- +-/* static */ +-size_t ZipMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, +- size_t size) { +- auto* mem_file = static_cast<ZipMemFile*>(opaque); +- if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) { +- return 0; +- } +- if (mem_file->offset_ + size > mem_file->Size()) { +- size = mem_file->Size() - mem_file->offset_; +- } +- memcpy(buf, +- static_cast<const char*>(mem_file->data_.c_str()) + mem_file->offset_, +- size); +- mem_file->offset_ += size; +- return size; +-} +- +-/* static */ +-size_t ZipMemFile::WriteFile(voidpf opaque, voidpf stream, const void* buf, +- size_t size) { +- auto* mem_file = static_cast<ZipMemFile*>(opaque); +- if (mem_file->offset_ + size > mem_file->Size()) { +- mem_file->data_.resize(mem_file->offset_ + size); +- } +- mem_file->data_.replace(mem_file->offset_, size, +- static_cast<const char*>(buf), size); +- mem_file->offset_ += size; +- return size; +-} +- +-/* static */ +-ptrdiff_t ZipMemFile::TellFile(voidpf opaque, voidpf stream) { +- return static_cast<ZipMemFile*>(opaque)->offset_; +-} +- +-/* static */ +-ptrdiff_t ZipMemFile::SeekFile(voidpf opaque, voidpf stream, size_t offset, +- int origin) { +- auto* mem_file = static_cast<ZipMemFile*>(opaque); +- switch (origin) { +- case SEEK_SET: +- mem_file->offset_ = offset; +- return 0; +- case SEEK_CUR: +- if (mem_file->offset_ + offset < 0 || +- mem_file->offset_ + offset > mem_file->Size()) { +- return -1; +- } +- mem_file->offset_ += offset; +- return 0; +- case SEEK_END: +- if (mem_file->Size() - offset < 0 || +- mem_file->Size() - offset > mem_file->Size()) { +- return -1; +- } +- mem_file->offset_ = offset + mem_file->Size(); +- return 0; +- default: +- return -1; +- } +-} +- +-/* static */ +-int ZipMemFile::CloseFile(voidpf opaque, voidpf stream) { +- // Nothing to do. +- return 0; +-} +- +-/* static */ +-int ZipMemFile::ErrorFile(voidpf opaque, voidpf stream) { +- // Unused. +- return 0; +-} +- +-} // namespace metadata +-} // namespace tflite +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h +deleted file mode 100644 +index ef7843d70cff6..0000000000000 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h ++++ /dev/null +@@ -1,71 +0,0 @@ +-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +- +-Licensed under the Apache License, Version 2.0 (the "License"); +-you may not use this file except in compliance with the License. +-You may obtain a copy of the License at +- +- http://www.apache.org/licenses/LICENSE-2.0 +- +-Unless required by applicable law or agreed to in writing, software +-distributed under the License is distributed on an "AS IS" BASIS, +-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-See the License for the specific language governing permissions and +-limitations under the License. +-==============================================================================*/ +- +-#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_ +-#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_ +- +-#include <cstdlib> +- +-#include "absl/strings/string_view.h" // from @com_google_absl +-#include "contrib/minizip/ioapi.h" +- +-namespace tflite { +-namespace metadata { +- +-// In-memory zip file implementation. +-// +-// Adapted from [1], with a few key differences: +-// * backed by an `std::string` instead of malloc-ed C buffers, +-// * supports opening the file for writing through `zipOpen2`. +-// +-// [1]: +-// https://github.com/google/libkml/blob/master/third_party/zlib-1.2.3/contrib/minizip/iomem_simple.c +-class ZipMemFile { +- public: +- // Constructs an in-memory zip file from a buffer. +- ZipMemFile(const char* buffer, size_t size); +- // Provides access to the `zlib_filefunc_def` implementation for the in-memory +- // zip file. +- zlib_filefunc_def& GetFileFuncDef(); +- // Provides access to the file contents. +- absl::string_view GetFileContent() const; +- +- private: +- // The string backing the in-memory file. +- std::string data_; +- // The current offset in the file. +- size_t offset_; +- // The `zlib_filefunc_def` implementation for this in-memory zip file. +- zlib_filefunc_def zlib_filefunc_def_; +- +- // Convenience function to access the current data size. +- size_t Size() const { return data_.size(); } +- +- // The file function implementations used in the `zlib_filefunc_def`. +- static voidpf OpenFile(voidpf opaque, const char* filename, int mode); +- static size_t ReadFile(voidpf opaque, voidpf stream, void* buf, size_t size); +- static size_t WriteFile(voidpf opaque, voidpf stream, const void* buf, +- size_t size); +- static ptrdiff_t TellFile(voidpf opaque, voidpf stream); +- static ptrdiff_t SeekFile(voidpf opaque, voidpf stream, size_t offset, +- int origin); +- static int CloseFile(voidpf opaque, voidpf stream); +- static int ErrorFile(voidpf opaque, voidpf stream); +-}; +- +-} // namespace metadata +-} // namespace tflite +- +-#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg +old mode 100755 +new mode 100644 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh +old mode 100644 +new mode 100755 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh +old mode 100644 +new mode 100755 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh +old mode 100644 +new mode 100755 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh +old mode 100644 +new mode 100755 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat +old mode 100644 +new mode 100755 +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0010-only-support-model-file-passed-in-from-mem.patch b/third_party/tflite_support/patches/0010-only-support-model-file-passed-in-from-mem.patch new file mode 100644 index 0000000..d97f8ed --- /dev/null +++ b/third_party/tflite_support/patches/0010-only-support-model-file-passed-in-from-mem.patch
@@ -0,0 +1,265 @@ +From 4b7e971e2f2f6ef3fd394858d975b64479047872 Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Mon, 20 Dec 2021 08:50:35 -0800 +Subject: [PATCH 10/11] only support model file passed in from mem + +--- + .../cc/task/core/external_file_handler.cc | 143 ++---------------- + .../cc/task/core/external_file_handler.h | 23 +-- + .../cc/task/core/tflite_engine.cc | 2 - + .../cc/task/core/tflite_engine.h | 2 - + 4 files changed, 10 insertions(+), 160 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc +index dcde0c926c653..e91a54fb7d11a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc +@@ -15,45 +15,25 @@ limitations under the License. + + #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" + +-#include <errno.h> +-#include <fcntl.h> + #include <stddef.h> +-#include <sys/mman.h> +-#include <unistd.h> +- + #include <memory> + #include <string> + +-#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" +-#include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" ++#include "tensorflow_lite_support/cc/port/statusor.h" + + namespace tflite { + namespace task { + namespace core { +-namespace { + + using ::absl::StatusCode; + using ::tflite::support::CreateStatusWithPayload; + using ::tflite::support::StatusOr; + using ::tflite::support::TfLiteSupportStatus; + +-// Gets the offset aligned to page size for mapping given files into memory by +-// file descriptor correctly, as according to mmap(2), the offset used in mmap +-// must be a multiple of sysconf(_SC_PAGE_SIZE). +-int64 GetPageSizeAlignedOffset(int64 offset) { +- int64 aligned_offset = offset; +- int64 page_size = sysconf(_SC_PAGE_SIZE); +- if (offset % page_size != 0) { +- aligned_offset = offset / page_size * page_size; +- } +- return aligned_offset; +-} +- +-} // namespace +- + /* static */ + StatusOr<std::unique_ptr<ExternalFileHandler>> + ExternalFileHandler::CreateFromExternalFile(const ExternalFile* external_file) { +@@ -71,123 +51,18 @@ absl::Status ExternalFileHandler::MapExternalFile() { + if (!external_file_.file_content().empty()) { + return absl::OkStatus(); + } +- if (external_file_.file_name().empty() && +- !external_file_.has_file_descriptor_meta()) { +- return CreateStatusWithPayload( +- StatusCode::kInvalidArgument, +- "ExternalFile must specify at least one of 'file_content', file_name' " +- "or 'file_descriptor_meta'.", +- TfLiteSupportStatus::kInvalidArgumentError); +- } +- // Obtain file descriptor, offset and size. +- int fd = -1; +- if (!external_file_.file_name().empty()) { +- owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY); +- if (owned_fd_ < 0) { +- const std::string error_message = absl::StrFormat( +- "Unable to open file at %s", external_file_.file_name()); +- switch (errno) { +- case ENOENT: +- return CreateStatusWithPayload( +- StatusCode::kNotFound, error_message, +- TfLiteSupportStatus::kFileNotFoundError); +- case EACCES: +- case EPERM: +- return CreateStatusWithPayload( +- StatusCode::kPermissionDenied, error_message, +- TfLiteSupportStatus::kFilePermissionDeniedError); +- case EINTR: +- return CreateStatusWithPayload(StatusCode::kUnavailable, +- error_message, +- TfLiteSupportStatus::kFileReadError); +- case EBADF: +- return CreateStatusWithPayload(StatusCode::kFailedPrecondition, +- error_message, +- TfLiteSupportStatus::kFileReadError); +- default: +- return CreateStatusWithPayload( +- StatusCode::kUnknown, +- absl::StrFormat("%s, errno=%d", error_message, errno), +- TfLiteSupportStatus::kFileReadError); +- } +- } +- fd = owned_fd_; +- } else { +- fd = external_file_.file_descriptor_meta().fd(); +- if (fd < 0) { +- return CreateStatusWithPayload( +- StatusCode::kInvalidArgument, +- absl::StrFormat("Provided file descriptor is invalid: %d < 0", fd), +- TfLiteSupportStatus::kInvalidArgumentError); +- } +- buffer_offset_ = external_file_.file_descriptor_meta().offset(); +- buffer_size_ = external_file_.file_descriptor_meta().length(); +- } +- // Get actual file size. Always use 0 as offset to lseek(2) to get the actual +- // file size, as SEEK_END returns the size of the file *plus* offset. +- size_t file_size = lseek(fd, /*offset=*/0, SEEK_END); +- if (file_size <= 0) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, +- absl::StrFormat("Unable to get file size, errno=%d", errno), +- TfLiteSupportStatus::kFileReadError); +- } +- // Deduce buffer size if not explicitly provided through file descriptor. +- if (buffer_size_ <= 0) { +- buffer_size_ = file_size - buffer_offset_; +- } +- // Check for out of range issues. +- if (file_size <= buffer_offset_) { +- return CreateStatusWithPayload( +- StatusCode::kInvalidArgument, +- absl::StrFormat("Provided file offset (%d) exceeds or matches actual " +- "file length (%d)", +- buffer_offset_, file_size), +- TfLiteSupportStatus::kInvalidArgumentError); +- } +- if (file_size < buffer_size_ + buffer_offset_) { +- return CreateStatusWithPayload( +- StatusCode::kInvalidArgument, +- absl::StrFormat("Provided file length + offset (%d) exceeds actual " +- "file length (%d)", +- buffer_size_ + buffer_offset_, file_size), +- TfLiteSupportStatus::kInvalidArgumentError); +- } +- // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with +- // extra leading bytes and adjust buffer_size_ to account for the extra +- // leading bytes. +- buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_); +- buffer_aligned_size_ = buffer_size_ + buffer_offset_ - buffer_aligned_offset_; +- // Map into memory. +- buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED, +- fd, buffer_aligned_offset_); +- if (buffer_ == MAP_FAILED) { +- return CreateStatusWithPayload( +- StatusCode::kUnknown, +- absl::StrFormat("Unable to map file to memory buffer, errno=%d", errno), +- TfLiteSupportStatus::kFileMmapError); +- } +- return absl::OkStatus(); ++ ++ return CreateStatusWithPayload(StatusCode::kInvalidArgument, ++ "ExternalFile must specify 'file_content' " ++ "to be compatible with Chromium.", ++ TfLiteSupportStatus::kInvalidArgumentError); + } + + absl::string_view ExternalFileHandler::GetFileContent() { +- if (!external_file_.file_content().empty()) { +- return external_file_.file_content(); +- } else { +- return absl::string_view(static_cast<const char*>(buffer_) + +- buffer_offset_ - buffer_aligned_offset_, +- buffer_size_); +- } ++ return external_file_.file_content(); + } + +-ExternalFileHandler::~ExternalFileHandler() { +- if (buffer_ != MAP_FAILED) { +- munmap(buffer_, buffer_aligned_size_); +- } +- if (owned_fd_ >= 0) { +- close(owned_fd_); +- } +-} ++ExternalFileHandler::~ExternalFileHandler() = default; + + } // namespace core + } // namespace task +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h +index cf0bdf0b48037..48c62813e212e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h +@@ -18,7 +18,7 @@ limitations under the License. + + #include <memory> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -64,27 +64,6 @@ class ExternalFileHandler { + + // Reference to the input ExternalFile. + const ExternalFile& external_file_; +- +- // The file descriptor of the ExternalFile if provided by path, as it is +- // opened and owned by this class. Set to -1 otherwise. +- int owned_fd_{-1}; +- +- // Points to the memory buffer mapped from the file descriptor of the +- // ExternalFile, if provided by path or file descriptor. +- void* buffer_{}; +- +- // The mapped memory buffer offset, if any. +- int64 buffer_offset_{}; +- // The size in bytes of the mapped memory buffer, if any. +- int64 buffer_size_{}; +- +- // As mmap(2) requires the offset to be a multiple of sysconf(_SC_PAGE_SIZE): +- +- // The aligned mapped memory buffer offset, if any. +- int64 buffer_aligned_offset_{}; +- // The aligned mapped memory buffer size in bytes taking into account the +- // offset shift introduced by buffer_aligned_memory_offset_, if any. +- int64 buffer_aligned_size_{}; + }; + + } // namespace core +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +index 8cd4585161df7..484b9a099ecdc 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +@@ -15,8 +15,6 @@ limitations under the License. + + #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +-#include <unistd.h> +- + #include <memory> + + #include "absl/strings/match.h" // from @com_google_absl +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +index 9b44c6e5c022a..53dabdc4841d7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +@@ -16,8 +16,6 @@ limitations under the License. + #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ + #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ + +-#include <sys/mman.h> +- + #include <memory> + + #include "absl/memory/memory.h" // from @com_google_absl +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0011-run-clang-format.patch b/third_party/tflite_support/patches/0011-run-clang-format.patch new file mode 100644 index 0000000..6f390e38 --- /dev/null +++ b/third_party/tflite_support/patches/0011-run-clang-format.patch
@@ -0,0 +1,43843 @@ +From 66d958947df0a4366b2a808e2a74e5ba412a2c38 Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Mon, 20 Dec 2021 11:40:47 -0800 +Subject: [PATCH 11/11] run clang format + +--- + .../configuration/edgetpu_coral_plugin.cc | 20 +- + .../edgetpu_coral_plugin_test.cc | 3 +- + .../src/tensorflow_lite_support/c/common.cc | 2 +- + .../src/tensorflow_lite_support/c/common.h | 4 +- + .../tensorflow_lite_support/c/common_utils.cc | 11 +- + .../tensorflow_lite_support/c/common_utils.h | 3 +- + .../c/task/text/bert_nl_classifier.cc | 6 +- + .../c/task/text/bert_nl_classifier.h | 6 +- + .../c/task/text/bert_question_answerer.cc | 3 +- + .../c/task/text/bert_question_answerer.h | 3 +- + .../c/task/text/nl_classifier.cc | 3 +- + .../c/task/text/nl_classifier.h | 3 +- + .../c/task/vision/image_classifier.cc | 9 +- + .../c/task/vision/image_classifier.h | 9 +- + .../c/task/vision/object_detector.h | 6 +- + .../test/task/vision/image_classifier_test.cc | 84 +- + .../src/tensorflow_lite_support/cc/common.cc | 2 +- + .../src/tensorflow_lite_support/cc/common.h | 5 +- + .../cc/port/default/status_macros.h | 2 +- + .../cc/port/default/statusor_internals.h | 38 +- + .../cc/port/default/tflite_wrapper.cc | 12 +- + .../cc/port/default/tflite_wrapper.h | 2 +- + .../cc/port/integral_types.h | 2 +- + .../cc/task/audio/audio_classifier.cc | 2 +- + .../cc/task/audio/audio_embedder.h | 6 +- + .../cc/task/audio/core/audio_buffer.h | 10 +- + .../cc/task/audio/utils/wav_io.cc | 19 +- + .../cc/task/audio/utils/wav_io.h | 6 +- + .../cc/task/core/base_task_api.h | 2 +- + .../cc/task/core/classification_head.h | 2 +- + .../cc/task/core/error_reporter.cc | 8 +- + .../cc/task/core/label_map_item.cc | 5 +- + .../cc/task/core/label_map_item.h | 7 +- + .../cc/task/core/proto/external_file.proto | 2 - + .../cc/task/core/score_calibration.cc | 8 +- + .../cc/task/core/score_calibration.h | 11 +- + .../cc/task/core/task_api_factory.h | 8 +- + .../cc/task/core/task_utils.h | 23 +- + .../cc/task/core/tflite_engine.cc | 14 +- + .../cc/task/core/tflite_engine.h | 13 +- + .../cc/task/processor/audio_preprocessor.cc | 5 +- + .../processor/classification_postprocessor.cc | 5 +- + .../task/processor/embedding_postprocessor.h | 10 +- + .../cc/task/processor/image_preprocessor.cc | 6 +- + .../cc/task/processor/processor.h | 5 +- + .../cc/task/processor/regex_preprocessor.cc | 3 +- + .../cc/task/processor/regex_preprocessor.h | 3 +- + .../cc/task/text/bert_nl_classifier.cc | 7 +- + .../cc/task/text/bert_nl_classifier.h | 2 +- + .../cc/task/text/bert_question_answerer.cc | 34 +- + .../cc/task/text/bert_question_answerer.h | 7 +- + .../task/text/nlclassifier/nl_classifier.cc | 16 +- + .../cc/task/text/nlclassifier/nl_classifier.h | 16 +- + .../cc/task/text/question_answerer.h | 6 +- + .../text/universal_sentence_encoder_qa.cc | 14 +- + .../task/text/universal_sentence_encoder_qa.h | 9 +- + .../task/vision/core/base_vision_task_api.h | 9 +- + .../cc/task/vision/core/classification_head.h | 2 +- + .../cc/task/vision/core/frame_buffer.h | 47 +- + .../cc/task/vision/core/label_map_item.cc | 5 +- + .../cc/task/vision/core/label_map_item.h | 7 +- + .../cc/task/vision/image_classifier.cc | 14 +- + .../cc/task/vision/image_classifier.h | 8 +- + .../cc/task/vision/image_embedder.cc | 17 +- + .../cc/task/vision/image_embedder.h | 9 +- + .../cc/task/vision/image_segmenter.cc | 15 +- + .../cc/task/vision/image_segmenter.h | 8 +- + .../cc/task/vision/object_detector.cc | 16 +- + .../cc/task/vision/object_detector.h | 5 +- + .../cc/task/vision/proto/segmentations.proto | 8 +- + .../vision/utils/frame_buffer_common_utils.cc | 59 +- + .../vision/utils/frame_buffer_common_utils.h | 37 +- + .../task/vision/utils/frame_buffer_utils.cc | 50 +- + .../cc/task/vision/utils/frame_buffer_utils.h | 40 +- + .../utils/frame_buffer_utils_interface.h | 11 +- + .../vision/utils/libyuv_frame_buffer_utils.cc | 81 +- + .../vision/utils/libyuv_frame_buffer_utils.h | 9 +- + .../cc/task/vision/utils/score_calibration.cc | 8 +- + .../cc/task/vision/utils/score_calibration.h | 11 +- + .../cc/test/common_test.cc | 2 +- + .../task/processor/image_preprocessor_test.cc | 13 +- + .../test/task/text/bert_nl_classifier_test.cc | 36 +- + .../task/text/bert_question_answerer_test.cc | 7 +- + .../text/nlclassifier/nl_classifier_test.cc | 83 +- + .../test/task/vision/image_classifier_test.cc | 149 +- + .../test/task/vision/image_embedder_test.cc | 95 +- + .../test/task/vision/image_segmenter_test.cc | 117 +- + .../test/task/vision/object_detector_test.cc | 157 +- + .../cc/test/test_utils.cc | 18 +- + .../cc/test/test_utils.h | 6 +- + .../cc/text/tokenizers/bert_tokenizer.cc | 3 +- + .../cc/text/tokenizers/bert_tokenizer.h | 3 +- + .../cc/text/tokenizers/bert_tokenizer_jni.cc | 25 +- + .../cc/text/tokenizers/regex_tokenizer.cc | 4 +- + .../cc/text/tokenizers/sentencepiece_jni.cc | 20 +- + .../cc/text/tokenizers/tokenizer_jni_lib.cc | 3 +- + .../cc/text/tokenizers/tokenizer_jni_lib.h | 3 +- + .../cc/text/tokenizers/tokenizer_utils.cc | 6 +- + .../cc/text/tokenizers/tokenizer_utils.h | 1 - + .../cc/utils/common_utils.cc | 3 +- + .../cc/utils/common_utils.h | 3 +- + .../cc/utils/jni_utils.cc | 7 +- + .../cc/utils/jni_utils.h | 8 +- + .../codegen/android_java_generator.cc | 37 +- + .../codegen/android_java_generator.h | 5 +- + .../codegen/code_generator.cc | 3 +- + .../codegen/code_generator.h | 3 +- + .../codegen/code_generator_test.cc | 3 +- + .../codegen/metadata_helper.h | 2 +- + .../codegen/python/codegen_lib.cc | 9 +- + .../tensorflow_lite_support/codegen/utils.cc | 36 +- + .../custom_ops/kernel/ngrams.cc | 7 +- + .../custom_ops/kernel/ngrams_op_resolver.cc | 2 +- + .../custom_ops/kernel/ngrams_test.cc | 9 +- + .../kernel/ragged/py_tflite_registerer.h | 2 +- + .../kernel/ragged/ragged_range_tflite.cc | 9 +- + .../kernel/ragged/ragged_range_tflite_test.cc | 3 +- + .../ragged/ragged_tensor_to_tensor_tflite.cc | 47 +- + .../ragged_tensor_to_tensor_tflite_test.cc | 6 +- + .../kernel/sentencepiece/model_converter.cc | 10 +- + .../kernel/sentencepiece/model_converter.h | 6 +- + .../sentencepiece/optimized_decoder_test.cc | 6 +- + .../kernel/sentencepiece/optimized_encoder.cc | 23 +- + .../kernel/sentencepiece/optimized_encoder.h | 10 +- + .../sentencepiece/optimized_encoder_test.cc | 8 +- + .../sentencepiece/py_tflite_registerer.h | 2 +- + .../sentencepiece_detokenizer_tflite.cc | 3 +- + .../sentencepiece_tokenizer_op.cc | 6 +- + .../sentencepiece_tokenizer_tflite.cc | 7 +- + .../custom_ops/kernel/whitespace_tokenizer.cc | 13 +- + .../whitespace_tokenizer_op_resolver.cc | 2 +- + .../audio/desktop/audio_classifier_demo.cc | 18 +- + .../audio/desktop/audio_classifier_lib.cc | 11 +- + .../task/audio/desktop/audio_classifier_lib.h | 3 +- + .../text/desktop/bert_nl_classifier_demo.cc | 14 +- + .../desktop/bert_question_answerer_demo.cc | 18 +- + .../task/text/desktop/nl_classifier_demo.cc | 14 +- + .../universal_sentence_encoder_qa_demo.cc | 32 +- + .../vision/desktop/image_classifier_demo.cc | 34 +- + .../vision/desktop/image_embedder_demo.cc | 30 +- + .../vision/desktop/image_segmenter_demo.cc | 24 +- + .../vision/desktop/object_detector_demo.cc | 40 +- + .../task/vision/desktop/utils/image_utils.cc | 12 +- + .../task/vision/desktop/utils/image_utils.h | 2 +- + .../ios/sources/TFLCommon.h | 11 +- + .../ios/sources/TFLCommonUtils.h | 41 +- + .../ios/sources/TFLCommonUtils.m | 42 +- + .../core/sources/TFLBaseOptions+Helpers.h | 2 +- + .../core/sources/TFLBaseOptions+Helpers.m | 2 +- + .../ios/task/core/sources/TFLBaseOptions.h | 32 +- + .../ios/task/core/sources/TFLBaseOptions.m | 16 +- + .../TFLClassificationOptions+Helpers.h | 6 +- + .../TFLClassificationOptions+Helpers.m | 66 +- + .../sources/TFLClassificationOptions.h | 9 +- + .../sources/TFLClassificationOptions.m | 5 +- + .../sources/TFLClassificationResult.h | 23 +- + .../utils/sources/TFLClassificationUtils.h | 21 +- + .../utils/sources/TFLClassificationUtils.m | 31 +- + .../Sources/TFLBertNLClassifier.h | 21 +- + .../Sources/TFLBertNLClassifier.m | 28 +- + .../nlclassifier/Sources/TFLNLClassifier.h | 47 +- + .../nlclassifier/Sources/TFLNLClassifier.m | 18 +- + .../Tests/TFLBertNLClassifierTest.m | 29 +- + .../nlclassifier/Tests/TFLNLClassifierTest.m | 28 +- + .../text/qa/Sources/TFLBertQuestionAnswerer.h | 4 +- + .../text/qa/Sources/TFLBertQuestionAnswerer.m | 23 +- + .../qa/Tests/TFLBertQuestionAnswererTest.m | 33 +- + .../task/vision/sources/TFLImageClassifier.h | 37 +- + .../task/vision/sources/TFLImageClassifier.m | 52 +- + .../task/vision/utils/sources/GMLImageUtils.h | 5 +- + .../task/vision/utils/sources/GMLImageUtils.m | 146 +- + .../TFLImageClassifierTests.m | 109 +- + .../tokenizers/Sources/TFLBertTokenizer.h | 6 +- + .../tokenizers/Sources/TFLBertTokenizer.mm | 10 +- + .../Sources/TFLSentencepieceTokenizer.h | 2 +- + .../Sources/TFLSentencepieceTokenizer.mm | 12 +- + .../text/tokenizers/Sources/TFLTokenizer.h | 4 +- + .../tokenizers/Sources/TFLTokenizerUtil.h | 11 +- + .../tokenizers/Sources/TFLTokenizerUtil.mm | 15 +- + .../lite/support/audio/TensorAudio.java | 524 ++--- + .../lite/support/common/FileUtil.java | 301 +-- + .../lite/support/common/Operator.java | 15 +- + .../lite/support/common/Processor.java | 2 +- + .../support/common/SequentialProcessor.java | 83 +- + .../lite/support/common/TensorOperator.java | 6 +- + .../lite/support/common/TensorProcessor.java | 57 +- + .../common/internal/SupportPreconditions.java | 302 +-- + .../lite/support/common/ops/CastOp.java | 55 +- + .../lite/support/common/ops/DequantizeOp.java | 9 +- + .../lite/support/common/ops/NormalizeOp.java | 245 ++- + .../lite/support/common/ops/QuantizeOp.java | 9 +- + .../lite/support/image/BitmapContainer.java | 116 +- + .../lite/support/image/BoundingBoxUtil.java | 369 ++-- + .../lite/support/image/ColorSpaceType.java | 623 +++--- + .../lite/support/image/ImageContainer.java | 36 +- + .../lite/support/image/ImageConversions.java | 217 +- + .../lite/support/image/ImageOperator.java | 41 +- + .../lite/support/image/ImageProcessor.java | 285 +-- + .../lite/support/image/ImageProperties.java | 91 +- + .../support/image/MediaImageContainer.java | 112 +- + .../lite/support/image/MlImageAdapter.java | 160 +- + .../support/image/TensorBufferContainer.java | 202 +- + .../lite/support/image/TensorImage.java | 677 +++--- + .../lite/support/image/ops/ResizeOp.java | 105 +- + .../image/ops/ResizeWithCropOrPadOp.java | 170 +- + .../lite/support/image/ops/Rot90Op.java | 141 +- + .../image/ops/TensorOperatorWrapper.java | 78 +- + .../image/ops/TransformToGrayscaleOp.java | 127 +- + .../lite/support/label/Category.java | 192 +- + .../lite/support/label/LabelUtil.java | 77 +- + .../lite/support/label/TensorLabel.java | 331 +-- + .../lite/support/label/ops/LabelAxisOp.java | 70 +- + .../lite/support/model/GpuDelegateProxy.java | 71 +- + .../tensorflow/lite/support/model/Model.java | 449 ++-- + .../support/tensorbuffer/TensorBuffer.java | 899 ++++---- + .../tensorbuffer/TensorBufferFloat.java | 181 +- + .../tensorbuffer/TensorBufferUint8.java | 188 +- + .../audio/classifier/AudioClassifier.java | 857 ++++---- + .../audio/classifier/Classifications.java | 28 +- + .../lite/task/core/BaseOptions.java | 105 +- + .../lite/task/core/BaseTaskApi.java | 122 +- + .../lite/task/core/ComputeSettings.java | 48 +- + .../lite/task/core/TaskJniUtils.java | 275 ++- + .../core/vision/ImageProcessingOptions.java | 125 +- + .../text/nlclassifier/BertNLClassifier.java | 391 ++-- + .../task/text/nlclassifier/NLClassifier.java | 568 ++--- + .../task/text/qa/BertQuestionAnswerer.java | 394 ++-- + .../lite/task/text/qa/QaAnswer.java | 60 +- + .../lite/task/text/qa/QuestionAnswerer.java | 19 +- + .../vision/classifier/Classifications.java | 25 +- + .../vision/classifier/ImageClassifier.java | 882 ++++---- + .../task/vision/core/BaseVisionTaskApi.java | 349 ++-- + .../lite/task/vision/detector/Detection.java | 26 +- + .../task/vision/detector/ObjectDetector.java | 873 ++++---- + .../task/vision/segmenter/ColoredLabel.java | 112 +- + .../task/vision/segmenter/ImageSegmenter.java | 752 ++++--- + .../task/vision/segmenter/OutputType.java | 202 +- + .../task/vision/segmenter/Segmentation.java | 106 +- + .../lite/support/audio/TensorAudioTest.java | 486 ++--- + .../lite/support/common/FileUtilTest.java | 129 +- + .../support/common/TensorProcessorTest.java | 91 +- + .../lite/support/common/ops/CastOpTest.java | 91 +- + .../support/common/ops/DequantizeOpTest.java | 23 +- + .../support/common/ops/NormalizeOpTest.java | 217 +- + .../support/common/ops/QuantizeOpTest.java | 21 +- + .../support/image/BoundingBoxUtilTest.java | 343 ++-- + .../image/ColorSpaceTypeInstrumentedTest.java | 37 +- + .../support/image/ColorSpaceTypeTest.java | 703 +++---- + .../ImageConversionsInstrumentedTest.java | 338 +-- + .../support/image/ImageConversionsTest.java | 164 +- + .../image/ImageProcessorInstrumentedTest.java | 221 +- + .../support/image/ImageProcessorTest.java | 209 +- + .../support/image/MlImageAdapterTest.java | 259 +-- + .../image/TensorImageInstrumentedTest.java | 208 +- + .../lite/support/image/TensorImageTest.java | 1391 ++++++------- + .../lite/support/image/TestImageCreator.java | 183 +- + .../image/ops/ResizeOpInstrumentedTest.java | 103 +- + ...ResizeWithCropOrPadOpInstrumentedTest.java | 239 ++- + .../image/ops/Rot90OpInstrumentedTest.java | 122 +- + ...ransformToGrayScaleOpInstrumentedTest.java | 104 +- + .../lite/support/label/CategoryTest.java | 204 +- + .../lite/support/label/LabelUtilTest.java | 47 +- + .../lite/support/label/TensorLabelTest.java | 327 +-- + .../support/label/ops/LabelAxisOpTest.java | 160 +- + .../GpuDelegateProxyInstrumentedTest.java | 18 +- + .../support/model/GpuDelegateProxyTest.java | 11 +- + .../lite/support/model/ModelTest.java | 244 +-- + .../tensorbuffer/TensorBufferFloatTest.java | 82 +- + .../tensorbuffer/TensorBufferTest.java | 1707 +++++++-------- + .../tensorbuffer/TensorBufferUint8Test.java | 82 +- + .../audio/classifier/audio_classifier_jni.cc | 42 +- + .../src/native/task/core/task_jni_utils.cc | 5 +- + .../bert/bert_nl_classifier_jni.cc | 23 +- + .../text/nlclassifier/nl_classifier_jni.cc | 21 +- + .../text/qa/bert_question_answerer_jni.cc | 24 +- + .../vision/classifier/image_classifier_jni.cc | 27 +- + .../vision/core/base_vision_task_api_jni.cc | 40 +- + .../vision/detector/object_detector_jni.cc | 27 +- + .../java/src/native/task/vision/jni_utils.cc | 30 +- + .../java/src/native/task/vision/jni_utils.h | 28 +- + .../vision/segmenter/image_segmenter_jni.cc | 32 +- + .../metadata/cc/metadata_extractor.cc | 21 +- + .../metadata/cc/metadata_extractor.h | 4 +- + .../metadata/cc/metadata_populator.h | 7 +- + .../metadata/cc/metadata_version.cc | 35 +- + .../flatbuffers_lib/flatbuffers_lib.cc | 2 +- + .../support/metadata/BoundedInputStream.java | 138 +- + .../support/metadata/ByteBufferChannel.java | 188 +- + .../support/metadata/MetadataExtractor.java | 622 +++--- + .../lite/support/metadata/MetadataParser.java | 12 +- + .../lite/support/metadata/ModelInfo.java | 448 ++-- + .../support/metadata/ModelMetadataInfo.java | 243 ++- + .../lite/support/metadata/Preconditions.java | 306 +-- + .../metadata/SeekableByteChannelCompat.java | 140 +- + .../lite/support/metadata/ZipFile.java | 686 +++---- + .../metadata/BoundedInputStreamTest.java | 429 ++-- + .../metadata/ByteBufferChannelTest.java | 480 +++-- + .../metadata/MetadataExtractorTest.java | 1828 ++++++++--------- + .../support/metadata/MetadataParserTest.java | 18 +- + .../lite/support/metadata/ZipFileTest.java | 206 +- + .../odml/ios/image/apis/GMLImage.h | 47 +- + .../odml/ios/image/sources/GMLImage.m | 2 +- + .../odml/ios/image/tests/GMLImageTests.m | 73 +- + .../android/odml/image/BitmapExtractor.java | 43 +- + .../odml/image/BitmapImageContainer.java | 70 +- + .../odml/image/BitmapMlImageBuilder.java | 137 +- + .../odml/image/ByteBufferExtractor.java | 421 ++-- + .../odml/image/ByteBufferImageContainer.java | 68 +- + .../odml/image/ByteBufferMlImageBuilder.java | 135 +- + .../android/odml/image/ImageContainer.java | 12 +- + .../android/odml/image/ImageProperties.java | 92 +- + .../odml/image/MediaImageContainer.java | 81 +- + .../odml/image/MediaImageExtractor.java | 42 +- + .../odml/image/MediaMlImageBuilder.java | 105 +- + .../google/android/odml/image/MlImage.java | 423 ++-- + .../odml/image/BitmapExtractorTest.java | 46 +- + .../odml/image/BitmapMlImageBuilderTest.java | 116 +- + .../odml/image/ByteBufferExtractorTest.java | 264 ++- + .../image/ByteBufferMlImageBuilderTest.java | 93 +- + .../odml/image/MediaImageExtractorTest.java | 48 +- + .../odml/image/MediaMlImageBuilderTest.java | 109 +- + .../android/odml/image/TestImageCreator.java | 211 +- + .../src/third_party/fft2d/fft.h | 12 +- + .../src/third_party/fft2d/fft2d.h | 12 +- + 324 files changed, 17479 insertions(+), 17052 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc +index 9f27f3baae82f..6a16d12856258 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc +@@ -17,12 +17,12 @@ limitations under the License. + + #include <glog/logging.h> + #include "absl/container/node_hash_map.h" // from @com_google_absl +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/strings/match.h" // from @com_google_absl +-#include "absl/strings/numbers.h" // from @com_google_absl +-#include "tflite/public/edgetpu_c.h" ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/strings/match.h" // from @com_google_absl ++#include "absl/strings/numbers.h" // from @com_google_absl + #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" + #include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" ++#include "tflite/public/edgetpu_c.h" + + namespace tflite { + namespace delegates { +@@ -50,12 +50,16 @@ inline std::string ConvertBool(bool from_bool) { + return from_bool ? "True" : "False"; + } + +-bool MatchDevice(const std::string& device, const std::string& type, ++bool MatchDevice(const std::string& device, ++ const std::string& type, + int* index) { + const auto prefix(type + ":"); +- if (!absl::StartsWith(device, prefix)) return false; +- if (!absl::SimpleAtoi(device.substr(prefix.size()), index)) return false; +- if (*index < 0) return false; ++ if (!absl::StartsWith(device, prefix)) ++ return false; ++ if (!absl::SimpleAtoi(device.substr(prefix.size()), index)) ++ return false; ++ if (*index < 0) ++ return false; + return true; + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc +index 83cb6f24b1277..cc183a65a9e5f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc +@@ -43,7 +43,8 @@ using ::tflite::task::vision::ImageDataFree; + + using EdgeTpuCoralPluginTest = testing::TestWithParam<std::string>; + +-INSTANTIATE_TEST_SUITE_P(CoralPluginTests, EdgeTpuCoralPluginTest, ++INSTANTIATE_TEST_SUITE_P(CoralPluginTests, ++ EdgeTpuCoralPluginTest, + testing::Values(kRegularModelFilePath, + kEdgeTpuModelFilePath)); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc +index 2a182bbd6535a..f0974ed26b826 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc +@@ -17,7 +17,7 @@ limitations under the License. + + #include <cstdlib> + +-void TfLiteSupportErrorDelete(TfLiteSupportError *error) { ++void TfLiteSupportErrorDelete(TfLiteSupportError* error) { + // `strdup` obtains memory using `malloc` and the memory needs to be + // released using `free`. + free(error->message); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h +index 1e21f1dcb31dc..3ced64226987f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h +@@ -190,10 +190,10 @@ typedef struct TfLiteSupportError { + // Holds the error code. + enum TfLiteSupportErrorCode code; + // Detailed description of the error. +- char *message; ++ char* message; + } TfLiteSupportError; + +-void TfLiteSupportErrorDelete(TfLiteSupportError *error); ++void TfLiteSupportErrorDelete(TfLiteSupportError* error); + + #ifdef __cplusplus + } // extern "C" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc +index 39287377c4b36..39afb9c8cbdf3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc +@@ -18,15 +18,17 @@ limitations under the License. + #include <string> + + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + + namespace tflite { + namespace support { + + void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code, +- const char* message, TfLiteSupportError** error) { +- if (error == nullptr) return; ++ const char* message, ++ TfLiteSupportError** error) { ++ if (error == nullptr) ++ return; + + *error = new TfLiteSupportError; + (*error)->code = code; +@@ -35,7 +37,8 @@ void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code, + + void CreateTfLiteSupportErrorWithStatus(const absl::Status& status, + TfLiteSupportError** error) { +- if (status.ok() || error == nullptr) return; ++ if (status.ok() || error == nullptr) ++ return; + + // Payload of absl::Status created by the tflite task library stores an + // appropriate value of the enum TfLiteSupportStatus. The integer value +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h +index 6959029575663..551f64a598970 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h +@@ -27,7 +27,8 @@ namespace support { + + // Creates a TfLiteSupportError with a TfLiteSupportErrorCode and message. + void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code, +- const char* message, TfLiteSupportError** error); ++ const char* message, ++ TfLiteSupportError** error); + + // Creates a TfLiteSupportError from absl::Status and passes it back as a + // parameter which is a pointer to the error pointer. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc +index 26888a832fc34..52907f4fe7d35 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc +@@ -40,7 +40,8 @@ struct TfLiteBertNLClassifier { + }; + + TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions( +- const char* model_path, const TfLiteBertNLClassifierOptions* options) { ++ const char* model_path, ++ const TfLiteBertNLClassifierOptions* options) { + BertNLClassifierOptionsCpp cc_options; + + cc_options.mutable_base_options()->mutable_model_file()->set_file_name( +@@ -64,7 +65,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path) { + } + + Categories* TfLiteBertNLClassifierClassify( +- const TfLiteBertNLClassifier* classifier, const char* text) { ++ const TfLiteBertNLClassifier* classifier, ++ const char* text) { + std::vector<CategoryCpp> results = + + classifier->impl->Classify(absl::string_view(text).data()); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h +index 430f5735c6bd2..94138a291233b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h +@@ -48,7 +48,8 @@ typedef struct TfLiteBertNLClassifierOptions { + // Creates TfLiteBertNLClassifier from model path and options, returns nullptr + // if the file doesn't exist or is not a well formatted TFLite model path. + TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions( +- const char* model_path, const TfLiteBertNLClassifierOptions* options); ++ const char* model_path, ++ const TfLiteBertNLClassifierOptions* options); + + // Creates TfLiteBertNLClassifier from model path and default options, returns + // nullptr if the file doesn't exist or is not a well formatted TFLite model +@@ -57,7 +58,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path); + + // Invokes the encapsulated TFLite model and classifies the input text. + Categories* TfLiteBertNLClassifierClassify( +- const TfLiteBertNLClassifier* classifier, const char* text); ++ const TfLiteBertNLClassifier* classifier, ++ const char* text); + + void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc +index d0d1639357348..1887d5234d180 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc +@@ -48,7 +48,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate( + } + + TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer( +- const TfLiteBertQuestionAnswerer* question_answerer, const char* context, ++ const TfLiteBertQuestionAnswerer* question_answerer, ++ const char* context, + const char* question) { + std::vector<QaAnswerCpp> answers = question_answerer->impl->Answer( + absl::string_view(context).data(), absl::string_view(question).data()); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h +index 7bc6e6ed385db..e9a1190356914 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h +@@ -58,7 +58,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate( + // Invokes the encapsulated TFLite model and answers a question based on + // context. + TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer( +- const TfLiteBertQuestionAnswerer* question_answerer, const char* context, ++ const TfLiteBertQuestionAnswerer* question_answerer, ++ const char* context, + const char* question); + + void TfLiteBertQuestionAnswererDelete( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc +index d6d86f67a620a..1e6805c1d1cd6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc +@@ -37,7 +37,8 @@ struct TfLiteNLClassifier { + }; + + TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions( +- const char* model_path, const TfLiteNLClassifierOptions* options) { ++ const char* model_path, ++ const TfLiteNLClassifierOptions* options) { + auto classifier_status = NLClassifierCpp::CreateFromFileAndOptions( + std::string(model_path), + { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h +index c47dd59b13eb4..389ca5d686df0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h +@@ -48,7 +48,8 @@ typedef struct TfLiteNLClassifierOptions { + // Creates TfLiteNLClassifier from model path and options, returns nullptr if + // the file doesn't exist or is not a well formatted TFLite model path. + TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions( +- const char* model_path, const TfLiteNLClassifierOptions* options); ++ const char* model_path, ++ const TfLiteNLClassifierOptions* options); + + // Invokes the encapsulated TFLite model and classifies the input text. + Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc +index edf3889059b27..8981e66b41d0c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc +@@ -108,7 +108,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate() { + } + + TfLiteImageClassifier* TfLiteImageClassifierFromOptions( +- const TfLiteImageClassifierOptions* options, TfLiteSupportError** error) { ++ const TfLiteImageClassifierOptions* options, ++ TfLiteSupportError** error) { + StatusOr<ImageClassifierOptionsCpp> cpp_option_status = + CreateImageClassifierCppOptionsFromCOptions(options); + +@@ -175,7 +176,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct( + + TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( + const TfLiteImageClassifier* classifier, +- const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi, ++ const TfLiteFrameBuffer* frame_buffer, ++ const TfLiteBoundingBox* roi, + TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( +@@ -219,7 +221,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( + + TfLiteClassificationResult* TfLiteImageClassifierClassify( + const TfLiteImageClassifier* classifier, +- const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error) { ++ const TfLiteFrameBuffer* frame_buffer, ++ TfLiteSupportError** error) { + return TfLiteImageClassifierClassifyWithRoi(classifier, frame_buffer, nullptr, + error); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h +index 290e57d56f5a1..8a53e5e2a079e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h +@@ -158,7 +158,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(); + // TfLiteSupportErrorDelete(error) + // + TfLiteImageClassifier* TfLiteImageClassifierFromOptions( +- const TfLiteImageClassifierOptions* options, TfLiteSupportError** error); ++ const TfLiteImageClassifierOptions* options, ++ TfLiteSupportError** error); + + // Invokes the encapsulated TFLite model and classifies the frame_buffer. + // Returns a pointer to the created classification result in case of success or +@@ -186,7 +187,8 @@ TfLiteImageClassifier* TfLiteImageClassifierFromOptions( + // + TfLiteClassificationResult* TfLiteImageClassifierClassify( + const TfLiteImageClassifier* classifier, +- const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error); ++ const TfLiteFrameBuffer* frame_buffer, ++ TfLiteSupportError** error); + + // Invokes the encapsulated TFLite model and classifies the region of the + // frame_buffer specified by the bounding box. Same as TfLiteImageClassifier* +@@ -198,7 +200,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassify( + // operations. + TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( + const TfLiteImageClassifier* classifier, +- const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi, ++ const TfLiteFrameBuffer* frame_buffer, ++ const TfLiteBoundingBox* roi, + TfLiteSupportError** error); + + // Disposes off the image classifier. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h +index a46cf043aeb24..5a2d3e1d1e4d2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h +@@ -157,7 +157,8 @@ TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate(); + // TfLiteSupportErrorDelete(error) + // + TfLiteObjectDetector* TfLiteObjectDetectorFromOptions( +- const TfLiteObjectDetectorOptions* options, TfLiteSupportError** error); ++ const TfLiteObjectDetectorOptions* options, ++ TfLiteSupportError** error); + + // Invokes the encapsulated TFLite model and performs object detection on the + // frame_buffer. Returns a pointer to the created object detection result result +@@ -185,7 +186,8 @@ TfLiteObjectDetector* TfLiteObjectDetectorFromOptions( + // TfLiteSupportErrorDelete(error) + // + TfLiteDetectionResult* TfLiteObjectDetectorDetect( +- const TfLiteObjectDetector* detector, const TfLiteFrameBuffer* frame_buffer, ++ const TfLiteObjectDetector* detector, ++ const TfLiteFrameBuffer* frame_buffer, + TfLiteSupportError** error); + + // Disposes off the object detector. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc +index 688af14580ab3..b398b7adafe5c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc +@@ -44,8 +44,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; + + StatusOr<ImageData> LoadImage(const char* image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {}; +@@ -56,7 +56,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithNullOptionsAndError) { + TfLiteImageClassifierFromOptions(nullptr, &error); + + EXPECT_EQ(image_classifier, nullptr); +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); +@@ -71,7 +72,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) { + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, nullptr); + EXPECT_EQ(image_classifier, nullptr); +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); + } + + TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) { +@@ -82,7 +84,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) { + TfLiteImageClassifierFromOptions(&options, &error); + + EXPECT_EQ(image_classifier, nullptr); +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); +@@ -93,9 +96,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) { + } + + TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) { +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetQuantizedWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kMobileNetQuantizedWithMetadata); + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + TfLiteImageClassifier* image_classifier = +@@ -106,9 +108,8 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) { + } + + TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) { +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetQuantizedWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kMobileNetQuantizedWithMetadata); + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + options.base_options.compute_settings.cpu_settings.num_threads = 3; +@@ -120,15 +121,16 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) { + EXPECT_NE(image_classifier, nullptr); + EXPECT_EQ(error, nullptr); + +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); +- if (error) TfLiteSupportErrorDelete(error); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); ++ if (error) ++ TfLiteSupportErrorDelete(error); + } + + TEST_F(ImageClassifierFromOptionsTest, + FailsWithClassNameDenyListAndClassNameAllowListAndError) { +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetQuantizedWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kMobileNetQuantizedWithMetadata); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); +@@ -146,7 +148,8 @@ TEST_F(ImageClassifierFromOptionsTest, + TfLiteImageClassifierFromOptions(&options, &error); + + EXPECT_EQ(image_classifier, nullptr); +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); +@@ -158,7 +161,8 @@ TEST_F(ImageClassifierFromOptionsTest, + + TEST(ImageClassifierNullClassifierClassifyTest, + FailsWithNullImageClassifierAndError) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = +@@ -181,9 +185,8 @@ TEST(ImageClassifierNullClassifierClassifyTest, + class ImageClassifierClassifyTest : public tflite_shims::testing::Test { + protected: + void SetUp() override { +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetQuantizedWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kMobileNetQuantizedWithMetadata); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); +@@ -196,7 +199,8 @@ class ImageClassifierClassifyTest : public tflite_shims::testing::Test { + }; + + TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, +@@ -223,7 +227,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) { + } + + TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = +@@ -244,7 +249,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) { + } + + TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = {.format = kRGB, .orientation = kTopLeft}; + +@@ -267,7 +273,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) { + } + + TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, +@@ -298,7 +305,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) { + } + + TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, +@@ -330,9 +338,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) { + TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + SucceedsWithClassNameDenyList) { + char* denylisted_label_name = (char*)"cheeseburger"; +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetQuantizedWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kMobileNetQuantizedWithMetadata); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); +@@ -345,7 +352,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + TfLiteImageClassifierFromOptions(&options, nullptr); + ASSERT_NE(image_classifier, nullptr); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, +@@ -357,7 +365,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr); + + ImageDataFree(&image_data); +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(classification_result, nullptr); + EXPECT_GE(classification_result->size, 1); +@@ -374,10 +383,9 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + SucceedsWithClassNameAllowList) { + char* allowlisted_label_name = (char*)"cheeseburger"; +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetQuantizedWithMetadata) +- .data(); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kMobileNetQuantizedWithMetadata) ++ .data(); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); +@@ -390,7 +398,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + TfLiteImageClassifierFromOptions(&options, nullptr); + ASSERT_NE(image_classifier, nullptr); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, ++ LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, +@@ -402,7 +411,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr); + + ImageDataFree(&image_data); +- if (image_classifier) TfLiteImageClassifierDelete(image_classifier); ++ if (image_classifier) ++ TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(classification_result, nullptr); + EXPECT_GE(classification_result->size, 1); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc +index abfef722d6659..09e9a83e07bef 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc +@@ -15,7 +15,7 @@ limitations under the License. + + #include "tensorflow_lite_support/cc/common.h" + +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "absl/strings/str_cat.h" // from @com_google_absl + + namespace tflite { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h +index b06e9f58459af..71dd920b86bed 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h +@@ -16,7 +16,7 @@ limitations under the License. + #ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ + #define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + + namespace tflite { +@@ -164,7 +164,8 @@ enum class TfLiteSupportStatus { + // more than returning an object identical to an OK status. See `absl::Status` + // for more details. + absl::Status CreateStatusWithPayload( +- absl::StatusCode canonical_code, absl::string_view message, ++ absl::StatusCode canonical_code, ++ absl::string_view message, + tflite::support::TfLiteSupportStatus tfls_code = + tflite::support::TfLiteSupportStatus::kError); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h +index 14999ca37b7ac..cb145dbd232c8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h +@@ -18,7 +18,7 @@ limitations under the License. + #define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ + + #include "absl/base/optimization.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + + // Evaluates an expression that produces a `absl::Status`. If the status is not + // ok, returns it from the current function. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h +index dc04c293c6ffd..81ec3c1ab5f86 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h +@@ -21,8 +21,8 @@ limitations under the License. + #include <utility> + + #include "absl/meta/type_traits.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/utility/utility.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/utility/utility.h" // from @com_google_absl + + namespace tflite { + namespace support { +@@ -63,7 +63,8 @@ struct IsDirectInitializationAmbiguous + U>::value, + std::false_type, + IsDirectInitializationAmbiguous< +- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {}; ++ T, ++ absl::remove_cv_t<absl::remove_reference_t<U>>>> {}; + + template <typename T, typename V> + struct IsDirectInitializationAmbiguous<T, tflite::support::StatusOr<V>> +@@ -101,7 +102,8 @@ struct IsForwardingAssignmentAmbiguous + U>::value, + std::false_type, + IsForwardingAssignmentAmbiguous< +- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {}; ++ T, ++ absl::remove_cv_t<absl::remove_reference_t<U>>>> {}; + + template <typename T, typename U> + struct IsForwardingAssignmentAmbiguous<T, tflite::support::StatusOr<U>> +@@ -136,7 +138,8 @@ template <typename T, typename... Args> + void PlacementNew(void* p, Args&&... args) { + #if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. +- if (p == nullptr) __builtin_unreachable(); ++ if (p == nullptr) ++ __builtin_unreachable(); + #endif + new (p) T(std::forward<Args>(args)...); + } +@@ -207,7 +210,8 @@ class StatusOrData { + } + + StatusOrData& operator=(const StatusOrData& other) { +- if (this == &other) return *this; ++ if (this == &other) ++ return *this; + if (other.ok()) + Assign(other.data_); + else +@@ -216,7 +220,8 @@ class StatusOrData { + } + + StatusOrData& operator=(StatusOrData&& other) { +- if (this == &other) return *this; ++ if (this == &other) ++ return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else +@@ -295,15 +300,18 @@ class StatusOrData { + }; + + void Clear() { +- if (ok()) data_.~T(); ++ if (ok()) ++ data_.~T(); + } + + void EnsureOk() const { +- if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_); ++ if (ABSL_PREDICT_FALSE(!ok())) ++ Helper::Crash(status_); + } + + void EnsureNotOk() { +- if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_); ++ if (ABSL_PREDICT_FALSE(ok())) ++ Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (ie. data_) through placement new with the passed +@@ -362,8 +370,9 @@ struct MoveCtorBase<T, false> { + MoveCtorBase& operator=(MoveCtorBase&&) = default; + }; + +-template <typename T, bool = std::is_copy_constructible<T>::value&& +- std::is_copy_assignable<T>::value> ++template <typename T, ++ bool = std::is_copy_constructible<T>::value&& ++ std::is_copy_assignable<T>::value> + struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; +@@ -381,8 +390,9 @@ struct CopyAssignBase<T, false> { + CopyAssignBase& operator=(CopyAssignBase&&) = default; + }; + +-template <typename T, bool = std::is_move_constructible<T>::value&& +- std::is_move_assignable<T>::value> ++template <typename T, ++ bool = std::is_move_constructible<T>::value&& ++ std::is_move_assignable<T>::value> + struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc +index 6334c02d738a6..0b3e5d6a2269a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc +@@ -15,7 +15,7 @@ limitations under the License. + + #include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h" + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/delegates/interpreter_utils.h" +@@ -53,7 +53,8 @@ TfLiteInterpreterWrapper::TfLiteInterpreterWrapper( + : delegate_(nullptr, nullptr), + got_error_do_not_delegate_anymore_(false), + default_model_namespace_(default_model_namespace), +- default_model_id_(default_model_id), mini_benchmark_(nullptr) {} ++ default_model_id_(default_model_id), ++ mini_benchmark_(nullptr) {} + + std::string TfLiteInterpreterWrapper::ModelNamespace() { + const auto& ns_from_acceleration = +@@ -299,7 +300,9 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() { + return absl::OkStatus(); + } + +-void TfLiteInterpreterWrapper::Cancel() { cancel_flag_.Set(true); } ++void TfLiteInterpreterWrapper::Cancel() { ++ cancel_flag_.Set(true); ++} + + void TfLiteInterpreterWrapper::SetTfLiteCancellation() { + // Create a cancellation check function and set to the TFLite interpreter. +@@ -312,7 +315,8 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() { + } + + absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin( +- const std::string& name, const tflite::TFLiteSettings& tflite_settings) { ++ const std::string& name, ++ const tflite::TFLiteSettings& tflite_settings) { + delegate_plugin_ = DelegatePluginRegistry::CreateByName( + absl::StrFormat("%sPlugin", name), tflite_settings); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h +index 278ae7643264e..9f32fa8735ccf 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h +@@ -19,7 +19,7 @@ limitations under the License. + #include <string> + #include <utility> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h +index 0d808ab24d6cc..dc6183bee693c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h +@@ -37,7 +37,7 @@ typedef unsigned long uword_t; + #define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also. + #define GG_LL_FORMAT_W L"ll" + +-const uint8 kuint8max{0xFF}; ++const uint8 kuint8max{0xFF}; + const uint16 kuint16max{0xFFFF}; + const uint32 kuint32max{0xFFFFFFFF}; + const uint64 kuint64max{GG_ULONGLONG(0xFFFFFFFFFFFFFFFF)}; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc +index 4b1439dcc0719..4be3e53c11972 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc +@@ -17,7 +17,7 @@ limitations under the License. + + #include <initializer_list> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/c/c_api_types.h" + #include "tensorflow_lite_support/cc/common.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h +index 28b379996cb42..a3d4c5717f239 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h +@@ -27,9 +27,9 @@ limitations under the License. + namespace tflite { + namespace task { + namespace audio { +-class AudioEmbedder +- : public tflite::task::core::BaseTaskApi< +- tflite::task::processor::EmbeddingResult, const AudioBuffer&> { ++class AudioEmbedder : public tflite::task::core::BaseTaskApi< ++ tflite::task::processor::EmbeddingResult, ++ const AudioBuffer&> { + public: + // Use base class constructor. + using BaseTaskApi::BaseTaskApi; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h +index 39110ed8d0b15..d922e48af25bc 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h +@@ -17,8 +17,8 @@ limitations under the License. + + #include <memory> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -41,7 +41,8 @@ class AudioBuffer { + // Factory method for creating an AudioBuffer object. The internal buffer does + // not take the ownership of the input backing buffer. + static tflite::support::StatusOr<std::unique_ptr<AudioBuffer>> Create( +- const float* audio_buffer, int buffer_size, ++ const float* audio_buffer, ++ int buffer_size, + const AudioFormat& audio_format) { + return absl::make_unique<AudioBuffer>(audio_buffer, buffer_size, + audio_format); +@@ -50,7 +51,8 @@ class AudioBuffer { + // AudioBuffer for internal use only. Uses the factory method to construct + // AudioBuffer instance. The internal buffer does not take the ownership of + // the input backing buffer. +- AudioBuffer(const float* audio_buffer, int buffer_size, ++ AudioBuffer(const float* audio_buffer, ++ int buffer_size, + const AudioFormat& audio_format) + : audio_buffer_(audio_buffer), + buffer_size_(buffer_size), +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc +index 3c0ad996a9919..9ae3fbec70543 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc +@@ -27,9 +27,9 @@ limitations under the License. + #include <fstream> + #include <limits> + +-#include "absl/base/casts.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_cat.h" // from @com_google_absl ++#include "absl/base/casts.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/status_macros.h" + +@@ -62,7 +62,9 @@ std::string ReadFile(const std::string filepath) { + + // Handles moving the data index forward, validating the arguments, and avoiding + // overflow or underflow. +-absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size, ++absl::Status IncrementOffset(int old_offset, ++ size_t increment, ++ size_t max_size, + int* new_offset) { + if (old_offset < 0) { + return absl::InvalidArgumentError( +@@ -87,7 +89,8 @@ absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size, + } + + absl::Status ExpectText(const std::string& data, +- const std::string& expected_text, int* offset) { ++ const std::string& expected_text, ++ int* offset) { + int new_offset; + RETURN_IF_ERROR( + IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset)); +@@ -101,8 +104,10 @@ absl::Status ExpectText(const std::string& data, + return absl::OkStatus(); + } + +-absl::Status ReadString(const std::string& data, int expected_length, +- std::string* value, int* offset) { ++absl::Status ReadString(const std::string& data, ++ int expected_length, ++ std::string* value, ++ int* offset) { + int new_offset; + RETURN_IF_ERROR( + IncrementOffset(*offset, expected_length, data.size(), &new_offset)); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h +index 51271fc065c83..9aca5d06f7985 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h +@@ -20,9 +20,9 @@ limitations under the License. + + #define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_ + ++#include <cstdint> + #include <string> + #include <vector> +-#include <cstdint> + + #include "absl/status/status.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -64,7 +64,9 @@ absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string, + + // Handles moving the data index forward, validating the arguments, and avoiding + // overflow or underflow. +-absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size, ++absl::Status IncrementOffset(int old_offset, ++ size_t increment, ++ size_t max_size, + int* new_offset); + + // This function is only exposed in the header for testing purposes, as a +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h +index d743383734b42..effd42f0f0336 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h +@@ -18,7 +18,7 @@ limitations under the License. + + #include <utility> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow_lite_support/cc/common.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h +index c868060f9894a..c91552f7ec82e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h +@@ -18,7 +18,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/label_map_item.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc +index 80dea95cce24b..a626ce6030b96 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc +@@ -35,9 +35,13 @@ int ErrorReporter::Report(const char* format, va_list args) { + return num_characters; + } + +-std::string ErrorReporter::message() { return last_message_; } ++std::string ErrorReporter::message() { ++ return last_message_; ++} + +-std::string ErrorReporter::previous_message() { return second_last_message_; } ++std::string ErrorReporter::previous_message() { ++ return second_last_message_; ++} + + } // namespace core + } // namespace task +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc +index 694c55ab34e78..72e4b670cb172 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc +@@ -15,7 +15,7 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/core/label_map_item.h" + + #include "absl/strings/str_format.h" // from @com_google_absl +-#include "absl/strings/str_split.h" // from @com_google_absl ++#include "absl/strings/str_split.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + + namespace tflite { +@@ -28,7 +28,8 @@ using ::tflite::support::StatusOr; + using ::tflite::support::TfLiteSupportStatus; + + StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( +- absl::string_view labels_file, absl::string_view display_names_file) { ++ absl::string_view labels_file, ++ absl::string_view display_names_file) { + if (labels_file.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected non-empty labels file.", +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h +index 4d8422a2a572d..d8e1f70d8fab1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h +@@ -20,8 +20,8 @@ limitations under the License. + + #include "absl/container/flat_hash_map.h" // from @com_google_absl + #include "absl/container/flat_hash_set.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/string_view.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + + namespace tflite { +@@ -49,7 +49,8 @@ struct LabelMapItem { + // Returns an error e.g. if there's a mismatch between the number of labels and + // display names. + tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( +- absl::string_view labels_file, absl::string_view display_names_file); ++ absl::string_view labels_file, ++ absl::string_view display_names_file); + + // A class that represents a hierarchy of labels as specified in a label map. + // +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto +index c0a42124e1b50..91b6a214b1253 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto +@@ -17,7 +17,6 @@ syntax = "proto2"; + + package tflite.task.core; + +- + // Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or + // plain-text labels file). The files can be specified by one of the following + // three ways: +@@ -64,4 +63,3 @@ message FileDescriptorMeta { + // offset of a given asset obtained from AssetFileDescriptor#getStartOffset(). + optional int64 offset = 3; + } +- +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc +index 818839a77e43d..e7faebad487b9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc +@@ -19,11 +19,11 @@ limitations under the License. + #include <utility> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl +-#include "absl/strings/str_split.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/strings/str_split.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl +-#include "absl/types/optional.h" // from @com_google_absl ++#include "absl/types/optional.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h +index c1b945f76ab48..6e2b308bef101 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h +@@ -23,9 +23,9 @@ limitations under the License. + #include <vector> + + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/string_view.h" // from @com_google_absl +-#include "absl/types/optional.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/string_view.h" // from @com_google_absl ++#include "absl/types/optional.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/label_map_item.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +@@ -37,7 +37,10 @@ namespace core { + // Sigmoid structure. + struct Sigmoid { + Sigmoid() : scale(1.0) {} +- Sigmoid(std::string label, float slope, float offset, float scale = 1.0, ++ Sigmoid(std::string label, ++ float slope, ++ float offset, ++ float scale = 1.0, + absl::optional<float> min_uncalibrated_score = absl::nullopt) + : label(label), + slope(slope), +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h +index 4e4b42cceaff7..f42d703fd1ae8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h +@@ -18,7 +18,7 @@ limitations under the License. + + #include <memory> + +-#include "absl/base/macros.h" // from @com_google_absl ++#include "absl/base/macros.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl + #include "tensorflow/lite/core/api/op_resolver.h" + #include "tensorflow/lite/kernels/op_macros.h" +@@ -48,7 +48,8 @@ class TaskAPIFactory { + "Use CreateFromBaseOptions and configure model input from " + "tensorflow_lite_support/cc/task/core/proto/base_options.proto") + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer( +- const char* buffer_data, size_t buffer_size, ++ const char* buffer_data, ++ size_t buffer_size, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1, +@@ -151,7 +152,8 @@ class TaskAPIFactory { + private: + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine( +- std::unique_ptr<TfLiteEngine> engine, int num_threads, ++ std::unique_ptr<TfLiteEngine> engine, ++ int num_threads, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { + tflite::proto::ComputeSettings settings_copy = +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h +index e95ea73a4a812..7cde474dcd8f6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h +@@ -21,9 +21,9 @@ limitations under the License. + #include <numeric> + #include <vector> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_cat.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +@@ -65,9 +65,11 @@ tflite::support::StatusOr<T*> AssertAndReturnTypedTensor( + // type or has not the same number of elements. + // Note: std::negation is not used because it is from C++17, where the code will + // be compiled using C++14 in OSS. +-template <typename T, typename = std::enable_if_t< +- std::is_same<T, std::string>::value == false>> +-inline absl::Status PopulateTensor(const T* data, int num_elements, ++template < ++ typename T, ++ typename = std::enable_if_t<std::is_same<T, std::string>::value == false>> ++inline absl::Status PopulateTensor(const T* data, ++ int num_elements, + TfLiteTensor* tensor) { + T* v; + ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor)); +@@ -92,7 +94,8 @@ inline absl::Status PopulateTensor(const std::vector<T>& data, + + template <> + inline absl::Status PopulateTensor<std::string>( +- const std::vector<std::string>& data, TfLiteTensor* tensor) { ++ const std::vector<std::string>& data, ++ TfLiteTensor* tensor) { + if (tensor->type != kTfLiteString) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, +@@ -143,7 +146,8 @@ inline absl::Status PopulateVector(const TfLiteTensor* tensor, + + template <> + inline absl::Status PopulateVector<std::string>( +- const TfLiteTensor* tensor, std::vector<std::string>* data) { ++ const TfLiteTensor* tensor, ++ std::vector<std::string>* data) { + if (tensor->type != typeToTfLiteType<std::string>()) { + return absl::InvalidArgumentError("not of type string"); + } +@@ -161,7 +165,8 @@ inline absl::Status PopulateVector<std::string>( + // Note: std::negation is not used because it is from C++17, where the code will + // be compiled using C++14 in OSS. + template < +- class TRepeatedField, class T = float, ++ class TRepeatedField, ++ class T = float, + typename = std::enable_if_t<std::is_same<T, std::string>::value == false>> + inline absl::Status PopulateVectorToRepeated(const TfLiteTensor* tensor, + TRepeatedField* data) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +index 484b9a099ecdc..0b34bad4f18f7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +@@ -17,7 +17,7 @@ limitations under the License. + + #include <memory> + +-#include "absl/strings/match.h" // from @com_google_absl ++#include "absl/strings/match.h" // from @com_google_absl + #include "absl/strings/str_cat.h" // from @com_google_absl + #include "tensorflow/lite/builtin_ops.h" + #include "tensorflow/lite/core/shims/cc/kernels/register.h" +@@ -53,7 +53,8 @@ using ::tflite::support::CreateStatusWithPayload; + using ::tflite::support::InterpreterCreationResources; + using ::tflite::support::TfLiteSupportStatus; + +-bool TfLiteEngine::Verifier::Verify(const char* data, int length, ++bool TfLiteEngine::Verifier::Verify(const char* data, ++ int length, + tflite::ErrorReporter* reporter) { + return tflite_shims::Verify(data, length, reporter); + } +@@ -84,7 +85,8 @@ std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() { + } + + void TfLiteEngine::VerifyAndBuildModelFromBuffer( +- const char* buffer_data, size_t buffer_size, ++ const char* buffer_data, ++ size_t buffer_size, + TfLiteVerifier* extra_verifier) { + model_ = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer( + buffer_data, buffer_size, extra_verifier, &error_reporter_); +@@ -131,7 +133,8 @@ absl::Status TfLiteEngine::InitializeFromModelFileHandler( + } + + absl::Status TfLiteEngine::BuildModelFromFlatBuffer( +- const char* buffer_data, size_t buffer_size, ++ const char* buffer_data, ++ size_t buffer_size, + const tflite::proto::ComputeSettings& compute_settings) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, +@@ -220,7 +223,8 @@ absl::Status TfLiteEngine::InitInterpreter(int num_threads) { + // absl::Status TfLiteEngine::InitInterpreter( + // const tflite::proto::ComputeSettings& compute_settings) + absl::Status TfLiteEngine::InitInterpreter( +- const tflite::proto::ComputeSettings& compute_settings, int num_threads) { ++ const tflite::proto::ComputeSettings& compute_settings, ++ int num_threads) { + ComputeSettings settings_copy = ComputeSettings(compute_settings); + settings_copy.mutable_tflite_settings() + ->mutable_cpu_settings() +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +index 53dabdc4841d7..0cbaa738e6db6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +@@ -18,8 +18,8 @@ limitations under the License. + + #include <memory> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow/lite/core/api/op_resolver.h" + #include "tensorflow/lite/core/shims/c/common.h" +@@ -96,7 +96,8 @@ class TfLiteEngine { + // object. This performs extra verification on the input data using + // tflite::Verify. + absl::Status BuildModelFromFlatBuffer( +- const char* buffer_data, size_t buffer_size, ++ const char* buffer_data, ++ size_t buffer_size, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()); + +@@ -138,7 +139,8 @@ class TfLiteEngine { + // absl::Status TfLiteEngine::InitInterpreter( + // const tflite::proto::ComputeSettings& compute_settings) + absl::Status InitInterpreter( +- const tflite::proto::ComputeSettings& compute_settings, int num_threads); ++ const tflite::proto::ComputeSettings& compute_settings, ++ int num_threads); + + // Cancels the on-going `Invoke()` call if any and if possible. This method + // can be called from a different thread than the one where `Invoke()` is +@@ -155,7 +157,8 @@ class TfLiteEngine { + // the FlatBuffer data provided as input. + class Verifier : public tflite::TfLiteVerifier { + public: +- bool Verify(const char* data, int length, ++ bool Verify(const char* data, ++ int length, + tflite::ErrorReporter* reporter) override; + }; + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc +index e3ea2b134e3f4..254d0689e5ecc 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc +@@ -14,7 +14,7 @@ limitations under the License. + ==============================================================================*/ + #include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h" + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -29,7 +29,8 @@ namespace { + // Looks up AudioProperty from metadata. If no error occurs, the returned value + // is guaranteed to be valid (not null). + tflite::support::StatusOr<const AudioProperties*> GetAudioPropertiesSafe( +- const TensorMetadata* tensor_metadata, int input_index) { ++ const TensorMetadata* tensor_metadata, ++ int input_index) { + if (tensor_metadata->content() == nullptr || + tensor_metadata->content()->content_properties() == nullptr) { + return CreateStatusWithPayload( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc +index 9c11083c4f839..63962003f5e77 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc +@@ -17,7 +17,7 @@ limitations under the License. + + #include <memory> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/c/c_api_types.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -42,7 +42,8 @@ using ::tflite::task::core::ScoreCalibration; + /* static */ + tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>> + ClassificationPostprocessor::Create( +- core::TfLiteEngine* engine, const std::initializer_list<int> output_indices, ++ core::TfLiteEngine* engine, ++ const std::initializer_list<int> output_indices, + std::unique_ptr<ClassificationOptions> options) { + ASSIGN_OR_RETURN(auto processor, + Processor::Create<ClassificationPostprocessor>( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h +index fdc872a23d3d4..78cef8ab57e3d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h +@@ -66,8 +66,8 @@ class EmbeddingPostprocessor : public Postprocessor { + + // Performs actual cosine similarity computation. + template <typename T> +- static tflite::support::StatusOr<double> ComputeCosineSimilarity( +- const T* u, const T* v, int num_elements); ++ static tflite::support::StatusOr<double> ++ ComputeCosineSimilarity(const T* u, const T* v, int num_elements); + + template <typename T> + void NormalizeFeatureVector(T* feature_vector) const; +@@ -143,7 +143,8 @@ void EmbeddingPostprocessor::QuantizeFeatureVector(T* feature_vector) const { + /* static */ + template <typename T> + tflite::support::StatusOr<double> +-EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v, ++EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, ++ const T* v, + int num_elements) { + if (num_elements <= 0) { + return CreateStatusWithPayload( +@@ -171,7 +172,8 @@ EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v, + /* static */ + template <typename T> + tflite::support::StatusOr<double> EmbeddingPostprocessor::CosineSimilarity( +- const T& u, const T& v) { ++ const T& u, ++ const T& v) { + if (u.has_value_string() && v.has_value_string()) { + if (u.value_string().size() != v.value_string().size()) { + return CreateStatusWithPayload( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc +index 7ad4ad4703789..310a1f5eba724 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc +@@ -36,7 +36,8 @@ using ::tflite::task::vision::FrameBuffer; + /* static */ + tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>> + ImagePreprocessor::Create( +- core::TfLiteEngine* engine, const std::initializer_list<int> input_indices, ++ core::TfLiteEngine* engine, ++ const std::initializer_list<int> input_indices, + const vision::FrameBufferUtils::ProcessEngine& process_engine) { + ASSIGN_OR_RETURN(auto processor, + Processor::Create<ImagePreprocessor>( +@@ -49,7 +50,8 @@ ImagePreprocessor::Create( + + // Returns false if image preprocessing could be skipped, true otherwise. + bool ImagePreprocessor::IsImagePreprocessingNeeded( +- const FrameBuffer& frame_buffer, const BoundingBox& roi) { ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) { + // Is crop required? + if (roi.origin_x() != 0 || roi.origin_y() != 0 || + roi.width() != frame_buffer.dimension().width || +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h +index 4aad40b2afd97..b3c43605ac82e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h +@@ -18,7 +18,7 @@ limitations under the License. + #include <initializer_list> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/core/shims/c/common.h" + #include "tensorflow_lite_support/cc/common.h" +@@ -52,7 +52,8 @@ class Processor { + // num_expected_tensors, engine, tensor_indices); + template <typename T, EnableIfProcessorSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> Create( +- int num_expected_tensors, tflite::task::core::TfLiteEngine* engine, ++ int num_expected_tensors, ++ tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> tensor_indices, + bool requires_metadata = true) { + auto processor = absl::make_unique<T>(engine, tensor_indices); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc +index af923b4d6f2c1..58b77b6952de1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc +@@ -55,7 +55,8 @@ StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile( + + /* static */ + StatusOr<std::unique_ptr<RegexPreprocessor>> RegexPreprocessor::Create( +- tflite::task::core::TfLiteEngine* engine, int input_tensor_index) { ++ tflite::task::core::TfLiteEngine* engine, ++ int input_tensor_index) { + ASSIGN_OR_RETURN(auto processor, Processor::Create<RegexPreprocessor>( + /* num_expected_tensors = */ 1, engine, + {input_tensor_index}, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h +index 1f92bcc18e524..bdd4e5e207a12 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h +@@ -34,7 +34,8 @@ namespace processor { + class RegexPreprocessor : public TextPreprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<RegexPreprocessor>> Create( +- tflite::task::core::TfLiteEngine* engine, int input_tensor_index); ++ tflite::task::core::TfLiteEngine* engine, ++ int input_tensor_index); + + absl::Status Preprocess(const std::string& text); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc +index c52f73be8b7a8..ac8fa548c669d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc +@@ -23,8 +23,8 @@ limitations under the License. + #include <utility> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/ascii.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/ascii.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/api/op_resolver.h" +@@ -76,7 +76,8 @@ int GetLastDimSize(const TfLiteTensor* tensor) { + } // namespace + + absl::Status BertNLClassifier::Preprocess( +- const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { ++ const std::vector<TfLiteTensor*>& input_tensors, ++ const std::string& input) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + auto* ids_tensor = +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h +index 541b5561d5c6d..91bcfe50712d0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h +@@ -22,7 +22,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/base/macros.h" // from @com_google_absl ++#include "absl/base/macros.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/api/op_resolver.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc +index 6b37649d4fbfd..591b70e84eb22 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc +@@ -15,8 +15,8 @@ limitations under the License. + + #include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_join.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_join.h" // from @com_google_absl + #include "absl/strings/str_split.h" // from @com_google_absl + #include "tensorflow/lite/core/shims/cc/kernels/register.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -111,7 +111,8 @@ StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd( + + StatusOr<std::unique_ptr<QuestionAnswerer>> + BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( +- const std::string& path_to_model, const std::string& path_to_vocab) { ++ const std::string& path_to_model, ++ const std::string& path_to_vocab) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, +@@ -125,8 +126,10 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( + + StatusOr<std::unique_ptr<QuestionAnswerer>> + BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( +- const char* model_buffer_data, size_t model_buffer_size, +- const char* vocab_buffer_data, size_t vocab_buffer_size) { ++ const char* model_buffer_data, ++ size_t model_buffer_size, ++ const char* vocab_buffer_data, ++ size_t vocab_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, +@@ -141,7 +144,8 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( + + StatusOr<std::unique_ptr<QuestionAnswerer>> + BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile( +- const std::string& path_to_model, const std::string& path_to_spmodel) { ++ const std::string& path_to_model, ++ const std::string& path_to_spmodel) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, +@@ -155,8 +159,10 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile( + + StatusOr<std::unique_ptr<QuestionAnswerer>> + BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( +- const char* model_buffer_data, size_t model_buffer_size, +- const char* spmodel_buffer_data, size_t spmodel_buffer_size) { ++ const char* model_buffer_data, ++ size_t model_buffer_size, ++ const char* spmodel_buffer_data, ++ size_t spmodel_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, +@@ -170,14 +176,16 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( + } + + std::vector<QaAnswer> BertQuestionAnswerer::Answer( +- const std::string& context, const std::string& question) { ++ const std::string& context, ++ const std::string& question) { + // The BertQuestionAnswererer implementation for Preprocess() and + // Postprocess() never returns errors: just call value(). + return Infer(context, question).value(); + } + + absl::Status BertQuestionAnswerer::Preprocess( +- const std::vector<TfLiteTensor*>& input_tensors, const std::string& context, ++ const std::vector<TfLiteTensor*>& input_tensors, ++ const std::string& context, + const std::string& query) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); +@@ -392,7 +400,8 @@ void BertQuestionAnswerer::InitializeBertTokenizer( + } + + void BertQuestionAnswerer::InitializeBertTokenizerFromBinary( +- const char* vocab_buffer_data, size_t vocab_buffer_size) { ++ const char* vocab_buffer_data, ++ size_t vocab_buffer_size) { + tokenizer_ = + absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size); + } +@@ -403,7 +412,8 @@ void BertQuestionAnswerer::InitializeSentencepieceTokenizer( + } + + void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary( +- const char* spmodel_buffer_data, size_t spmodel_buffer_size) { ++ const char* spmodel_buffer_data, ++ size_t spmodel_buffer_size) { + tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data, + spmodel_buffer_size); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h +index f041cc8e51637..52ec835371386 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h +@@ -16,9 +16,9 @@ limitations under the License. + #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_ + #define TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_ + +-#include "absl/base/macros.h" // from @com_google_absl ++#include "absl/base/macros.h" // from @com_google_absl + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/base_task_api.h" + #include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +@@ -136,7 +136,8 @@ class BertQuestionAnswerer : public QuestionAnswerer { + void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel); + // Initialize API with a SentencepieceTokenizer from the model buffer. + void InitializeSentencepieceTokenizerFromBinary( +- const char* spmodel_buffer_data, size_t spmodel_buffer_size); ++ const char* spmodel_buffer_data, ++ size_t spmodel_buffer_size); + + // Initialize the API with the tokenizer set in the metadata. + absl::Status InitializeFromMetadata( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +index d3697f326db1b..6986bcc665733 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +@@ -22,8 +22,8 @@ limitations under the License. + #include <vector> + + #include "absl/algorithm/container.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_cat.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/c/common.h" +@@ -200,7 +200,8 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) { + } + + absl::Status NLClassifier::Preprocess( +- const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { ++ const std::vector<TfLiteTensor*>& input_tensors, ++ const std::string& input) { + TfLiteTensor* input_tensor = FindTensorWithNameOrIndex( + input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), + struct_options_.input_tensor_name, struct_options_.input_tensor_index); +@@ -446,7 +447,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromOptions( + + StatusOr<std::unique_ptr<NLClassifier>> + NLClassifier::CreateFromBufferAndOptions( +- const char* model_buffer_data, size_t model_buffer_size, ++ const char* model_buffer_data, ++ size_t model_buffer_size, + const NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<NLClassifier> nl_classifier; +@@ -459,7 +461,8 @@ NLClassifier::CreateFromBufferAndOptions( + } + + StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions( +- const std::string& path_to_model, const NLClassifierOptions& options, ++ const std::string& path_to_model, ++ const NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<NLClassifier> nl_classifier; + ASSIGN_OR_RETURN(nl_classifier, +@@ -470,7 +473,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions( + } + + StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions( +- int fd, const NLClassifierOptions& options, ++ int fd, ++ const NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<NLClassifier> nl_classifier; + ASSIGN_OR_RETURN(nl_classifier, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +index 2adafba8f2fa9..331a6e4274342 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +@@ -23,8 +23,8 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/base/macros.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/base/macros.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/api/op_resolver.h" +@@ -109,7 +109,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromBufferAndOptions( +- const char* model_buffer_data, size_t model_buffer_size, ++ const char* model_buffer_data, ++ size_t model_buffer_size, + const NLClassifierOptions& options = {}, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); +@@ -118,7 +119,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromFileAndOptions( +- const std::string& path_to_model, const NLClassifierOptions& options = {}, ++ const std::string& path_to_model, ++ const NLClassifierOptions& options = {}, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + +@@ -126,7 +128,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromFdAndOptions( +- int fd, const NLClassifierOptions& options = {}, ++ int fd, ++ const NLClassifierOptions& options = {}, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + +@@ -177,7 +180,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, + const std::vector<TensorType*>& tensors, + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + metadata_array, +- const std::string& name, int index) { ++ const std::string& name, ++ int index) { + if (metadata_array != nullptr && metadata_array->size() == tensors.size()) { + for (size_t i = 0; i < metadata_array->size(); i++) { + if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h +index 4cde4329a716b..df21662a40e3a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h +@@ -45,9 +45,9 @@ struct QaAnswer { + }; + + // Interface for an Question-Answer API. +-class QuestionAnswerer +- : public core::BaseTaskApi<std::vector<QaAnswer>, const std::string&, +- const std::string&> { ++class QuestionAnswerer : public core::BaseTaskApi<std::vector<QaAnswer>, ++ const std::string&, ++ const std::string&> { + public: + explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) + : BaseTaskApi(std::move(engine)) {} +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc +index 069491f6e47c9..2937a175c5e3c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc +@@ -21,7 +21,7 @@ limitations under the License. + #include <vector> + + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/base_task_api.h" + #include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +@@ -197,7 +197,8 @@ StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeQuery( + } + + StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeResponse( +- absl::string_view response_text, absl::string_view response_context) { ++ absl::string_view response_text, ++ absl::string_view response_context) { + if (response_text.empty() && response_context.empty()) { + return Status( + StatusCode::kInvalidArgument, +@@ -218,7 +219,8 @@ StatusOr<float> UniversalSentenceEncoderQA::Similarity(const FeatureVector& a, + } + + std::vector<size_t> UniversalSentenceEncoderQA::Top( +- const RetrievalOutput& output, size_t k) { ++ const RetrievalOutput& output, ++ size_t k) { + // Ensure k in [0, total_size). + // If k == 0, it means that all outputs are ranked. + if (k == 0) { +@@ -242,7 +244,8 @@ std::vector<size_t> UniversalSentenceEncoderQA::Top( + } + + Status UniversalSentenceEncoderQA::Preprocess( +- const std::vector<TfLiteTensor*>& input_tensors, const QAInput& input) { ++ const std::vector<TfLiteTensor*>& input_tensors, ++ const QAInput& input) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + TfLiteTensor* query_text_tensor = +@@ -293,7 +296,8 @@ StatusOr<QAOutput> UniversalSentenceEncoderQA::Postprocess( + } + + internal::QAOutput UniversalSentenceEncoderQA::Run( +- absl::string_view query_text, absl::string_view response_text, ++ absl::string_view query_text, ++ absl::string_view response_text, + absl::string_view response_context) { + QAInput input; + input.query_text = query_text; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h +index fae2f29721722..0269033918cc9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h +@@ -20,14 +20,14 @@ limitations under the License. + #include <vector> + + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/base_task_api.h" + #include "tensorflow_lite_support/cc/task/core/task_api_factory.h" + #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +-#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h" + #include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" ++#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h" + + namespace tflite { + namespace task { +@@ -73,7 +73,8 @@ class UniversalSentenceEncoderQA + // Encodes response from the text and/or context. + // Returns an error, if both text and context are empty. + tflite::support::StatusOr<FeatureVector> EncodeResponse( +- absl::string_view response_text, absl::string_view response_context); ++ absl::string_view response_text, ++ absl::string_view response_context); + + // Calculates similarity between two encoded vectors (require same size). + static tflite::support::StatusOr<float> Similarity(const FeatureVector& a, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h +index 76a03671b54af..d3557fc508c61 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h +@@ -23,7 +23,7 @@ limitations under the License. + + #include "absl/memory/memory.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/time/clock.h" // from @com_google_absl ++#include "absl/time/clock.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/integral_types.h" +@@ -45,11 +45,12 @@ namespace vision { + // Base class providing common logic for vision models. + template <class OutputType> + class BaseVisionTaskApi +- : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&, +- const BoundingBox&> { ++ : public tflite::task::core:: ++ BaseTaskApi<OutputType, const FrameBuffer&, const BoundingBox&> { + public: + explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine) +- : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&, ++ : tflite::task::core::BaseTaskApi<OutputType, ++ const FrameBuffer&, + const BoundingBox&>(std::move(engine)) { + } + // BaseVisionTaskApi is neither copyable nor movable. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h +index 47db0d121d43b..2e1aa6d652967 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h +@@ -18,7 +18,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h +index 1668447393e9e..2936f5acbb921 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h +@@ -22,12 +22,12 @@ limitations under the License. + #include <utility> + #include <vector> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_cat.h" // from @com_google_absl +-#include "absl/time/clock.h" // from @com_google_absl +-#include "absl/time/time.h" // from @com_google_absl +-#include "absl/types/optional.h" // from @com_google_absl ++#include "absl/time/clock.h" // from @com_google_absl ++#include "absl/time/time.h" // from @com_google_absl ++#include "absl/types/optional.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/statusor.h" + +@@ -74,7 +74,16 @@ namespace vision { + class FrameBuffer { + public: + // Colorspace formats. +- enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY, kUNKNOWN}; ++ enum class Format { ++ kRGBA, ++ kRGB, ++ kNV12, ++ kNV21, ++ kYV12, ++ kYV21, ++ kGRAY, ++ kUNKNOWN ++ }; + + // Stride information. + struct Stride { +@@ -166,7 +175,8 @@ class FrameBuffer { + // buffers. In a streaming use case (e.g continuous camera stream), the + // timestamp can be used as an ID to identify a frame. + static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes, +- Dimension dimension, Format format, ++ Dimension dimension, ++ Format format, + Orientation orientation, + absl::Time timestamp) { + return absl::make_unique<FrameBuffer>(planes, dimension, format, +@@ -177,7 +187,8 @@ class FrameBuffer { + // backing buffers. In a streaming use case (e.g continuous camera stream), + // the timestamp can be used as an ID to identify a frame. + static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes, +- Dimension dimension, Format format, ++ Dimension dimension, ++ Format format, + Orientation orientation, + absl::Time timestamp) { + return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format, +@@ -189,7 +200,8 @@ class FrameBuffer { + // more suitable for processing use case that does not need to re-identify + // this buffer. + static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes, +- Dimension dimension, Format format, ++ Dimension dimension, ++ Format format, + Orientation orientation) { + return absl::make_unique<FrameBuffer>(planes, dimension, format, + orientation, absl::Now()); +@@ -200,7 +212,8 @@ class FrameBuffer { + // method is more suitable for processing use case that does not need to + // re-identify this buffer. + static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes, +- Dimension dimension, Format format, ++ Dimension dimension, ++ Format format, + Orientation orientation) { + return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format, + orientation, absl::Now()); +@@ -217,8 +230,11 @@ class FrameBuffer { + // The FrameBuffer does not take ownership of the backing buffer. The backing + // buffer is read-only and the caller is responsible for maintaining the + // backing buffer lifecycle for the lifetime of FrameBuffer. +- FrameBuffer(const std::vector<Plane>& planes, Dimension dimension, +- Format format, Orientation orientation, absl::Time timestamp) ++ FrameBuffer(const std::vector<Plane>& planes, ++ Dimension dimension, ++ Format format, ++ Orientation orientation, ++ absl::Time timestamp) + : planes_(planes), + dimension_(dimension), + format_(format), +@@ -230,8 +246,11 @@ class FrameBuffer { + // The FrameBuffer does not take ownership of the backing buffer. The backing + // buffer is read-only and the caller is responsible for maintaining the + // backing buffer lifecycle for the lifetime of FrameBuffer. +- FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format, +- Orientation orientation, absl::Time timestamp) ++ FrameBuffer(std::vector<Plane>&& planes, ++ Dimension dimension, ++ Format format, ++ Orientation orientation, ++ absl::Time timestamp) + : planes_(std::move(planes)), + dimension_(dimension), + format_(format), +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc +index 9c82b63a10359..67fe07534b52a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc +@@ -16,7 +16,7 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" + + #include "absl/strings/str_format.h" // from @com_google_absl +-#include "absl/strings/str_split.h" // from @com_google_absl ++#include "absl/strings/str_split.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + + namespace tflite { +@@ -29,7 +29,8 @@ using ::tflite::support::StatusOr; + using ::tflite::support::TfLiteSupportStatus; + + StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( +- absl::string_view labels_file, absl::string_view display_names_file) { ++ absl::string_view labels_file, ++ absl::string_view display_names_file) { + if (labels_file.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected non-empty labels file.", +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h +index 0fb66f2639806..20c316ba4a992 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h +@@ -20,8 +20,8 @@ limitations under the License. + + #include "absl/container/flat_hash_map.h" // from @com_google_absl + #include "absl/container/flat_hash_set.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/string_view.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + + namespace tflite { +@@ -49,7 +49,8 @@ struct LabelMapItem { + // Returns an error e.g. if there's a mismatch between the number of labels and + // display names. + tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( +- absl::string_view labels_file, absl::string_view display_names_file); ++ absl::string_view labels_file, ++ absl::string_view display_names_file); + + // A class that represents a hierarchy of labels as specified in a label map. + // +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc +index aa1e7707dd99b..36ab3c3ca1903 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc +@@ -16,9 +16,9 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/vision/image_classifier.h" + + #include "absl/algorithm/container.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl +-#include "flatbuffers/flatbuffers.h" // from @flatbuffers ++#include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -146,7 +146,9 @@ absl::Status ImageClassifier::PreInit() { + return absl::OkStatus(); + } + +-absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); } ++absl::Status ImageClassifier::PostInit() { ++ return InitScoreCalibrations(); ++} + + absl::Status ImageClassifier::CheckAndSetOutputs() { + num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter()); +@@ -380,13 +382,15 @@ StatusOr<ClassificationResult> ImageClassifier::Classify( + } + + StatusOr<ClassificationResult> ImageClassifier::Classify( +- const FrameBuffer& frame_buffer, const BoundingBox& roi) { ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) { + return InferWithFallback(frame_buffer, roi); + } + + StatusOr<ClassificationResult> ImageClassifier::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) { ++ const FrameBuffer& /*frame_buffer*/, ++ const BoundingBox& /*roi*/) { + if (output_tensors.size() != num_outputs_) { + return CreateStatusWithPayload( + StatusCode::kInternal, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h +index b2f595715e9da..eb0c13ec55c5b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h +@@ -20,7 +20,7 @@ limitations under the License. + #include <vector> + + #include "absl/container/flat_hash_set.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/api/op_resolver.h" + #include "tensorflow/lite/core/shims/cc/kernels/register.h" +@@ -109,7 +109,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> { + // region of interest is not clamped, so this method will return a non-ok + // status if the region is out of these bounds. + tflite::support::StatusOr<ClassificationResult> Classify( +- const FrameBuffer& frame_buffer, const BoundingBox& roi); ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi); + + protected: + // The options used to build this ImageClassifier. +@@ -123,7 +124,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> { + // results. + tflite::support::StatusOr<ClassificationResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& frame_buffer, const BoundingBox& roi) override; ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) override; + + // Performs sanity checks on the provided ImageClassifierOptions. + static absl::Status SanityCheckOptions(const ImageClassifierOptions& options); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc +index 0ce46fb9f9806..943a39b1f762e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc +@@ -18,10 +18,10 @@ limitations under the License. + #include <algorithm> + + #include "absl/container/node_hash_set.h" // from @com_google_absl +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl +-#include "absl/strings/string_view.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -51,7 +51,8 @@ CreatePostprocessor(core::TfLiteEngine* engine, + + /* static */ + tflite::support::StatusOr<double> ImageEmbedder::CosineSimilarity( +- const FeatureVector& u, const FeatureVector& v) { ++ const FeatureVector& u, ++ const FeatureVector& v) { + return processor::EmbeddingPostprocessor::CosineSimilarity(u, v); + } + +@@ -118,13 +119,15 @@ tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed( + } + + tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed( +- const FrameBuffer& frame_buffer, const BoundingBox& roi) { ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) { + return InferWithFallback(frame_buffer, roi); + } + + tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) { ++ const FrameBuffer& /*frame_buffer*/, ++ const BoundingBox& /*roi*/) { + EmbeddingResult result; + for (int i = 0; i < postprocessors_.size(); ++i) { + RETURN_IF_ERROR( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h +index bc321c83d3774..93e2455eebd19 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h +@@ -90,7 +90,8 @@ class ImageEmbedder + // region of interest. Note that the region of interest is not clamped, so + // this method will fail if the region is out of bounds of the input image. + tflite::support::StatusOr<EmbeddingResult> Embed( +- const FrameBuffer& frame_buffer, const BoundingBox& roi); ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi); + + // Returns the Embedding output by the output_index'th layer. In (the most + // common) case where a single embedding is produced, you can just call +@@ -113,7 +114,8 @@ class ImageEmbedder + // + // [1]: https://en.wikipedia.org/wiki/Cosine_similarity + static tflite::support::StatusOr<double> CosineSimilarity( +- const FeatureVector& u, const FeatureVector& v); ++ const FeatureVector& u, ++ const FeatureVector& v); + + protected: + // The options used to build this ImageEmbedder. +@@ -122,7 +124,8 @@ class ImageEmbedder + // Post-processing to transform the raw model outputs into embedding results. + tflite::support::StatusOr<EmbeddingResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& frame_buffer, const BoundingBox& roi) override; ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) override; + + // Performs pre-initialization actions. + virtual absl::Status PreInit(); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc +index f87c6b078eddc..20a34a956200b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc +@@ -17,8 +17,8 @@ limitations under the License. + + #include <algorithm> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/c/common.h" +@@ -110,7 +110,8 @@ constexpr uint8 kColorMap[768] = { + + StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny( + const ModelMetadataExtractor& metadata_extractor, +- const TensorMetadata& tensor_metadata, absl::string_view locale) { ++ const TensorMetadata& tensor_metadata, ++ absl::string_view locale) { + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS); +@@ -332,7 +333,8 @@ StatusOr<SegmentationResult> ImageSegmenter::Segment( + + StatusOr<SegmentationResult> ImageSegmenter::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) { ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& /*roi*/) { + if (output_tensors.size() != 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, +@@ -432,7 +434,10 @@ StatusOr<SegmentationResult> ImageSegmenter::Postprocess( + } + + StatusOr<float> ImageSegmenter::GetOutputConfidence( +- const TfLiteTensor& output_tensor, int x, int y, int depth) { ++ const TfLiteTensor& output_tensor, ++ int x, ++ int y, ++ int depth) { + int index = output_width_ * output_depth_ * y + output_depth_ * x + depth; + if (has_uint8_outputs_) { + ASSIGN_OR_RETURN(const uint8* data, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h +index 3f51f4962738e..e255110d9dc66 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h +@@ -119,7 +119,8 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> { + // results. + tflite::support::StatusOr<SegmentationResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& frame_buffer, const BoundingBox& roi) override; ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) override; + + // Performs sanity checks on the provided ImageSegmenterOptions. + static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options); +@@ -148,7 +149,10 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> { + // Returns the output confidence at coordinates {x, y, depth}, dequantizing + // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true). + tflite::support::StatusOr<float> GetOutputConfidence( +- const TfLiteTensor& output_tensor, int x, int y, int depth); ++ const TfLiteTensor& output_tensor, ++ int x, ++ int y, ++ int depth); + + // Prebuilt list of ColoredLabel attached to each Segmentation result. The + // i-th item in this list corresponds to the i-th label map item. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc +index 872bd8d5876a4..3eb512699bbda 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc +@@ -20,9 +20,9 @@ limitations under the License. + #include <vector> + + #include <glog/logging.h> +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow_lite_support/cc/common.h" +@@ -141,7 +141,8 @@ StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties( + + StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny( + const ModelMetadataExtractor& metadata_extractor, +- const TensorMetadata& tensor_metadata, absl::string_view locale) { ++ const TensorMetadata& tensor_metadata, ++ absl::string_view locale) { + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS); +@@ -370,7 +371,9 @@ absl::Status ObjectDetector::PreInit() { + return absl::OkStatus(); + } + +-absl::Status ObjectDetector::PostInit() { return InitScoreCalibrations(); } ++absl::Status ObjectDetector::PostInit() { ++ return InitScoreCalibrations(); ++} + + StatusOr<SigmoidCalibrationParameters> BuildCalibrationParametersIfAny( + const tflite::metadata::ModelMetadataExtractor& metadata_extractor, +@@ -599,7 +602,8 @@ StatusOr<DetectionResult> ObjectDetector::Detect( + + StatusOr<DetectionResult> ObjectDetector::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) { ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& /*roi*/) { + // Most of the checks here should never happen, as outputs have been validated + // at construction time. Checking nonetheless and returning internal errors if + // something bad happens. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h +index eaa6b5371ba52..c37fa8771081e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h +@@ -19,7 +19,7 @@ limitations under the License. + #include <memory> + + #include "absl/container/flat_hash_set.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "tensorflow/lite/core/api/op_resolver.h" + #include "tensorflow/lite/core/shims/cc/kernels/register.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -123,7 +123,8 @@ class ObjectDetector : public BaseVisionTaskApi<DetectionResult> { + // Post-processing to transform the raw model outputs into detection results. + tflite::support::StatusOr<DetectionResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, +- const FrameBuffer& frame_buffer, const BoundingBox& roi) override; ++ const FrameBuffer& frame_buffer, ++ const BoundingBox& roi) override; + + // Performs sanity checks on the provided ObjectDetectorOptions. + static absl::Status SanityCheckOptions(const ObjectDetectorOptions& options); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto +index 259bee8194735..f6df558cc1a1a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto +@@ -31,17 +31,13 @@ message Segmentation { + // pixel, the value indicates the prediction confidence usually in the [0, 1] + // range where higher values represent a stronger confidence. Ultimately this + // is model specific, and other range of values might be used. +- message ConfidenceMask { +- repeated float value = 1 [packed = true]; +- } ++ message ConfidenceMask { repeated float value = 1 [packed = true]; } + + // List of confidence masks with respect to the model output depth (this depth + // represents how many classes are supported). Note: some models have a single + // class (e.g. a sky segmentation model) which turns into a single confidence + // mask in this list. +- message ConfidenceMasks { +- repeated ConfidenceMask confidence_mask = 1; +- } ++ message ConfidenceMasks { repeated ConfidenceMask confidence_mask = 1; } + + // IMPORTANT: segmentation masks are not direcly suited for display, in + // particular: +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc +index f30af1e7d27d8..cea7ef3fb1f23 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc +@@ -18,7 +18,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/strings/str_cat.h" // from @com_google_absl ++#include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/status_macros.h" + +@@ -36,8 +36,10 @@ constexpr int kGrayChannel = 1; + // Creates a FrameBuffer from one plane raw NV21/NV12 buffer and passing + // arguments. + StatusOr<std::unique_ptr<FrameBuffer>> CreateFromOnePlaneNVRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, +- FrameBuffer::Format format, FrameBuffer::Orientation orientation, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, ++ FrameBuffer::Format format, ++ FrameBuffer::Orientation orientation, + const absl::Time timestamp) { + FrameBuffer::Plane input_plane = {/*buffer=*/input, + /*stride=*/{dimension.width, kGrayChannel}}; +@@ -129,7 +131,8 @@ StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) { + } + + StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension( +- FrameBuffer::Dimension dimension, FrameBuffer::Format format) { ++ FrameBuffer::Dimension dimension, ++ FrameBuffer::Format format) { + if (dimension.width <= 0 || dimension.height <= 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width, +@@ -176,7 +179,8 @@ absl::Status ValidateBufferFormat(const FrameBuffer& buffer) { + case FrameBuffer::Format::kGRAY: + case FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kRGBA: +- if (buffer.plane_count() == 1) return absl::OkStatus(); ++ if (buffer.plane_count() == 1) ++ return absl::OkStatus(); + return absl::InvalidArgumentError( + "Plane count must be 1 for grayscale and RGB[a] buffers."); + case FrameBuffer::Format::kNV21: +@@ -252,8 +256,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer, + } + + absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer, +- const FrameBuffer& output_buffer, int x0, +- int y0, int x1, int y1) { ++ const FrameBuffer& output_buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); +@@ -314,8 +321,10 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, + + // Creates a FrameBuffer from raw RGBA buffer and passing arguments. + std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, +- FrameBuffer::Orientation orientation, const absl::Time timestamp, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, ++ FrameBuffer::Orientation orientation, ++ const absl::Time timestamp, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kRgbaChannels; +@@ -330,8 +339,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( + + // Creates a FrameBuffer from raw RGB buffer and passing arguments. + std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, +- FrameBuffer::Orientation orientation, const absl::Time timestamp, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, ++ FrameBuffer::Orientation orientation, ++ const absl::Time timestamp, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kRgbChannels; +@@ -345,8 +356,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( + + // Creates a FrameBuffer from raw grayscale buffer and passing arguments. + std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, +- FrameBuffer::Orientation orientation, const absl::Time timestamp, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, ++ FrameBuffer::Orientation orientation, ++ const absl::Time timestamp, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kGrayChannel; +@@ -361,10 +374,16 @@ std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( + + // Creates a FrameBuffer from raw YUV buffer and passing arguments. + StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( +- const uint8* y_plane, const uint8* u_plane, const uint8* v_plane, +- FrameBuffer::Format format, FrameBuffer::Dimension dimension, +- int row_stride_y, int row_stride_uv, int pixel_stride_uv, +- FrameBuffer::Orientation orientation, const absl::Time timestamp) { ++ const uint8* y_plane, ++ const uint8* u_plane, ++ const uint8* v_plane, ++ FrameBuffer::Format format, ++ FrameBuffer::Dimension dimension, ++ int row_stride_y, ++ int row_stride_uv, ++ int pixel_stride_uv, ++ FrameBuffer::Orientation orientation, ++ const absl::Time timestamp) { + const int pixel_stride_y = 1; + std::vector<FrameBuffer::Plane> planes; + if (format == FrameBuffer::Format::kNV21 || +@@ -385,9 +404,11 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( + } + + StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer( +- const uint8* buffer, FrameBuffer::Dimension dimension, ++ const uint8* buffer, ++ FrameBuffer::Dimension dimension, + const FrameBuffer::Format target_format, +- FrameBuffer::Orientation orientation, absl::Time timestamp) { ++ FrameBuffer::Orientation orientation, ++ absl::Time timestamp) { + switch (target_format) { + case FrameBuffer::Format::kNV12: + return CreateFromOnePlaneNVRawBuffer(buffer, dimension, target_format, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h +index 470e76b9037a1..7ebf69fadc3de 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h +@@ -18,8 +18,8 @@ limitations under the License. + #include <memory> + + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/time/clock.h" // from @com_google_absl +-#include "absl/time/time.h" // from @com_google_absl ++#include "absl/time/clock.h" // from @com_google_absl ++#include "absl/time/time.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +@@ -58,7 +58,8 @@ tflite::support::StatusOr<const uint8*> GetUvRawBuffer( + // supported formats. This method assums the UV plane share the same dimension, + // especially for the YV12 / YV21 formats. + tflite::support::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension( +- FrameBuffer::Dimension dimension, FrameBuffer::Format format); ++ FrameBuffer::Dimension dimension, ++ FrameBuffer::Format format); + + // Returns crop dimension based on crop start and end points. + FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1); +@@ -92,8 +93,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer, + // (x0, y0) represents the top-left point of the buffer. + // (x1, y1) represents the bottom-right point of the buffer. + absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer, +- const FrameBuffer& output_buffer, int x0, +- int y0, int x1, int y1); ++ const FrameBuffer& output_buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1); + + // Validates the given inputs for flipping `buffer` horizontally or vertically. + absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer, +@@ -110,36 +114,45 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, + + // Creates a FrameBuffer from raw RGBA buffer and passing arguments. + std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now(), + FrameBuffer::Stride stride = kDefaultStride); + + // Creates a FrameBuffer from raw RGB buffer and passing arguments. + std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now(), + FrameBuffer::Stride stride = kDefaultStride); + + // Creates a FrameBuffer from raw grayscale buffer and passing arguments. + std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( +- const uint8* input, FrameBuffer::Dimension dimension, ++ const uint8* input, ++ FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now(), + FrameBuffer::Stride stride = kDefaultStride); + + // Creates a FrameBuffer from raw YUV buffer and passing arguments. + tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( +- const uint8* y_plane, const uint8* u_plane, const uint8* v_plane, +- FrameBuffer::Format format, FrameBuffer::Dimension dimension, +- int row_stride_y, int row_stride_uv, int pixel_stride_uv, ++ const uint8* y_plane, ++ const uint8* u_plane, ++ const uint8* v_plane, ++ FrameBuffer::Format format, ++ FrameBuffer::Dimension dimension, ++ int row_stride_y, ++ int row_stride_uv, ++ int pixel_stride_uv, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); + + // Creates an instance of FrameBuffer from raw buffer and passing arguments. + tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer( +- const uint8* buffer, FrameBuffer::Dimension dimension, ++ const uint8* buffer, ++ FrameBuffer::Dimension dimension, + FrameBuffer::Format target_format, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc +index 4d767fc3e48b2..4728c30cb60dc 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc +@@ -22,8 +22,8 @@ limitations under the License. + #include <utility> + #include <vector> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/kernels/internal/compatibility.h" + #include "tensorflow/lite/kernels/op_macros.h" +@@ -91,7 +91,8 @@ static int GetOrientationIndex(FrameBuffer::Orientation orientation) { + // The new box origin is (x:box.origin_y, y:width - (box.origin_x + box.width). + // The new box dimension is (w: box.height, h: box.width). + // +-static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle, ++static BoundingBox RotateBoundingBox(const BoundingBox& box, ++ int angle, + FrameBuffer::Dimension frame_dimension) { + int rx = box.origin_x(), ry = box.origin_y(), rw = box.width(), + rh = box.height(); +@@ -130,9 +131,12 @@ static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle, + // in counterclockwise degree in one of the values [0, 90, 180, 270]. + // + // See `RotateBoundingBox` above for more details. +-static void RotateCoordinates(int from_x, int from_y, int angle, ++static void RotateCoordinates(int from_x, ++ int from_y, ++ int angle, + const FrameBuffer::Dimension& frame_dimension, +- int* to_x, int* to_y) { ++ int* to_x, ++ int* to_y) { + switch (angle) { + case 0: + *to_x = from_x; +@@ -199,7 +203,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box, + } + + BoundingBox OrientAndDenormalizeBoundingBox( +- float from_left, float from_top, float from_right, float from_bottom, ++ float from_left, ++ float from_top, ++ float from_right, ++ float from_bottom, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension) { +@@ -214,10 +221,12 @@ BoundingBox OrientAndDenormalizeBoundingBox( + return to_box; + } + +-void OrientCoordinates(int from_x, int from_y, ++void OrientCoordinates(int from_x, ++ int from_y, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, +- FrameBuffer::Dimension from_dimension, int* to_x, ++ FrameBuffer::Dimension from_dimension, ++ int* to_x, + int* to_y) { + *to_x = from_x; + *to_y = from_y; +@@ -298,15 +307,19 @@ bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation, + return params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270; + } + +-absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, int y0, +- int x1, int y1, ++absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->Crop(buffer, x0, y0, x1, y1, output_buffer); + } + + FrameBuffer::Dimension FrameBufferUtils::GetSize( +- const FrameBuffer& buffer, const FrameBufferOperation& operation) { ++ const FrameBuffer& buffer, ++ const FrameBufferOperation& operation) { + FrameBuffer::Dimension dimension = buffer.dimension(); + if (absl::holds_alternative<OrientOperation>(operation)) { + OrientParams params = +@@ -327,7 +340,8 @@ FrameBuffer::Dimension FrameBufferUtils::GetSize( + } + + std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes( +- const uint8* buffer, FrameBuffer::Dimension dimension, ++ const uint8* buffer, ++ FrameBuffer::Dimension dimension, + FrameBuffer::Format format) { + std::vector<FrameBuffer::Plane> planes; + switch (format) { +@@ -378,7 +392,8 @@ std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes( + } + + FrameBuffer::Orientation FrameBufferUtils::GetOrientation( +- const FrameBuffer& buffer, const FrameBufferOperation& operation) { ++ const FrameBuffer& buffer, ++ const FrameBufferOperation& operation) { + if (absl::holds_alternative<OrientOperation>(operation)) { + return absl::get<OrientOperation>(operation).to_orientation; + } +@@ -386,7 +401,8 @@ FrameBuffer::Orientation FrameBufferUtils::GetOrientation( + } + + FrameBuffer::Format FrameBufferUtils::GetFormat( +- const FrameBuffer& buffer, const FrameBufferOperation& operation) { ++ const FrameBuffer& buffer, ++ const FrameBufferOperation& operation) { + if (absl::holds_alternative<ConvertOperation>(operation)) { + return absl::get<ConvertOperation>(operation).to_format; + } +@@ -578,8 +594,10 @@ absl::Status FrameBufferUtils::Execute( + } + + absl::Status FrameBufferUtils::Preprocess( +- const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box, +- FrameBuffer* output_buffer, bool uniform_resizing) { ++ const FrameBuffer& buffer, ++ absl::optional<BoundingBox> bounding_box, ++ FrameBuffer* output_buffer, ++ bool uniform_resizing) { + std::vector<FrameBufferOperation> frame_buffer_operations; + // Handle cropping and resizing. + bool needs_dimension_swap = +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h +index 59e80e5765bb0..48549461159cb 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h +@@ -19,9 +19,9 @@ limitations under the License. + #include <memory> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/types/optional.h" // from @com_google_absl +-#include "absl/types/variant.h" // from @com_google_absl ++#include "absl/types/variant.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" + #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +@@ -45,7 +45,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box, + + // Same as OrientBoundingBox but from normalized coordinates. + BoundingBox OrientAndDenormalizeBoundingBox( +- float from_left, float from_top, float from_right, float from_bottom, ++ float from_left, ++ float from_top, ++ float from_right, ++ float from_bottom, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension); +@@ -53,10 +56,12 @@ BoundingBox OrientAndDenormalizeBoundingBox( + // Rotates `(from_x, from_y)` coordinates from an image of dimension + // `from_dimension` and orientation `from_orientation` into `(to_x, to_y)` + // coordinates with orientation `to_orientation`. +-void OrientCoordinates(int from_x, int from_y, ++void OrientCoordinates(int from_x, ++ int from_y, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, +- FrameBuffer::Dimension from_dimension, int* to_x, ++ FrameBuffer::Dimension from_dimension, ++ int* to_x, + int* to_y); + + // Returns whether the conversion from from_orientation to to_orientation +@@ -92,7 +97,8 @@ OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation, + // To perform just cropping, the `crop_width` and `crop_height` should be the + // same as `resize_width` `and resize_height`. + struct CropResizeOperation { +- CropResizeOperation(int crop_origin_x, int crop_origin_y, ++ CropResizeOperation(int crop_origin_x, ++ int crop_origin_y, + FrameBuffer::Dimension crop_dimension, + FrameBuffer::Dimension resize_dimension) + : crop_origin_x(crop_origin_x), +@@ -124,7 +130,8 @@ struct CropResizeOperation { + // The resized region is aligned to the upper left pixel of the output buffer. + // The unfilled area of the output buffer remains untouched. + struct UniformCropResizeOperation { +- UniformCropResizeOperation(int crop_origin_x, int crop_origin_y, ++ UniformCropResizeOperation(int crop_origin_x, ++ int crop_origin_y, + FrameBuffer::Dimension crop_dimension, + FrameBuffer::Dimension output_dimension) + : crop_origin_x(crop_origin_x), +@@ -154,9 +161,10 @@ struct OrientOperation { + + // A variant of the supported operations on FrameBuffers. Alias for user + // convenience. +-using FrameBufferOperation = +- absl::variant<CropResizeOperation, ConvertOperation, OrientOperation, +- UniformCropResizeOperation>; ++using FrameBufferOperation = absl::variant<CropResizeOperation, ++ ConvertOperation, ++ OrientOperation, ++ UniformCropResizeOperation>; + + // Image processing utility. This utility provides both basic image buffer + // manipulations (e.g. rotation, format conversion, resizing, etc) as well as +@@ -212,7 +220,11 @@ class FrameBufferUtils { + // should be big enough to store the operation result. If the `output_buffer` + // size dimension does not match with crop dimension, then a resize is + // automatically performed. +- absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, ++ absl::Status Crop(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, + FrameBuffer* output_buffer); + + // Performs resizing operation. +@@ -229,7 +241,8 @@ class FrameBufferUtils { + // + // The output_buffer should have metadata populated and its backing buffer + // should be big enough to store the operation result. +- absl::Status Rotate(const FrameBuffer& buffer, RotationDegree rotation, ++ absl::Status Rotate(const FrameBuffer& buffer, ++ RotationDegree rotation, + FrameBuffer* output_buffer); + + // Performs horizontal flip operation. +@@ -305,7 +318,8 @@ class FrameBufferUtils { + + // Returns the new FrameBuffer orientation after command is processed. + FrameBuffer::Orientation GetOrientation( +- const FrameBuffer& buffer, const FrameBufferOperation& operation); ++ const FrameBuffer& buffer, ++ const FrameBufferOperation& operation); + + // Returns the new FrameBuffer format after command is processed. + FrameBuffer::Format GetFormat(const FrameBuffer& buffer, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h +index ec0c3119ea4e8..59da2206bb06f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h +@@ -37,8 +37,12 @@ class FrameBufferUtilsInterface { + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. +- virtual absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, +- int y1, FrameBuffer* output_buffer) = 0; ++ virtual absl::Status Crop(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, ++ FrameBuffer* output_buffer) = 0; + + // Resizes `buffer` to the size of the given `output_buffer`. + // +@@ -57,7 +61,8 @@ class FrameBufferUtilsInterface { + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. +- virtual absl::Status Rotate(const FrameBuffer& buffer, int angle_deg, ++ virtual absl::Status Rotate(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) = 0; + + // Flips `buffer` horizontally. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +index 6fd3ca81c984c..a00c8223fac99 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +@@ -20,10 +20,10 @@ limitations under the License. + #include <memory> + #include <string> + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_cat.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl +-#include "libyuv.h" // from @libyuv ++#include "libyuv.h" // from @libyuv + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -383,7 +383,8 @@ absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + + // Converts `buffer` to libyuv ARGB format and stores the conversion result + // in `dest_argb`. +-absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb, ++absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, ++ uint8* dest_argb, + int dest_stride_argb) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + if (buffer.format() != FrameBuffer::Format::kRGB) { +@@ -420,7 +421,8 @@ absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb, + + // Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and + // stores the conversion result in `output_buffer`. +-absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb, ++absl::Status ConvertArgbToRgb(uint8* src_argb, ++ int src_stride_argb, + FrameBuffer* output_buffer) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); + if (output_buffer->format() != FrameBuffer::Format::kRGB) { +@@ -456,7 +458,8 @@ absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb, + + // Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in + // memory) format and stores the conversion result in `dest_argb`. +-absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, uint8* dest_argb, ++absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, ++ uint8* dest_argb, + int dest_stride_argb) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + if (buffer.format() != FrameBuffer::Format::kRGBA) { +@@ -674,7 +677,8 @@ libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) { + } + } + +-absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg, ++absl::Status RotateRgba(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( +@@ -698,7 +702,8 @@ absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg, + return absl::OkStatus(); + } + +-absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg, ++absl::Status RotateRgb(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) { + // libyuv does not support rotate kRGB (RGB24) foramat. In this method, the + // implementation converts kRGB format to ARGB and use ARGB buffer for +@@ -731,7 +736,8 @@ absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg, + output_buffer); + } + +-absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg, ++absl::Status RotateGray(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( +@@ -754,7 +760,8 @@ absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg, + } + + // Rotates YV12/YV21 frame buffer. +-absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg, ++absl::Status RotateYv(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); +@@ -779,7 +786,8 @@ absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg, + // Rotates NV12/NV21 frame buffer. + // TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly + // support that. +-absl::Status RotateNv(const FrameBuffer& buffer, int angle_deg, ++absl::Status RotateNv(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) { + if (buffer.format() != FrameBuffer::Format::kNV12 && + buffer.format() != FrameBuffer::Format::kNV21) { +@@ -869,8 +877,12 @@ absl::Status FlipPlaneVertically(const FrameBuffer& buffer, + } + + // This method only supports kGRAY, kRGBA, and kRGB formats. +-absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1, +- int y1, FrameBuffer* output_buffer) { ++absl::Status CropPlane(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, ++ FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, +@@ -897,7 +909,11 @@ absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1, + + // Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel + // position (x0, y0) and the bottom right pixel position (x1, y1). +-absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, ++absl::Status CropNv(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); +@@ -929,7 +945,11 @@ absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + + // Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel + // position (x0, y0) and the bottom right pixel position (x1, y1). +-absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, ++absl::Status CropYv(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); +@@ -964,8 +984,12 @@ absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + return absl::OkStatus(); + } + +-absl::Status CropResizeYuv(const FrameBuffer& buffer, int x0, int y0, int x1, +- int y1, FrameBuffer* output_buffer) { ++absl::Status CropResizeYuv(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, ++ FrameBuffer* output_buffer) { + FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); + if (crop_dimension == output_buffer->dimension()) { + switch (buffer.format()) { +@@ -1293,8 +1317,12 @@ absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + } + + // This method only supports kGRAY, kRGBA, and kRGB formats. +-absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1, +- int y1, FrameBuffer* output_buffer) { ++absl::Status CropResize(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, ++ FrameBuffer* output_buffer) { + FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); + if (crop_dimension == output_buffer->dimension()) { + return CropPlane(buffer, x0, y0, x1, y1, output_buffer); +@@ -1326,10 +1354,13 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1, + } + } + +-} // namespace ++} // namespace + +-absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, +- int y0, int x1, int y1, ++absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, + FrameBuffer* output_buffer) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); +@@ -1410,7 +1441,8 @@ absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer, + } + + absl::Status LibyuvFrameBufferUtils::FlipHorizontally( +- const FrameBuffer& buffer, FrameBuffer* output_buffer) { ++ const FrameBuffer& buffer, ++ FrameBuffer* output_buffer) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); + RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer)); +@@ -1438,7 +1470,8 @@ absl::Status LibyuvFrameBufferUtils::FlipHorizontally( + } + + absl::Status LibyuvFrameBufferUtils::FlipVertically( +- const FrameBuffer& buffer, FrameBuffer* output_buffer) { ++ const FrameBuffer& buffer, ++ FrameBuffer* output_buffer) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); + RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer)); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h +index 5da898bc058a4..6f83559139130 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h +@@ -41,7 +41,11 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface { + // + // Crop region dimensions must be equal or smaller than input `buffer` + // dimensions. +- absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, ++ absl::Status Crop(const FrameBuffer& buffer, ++ int x0, ++ int y0, ++ int x1, ++ int y1, + FrameBuffer* output_buffer) override; + + // Resizes `buffer` to the size of the given `output_buffer`. +@@ -51,7 +55,8 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface { + // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees). + // + // The given angle must be a multiple of 90 degrees. +- absl::Status Rotate(const FrameBuffer& buffer, int angle_deg, ++ absl::Status Rotate(const FrameBuffer& buffer, ++ int angle_deg, + FrameBuffer* output_buffer) override; + + // Flips `buffer` horizontally. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc +index bc57c0b904534..d58969d96827e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc +@@ -20,11 +20,11 @@ limitations under the License. + #include <utility> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl +-#include "absl/strings/str_split.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/strings/str_split.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl +-#include "absl/types/optional.h" // from @com_google_absl ++#include "absl/types/optional.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h +index 95cbecf54bd1d..e2b403d9b35b9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h +@@ -23,9 +23,9 @@ limitations under the License. + #include <vector> + + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/string_view.h" // from @com_google_absl +-#include "absl/types/optional.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/string_view.h" // from @com_google_absl ++#include "absl/types/optional.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +@@ -37,7 +37,10 @@ namespace vision { + // Sigmoid structure. + struct Sigmoid { + Sigmoid() : scale(1.0) {} +- Sigmoid(std::string label, float slope, float offset, float scale = 1.0, ++ Sigmoid(std::string label, ++ float slope, ++ float offset, ++ float scale = 1.0, + absl::optional<float> min_uncalibrated_score = absl::nullopt) + : label(label), + slope(slope), +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc +index 311994c1abbf9..bc2f9dfd53a96 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc +@@ -16,7 +16,7 @@ limitations under the License. + #include "tensorflow_lite_support/cc/common.h" + + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/gmock.h" + #include "tensorflow_lite_support/cc/port/gtest.h" + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc +index d0a7e33129e7e..9ae943548dc63 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc +@@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] = + constexpr char kDilatedConvolutionModelWithMetaData[] = "dilated_conv.tflite"; + + StatusOr<ImageData> LoadImage(std::string image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + class DynamicInputTest : public tflite_shims::testing::Test { +@@ -60,7 +60,7 @@ class DynamicInputTest : public tflite_shims::testing::Test { + SUPPORT_ASSERT_OK(engine_->InitInterpreter()); + + SUPPORT_ASSERT_OK_AND_ASSIGN(auto preprocessor, +- ImagePreprocessor::Create(engine_.get(), {0})); ++ ImagePreprocessor::Create(engine_.get(), {0})); + + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( +@@ -94,9 +94,10 @@ TEST_F(DynamicInputTest, GoldenImageComparison) { + PreprocessImage(); + + // Get the processed input image. +- SUPPORT_ASSERT_OK_AND_ASSIGN(float* processed_input_data, +- tflite::task::core::AssertAndReturnTypedTensor<float>( +- engine_->GetInputs()[0])); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ float* processed_input_data, ++ tflite::task::core::AssertAndReturnTypedTensor<float>( ++ engine_->GetInputs()[0])); + + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + const uint8* image_data = image.pixel_data; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc +index 629f069e7b8d1..c4a8cea0d53b9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc +@@ -49,8 +49,7 @@ constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; + constexpr int kMaxSeqLen = 128; + + std::string GetFullPath(absl::string_view file_name) { +- return JoinPath("./" /*test src dir*/, kTestDataDirectory, +- file_name); ++ return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name); + } + + class BertNLClassifierTest : public tflite_shims::testing::Test {}; +@@ -77,14 +76,15 @@ TEST_F(BertNLClassifierTest, CreateFromOptionsFailsWithMissingBaseOptions) { + } + + TEST_F(BertNLClassifierTest, TestNLClassifierCreationFilePath) { +- SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath))); ++ SUPPORT_ASSERT_OK( ++ BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath))); + } + + TEST_F(BertNLClassifierTest, TestNLClassifierCreationBinary) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromBuffer(model_buffer.data(), +- model_buffer.size())); ++ model_buffer.size())); + } + + TEST_F(BertNLClassifierTest, TestNLClassifierCreationFailure) { +@@ -136,7 +136,7 @@ TEST_F(BertNLClassifierTest, ClassifySucceedsWithBaseOptions) { + contents); + + SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, +- BertNLClassifier::CreateFromOptions(options)); ++ BertNLClassifier::CreateFromOptions(options)); + } + + verify_classifier(std::move(classifier), /*verify_positive=*/false); +@@ -146,8 +146,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, +- BertNLClassifier::CreateFromBuffer(model_buffer.data(), +- model_buffer.size())); ++ BertNLClassifier::CreateFromBuffer( ++ model_buffer.data(), model_buffer.size())); + + verify_classifier(std::move(classifier), false); + } +@@ -156,24 +156,26 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, +- BertNLClassifier::CreateFromBuffer(model_buffer.data(), +- model_buffer.size())); ++ BertNLClassifier::CreateFromBuffer( ++ model_buffer.data(), model_buffer.size())); + + verify_classifier(std::move(classifier), true); + } + + TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, +- BertNLClassifier::CreateFromFd(open( +- GetFullPath(kTestModelPath).c_str(), O_RDONLY))); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<BertNLClassifier> classifier, ++ BertNLClassifier::CreateFromFd( ++ open(GetFullPath(kTestModelPath).c_str(), O_RDONLY))); + + verify_classifier(std::move(classifier), false); + } + + TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, +- BertNLClassifier::CreateFromFd(open( +- GetFullPath(kTestModelPath).c_str(), O_RDONLY))); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<BertNLClassifier> classifier, ++ BertNLClassifier::CreateFromFd( ++ open(GetFullPath(kTestModelPath).c_str(), O_RDONLY))); + + verify_classifier(std::move(classifier), true); + } +@@ -191,8 +193,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) { + } + ss_for_positive_review << " movie review"; + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, +- BertNLClassifier::CreateFromBuffer(model_buffer.data(), +- model_buffer.size())); ++ BertNLClassifier::CreateFromBuffer( ++ model_buffer.data(), model_buffer.size())); + + std::vector<core::Category> results = + classifier->Classify(ss_for_positive_review.str()); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc +index 252441df1cb59..a70dab7782044 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc +@@ -69,8 +69,7 @@ constexpr int kPredictAnsNum = 5; + class BertQuestionAnswererTest : public tflite_shims::testing::Test {}; + + std::string GetFullPath(absl::string_view file_name) { +- return JoinPath("./" /*test src dir*/, kTestDataDirectory, +- file_name); ++ return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name); + } + + TEST_F(BertQuestionAnswererTest, +@@ -108,8 +107,8 @@ TEST_F(BertQuestionAnswererTest, AnswerSucceedsWithModelWithMetadata) { + options.mutable_base_options()->mutable_model_file()->set_file_content( + contents); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(question_answerer, +- BertQuestionAnswerer::CreateFromOptions(options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ question_answerer, BertQuestionAnswerer::CreateFromOptions(options)); + } + + std::vector<QaAnswer> answer = question_answerer->Answer(kContext, kQuestion); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc +index 67b03c3a45323..81198cfca30fc 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc +@@ -121,8 +121,7 @@ struct ProtoOptionsTestParam { + }; + + std::string GetFullPath(absl::string_view file_name) { +- return JoinPath("./" /*test src dir*/, kTestDataDirectory, +- file_name); ++ return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name); + } + + class ProtoOptionsTest : public TestWithParam<ProtoOptionsTestParam> { +@@ -163,7 +162,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) { + options.mutable_base_options()->mutable_model_file()->set_file_content( + contents); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, NLClassifier::CreateFromOptions(options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, ++ NLClassifier::CreateFromOptions(options)); + } + + std::vector<core::Category> positive_results = +@@ -180,8 +180,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) { + + TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) { + NLClassifierProtoOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kTestModelPath)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)); + options.set_input_tensor_name("invalid_tensor_name"); + options.set_input_tensor_index(-1); + +@@ -200,8 +200,8 @@ TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) { + + TEST_F(ProtoOptionsTest, CreationFromIncorrectOutputScoreTensor) { + NLClassifierProtoOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kTestModelPath)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)); + options.set_output_score_tensor_name("invalid_tensor_name"); + options.set_output_score_tensor_index(-1); + +@@ -224,7 +224,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithRegexTokenizer) { + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelWithRegexTokenizer)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, +- NLClassifier::CreateFromOptions(options)); ++ NLClassifier::CreateFromOptions(options)); + + std::vector<core::Category> positive_results = + classifier->Classify(kPositiveInput); +@@ -277,7 +277,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithAssociatedLabelBuiltinOps) { + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelWithLabelBuiltInOpsPath)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, +- NLClassifier::CreateFromOptions(options)); ++ NLClassifier::CreateFromOptions(options)); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"Negative", 0.49332118034362793}, +@@ -296,8 +296,10 @@ struct ProtoOptionsTestParamToString { + }; + + NLClassifierProtoOptions CreateProtoOptionsFromTensorName( +- const char* input_tensor_name, const char* output_score_tensor_name, +- const char* output_label_tensor_name, const char* model_path) { ++ const char* input_tensor_name, ++ const char* output_score_tensor_name, ++ const char* output_label_tensor_name, ++ const char* model_path) { + NLClassifierProtoOptions options; + options.set_input_tensor_name(input_tensor_name); + options.set_output_score_tensor_name(output_score_tensor_name); +@@ -310,8 +312,10 @@ NLClassifierProtoOptions CreateProtoOptionsFromTensorName( + } + + NLClassifierProtoOptions CreateProtoOptionsFromTensorIndex( +- const int input_tensor_index, const int output_score_tensor_index, +- const int output_label_tensor_index, const char* model_path) { ++ const int input_tensor_index, ++ const int output_score_tensor_index, ++ const int output_label_tensor_index, ++ const char* model_path) { + NLClassifierProtoOptions options; + options.set_input_tensor_index(input_tensor_index); + options.set_output_score_tensor_index(output_score_tensor_index); +@@ -439,14 +443,16 @@ TEST_P(ProtoOptionsTest, TestClassify) { + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); + } + +-INSTANTIATE_TEST_SUITE_P(TestClassify, ProtoOptionsTest, ++INSTANTIATE_TEST_SUITE_P(TestClassify, ++ ProtoOptionsTest, + ValuesIn(ClassifyParams()), + ProtoOptionsTestParamToString()); + + // Tests for struct sNLClassifierOptions. + class StructOptionsTest : public tflite_shims::testing::Test {}; + +-void AssertStatus(absl::Status status, absl::StatusCode status_code, ++void AssertStatus(absl::Status status, ++ absl::StatusCode status_code, + TfLiteSupportStatus tfls_code) { + ASSERT_EQ(status.code(), status_code); + EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload), +@@ -454,30 +460,29 @@ void AssertStatus(absl::Status status, absl::StatusCode status_code, + } + + TEST_F(StructOptionsTest, TestApiCreationFromBuffer) { +- std::string model_buffer = +- LoadBinaryContent(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, kTestModelPath) +- .c_str()); ++ std::string model_buffer = LoadBinaryContent( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath) ++ .c_str()); + SUPPORT_ASSERT_OK(NLClassifier::CreateFromBufferAndOptions( + model_buffer.data(), model_buffer.size(), {}, CreateCustomResolver())); + } + + TEST_F(StructOptionsTest, TestApiCreationFromFile) { +- SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(GetFullPath(kTestModelPath), +- {}, CreateCustomResolver())); ++ SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions( ++ GetFullPath(kTestModelPath), {}, CreateCustomResolver())); + } + + TEST_F(StructOptionsTest, TestApiCreationFromIncorrectInputTensor) { + NLClassifierOptions options; + options.input_tensor_index = -1; + options.input_tensor_name = "I do not exist"; +- AssertStatus(NLClassifier::CreateFromFileAndOptions( +- JoinPath("./" /*test src dir*/, +- kTestDataDirectory, kTestModelPath), +- options, CreateCustomResolver()) +- .status(), +- absl::StatusCode::kInvalidArgument, +- TfLiteSupportStatus::kInputTensorNotFoundError); ++ AssertStatus( ++ NLClassifier::CreateFromFileAndOptions( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath), ++ options, CreateCustomResolver()) ++ .status(), ++ absl::StatusCode::kInvalidArgument, ++ TfLiteSupportStatus::kInputTensorNotFoundError); + } + + TEST_F(StructOptionsTest, TestApiCreationFromIncorrectOutputScoreTensor) { +@@ -497,9 +502,10 @@ TEST_F(StructOptionsTest, TestInferenceWithRegexTokenizer) { + options.output_score_tensor_name = "probability"; + + // The model with regex tokenizer doesn't need any custom ops. +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, +- NLClassifier::CreateFromFileAndOptions( +- GetFullPath(kTestModelWithRegexTokenizer), options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<NLClassifier> classifier, ++ NLClassifier::CreateFromFileAndOptions( ++ GetFullPath(kTestModelWithRegexTokenizer), options)); + + std::vector<core::Category> positive_results = + classifier->Classify(kPositiveInput); +@@ -519,9 +525,9 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) { + options.output_score_tensor_index = 0; + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, +- NLClassifier::CreateFromFileAndOptions( +- GetFullPath(kTestModelBoolOutputPath), options, +- CreateCustomResolver())); ++ NLClassifier::CreateFromFileAndOptions( ++ GetFullPath(kTestModelBoolOutputPath), ++ options, CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"0", 1}, +@@ -535,10 +541,11 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) { + TEST_F(StructOptionsTest, TestInferenceWithAssociatedLabelCustomOps) { + NLClassifierOptions options; + options.output_score_tensor_name = kMetadataOutputScoreTensorName; +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, +- NLClassifier::CreateFromFileAndOptions( +- GetFullPath(kTestModelWithLabelCustomOpsPath), +- options, CreateCustomResolver())); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<NLClassifier> classifier, ++ NLClassifier::CreateFromFileAndOptions( ++ GetFullPath(kTestModelWithLabelCustomOpsPath), options, ++ CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"label0", 255}, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc +index 6da6b6f7a2da3..ae4e48cac2410 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc +@@ -17,9 +17,9 @@ limitations under the License. + + #include <memory> + +-#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/shims/cc/shims_test_util.h" + #include "tensorflow/lite/kernels/builtin_op_kernels.h" +@@ -70,8 +70,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] = + constexpr char kAutoMLModelWithMetadata[] = "automl_labeler_model.tflite"; + + StatusOr<ImageData> LoadImage(std::string image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + // If the proto definition changes, please also change this function. +@@ -159,9 +159,8 @@ TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) { + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); +@@ -234,9 +233,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) { + TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) { + ImageClassifierOptions options; + options.set_num_threads(4); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + SUPPORT_ASSERT_OK(ImageClassifier::CreateFromOptions(options)); + } +@@ -248,9 +246,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2)); + TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + ImageClassifierOptions options; + options.set_num_threads(GetParam()); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); +@@ -273,12 +270,12 @@ TEST(ClassifyTest, SucceedsWithFloatModel) { + + ImageClassifierOptions options; + options.set_max_results(3); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); +@@ -307,19 +304,20 @@ TEST(ClassifyTest, SucceedsWithFloatModel) { + } + + TEST(ClassifyTest, SucceedsWithRegionOfInterest) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("multi_objects.jpg")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, ++ LoadImage("multi_objects.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ImageClassifierOptions options; + options.set_max_results(1); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + // Crop around the soccer ball. + BoundingBox roi; +@@ -358,8 +356,9 @@ TEST(ClassifyTest, SucceedsWithQuantizedModel) { + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); +@@ -391,12 +390,12 @@ TEST(ClassifyTest, SucceedsWithBaseOptions) { + + ImageClassifierOptions options; + options.set_max_results(3); +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); +@@ -426,11 +425,11 @@ TEST(ClassifyTest, SucceedsWithBaseOptions) { + + TEST(ClassifyTest, GetInputCountSucceeds) { + ImageClassifierOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + int32_t input_count = image_classifier->GetInputCount(); + EXPECT_THAT(input_count, 1); +@@ -438,11 +437,11 @@ TEST(ClassifyTest, GetInputCountSucceeds) { + + TEST(ClassifyTest, GetInputShapeSucceeds) { + ImageClassifierOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + // Verify the shape array size. + const TfLiteIntArray* input_shape_0 = image_classifier->GetInputShape(0); +@@ -456,11 +455,11 @@ TEST(ClassifyTest, GetInputShapeSucceeds) { + + TEST(ClassifyTest, GetOutputCountSucceeds) { + ImageClassifierOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + int32_t output_count = image_classifier->GetOutputCount(); + EXPECT_THAT(output_count, 1); +@@ -468,11 +467,11 @@ TEST(ClassifyTest, GetOutputCountSucceeds) { + + TEST(ClassifyTest, GetOutputShapeSucceeds) { + ImageClassifierOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetFloatWithMetadata)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, +- ImageClassifier::CreateFromOptions(options)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::unique_ptr<ImageClassifier> image_classifier, ++ ImageClassifier::CreateFromOptions(options)); + + // Verify the shape array size. + const TfLiteIntArray* output_shape_0 = image_classifier->GetOutputShape(0); +@@ -537,9 +536,8 @@ class PostprocessTest : public tflite_shims::testing::Test { + + TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + ImageClassifierOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kAutoMLModelWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.set_max_results(3); + + SetUp(options); +@@ -551,9 +549,10 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result, +- test_image_classifier_->Postprocess( +- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ ClassificationResult result, ++ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, ++ /*roi=*/{})); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( +@@ -568,9 +567,8 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + + TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) { + ImageClassifierOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kAutoMLModelWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.set_score_threshold(0.4); + + SetUp(options); +@@ -582,9 +580,10 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) { + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result, +- test_image_classifier_->Postprocess( +- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ ClassificationResult result, ++ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, ++ /*roi=*/{})); + + ExpectApproximatelyEqual( + result, +@@ -599,9 +598,8 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) { + + TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + ImageClassifierOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kAutoMLModelWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.add_class_name_whitelist("dandelion"); + options.add_class_name_whitelist("daisy"); + +@@ -614,9 +612,10 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result, +- test_image_classifier_->Postprocess( +- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ ClassificationResult result, ++ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, ++ /*roi=*/{})); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( +@@ -630,9 +629,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + + TEST_F(PostprocessTest, SucceedsWithBlacklistOption) { + ImageClassifierOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kAutoMLModelWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.add_class_name_blacklist("dandelion"); + options.add_class_name_blacklist("daisy"); + +@@ -645,9 +643,10 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) { + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); +- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result, +- test_image_classifier_->Postprocess( +- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ ClassificationResult result, ++ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, ++ /*roi=*/{})); + + ExpectApproximatelyEqual( + result, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc +index d5606d12440b0..8877f28b98beb 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc +@@ -17,7 +17,7 @@ limitations under the License. + + #include <memory> + +-#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/shims/cc/shims_test_util.h" +@@ -59,8 +59,8 @@ constexpr char kMobileNetV3[] = "mobilenet_v3_small_100_224_embedder.tflite"; + constexpr double kSimilarityTolerancy = 1e-6; + + StatusOr<ImageData> LoadImage(std::string image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + class MobileNetV3OpResolver : public ::tflite::MutableOpResolver { +@@ -93,8 +93,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + + TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + + SUPPORT_ASSERT_OK(ImageEmbedder::CreateFromOptions( + options, absl::make_unique<MobileNetV3OpResolver>())); +@@ -113,8 +113,8 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver { + + TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + + auto image_embedder_or = ImageEmbedder::CreateFromOptions( + options, absl::make_unique<MobileNetV3OpResolverMissingOps>()); +@@ -231,8 +231,9 @@ TEST(CosineSimilarityTest, Succeeds) { + // Prevent literal from being interpreted as null-terminated C-style string. + *v_quantized.mutable_value_string() = std::string("\x80\x00\x00\x00", 4); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(double float_similarity, +- ImageEmbedder::CosineSimilarity(u_float, v_float)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ double float_similarity, ++ ImageEmbedder::CosineSimilarity(u_float, v_float)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + double quantized_similarity, + ImageEmbedder::CosineSimilarity(u_quantized, v_quantized)); +@@ -246,10 +247,10 @@ TEST(CosineSimilarityTest, Succeeds) { + TEST(EmbedTest, SucceedsWithoutL2Normalization) { + // Create embedder. + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, +- ImageEmbedder::CreateFromOptions(options)); ++ ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( +@@ -260,10 +261,10 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) { + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, +- embedder->Embed(*image_frame_buffer)); ++ embedder->Embed(*image_frame_buffer)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, +- embedder->Embed(*crop_frame_buffer)); ++ embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes +@@ -276,9 +277,9 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) { + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_float_size(), 1024); + // Check cosine similarity. +- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity, +- ImageEmbedder::CosineSimilarity(image_feature_vector, +- crop_feature_vector)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, ++ crop_feature_vector)); + double expected_similarity = 0.932738; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); + } +@@ -287,11 +288,11 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) { + TEST(EmbedTest, SucceedsWithL2Normalization) { + // Create embedder. + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + options.set_l2_normalize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, +- ImageEmbedder::CreateFromOptions(options)); ++ ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( +@@ -302,10 +303,10 @@ TEST(EmbedTest, SucceedsWithL2Normalization) { + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, +- embedder->Embed(*image_frame_buffer)); ++ embedder->Embed(*image_frame_buffer)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, +- embedder->Embed(*crop_frame_buffer)); ++ embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes +@@ -318,9 +319,9 @@ TEST(EmbedTest, SucceedsWithL2Normalization) { + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_float_size(), 1024); + // Check cosine similarity. +- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity, +- ImageEmbedder::CosineSimilarity(image_feature_vector, +- crop_feature_vector)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, ++ crop_feature_vector)); + double expected_similarity = 0.932738; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); + } +@@ -331,12 +332,12 @@ TEST(EmbedTest, SucceedsWithL2Normalization) { + TEST(EmbedTest, SucceedsWithQuantization) { + // Create embedder. + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + options.set_l2_normalize(true); + options.set_quantize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, +- ImageEmbedder::CreateFromOptions(options)); ++ ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( +@@ -347,10 +348,10 @@ TEST(EmbedTest, SucceedsWithQuantization) { + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, +- embedder->Embed(*image_frame_buffer)); ++ embedder->Embed(*image_frame_buffer)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, +- embedder->Embed(*crop_frame_buffer)); ++ embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes +@@ -363,9 +364,9 @@ TEST(EmbedTest, SucceedsWithQuantization) { + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_string().size(), 1024); + // Check cosine similarity. +- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity, +- ImageEmbedder::CosineSimilarity(image_feature_vector, +- crop_feature_vector)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, ++ crop_feature_vector)); + // Close to but expectedly different from the above tests due to slight loss + // of precision during quantization: + double expected_similarity = 0.929717; +@@ -378,10 +379,10 @@ TEST(EmbedTest, SucceedsWithQuantization) { + TEST(EmbedTest, SucceedsWithRegionOfInterest) { + // Create embedder. + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, +- ImageEmbedder::CreateFromOptions(options)); ++ ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( +@@ -398,10 +399,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) { + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, +- embedder->Embed(*image_frame_buffer, roi)); ++ embedder->Embed(*image_frame_buffer, roi)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, +- embedder->Embed(*crop_frame_buffer)); ++ embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes +@@ -414,9 +415,9 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) { + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_float_size(), 1024); + // Check cosine similarity. +- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity, +- ImageEmbedder::CosineSimilarity(image_feature_vector, +- crop_feature_vector)); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, ++ crop_feature_vector)); + double expected_similarity = 0.999914; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); + } +@@ -424,10 +425,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) { + TEST(GetEmbeddingDimension, Succeeds) { + // Create embedder. + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, +- ImageEmbedder::CreateFromOptions(options)); ++ ImageEmbedder::CreateFromOptions(options)); + + EXPECT_EQ(embedder->GetEmbeddingDimension(0), 1024); + EXPECT_EQ(embedder->GetEmbeddingDimension(1), -1); +@@ -436,10 +437,10 @@ TEST(GetEmbeddingDimension, Succeeds) { + TEST(GetNumberOfOutputLayers, Succeeds) { + // Create embedder. + ImageEmbedderOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, +- ImageEmbedder::CreateFromOptions(options)); ++ ImageEmbedder::CreateFromOptions(options)); + + EXPECT_EQ(embedder->GetNumberOfOutputLayers(), 1); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc +index 3aab0bbee48ef..dc768a43a8726 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc +@@ -17,9 +17,9 @@ limitations under the License. + + #include <memory> + +-#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/shims/cc/shims_test_util.h" + #include "tensorflow/lite/kernels/builtin_op_kernels.h" +@@ -99,8 +99,8 @@ constexpr float kGoldenMaskTolerance = 1e-2; + constexpr int kGoldenMaskMagnificationFactor = 10; + + StatusOr<ImageData> LoadImage(std::string image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + // Checks that the two provided `Segmentation` protos are equal. +@@ -141,8 +141,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + + TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions( + options, absl::make_unique<DeepLabOpResolver>())); +@@ -160,8 +160,8 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { + + TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + auto image_segmenter_or = ImageSegmenter::CreateFromOptions( + options, absl::make_unique<DeepLabOpResolverMissingOps>()); +@@ -177,10 +177,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + + TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) { + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); +- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = + ImageSegmenter::CreateFromOptions(options); +@@ -212,8 +212,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + + TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) { + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + options.set_output_type(ImageSegmenterOptions::UNSPECIFIED); + + auto image_segmenter_or = ImageSegmenter::CreateFromOptions(options); +@@ -230,8 +230,8 @@ TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) { + TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) { + ImageSegmenterOptions options; + options.set_num_threads(4); +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(options)); + } +@@ -243,8 +243,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2)); + TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + ImageSegmenterOptions options; + options.set_num_threads(GetParam()); +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = + ImageSegmenter::CreateFromOptions(options); +@@ -263,21 +263,21 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + TEST(SegmentTest, SucceedsWithCategoryMask) { + // Load input and build frame buffer. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, +- LoadImage("segmentation_input_rotation0.jpg")); ++ LoadImage("segmentation_input_rotation0.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + // Load golden mask output. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask, +- LoadImage("segmentation_golden_rotation0.png")); ++ LoadImage("segmentation_golden_rotation0.png")); + + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter, +- ImageSegmenter::CreateFromOptions(options)); ++ ImageSegmenter::CreateFromOptions(options)); + SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result, +- image_segmenter->Segment(*frame_buffer)); ++ image_segmenter->Segment(*frame_buffer)); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +@@ -301,23 +301,24 @@ TEST(SegmentTest, SucceedsWithCategoryMask) { + + TEST(SegmentTest, SucceedsWithOrientation) { + // Load input and build frame buffer with kRightBottom orientation. +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, +- LoadImage("segmentation_input_rotation90_flop.jpg")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ ImageData rgb_image, LoadImage("segmentation_input_rotation90_flop.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}, + FrameBuffer::Orientation::kRightBottom); + // Load golden mask output. +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask, +- LoadImage("segmentation_golden_rotation90_flop.png")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ ImageData golden_mask, ++ LoadImage("segmentation_golden_rotation90_flop.png")); + + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter, +- ImageSegmenter::CreateFromOptions(options)); ++ ImageSegmenter::CreateFromOptions(options)); + SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result, +- image_segmenter->Segment(*frame_buffer)); ++ image_segmenter->Segment(*frame_buffer)); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +@@ -341,21 +342,21 @@ TEST(SegmentTest, SucceedsWithOrientation) { + TEST(SegmentTest, SucceedsWithBaseOptions) { + // Load input and build frame buffer. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, +- LoadImage("segmentation_input_rotation0.jpg")); ++ LoadImage("segmentation_input_rotation0.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + // Load golden mask output. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask, +- LoadImage("segmentation_golden_rotation0.png")); ++ LoadImage("segmentation_golden_rotation0.png")); + + ImageSegmenterOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter, +- ImageSegmenter::CreateFromOptions(options)); ++ ImageSegmenter::CreateFromOptions(options)); + SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result, +- image_segmenter->Segment(*frame_buffer)); ++ image_segmenter->Segment(*frame_buffer)); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +@@ -461,18 +462,18 @@ class PostprocessTest : public tflite_shims::testing::Test { + + TEST_F(PostprocessTest, SucceedsWithCategoryMask) { + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbaRawBuffer(/*input=*/nullptr, {}); + + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, +- FillAndGetOutputTensor()); ++ FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, +- test_image_segmenter_->Postprocess( +- {output_tensor}, *frame_buffer, /*roi=*/{})); ++ test_image_segmenter_->Postprocess( ++ {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +@@ -487,8 +488,8 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMask) { + + TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) { + ImageSegmenterOptions options; +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + // Frame buffer with kRightBottom orientation. + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer( + /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom); +@@ -496,10 +497,10 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) { + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, +- FillAndGetOutputTensor()); ++ FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, +- test_image_segmenter_->Postprocess( +- {output_tensor}, *frame_buffer, /*roi=*/{})); ++ test_image_segmenter_->Postprocess( ++ {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +@@ -515,18 +516,18 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) { + TEST_F(PostprocessTest, SucceedsWithConfidenceMask) { + ImageSegmenterOptions options; + options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK); +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbaRawBuffer(/*input=*/nullptr, {}); + + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, +- FillAndGetOutputTensor()); ++ FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, +- test_image_segmenter_->Postprocess( +- {output_tensor}, *frame_buffer, /*roi=*/{})); ++ test_image_segmenter_->Postprocess( ++ {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +@@ -547,8 +548,8 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMask) { + TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) { + ImageSegmenterOptions options; + options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK); +- options.mutable_model_file_with_metadata()->set_file_name(JoinPath( +- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); ++ options.mutable_model_file_with_metadata()->set_file_name( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + // Frame buffer with kRightBottom orientation. + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer( + /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom); +@@ -556,10 +557,10 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) { + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, +- FillAndGetOutputTensor()); ++ FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, +- test_image_segmenter_->Postprocess( +- {output_tensor}, *frame_buffer, /*roi=*/{})); ++ test_image_segmenter_->Postprocess( ++ {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc +index ef1f6509080ed..4a33e4b479354 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc +@@ -17,9 +17,9 @@ limitations under the License. + + #include <memory> + +-#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl + #include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/core/shims/cc/shims_test_util.h" + #include "tensorflow/lite/kernels/builtin_op_kernels.h" +@@ -103,8 +103,8 @@ constexpr char kEfficientDetWithMetadata[] = + "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"; + + StatusOr<ImageData> LoadImage(std::string image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + // Checks that the two provided `DetectionResult` protos are equal, with a +@@ -153,9 +153,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + + TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions( + options, absl::make_unique<MobileSsdQuantizedOpResolver>())); +@@ -186,9 +185,8 @@ class MobileSsdQuantizedOpResolverMissingOps + + TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + auto object_detector_or = ObjectDetector::CreateFromOptions( + options, absl::make_unique<MobileSsdQuantizedOpResolverMissingOps>()); +@@ -203,12 +201,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + + TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); +@@ -241,9 +237,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + + TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_max_results(0); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = +@@ -260,9 +255,8 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + + TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.add_class_name_whitelist("foo"); + options.add_class_name_blacklist("bar"); + +@@ -281,9 +275,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) { + TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) { + ObjectDetectorOptions options; + options.set_num_threads(4); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(options)); + } +@@ -295,9 +288,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2)); + TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + ObjectDetectorOptions options; + options.set_num_threads(GetParam()); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); +@@ -315,51 +307,52 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + class DetectTest : public tflite_shims::testing::Test {}; + + TEST_F(DetectTest, Succeeds) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, ++ LoadImage("cats_and_dogs.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ObjectDetectorOptions options; + options.set_max_results(4); +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, +- ObjectDetector::CreateFromOptions(options)); ++ ObjectDetector::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result, +- object_detector->Detect(*frame_buffer)); ++ object_detector->Detect(*frame_buffer)); + ImageDataFree(&rgb_image); + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<DetectionResult>(kExpectResults)); + } + + TEST_F(DetectTest, SucceedswithBaseOptions) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, ++ LoadImage("cats_and_dogs.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ObjectDetectorOptions options; + options.set_max_results(4); +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, +- ObjectDetector::CreateFromOptions(options)); ++ ObjectDetector::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result, +- object_detector->Detect(*frame_buffer)); ++ object_detector->Detect(*frame_buffer)); + ImageDataFree(&rgb_image); + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<DetectionResult>(kExpectResults)); + } + + TEST_F(DetectTest, SucceedswithScoreCalibrations) { +- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, ++ LoadImage("cats_and_dogs.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); +@@ -371,10 +364,10 @@ TEST_F(DetectTest, SucceedswithScoreCalibrations) { + kMobileSsdWithMetadataDummyScoreCalibration)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, +- ObjectDetector::CreateFromOptions(options)); ++ ObjectDetector::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result, +- object_detector->Detect(*frame_buffer)); ++ object_detector->Detect(*frame_buffer)); + ImageDataFree(&rgb_image); + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<DetectionResult>(kExpectResults)); +@@ -482,20 +475,21 @@ class PostprocessTest : public tflite_shims::testing::Test { + + TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_score_threshold(0.5); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + +- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors, +- FillAndGetOutputTensors()); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ const std::vector<const TfLiteTensor*> output_tensors, ++ FillAndGetOutputTensors()); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result, +- test_object_detector_->Postprocess( +- output_tensors, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ DetectionResult result, ++ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, ++ /*roi=*/{})); + + ExpectApproximatelyEqual( + result, +@@ -517,16 +511,16 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) { + FrameBuffer::Orientation::kBottomRight); + + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_score_threshold(0.5); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + +- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors, +- FillAndGetOutputTensors()); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ const std::vector<const TfLiteTensor*> output_tensors, ++ FillAndGetOutputTensors()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + DetectionResult result, +@@ -549,20 +543,21 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) { + + TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_max_results(1); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + +- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors, +- FillAndGetOutputTensors()); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ const std::vector<const TfLiteTensor*> output_tensors, ++ FillAndGetOutputTensors()); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result, +- test_object_detector_->Postprocess( +- output_tensors, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ DetectionResult result, ++ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, ++ /*roi=*/{})); + + ExpectApproximatelyEqual( + result, +@@ -576,21 +571,22 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + + TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.add_class_name_whitelist("car"); + options.add_class_name_whitelist("motorcycle"); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + +- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors, +- FillAndGetOutputTensors()); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ const std::vector<const TfLiteTensor*> output_tensors, ++ FillAndGetOutputTensors()); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result, +- test_object_detector_->Postprocess( +- output_tensors, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ DetectionResult result, ++ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, ++ /*roi=*/{})); + + ExpectApproximatelyEqual( + result, +@@ -608,9 +604,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + + TEST_F(PostprocessTest, SucceedsWithBlacklistOption) { + ObjectDetectorOptions options; +- options.mutable_model_file_with_metadata()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileSsdWithMetadata)); ++ options.mutable_model_file_with_metadata()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.add_class_name_blacklist("car"); + // Setting score threshold to discard the 7 padded-with-zeros results. + options.set_score_threshold(0.1); +@@ -618,12 +613,14 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) { + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + +- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors, +- FillAndGetOutputTensors()); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ const std::vector<const TfLiteTensor*> output_tensors, ++ FillAndGetOutputTensors()); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result, +- test_object_detector_->Postprocess( +- output_tensors, *dummy_frame_buffer_, /*roi=*/{})); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ DetectionResult result, ++ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, ++ /*roi=*/{})); + + ExpectApproximatelyEqual( + result, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc +index 7937dbafb090b..c16815cb38061 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc +@@ -21,13 +21,16 @@ namespace tflite { + namespace task { + + std::string JoinPath(absl::string_view path1, absl::string_view path2) { +- if (path1.empty()) return std::string(path2); +- if (path2.empty()) return std::string(path1); ++ if (path1.empty()) ++ return std::string(path2); ++ if (path2.empty()) ++ return std::string(path1); + if (path1.back() == '/') { + if (path2.front() == '/') + return absl::StrCat(path1, absl::ClippedSubstr(path2, 1)); + } else { +- if (path2.front() != '/') return absl::StrCat(path1, "/", path2); ++ if (path2.front() != '/') ++ return absl::StrCat(path1, "/", path2); + } + return absl::StrCat(path1, path2); + } +@@ -44,14 +47,16 @@ std::string JoinPathImpl(bool honor_abs, + // This size calculation is worst-case: it assumes one extra "/" for every + // path other than the first. + size_t total_size = paths.size() - 1; +- for (const absl::string_view path : paths) total_size += path.size(); ++ for (const absl::string_view path : paths) ++ total_size += path.size(); + result.resize(total_size); + + auto begin = result.begin(); + auto out = begin; + bool trailing_slash = false; + for (absl::string_view path : paths) { +- if (path.empty()) continue; ++ if (path.empty()) ++ continue; + if (path.front() == '/') { + if (honor_abs) { + out = begin; // wipe out whatever we've built up so far. +@@ -59,7 +64,8 @@ std::string JoinPathImpl(bool honor_abs, + path.remove_prefix(1); + } + } else { +- if (!trailing_slash && out != begin) *out++ = '/'; ++ if (!trailing_slash && out != begin) ++ *out++ = '/'; + } + const size_t this_size = path.size(); + memcpy(&*out, path.data(), this_size); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h +index db72bc5d5ae98..1d730d5a6d981 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h +@@ -33,8 +33,10 @@ std::string JoinPathImpl(bool honor_abs, + std::string JoinPath(absl::string_view path1, absl::string_view path2); + + template <typename... T> +-inline std::string JoinPath(absl::string_view path1, absl::string_view path2, +- absl::string_view path3, const T&... args) { ++inline std::string JoinPath(absl::string_view path1, ++ absl::string_view path2, ++ absl::string_view path3, ++ const T&... args) { + return internal::JoinPathImpl(false, {path1, path2, path3, args...}); + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc +index 6a050668edcbe..53c88310dde43 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc +@@ -31,7 +31,8 @@ FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece( + } + + tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains( +- absl::string_view key, bool* value) const { ++ absl::string_view key, ++ bool* value) const { + *value = index_map_.contains(key); + return tensorflow::text::LookupStatus(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h +index aec178daf3cc5..1de54fa8f651c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h +@@ -103,7 +103,8 @@ class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer { + + // Initialize the tokenizer from buffer and size of vocab and tokenizer + // configs. +- BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size, ++ BertTokenizer(const char* vocab_buffer_data, ++ size_t vocab_buffer_size, + const BertTokenizerOptions& options = {}) + : BertTokenizer( + utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size), +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc +index 151161777863f..249bc2d1b6bc2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc +@@ -31,9 +31,14 @@ using ::tflite::support::utils::StringListToVector; + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT +- JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token, +- jint max_chars_per_sub_token, jstring jsuffix_indicator, +- jboolean use_unknown_token, jstring junknown_token, ++ JNIEnv* env, ++ jobject thiz, ++ jobject vocab_list, ++ jint max_bytes_per_token, ++ jint max_chars_per_sub_token, ++ jstring jsuffix_indicator, ++ jboolean use_unknown_token, ++ jstring junknown_token, + jboolean split_unknown_chars) { + // Convert java.util.List<String> into std::vector<string> + std::vector<std::string> vocab = StringListToVector(env, vocab_list); +@@ -66,20 +71,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResourc + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT +- JNIEnv* env, jobject thiz, jlong handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong handle) { + delete reinterpret_cast<BertTokenizer*>(handle); + return 0; + } + + extern "C" JNIEXPORT jobjectArray JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize( +- JNIEnv* env, jobject thiz, jlong handle, jstring jtext) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong handle, ++ jstring jtext) { + return nativeTokenize(env, handle, jtext); + } + + extern "C" JNIEXPORT jintArray JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT +- JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong handle, ++ jobjectArray jtokens) { + return nativeConvertTokensToIds(env, handle, jtokens); + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc +index 832f9df42f824..ded6fbd13ea4a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc +@@ -17,7 +17,7 @@ limitations under the License. + + #include <iostream> + +-#include "absl/strings/str_cat.h" // from @com_google_absl ++#include "absl/strings/str_cat.h" // from @com_google_absl + #include "absl/strings/substitute.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/utils/common_utils.h" + namespace tflite { +@@ -70,7 +70,7 @@ TokenizerResult RegexTokenizer::Tokenize(const std::string& input) { + re2::StringPiece extracted_delim_token; + while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) { + re2::StringPiece token(last_end.data(), +- extracted_delim_token.data() - last_end.data()); ++ extracted_delim_token.data() - last_end.data()); + bool has_non_empty_token = token.length() > 0; + + last_end = leftover; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc +index 6ecfff0d2baa1..8ca14c52eb262 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc +@@ -20,7 +20,7 @@ limitations under the License. + #include <utility> + #include <vector> + +-#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl + #include "absl/strings/str_split.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" + #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" +@@ -34,7 +34,9 @@ using ::tflite::support::utils::GetMappedFileBuffer; + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT +- JNIEnv* env, jobject obj, jobject model_buffer) { ++ JNIEnv* env, ++ jobject obj, ++ jobject model_buffer) { + auto model = GetMappedFileBuffer(env, model_buffer); + auto handle = + absl::make_unique<SentencePieceTokenizer>(model.data(), model.size()); +@@ -43,20 +45,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLo + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT +- JNIEnv* env, jobject obj, jlong handle) { ++ JNIEnv* env, ++ jobject obj, ++ jlong handle) { + delete reinterpret_cast<SentencePieceTokenizer*>(handle); + return 0; + } + + extern "C" JNIEXPORT jobjectArray JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT +- JNIEnv* env, jobject thiz, jlong handle, jstring jtext) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong handle, ++ jstring jtext) { + return nativeTokenize(env, handle, jtext); + } + + extern "C" JNIEXPORT jintArray JNICALL + Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT +- JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong handle, ++ jobjectArray jtokens) { + return nativeConvertTokensToIds(env, handle, jtokens); + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc +index a72523be5984e..4e32bc5581a48 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc +@@ -54,7 +54,8 @@ jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) { + return result; + } + +-jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle, ++jintArray nativeConvertTokensToIds(JNIEnv* env, ++ jlong handle, + jobjectArray jtokens) { + if (handle == 0) { + env->ThrowNew(env->FindClass(kIllegalStateException), +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h +index 33677d305a853..fd76f3aa553e4 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h +@@ -25,7 +25,8 @@ namespace support { + + jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext); + +-jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle, ++jintArray nativeConvertTokensToIds(JNIEnv* env, ++ jlong handle, + jobjectArray jtokens); + + } // namespace support +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc +index 28f0137f54278..32957d155dce6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc +@@ -73,9 +73,9 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit( + } + case ProcessUnitOptions_SentencePieceTokenizerOptions: { + return CreateStatusWithPayload( +- absl::StatusCode::kInvalidArgument, +- "Chromium does not support sentencepiece tokenization", +- TfLiteSupportStatus::kMetadataInvalidTokenizerError); ++ absl::StatusCode::kInvalidArgument, ++ "Chromium does not support sentencepiece tokenization", ++ TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + case ProcessUnitOptions_RegexTokenizerOptions: { + const tflite::RegexTokenizerOptions* options = +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h +index 2e50a79963f82..696c5d4e27db7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h +@@ -26,7 +26,6 @@ namespace support { + namespace text { + namespace tokenizer { + +- + // Create a Tokenizer from model metadata by extracting + tflite::support::StatusOr<std::unique_ptr<Tokenizer>> + CreateTokenizerFromProcessUnit( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc +index 84cc0ef6ae52e..3ea6b147fcdd6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc +@@ -83,7 +83,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile( + } + + absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer( +- const char* vocab_buffer_data, const size_t vocab_buffer_size) { ++ const char* vocab_buffer_data, ++ const size_t vocab_buffer_size) { + membuf sbuf(const_cast<char*>(vocab_buffer_data), + const_cast<char*>(vocab_buffer_data + vocab_buffer_size)); + absl::node_hash_map<std::string, int> vocab_index_map; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h +index 6921d2f5ac01b..275c4932f8ec0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h +@@ -41,7 +41,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile( + // Read a vocab buffer with one vocabulary and its corresponding index on each + // line separated by space, create a map of <vocab, index>. + absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer( +- const char* vocab_buffer_data, const size_t vocab_buffer_size); ++ const char* vocab_buffer_data, ++ const size_t vocab_buffer_size); + } // namespace utils + } // namespace support + } // namespace tflite +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc +index bf9e93f9aa24a..35ce822951ad8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc +@@ -18,8 +18,8 @@ limitations under the License. + #include <dlfcn.h> + #include <string.h> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow/lite/core/shims/c/experimental/acceleration/configuration/delegate_plugin.h" + #include "tensorflow/lite/core/shims/cc/experimental/acceleration/configuration/delegate_registry.h" +@@ -168,7 +168,8 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) { + va_end(args); + } + +-void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz, ++void ThrowExceptionWithMessage(JNIEnv* env, ++ const char* clazz, + const char* message) { + jclass e_class = env->FindClass(clazz); + if (strcmp(clazz, kAssertionError) == 0) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h +index 7f0674d3c9187..7caf49e479859 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h +@@ -22,7 +22,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -57,7 +57,8 @@ T CheckNotNull(JNIEnv* env, T&& t) { + // Converts a std::vector<T> into a Java ArrayList using a converter, which + // processes a single element in the vector before adding it to the ArrayList. + template <typename T> +-jobject ConvertVectorToArrayList(JNIEnv* env, const std::vector<T>& results, ++jobject ConvertVectorToArrayList(JNIEnv* env, ++ const std::vector<T>& results, + std::function<jobject(T)> converter) { + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_ctor = +@@ -91,7 +92,8 @@ jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes); + + void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +-void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz, ++void ThrowExceptionWithMessage(JNIEnv* env, ++ const char* clazz, + const char* message); + + const char* GetExceptionClassNameForStatusCode(absl::StatusCode status_code); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc +index eb94cb7020475..bb8f1f4d40655 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc +@@ -63,7 +63,8 @@ using details_android_java::TensorInfo; + // Using ctor and dtor to simulate an enter/exit schema like `with` in Python. + class AsBlock { + public: +- AsBlock(CodeWriter* code_writer, const std::string& before, ++ AsBlock(CodeWriter* code_writer, ++ const std::string& before, + bool trailing_blank_line = false) + : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) { + code_writer_->AppendNoNewLine(before); +@@ -105,7 +106,9 @@ std::string GetModelVersionedName(const ModelMetadata* metadata) { + } + + TensorInfo CreateTensorInfo(const TensorMetadata* metadata, +- const std::string& name, bool is_input, int index, ++ const std::string& name, ++ bool is_input, ++ int index, + ErrorReporter* err) { + TensorInfo tensor_info; + std::string tensor_identifier = is_input ? "input" : "output"; +@@ -273,7 +276,8 @@ bool IsImageUsed(const ModelInfo& model) { + + // The following functions generates the wrapper Java code for a model. + +-bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model, ++bool GenerateWrapperFileContent(CodeWriter* code_writer, ++ const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("// Generated by TFLite Support."); + code_writer->Append("package {{PACKAGE}};"); +@@ -291,7 +295,8 @@ bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model, + return true; + } + +-bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model, ++bool GenerateWrapperImports(CodeWriter* code_writer, ++ const ModelInfo& model, + ErrorReporter* err) { + const std::string support_pkg = "org.tensorflow.lite.support."; + std::vector<std::string> imports{ +@@ -336,7 +341,8 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model, + return true; + } + +-bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model, ++bool GenerateWrapperClass(CodeWriter* code_writer, ++ const ModelInfo& model, + ErrorReporter* err) { + code_writer->SetTokenValue("MODEL_VERSIONED_NAME", + model.model_versioned_name); +@@ -373,7 +379,8 @@ private static final String MODEL_NAME = "{{MODEL_PATH}}";)"); + return true; + } + +-bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model, ++bool GenerateWrapperOutputs(CodeWriter* code_writer, ++ const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */"); + auto class_block = AsBlock(code_writer, "public static class Outputs"); +@@ -459,7 +466,8 @@ bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model, + return true; + } + +-bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model, ++bool GenerateWrapperMetadata(CodeWriter* code_writer, ++ const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append( + "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */"); +@@ -605,7 +613,8 @@ public List<String> get{{NAME_U}}Labels() { + return true; + } + +-bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model, ++bool GenerateWrapperAPI(CodeWriter* code_writer, ++ const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append(R"(public Metadata getMetadata() { + return metadata; +@@ -980,8 +989,10 @@ AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root) + : CodeGenerator(), module_root_(module_root) {} + + GenerationResult AndroidJavaGenerator::Generate( +- const Model* model, const std::string& package_name, +- const std::string& model_class_name, const std::string& model_asset_path) { ++ const Model* model, ++ const std::string& package_name, ++ const std::string& model_class_name, ++ const std::string& model_asset_path) { + GenerationResult result; + if (model == nullptr) { + err_.Error( +@@ -1006,8 +1017,10 @@ GenerationResult AndroidJavaGenerator::Generate( + } + + GenerationResult AndroidJavaGenerator::Generate( +- const char* model_storage, const std::string& package_name, +- const std::string& model_class_name, const std::string& model_asset_path) { ++ const char* model_storage, ++ const std::string& package_name, ++ const std::string& model_class_name, ++ const std::string& model_asset_path) { + const Model* model = GetModel(model_storage); + return Generate(model, package_name, model_class_name, model_asset_path); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h +index 634ccf69f6c1a..1ea8bb2182a67 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h +@@ -20,10 +20,10 @@ limitations under the License. + #include <string> + #include <vector> + ++#include "tensorflow/lite/schema/schema_generated.h" + #include "tensorflow_lite_support/codegen/code_generator.h" + #include "tensorflow_lite_support/codegen/utils.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +-#include "tensorflow/lite/schema/schema_generated.h" + + namespace tflite { + namespace support { +@@ -90,7 +90,8 @@ class AndroidJavaGenerator : public CodeGenerator { + /// as "ImageClassifier", "MobileNetV2" or "MyModel". + /// - model_asset_path: The relevant path to the model file in the asset. + // TODO(b/141225157): Automatically generate model_class_name. +- GenerationResult Generate(const Model* model, const std::string& package_name, ++ GenerationResult Generate(const Model* model, ++ const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc +index 1337708d4ac66..b6ec55cbc5e8b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc +@@ -144,7 +144,8 @@ std::string CodeGenerator::NameTensor(const TensorMetadata& tensor, + } + + void CodeGenerator::ResolveConflictedInputAndOutputNames( +- std::vector<std::string>* inputs, std::vector<std::string>* outputs) { ++ std::vector<std::string>* inputs, ++ std::vector<std::string>* outputs) { + std::unordered_set<std::string> io_conflict; + auto& input_names = *inputs; + auto& output_names = *outputs; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h +index b557773ddcc7a..fe67327986bd7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h +@@ -70,7 +70,8 @@ class CodeGenerator { + static std::string NameTensor(const TensorMetadata& tensor, + const std::string& default_name); + static void ResolveConflictedInputAndOutputNames( +- std::vector<std::string>* input, std::vector<std::string>* output); ++ std::vector<std::string>* input, ++ std::vector<std::string>* output); + }; + + } // namespace codegen +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc +index 5e9d64a0d8f98..ccc87668ed3cb 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc +@@ -36,7 +36,8 @@ class CodeGeneratorTest : public ::testing::Test { + return CodeGenerator::ConvertToValidName(name); + } + static void ResolveConflictedInputAndOutputNames( +- std::vector<std::string>* input, std::vector<std::string>* output) { ++ std::vector<std::string>* input, ++ std::vector<std::string>* output) { + CodeGenerator::ResolveConflictedInputAndOutputNames(input, output); + } + }; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h +index 8e3dc6abaed66..193dfb2fb23f3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h +@@ -18,9 +18,9 @@ limitations under the License. + + #include <string> + ++#include "tensorflow/lite/schema/schema_generated.h" + #include "tensorflow_lite_support/codegen/utils.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +-#include "tensorflow/lite/schema/schema_generated.h" + + namespace tflite { + namespace support { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc +index 6b2cd5ea9a778..a9da2403afc4f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc +@@ -29,11 +29,10 @@ using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>; + + PYBIND11_MODULE(_pywrap_codegen, m) { + pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator") +- .def(pybind11::init<const std::string &>()) +- .def("generate", +- overload_cast_<const char *, const std::string &, +- const std::string &, const std::string &>()( +- &AndroidJavaGenerator::Generate)) ++ .def(pybind11::init<const std::string&>()) ++ .def("generate", overload_cast_<const char*, const std::string&, ++ const std::string&, const std::string&>()( ++ &AndroidJavaGenerator::Generate)) + .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage); + pybind11::class_<GenerationResult>(m, "GenerationResult") + .def(pybind11::init<>()) +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc +index c75fc5fae631d..e89d09629dda1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc +@@ -32,7 +32,8 @@ int ErrorReporter::Error(const char* format, ...) { + return Report("[ERROR] ", format, args); + } + +-int ErrorReporter::Report(const char* prefix, const char* format, ++int ErrorReporter::Report(const char* prefix, ++ const char* format, + va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); +@@ -69,9 +70,13 @@ void CodeWriter::SetIndentString(const std::string& indent_str) { + indent_str_ = indent_str; + } + +-void CodeWriter::Indent() { indent_++; } ++void CodeWriter::Indent() { ++ indent_++; ++} + +-void CodeWriter::Outdent() { indent_--; } ++void CodeWriter::Outdent() { ++ indent_--; ++} + + std::string CodeWriter::GenerateIndent() const { + std::string res; +@@ -82,7 +87,9 @@ std::string CodeWriter::GenerateIndent() const { + return res; + } + +-void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); } ++void CodeWriter::Append(const std::string& text) { ++ AppendInternal(text, true); ++} + + void CodeWriter::AppendNoNewLine(const std::string& text) { + AppendInternal(text, false); +@@ -144,15 +151,21 @@ void CodeWriter::AppendInternal(const std::string& text, bool newline) { + } + } + +-void CodeWriter::NewLine() { Append(""); } ++void CodeWriter::NewLine() { ++ Append(""); ++} + + void CodeWriter::Backspace(int n) { + buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0); + } + +-std::string CodeWriter::ToString() const { return buffer_; } ++std::string CodeWriter::ToString() const { ++ return buffer_; ++} + +-bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); } ++bool CodeWriter::IsStreamEmpty() const { ++ return buffer_.empty(); ++} + + void CodeWriter::Clear() { + buffer_.clear(); +@@ -181,11 +194,14 @@ std::string SnakeCaseToCamelCase(const std::string& s) { + } + + std::string JoinPath(const std::string& a, const std::string& b) { +- if (a.empty()) return b; ++ if (a.empty()) ++ return b; + std::string a_fixed = a; +- if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back(); ++ if (!a_fixed.empty() && a_fixed.back() == '/') ++ a_fixed.pop_back(); + std::string b_fixed = b; +- if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1); ++ if (!b_fixed.empty() && b_fixed.front() == '/') ++ b_fixed.erase(0, 1); + return a_fixed + "/" + b_fixed; + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc +index 3831c63ca17cc..f55ffb907f133 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc +@@ -66,7 +66,9 @@ struct NgramsAttributes { + string_separator(m["string_separator"].ToString()) {} + }; + +-inline bool OutputIsTensor(TfLiteNode* node) { return NumOutputs(node) == 1; } ++inline bool OutputIsTensor(TfLiteNode* node) { ++ return NumOutputs(node) == 1; ++} + inline int NumRowSplits(TfLiteNode* node) { + return NumInputs(node) - kRowSplitsStart; + } +@@ -176,7 +178,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + std::vector<StringRef> tokens; + for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) { + tokens.emplace_back(GetString(input_values, j)); +- if (tokens.size() < attributes.width) continue; ++ if (tokens.size() < attributes.width) ++ continue; + tokens.erase(tokens.begin(), + tokens.begin() + tokens.size() - attributes.width); + buffer.AddJoinedString(tokens, separator); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc +index b87fcac328623..dc21f37beb3bf 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc +@@ -15,8 +15,8 @@ limitations under the License. + + #include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h" + +-#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h" + #include "tensorflow/lite/mutable_op_resolver.h" ++#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h" + + namespace tflite { + namespace ops { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc +index 91ef47af6fd0f..4a5e671fa0987 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc +@@ -40,7 +40,8 @@ using ::testing::ElementsAreArray; + class NgramsModel : public SingleOpModel { + public: + // Constructor for testing the op with a tf.Tensor +- NgramsModel(int width, const std::string& string_separator, ++ NgramsModel(int width, ++ const std::string& string_separator, + const std::vector<std::string>& input_values, + const std::vector<int>& input_shape) { + input_values_ = AddInput(TensorType_STRING); +@@ -56,7 +57,8 @@ class NgramsModel : public SingleOpModel { + // Constructor for the op with a tf.RaggedTensor + // Note: This interface uses row_lengths, as they're closer to the + // dimensions in a TensorShape, but internally everything is row_splits. +- NgramsModel(int width, const std::string& string_separator, ++ NgramsModel(int width, ++ const std::string& string_separator, + const std::vector<std::string>& input_values, + const std::vector<std::vector<int64_t>> nested_row_lengths) { + std::vector<std::vector<int>> input_shapes; +@@ -203,8 +205,7 @@ TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) { + TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) { + std::vector<std::vector<int64_t>> nested_row_lengths; + nested_row_lengths.push_back({4}); +- NgramsModel m(2, " ", {"this", "is", "a", "test"}, +- nested_row_lengths); ++ NgramsModel m(2, " ", {"this", "is", "a", "test"}, nested_row_lengths); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this is", "is a", "a test")); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h +index ade3c5c178920..811be781d27fe 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h +@@ -20,6 +20,6 @@ limitations under the License. + // C-function that is called from the Python Wrapper. + + extern "C" void TFLite_RaggedTensorToTensorRegisterer( +- tflite::MutableOpResolver *resolver); ++ tflite::MutableOpResolver* resolver); + + #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc +index a35a6db9ad48f..9fc73dd0f9778 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc +@@ -71,9 +71,12 @@ TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) { + // nrows (number of output rows) is the size of the non-broadcast inputs, + // or 1 if all inputs are scalars. + std::vector<int> in_sizes; +- if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]); +- if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]); +- if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]); ++ if (!broadcast_starts) ++ in_sizes.push_back(input_starts.dims->data[0]); ++ if (!broadcast_limits) ++ in_sizes.push_back(input_limits.dims->data[0]); ++ if (!broadcast_deltas) ++ in_sizes.push_back(input_deltas.dims->data[0]); + if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes), + std::not_equal_to<>()) != std::end(in_sizes)) { + context->ReportError( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc +index 54cf4459a27ed..87a047c512ea7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc +@@ -39,7 +39,8 @@ class RaggedRangeOpModel : public SingleOpModel { + public: + static TensorType GetType(); + +- RaggedRangeOpModel(const std::vector<T>& start, const std::vector<T>& limits, ++ RaggedRangeOpModel(const std::vector<T>& start, ++ const std::vector<T>& limits, + const std::vector<T>& deltas) { + const TensorType value_type = GetType(); + std::vector<std::vector<int>> shapes; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc +index 09ac76c71b26c..ff5c14b8e5e08 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc +@@ -140,8 +140,10 @@ RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) { + } + + const TfLiteTensor* GetRowPartitionTensor( +- const ConversionAttributes& conversion_attributes, TfLiteContext* context, +- TfLiteNode* node, int dimension) { ++ const ConversionAttributes& conversion_attributes, ++ TfLiteContext* context, ++ TfLiteNode* node, ++ int dimension) { + if (conversion_attributes.partition_types.front() == + tensorflow::RowPartitionType::FIRST_DIM_SIZE) { + return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 + +@@ -211,7 +213,9 @@ int GetMaxWidthRowSplit(const TfLiteTensor* tensor) { + } + + int GetMaxWidth(const ConversionAttributes& conversion_attributes, +- TfLiteContext* context, TfLiteNode* node, int dimension) { ++ TfLiteContext* context, ++ TfLiteNode* node, ++ int dimension) { + const TfLiteTensor* tensor = GetRowPartitionTensor( + conversion_attributes, context, node, dimension - 1); + switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) { +@@ -226,7 +230,8 @@ int GetMaxWidth(const ConversionAttributes& conversion_attributes, + } + + RuntimeShape CombineRaggedTensorToTensorShapes( +- int ragged_rank, const RuntimeShape& output_shape, ++ int ragged_rank, ++ const RuntimeShape& output_shape, + const RuntimeShape& value_shape) { + // TODO(mgubin): No checks, see + // third_party/tensorflow/core/ops/ragged_to_dense_util.cc +@@ -247,9 +252,13 @@ RuntimeShape CombineRaggedTensorToTensorShapes( + } + + RuntimeShape CalculateOutputSize( +- const ConversionAttributes& conversion_attributes, TfLiteContext* context, +- TfLiteNode* node, int first_dimension, int ragged_rank, +- const TfLiteTensor& values, const TfLiteTensor& default_value, ++ const ConversionAttributes& conversion_attributes, ++ TfLiteContext* context, ++ TfLiteNode* node, ++ int first_dimension, ++ int ragged_rank, ++ const TfLiteTensor& values, ++ const TfLiteTensor& default_value, + const TfLiteTensor& output_shape) { + RuntimeShape values_shape(values.dims->size, values.dims->data); + RuntimeShape default_value_shape(default_value.dims->size, +@@ -331,7 +340,8 @@ void CalculateFirstParentOutputIndex(int first_dimension, + void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids, + const std::vector<int>& parent_output_index, + int output_index_multiplier, +- int output_size, std::vector<int>* result) { ++ int output_size, ++ std::vector<int>* result) { + const RuntimeShape tensor_shape(value_rowids.dims->size, + value_rowids.dims->data); + const int index_size = tensor_shape.FlatSize(); +@@ -380,7 +390,8 @@ void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids, + + void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split, + const std::vector<int>& parent_output_index, +- int output_index_multiplier, int output_size, ++ int output_index_multiplier, ++ int output_size, + std::vector<int>* result) { + const RuntimeShape row_split_shape(row_split.dims->size, + row_split.dims->data); +@@ -421,10 +432,14 @@ void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split, + } + + TfLiteStatus CalculateOutputIndex( +- const ConversionAttributes& conversion_attributes, TfLiteContext* context, +- TfLiteNode* node, int dimension, +- const std::vector<int>& parent_output_index, int output_index_multiplier, +- int output_size, std::vector<int>* result) { ++ const ConversionAttributes& conversion_attributes, ++ TfLiteContext* context, ++ TfLiteNode* node, ++ int dimension, ++ const std::vector<int>& parent_output_index, ++ int output_index_multiplier, ++ int output_size, ++ std::vector<int>* result) { + const TfLiteTensor* row_partition_tensor = + GetRowPartitionTensor(conversion_attributes, context, node, dimension); + auto partition_type = +@@ -447,7 +462,8 @@ TfLiteStatus CalculateOutputIndex( + } + + template <typename VALUE_TYPE> +-void SetOutputT(TfLiteContext* context, int ragged_rank, ++void SetOutputT(TfLiteContext* context, ++ int ragged_rank, + const std::vector<int>& output_index, + const TfLiteTensor& values_tensor, + const TfLiteTensor& default_value_tensor, +@@ -522,7 +538,8 @@ void SetOutputT(TfLiteContext* context, int ragged_rank, + } + } + +-void SetOutput(TfLiteContext* context, int ragged_rank, ++void SetOutput(TfLiteContext* context, ++ int ragged_rank, + const std::vector<int>& output_index, + const TfLiteTensor& values_tensor, + const TfLiteTensor& default_value_tensor, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc +index b1cde57c47c68..2f7a2a95b8478 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc +@@ -82,7 +82,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel { + std::vector<int32> GetOutputInt() { return ExtractVector<int32>(output_); } + + void InvokeFloat(const std::vector<int>& shape, +- const std::vector<float>& values, float default_value, ++ const std::vector<float>& values, ++ float default_value, + const std::vector<std::vector<int>>& partition_values) { + PopulateTensor(input_shape_, shape); + PopulateTensor(input_values_, values); +@@ -93,7 +94,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel { + SingleOpModel::Invoke(); + } + void InvokeInt(const std::vector<int>& shape, +- const std::vector<int32>& values, int32 default_value, ++ const std::vector<int32>& values, ++ int32 default_value, + const std::vector<std::vector<int>>& partition_values) { + PopulateTensor(input_shape_, shape); + PopulateTensor(input_values_, values); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc +index 4e2b87de37327..47ba9fdfebcae 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc +@@ -15,8 +15,8 @@ limitations under the License. + + #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_replace.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_replace.h" // from @com_google_absl + #include "src/sentencepiece_model.pb.h" // from @com_google_sentencepiece + #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h" + #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" +@@ -48,7 +48,8 @@ DecodePrecompiledCharsmap( + } + + tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( +- const std::string& model_config_str, int encoding_offset) { ++ const std::string& model_config_str, ++ int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( +@@ -128,7 +129,8 @@ tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( + + tflite::support::StatusOr<std::string> + ConvertSentencepieceModelToFlatBufferForDecoder( +- const std::string& model_config_str, int encoding_offset) { ++ const std::string& model_config_str, ++ int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h +index 5687b6287d140..03b3596820886 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h +@@ -27,13 +27,15 @@ namespace sentencepiece { + // Converts Sentencepiece configuration to flatbuffer format. + // encoding_offset is used by some encoders that combine different encodings. + tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( +- const std::string& model_config_str, int encoding_offset = 0); ++ const std::string& model_config_str, ++ int encoding_offset = 0); + + // Converts Sentencepiece configuration to flatbuffer format for encoder. + // encoding_offset is used by some encoders that combine different encodings. + tflite::support::StatusOr<std::string> + ConvertSentencepieceModelToFlatBufferForDecoder( +- const std::string& model_config_str, int encoding_offset = 0); ++ const std::string& model_config_str, ++ int encoding_offset = 0); + + // The functions that are provided for the Python wrapper. + std::string ConvertSentencepieceModel(const std::string& model_string); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc +index 8e130ef73b9b6..94161c2ac4c4e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc +@@ -19,9 +19,9 @@ limitations under the License. + + #include <gmock/gmock.h> + #include <gtest/gtest.h> +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl +-#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl ++#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece + #include "src/sentencepiece_processor.h" // from @com_google_sentencepiece + #include "tensorflow/core/platform/env.h" + #include "tensorflow_lite_support/cc/test/test_utils.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc +index 45fde32237c65..4148f8e96627a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc +@@ -31,7 +31,8 @@ const char kSpaceSymbol[] = "\xe2\x96\x81"; + + template <typename processing_callback> + std::tuple<std::string, std::vector<int>> process_string( +- const std::string& input, const std::vector<int>& offsets, ++ const std::string& input, ++ const std::vector<int>& offsets, + const processing_callback& pc) { + std::string result_string; + result_string.reserve(input.size()); +@@ -78,7 +79,9 @@ std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data, + } + + std::tuple<int, utils::string_view> find_replacement( +- const char* data, int len, const DoubleArrayTrie& dat, ++ const char* data, ++ int len, ++ const DoubleArrayTrie& dat, + const flatbuffers::Vector<int8_t>& replacements) { + const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); + if (!max_match.empty()) { +@@ -94,7 +97,8 @@ std::tuple<int, utils::string_view> find_replacement( + } // namespace + + std::tuple<std::string, std::vector<int>> NormalizeString( +- const std::string& in_string, const EncoderConfig& config) { ++ const std::string& in_string, ++ const EncoderConfig& config) { + std::vector<int> output_offsets; + std::string result = in_string; + output_offsets.reserve(in_string.length()); +@@ -145,8 +149,10 @@ std::tuple<std::string, std::vector<int>> NormalizeString( + + EncoderResult EncodeNormalizedString(const std::string& str, + const std::vector<int>& offsets, +- const EncoderConfig& config, bool add_bos, +- bool add_eos, bool reverse) { ++ const EncoderConfig& config, ++ bool add_bos, ++ bool add_eos, ++ bool reverse) { + const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); + const flatbuffers::Vector<float>* piece_scores = config.pieces_scores(); + const int unknown_code = config.unknown_code(); +@@ -219,8 +225,11 @@ EncoderResult EncodeNormalizedString(const std::string& str, + return result; + } + +-EncoderResult EncodeString(const std::string& string, const void* config_buffer, +- bool add_bos, bool add_eos, bool reverse) { ++EncoderResult EncodeString(const std::string& string, ++ const void* config_buffer, ++ bool add_bos, ++ bool add_eos, ++ bool reverse) { + // Get the config from the buffer. + const EncoderConfig* config = GetEncoderConfig(config_buffer); + if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h +index 44d6e88f2531c..b89154cbfa396 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h +@@ -37,12 +37,16 @@ struct EncoderResult { + std::vector<int> offsets; + }; + std::tuple<std::string, std::vector<int>> NormalizeString( +- const std::string& in_string, const EncoderConfig& config); ++ const std::string& in_string, ++ const EncoderConfig& config); + + // Encodes one string and returns ids and offsets. Takes the configuration as a + // type-erased buffer. +-EncoderResult EncodeString(const std::string& string, const void* config_buffer, +- bool add_bos, bool add_eos, bool reverse); ++EncoderResult EncodeString(const std::string& string, ++ const void* config_buffer, ++ bool add_bos, ++ bool add_eos, ++ bool reverse); + + } // namespace sentencepiece + } // namespace custom +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc +index e2787c785e8c4..dd956a22b26c1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc +@@ -19,10 +19,10 @@ limitations under the License. + + #include <gmock/gmock.h> + #include <gtest/gtest.h> +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl +-#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl ++#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece + #include "src/sentencepiece_processor.h" // from @com_google_sentencepiece + #include "tensorflow/core/platform/env.h" + #include "tensorflow_lite_support/cc/test/test_utils.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h +index deb4e4ee08dc2..3efcfefc6438d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h +@@ -20,6 +20,6 @@ limitations under the License. + // C-function that is called from the Python Wrapper. + + extern "C" void TFLite_SentencepieceTokenizerRegisterer( +- tflite::MutableOpResolver *resolver); ++ tflite::MutableOpResolver* resolver); + + #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc +index 54b34e4e33196..f5be376b45e12 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc +@@ -35,7 +35,8 @@ namespace detokenizer { + + constexpr int kOutputValuesInd = 0; + // Initializes text encoder object from serialized parameters. +-void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, ++void* Initialize(TfLiteContext* /*context*/, ++ const char* /*buffer*/, + size_t /*length*/) { + return nullptr; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc +index 41fc5aa28bf30..68f8e64492394 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc +@@ -16,16 +16,16 @@ limitations under the License. + #include <iterator> + #include <vector> + +-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" +-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h" + #include "tensorflow/core/framework/op.h" + #include "tensorflow/core/framework/op_kernel.h" + #include "tensorflow/core/framework/shape_inference.h" + #include "tensorflow/core/framework/tensor.h" + #include "tensorflow/core/protobuf/error_codes.pb.h" ++#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" ++#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h" + + namespace tensorflow { +-namespace ops{ ++namespace ops { + + // copied from third_party/tensorflow_text/core/ops/sentencepiece_ops.cc + REGISTER_OP("TFSentencepieceTokenizeOp") +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc +index 8309a6a2616fd..edb0160b508a3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc +@@ -16,8 +16,6 @@ limitations under the License. + /** + * Sentencepiece tflite tokenizer implementation. + */ +-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" +-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h" + #include "flatbuffers/flexbuffers.h" // from @flatbuffers + #include "tensorflow/lite/c/common.h" + #include "tensorflow/lite/context.h" +@@ -25,6 +23,8 @@ limitations under the License. + #include "tensorflow/lite/kernels/kernel_util.h" + #include "tensorflow/lite/model.h" + #include "tensorflow/lite/string_util.h" ++#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" ++#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h" + + namespace tflite { + namespace ops { +@@ -47,7 +47,8 @@ TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) { + } // namespace + + // Initializes text encoder object from serialized parameters. +-void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, ++void* Initialize(TfLiteContext* /*context*/, ++ const char* /*buffer*/, + size_t /*length*/) { + return nullptr; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc +index dad2f0004be06..8096a5008bd12 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc +@@ -19,10 +19,10 @@ limitations under the License. + #include <utility> + #include <vector> + ++#include "libutf/utf.h" + #include "tensorflow/lite/context.h" + #include "tensorflow/lite/kernels/kernel_util.h" + #include "tensorflow/lite/string_util.h" +-#include "libutf/utf.h" + + constexpr int kInput = 0; + constexpr int kOutputValues = 0; +@@ -49,7 +49,7 @@ inline bool OutputIsPaddedTensor(TfLiteNode* node) { + } + + inline int charntorune(Rune* r, const char* s, int n) { +- const int bytes_read = chartorune(r, const_cast<char *>(s)); ++ const int bytes_read = chartorune(r, const_cast<char*>(s)); + if (bytes_read > n) { + *r = Runeerror; + return 0; +@@ -66,7 +66,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) { + while (n > 0) { + Rune r; + int c = charntorune(&r, p, n); +- if (r == Runeerror) break; ++ if (r == Runeerror) ++ break; + + if (isspacerune(r)) { + if (start != nullptr) { +@@ -91,7 +92,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) { + + TfLiteStatus WritePaddedOutput( + const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens, +- const TfLiteTensor* input, TfLiteTensor* output_values) { ++ const TfLiteTensor* input, ++ TfLiteTensor* output_values) { + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) + 1); + for (int i = 0; i < NumDimensions(input); ++i) { + output_shape->data[i] = SizeOfDimension(input, i); +@@ -118,7 +120,8 @@ TfLiteStatus WritePaddedOutput( + + TfLiteStatus WriteRaggedOutput( + const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens, +- const TfLiteTensor* input, TfLiteTensor* output_values, ++ const TfLiteTensor* input, ++ TfLiteTensor* output_values, + std::vector<TfLiteTensor*> nested_row_splits) { + // The outer dimensions of the ragged tensor are all non-ragged. + for (int i = 0; i < nested_row_splits.size() - 1; ++i) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc +index 534fbef4aff2d..6166bc149bc00 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc +@@ -15,8 +15,8 @@ limitations under the License. + + #include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h" + +-#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h" + #include "tensorflow/lite/mutable_op_resolver.h" ++#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h" + + namespace tflite { + namespace ops { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc +index 7447870046f48..6339ed705bcb9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc +@@ -24,22 +24,30 @@ limitations under the License. + #include <iostream> + #include <limits> + +-#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl + #include "absl/flags/parse.h" // from @com_google_absl + #include "tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' audio classification model."); +-ABSL_FLAG(std::string, audio_wav_path, "", ++ABSL_FLAG(std::string, ++ audio_wav_path, ++ "", + "Absolute path to the 16-bit PCM WAV file to classify. The WAV " + "file must be monochannel and has a sampling rate matches the model " + "expected sampling rate (as in the Metadata). If the WAV file is " + "longer than what the model requires, only the beginning section is " + "used for inference."); +-ABSL_FLAG(float, score_threshold, 0.001f, ++ABSL_FLAG(float, ++ score_threshold, ++ 0.001f, + "Apply a filter on the results. Only display classes with score " + "higher than the threshold."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc +index 36d6633d902e3..a843501ec3d75 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc +@@ -19,7 +19,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/status_macros.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -34,7 +34,8 @@ namespace task { + namespace audio { + + tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile( +- const std::string& wav_file, int buffer_size, ++ const std::string& wav_file, ++ int buffer_size, + std::vector<float>* wav_data) { + std::string contents = ReadFile(wav_file); + +@@ -55,7 +56,8 @@ tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile( + } + + tflite::support::StatusOr<ClassificationResult> Classify( +- const std::string& model_path, const std::string& wav_file, ++ const std::string& model_path, ++ const std::string& wav_file, + bool use_coral) { + AudioClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( +@@ -97,7 +99,8 @@ void Display(const ClassificationResult& result, float score_threshold) { + std::cout << absl::StrFormat("\nHead[%d]: %s\n", i, head.head_name()); + for (int j = 0; j < head.classes_size(); j++) { + const auto& category = head.classes(j); +- if (category.score() < score_threshold) continue; ++ if (category.score() < score_threshold) ++ continue; + std::cout << absl::StrFormat("\tcategory[%s]: %.5f\t", + category.class_name(), category.score()); + if (!category.display_name().empty()) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h +index 6d23078ba3e19..13b2d7792e025 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h +@@ -28,7 +28,8 @@ namespace audio { + // than what the model requires, only the beginning section is used for + // inference. + tflite::support::StatusOr<ClassificationResult> Classify( +- const std::string& model_path, const std::string& wav_file, ++ const std::string& model_path, ++ const std::string& wav_file, + bool use_coral = false); + + // Prints the output classification result in the standard output. It only +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc +index 02eed2332b2e4..5203200808d60 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc +@@ -15,18 +15,22 @@ limitations under the License. + #include <iostream> + #include <limits> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/category.h" + #include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' bert classification model."); + ABSL_FLAG(std::string, text, "", "Text to classify."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc +index 4eaa2bbbdd9f5..f2577cfad54c2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc +@@ -15,19 +15,25 @@ limitations under the License. + #include <iostream> + #include <limits> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' bert question answerer model."); + ABSL_FLAG(std::string, question, "", "Question to ask."); +-ABSL_FLAG(std::string, context, "", ++ABSL_FLAG(std::string, ++ context, ++ "", + "Context the asked question is based upon."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc +index 49f233ce1e74c..613744ffdb20b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc +@@ -15,18 +15,22 @@ limitations under the License. + #include <iostream> + #include <limits> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/category.h" + #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' classification model."); + ABSL_FLAG(std::string, text, "", "Text to classify."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc +index 8c4a36c31674f..8ba00cb5d50bd 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc +@@ -14,26 +14,30 @@ limitations under the License. + ==============================================================================*/ + + // Demostration the usage of UniversalSentenceEncoderQA. +-#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h" ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_split.h" // from @com_google_absl +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-using tflite::task::text::RetrievalOptions; ++#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h" + using tflite::task::text::RetrievalInput; ++using tflite::task::text::RetrievalOptions; + using tflite::task::text::RetrievalOutput; + using tflite::task::text::retrieval::UniversalSentenceEncoderQA; + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' UniversalSentenceEncoderQA model."); +-ABSL_FLAG(std::string, question, "How are you feeling today?", ++ABSL_FLAG(std::string, ++ question, ++ "How are you feeling today?", + "Question to ask."); + ABSL_FLAG( +- std::string, answers, ++ std::string, ++ answers, + "I'm not feeling very well.:Paris is the capital of France.:He looks good.", + "Candidate answers seperated by `:`."); + +- + int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); +@@ -55,8 +59,8 @@ int main(int argc, char** argv) { + absl::GetFlag(FLAGS_model_path)); + auto status = UniversalSentenceEncoderQA::CreateFromOption(options); + CHECK_OK(status); +- std::unique_ptr<UniversalSentenceEncoderQA> client +- = std::move(status.value()); ++ std::unique_ptr<UniversalSentenceEncoderQA> client = ++ std::move(status.value()); + + // Create RetrievalInput with a query and responses. + RetrievalInput input; +@@ -80,8 +84,8 @@ int main(int argc, char** argv) { + // Consume the results according to the ranking. Here we just print them out. + std::cout << input.query_text() << std::endl; + for (size_t k : top) { +- std::cout << input.responses(k).raw_text().text() << ", " +- << input.responses(k).raw_text().context() << ", " +- << output.response_results(k).score() << std::endl; ++ std::cout << input.responses(k).raw_text().text() << ", " ++ << input.responses(k).raw_text().context() << ", " ++ << output.response_results(k).score() << std::endl; + } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc +index 8b2ed939686b3..bd2aaaf188726 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc +@@ -22,9 +22,9 @@ limitations under the License. + + #include <iostream> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +@@ -36,29 +36,43 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' image classifier model."); +-ABSL_FLAG(std::string, image_path, "", ++ABSL_FLAG(std::string, ++ image_path, ++ "", + "Absolute path to the image to classify. The image must be RGB or " + "RGBA (grayscale is not supported). The image EXIF orientation " + "flag, if any, is NOT taken into account."); +-ABSL_FLAG(int32, max_results, 5, ++ABSL_FLAG(int32, ++ max_results, ++ 5, + "Maximum number of classification results to display."); +-ABSL_FLAG(float, score_threshold, 0, ++ABSL_FLAG(float, ++ score_threshold, ++ 0, + "Classification results with a confidence score below this value are " + "rejected. If >= 0, overrides the score threshold(s) provided in the " + "TFLite Model Metadata. Ignored otherwise."); + ABSL_FLAG( +- std::vector<std::string>, class_name_whitelist, {}, ++ std::vector<std::string>, ++ class_name_whitelist, ++ {}, + "Comma-separated list of class names that acts as a whitelist. If " + "non-empty, classification results whose 'class_name' is not in this list " + "are filtered out. Mutually exclusive with 'class_name_blacklist'."); + ABSL_FLAG( +- std::vector<std::string>, class_name_blacklist, {}, ++ std::vector<std::string>, ++ class_name_blacklist, ++ {}, + "Comma-separated list of class names that acts as a blacklist. If " + "non-empty, classification results whose 'class_name' is in this list " + "are filtered out. Mutually exclusive with 'class_name_whitelist'."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc +index 722194f34ee5e..040878aa37841 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc +@@ -26,9 +26,9 @@ limitations under the License. + + #include <iostream> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +@@ -39,28 +39,40 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' image embedder model."); +-ABSL_FLAG(std::string, first_image_path, "", ++ABSL_FLAG(std::string, ++ first_image_path, ++ "", + "Absolute path to the first image, whose feature vector will be " + "extracted and compared to the second image using cosine similarity. " + "The image must be RGB or RGBA (grayscale is not supported). The " + "image EXIF orientation flag, if any, is NOT taken into account."); +-ABSL_FLAG(std::string, second_image_path, "", ++ABSL_FLAG(std::string, ++ second_image_path, ++ "", + "Absolute path to the second image, whose feature vector will be " + "extracted and compared to the first image using cosine similarity. " + "The image must be RGB or RGBA (grayscale is not supported). The " + "image EXIF orientation flag, if any, is NOT taken into account."); +-ABSL_FLAG(bool, l2_normalize, false, ++ABSL_FLAG(bool, ++ l2_normalize, ++ false, + "If true, the raw feature vectors returned by the image embedder " + "will be normalized with L2-norm. Generally only needed if the model " + "doesn't already contain a L2_NORMALIZATION TFLite Op."); + ABSL_FLAG( +- bool, quantize, false, ++ bool, ++ quantize, ++ false, + "If true, the raw feature vectors returned by the image embedder will " + "be quantized to 8 bit integers (uniform quantization) via post-processing " + "before cosine similarity is computed."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc +index 2cb606e011aca..6487fe92166cd 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc +@@ -23,10 +23,10 @@ limitations under the License. + + #include <iostream> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/match.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/match.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +@@ -37,16 +37,24 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' image segmenter model."); +-ABSL_FLAG(std::string, image_path, "", ++ABSL_FLAG(std::string, ++ image_path, ++ "", + "Absolute path to the image to segment. The image must be RGB or " + "RGBA (grayscale is not supported). The image EXIF orientation " + "flag, if any, is NOT taken into account."); +-ABSL_FLAG(std::string, output_mask_png, "", ++ABSL_FLAG(std::string, ++ output_mask_png, ++ "", + "Absolute path to the output category mask (confidence masks outputs " + "are not supported by this tool). Must have a '.png' extension."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc +index 0130b4550b9d9..9208439df6263 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc +@@ -24,10 +24,10 @@ limitations under the License. + #include <iostream> + #include <limits> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/match.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/match.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +@@ -40,32 +40,48 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' object detector model."); +-ABSL_FLAG(std::string, image_path, "", ++ABSL_FLAG(std::string, ++ image_path, ++ "", + "Absolute path to the image to run detection on. The image must be " + "RGB or RGBA (grayscale is not supported). The image EXIF " + "orientation flag, if any, is NOT taken into account."); +-ABSL_FLAG(std::string, output_png, "", ++ABSL_FLAG(std::string, ++ output_png, ++ "", + "Absolute path to a file where to draw the detection results on top " + "of the input image. Must have a '.png' extension."); +-ABSL_FLAG(int32, max_results, 5, ++ABSL_FLAG(int32, ++ max_results, ++ 5, + "Maximum number of detection results to display."); + ABSL_FLAG( +- float, score_threshold, std::numeric_limits<float>::lowest(), ++ float, ++ score_threshold, ++ std::numeric_limits<float>::lowest(), + "Detection results with a confidence score below this value are " + "rejected. If specified, overrides the score threshold(s) provided in the " + "TFLite Model Metadata. Ignored otherwise."); + ABSL_FLAG( +- std::vector<std::string>, class_name_whitelist, {}, ++ std::vector<std::string>, ++ class_name_whitelist, ++ {}, + "Comma-separated list of class names that acts as a whitelist. If " + "non-empty, detections results whose 'class_name' is not in this list " + "are filtered out. Mutually exclusive with 'class_name_blacklist'."); +-ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {}, ++ABSL_FLAG(std::vector<std::string>, ++ class_name_blacklist, ++ {}, + "Comma-separated list of class names that acts as a blacklist. If " + "non-empty, detections results whose 'class_name' is in this list " + "are filtered out. Mutually exclusive with 'class_name_whitelist'."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc +index 6f3aa737bd090..efdcda993f5e8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc +@@ -23,11 +23,11 @@ limitations under the License. + #define STB_IMAGE_IMPLEMENTATION + #define STB_IMAGE_WRITE_IMPLEMENTATION + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/match.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/match.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl +-#include "stb_image.h" // from @stblib +-#include "stb_image_write.h" // from @stblib ++#include "stb_image.h" // from @stblib ++#include "stb_image_write.h" // from @stblib + #include "tensorflow_lite_support/cc/port/status_macros.h" + #include "tensorflow_lite_support/cc/port/statusor.h" + +@@ -87,7 +87,9 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data, + return absl::OkStatus(); + } + +-void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); } ++void ImageDataFree(ImageData* image) { ++ stbi_image_free(image->pixel_data); ++} + + } // namespace vision + } // namespace task +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h +index a0b0c6bbad191..9e7e3ba500f2d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h +@@ -15,7 +15,7 @@ limitations under the License. + #ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ + #define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/statusor.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h +index a4fee55abe158..2ca42fb7f3fbe 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h +@@ -56,7 +56,8 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) { + + /** TensorFlow Lite metadata error codes. */ + +- /** Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. */ ++ /** Unexpected schema version (aka file_identifier) in the Metadata ++ FlatBuffer. */ + TFLSupportErrorCodeMetadataInvalidSchemaVersionError = 200, + + /** No such associated file within metadata, or file has not been packed. */ +@@ -198,11 +199,13 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) { + */ + TFLSupportErrorCodeImageProcessingBackendError, + +- /** kNotFound indicates some requested entity (such as a file or directory) was not found. */ ++ /** kNotFound indicates some requested entity (such as a file or directory) ++ was not found. */ + TFLSupportErrorCodeNotFoundError = 900, + +- /** kInternal indicates an internal error has occurred and some invariants expected by the +- * underlying system have not been satisfied. This error code is reserved for serious errors. ++ /** kInternal indicates an internal error has occurred and some invariants ++ * expected by the underlying system have not been satisfied. This error code ++ * is reserved for serious errors. + */ + TFLSupportErrorCodeInternalError, + } NS_SWIFT_NAME(SupportErrorCode); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h +index 8ef21659a4a1a..a194b2834323a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h +@@ -21,42 +21,43 @@ NS_ASSUME_NONNULL_BEGIN + @interface TFLCommonUtils : NSObject + + /** +- * Creates and saves an error originating from the task library with the given error code and +- * description. ++ * Creates and saves an error originating from the task library with the given ++ * error code and description. + * + * @param code Error code. + * @param description Error description. +- * @param error Pointer to the memory location where the created error should be saved. If `nil`, no +- * error will be saved. ++ * @param error Pointer to the memory location where the created error should be ++ * saved. If `nil`, no error will be saved. + */ + + (void)customErrorWithCode:(NSInteger)code +- description:(NSString *)description +- error:(NSError **)error; ++ description:(NSString*)description ++ error:(NSError**)error; + + /** +- * Creates and saves an error originating from the task library from a C library error, +- * TfLiteSupportError . ++ * Creates and saves an error originating from the task library from a C library ++ * error, TfLiteSupportError . + * + * @param supportError C library error. +- * @param error Pointer to the memory location where the created error should be saved. If `nil`, no +- * error will be saved. ++ * @param error Pointer to the memory location where the created error should be ++ * saved. If `nil`, no error will be saved. + */ +-+ (void)errorFromTfLiteSupportError:(TfLiteSupportError *)supportError error:(NSError **)error; +++ (void)errorFromTfLiteSupportError:(TfLiteSupportError*)supportError ++ error:(NSError**)error; + + /** +- * Allocates a block of memory with the specified size and returns a pointer to it. If memory cannot +- * be allocated because of an invalid memSize, it saves an error. In other cases, it terminates +- * program execution. ++ * Allocates a block of memory with the specified size and returns a pointer to ++ * it. If memory cannot be allocated because of an invalid memSize, it saves an ++ * error. In other cases, it terminates program execution. + * + * @param memSize size of memory to be allocated +- * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no +- * error will be saved. ++ * @param error Pointer to the memory location where errors if any should be ++ * saved. If `nil`, no error will be saved. + * +- * @return Pointer to the allocated block of memory on successfull allocation. nil in case as error +- * is encountered because of invalid memSize. If failure is due to any other reason, method +- * terminates program execution. ++ * @return Pointer to the allocated block of memory on successfull allocation. ++ * nil in case as error is encountered because of invalid memSize. If failure is ++ * due to any other reason, method terminates program execution. + */ +-+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error; +++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error; + @end + + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m +index 6fc2eadeeafe9..2f2d85a23593a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m +@@ -16,39 +16,43 @@ + #import "tensorflow_lite_support/ios/sources/TFLCommon.h" + + /** Error domain of TensorFlow Lite Support related errors. */ +-static NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; ++static NSString* const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; + + @implementation TFLCommonUtils + + + (void)customErrorWithCode:(NSInteger)code +- description:(NSString *)description +- error:(NSError **)error { ++ description:(NSString*)description ++ error:(NSError**)error { + if (error) +- *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain +- code:code +- userInfo:@{NSLocalizedDescriptionKey : description}]; ++ *error = ++ [NSError errorWithDomain:TFLSupportTaskErrorDomain ++ code:code ++ userInfo:@{NSLocalizedDescriptionKey : description}]; + } + +-+ (void)errorFromTfLiteSupportError:(TfLiteSupportError *)supportError error:(NSError **)error { +++ (void)errorFromTfLiteSupportError:(TfLiteSupportError*)supportError ++ error:(NSError**)error { + if (supportError && error) +- *error = [NSError +- errorWithDomain:TFLSupportTaskErrorDomain +- code:supportError->code +- userInfo:@{ +- NSLocalizedDescriptionKey : [NSString stringWithCString:supportError->message +- encoding:NSUTF8StringEncoding] +- }]; ++ *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain ++ code:supportError->code ++ userInfo:@{ ++ NSLocalizedDescriptionKey : [NSString ++ stringWithCString:supportError->message ++ encoding:NSUTF8StringEncoding] ++ }]; + } + +-+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error { +++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error { + if (!memSize) { +- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"Invalid memory size passed for allocation of object." +- error:error]; ++ [TFLCommonUtils ++ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError ++ description: ++ @"Invalid memory size passed for allocation of object." ++ error:error]; + return NULL; + } + +- void *allocedMemory = malloc(memSize); ++ void* allocedMemory = malloc(memSize); + if (!allocedMemory && memSize) { + exit(-1); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h +index 40ba41b8eb0f9..90864c703c411 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h +@@ -18,7 +18,7 @@ + NS_ASSUME_NONNULL_BEGIN + + @interface TFLBaseOptions (Helpers) +-- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions *)cBaseOptions; ++- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions*)cBaseOptions; + @end + + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m +index 0fed6d7c9966e..ddab0f7ab4207 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m +@@ -16,7 +16,7 @@ + + @implementation TFLBaseOptions (Helpers) + +-- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions *)cBaseOptions { ++- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions*)cBaseOptions { + if (self.modelFile.filePath) { + cBaseOptions->model_file.file_path = self.modelFile.filePath.UTF8String; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h +index cdcddabe7323a..0f92dd1005631 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h +@@ -19,10 +19,10 @@ NS_ASSUME_NONNULL_BEGIN + NS_SWIFT_NAME(CpuSettings) + @interface TFLCpuSettings : NSObject <NSCopying> + +-/** Specifies the number of threads to be used for TFLite ops that support multi-threadingwhen +- * running inference with CPU. +- * @discussion This property hould be greater than 0 or equal to -1. Setting it to -1 has the +- * effect to let TFLite runtime set the value. ++/** Specifies the number of threads to be used for TFLite ops that support ++ * multi-threadingwhen running inference with CPU. ++ * @discussion This property hould be greater than 0 or equal to -1. Setting it ++ * to -1 has the effect to let TFLite runtime set the value. + */ + @property(nonatomic, assign) int numThreads; + +@@ -35,7 +35,7 @@ NS_SWIFT_NAME(ComputeSettings) + @interface TFLComputeSettings : NSObject <NSCopying> + + /** Holds cpu settings. */ +-@property(nonatomic, copy) TFLCpuSettings *cpuSettings; ++@property(nonatomic, copy) TFLCpuSettings* cpuSettings; + + @end + +@@ -46,30 +46,32 @@ NS_SWIFT_NAME(ExternalFile) + @interface TFLExternalFile : NSObject <NSCopying> + + /** Path to the file in bundle. */ +-@property(nonatomic, copy) NSString *filePath; ++@property(nonatomic, copy) NSString* filePath; + /// Add provision for other sources in future. + + @end + + /** +- * Holds the base options that is used for creation of any type of task. It has fields with +- * important information acceleration configuration, tflite model source etc. ++ * Holds the base options that is used for creation of any type of task. It has ++ * fields with important information acceleration configuration, tflite model ++ * source etc. + */ + NS_SWIFT_NAME(BaseOptions) + @interface TFLBaseOptions : NSObject <NSCopying> + + /** +- * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model +- * Metadata[1] and associated files if exist. Fail to provide the necessary metadata and associated +- * files might result in errors. ++ * The external model file, as a single standalone TFLite file. It could be ++ * packed with TFLite Model Metadata[1] and associated files if exist. Fail to ++ * provide the necessary metadata and associated files might result in errors. + */ +-@property(nonatomic, copy) TFLExternalFile *modelFile; ++@property(nonatomic, copy) TFLExternalFile* modelFile; + + /** +- * Holds settings for one possible acceleration configuration including.cpu/gpu settings. +- * Please see documentation of TfLiteComputeSettings and its members for more details. ++ * Holds settings for one possible acceleration configuration including.cpu/gpu ++ * settings. Please see documentation of TfLiteComputeSettings and its members ++ * for more details. + */ +-@property(nonatomic, copy) TFLComputeSettings *computeSettings; ++@property(nonatomic, copy) TFLComputeSettings* computeSettings; + + @end + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m +index 826380f1f62db..1e536cdc08194 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m +@@ -25,8 +25,8 @@ + return self; + } + +-- (id)copyWithZone:(NSZone *)zone { +- TFLCpuSettings *cpuSettings = [[TFLCpuSettings alloc] init]; ++- (id)copyWithZone:(NSZone*)zone { ++ TFLCpuSettings* cpuSettings = [[TFLCpuSettings alloc] init]; + + [cpuSettings setNumThreads:self.numThreads]; + +@@ -46,8 +46,8 @@ + return self; + } + +-- (id)copyWithZone:(NSZone *)zone { +- TFLComputeSettings *computeSettings = [[TFLComputeSettings alloc] init]; ++- (id)copyWithZone:(NSZone*)zone { ++ TFLComputeSettings* computeSettings = [[TFLComputeSettings alloc] init]; + + [computeSettings setCpuSettings:self.cpuSettings]; + +@@ -59,8 +59,8 @@ + @implementation TFLExternalFile + @synthesize filePath; + +-- (id)copyWithZone:(NSZone *)zone { +- TFLExternalFile *externalFile = [[TFLExternalFile alloc] init]; ++- (id)copyWithZone:(NSZone*)zone { ++ TFLExternalFile* externalFile = [[TFLExternalFile alloc] init]; + + [externalFile setFilePath:self.filePath]; + +@@ -82,8 +82,8 @@ + return self; + } + +-- (id)copyWithZone:(NSZone *)zone { +- TFLBaseOptions *baseOptions = [[TFLBaseOptions alloc] init]; ++- (id)copyWithZone:(NSZone*)zone { ++ TFLBaseOptions* baseOptions = [[TFLBaseOptions alloc] init]; + + [baseOptions setModelFile:self.modelFile]; + [baseOptions setComputeSettings:self.computeSettings]; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h +index 623065d416904..78a1f965769aa 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h +@@ -19,11 +19,11 @@ NS_ASSUME_NONNULL_BEGIN + + @interface TFLClassificationOptions (Helpers) + - (BOOL)copyClassificationOptionsToCClassificationOptions: +- (TfLiteClassificationOptions *)cClassificationOptions +- error:(NSError **)error; ++ (TfLiteClassificationOptions*)cClassificationOptions ++ error:(NSError**)error; + + - (void)deleteCStringArraysOfClassificationOptions: +- (TfLiteClassificationOptions *)cClassificationOptions; ++ (TfLiteClassificationOptions*)cClassificationOptions; + @end + + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m +index f7aa5fdf18b36..07254ab675c4b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m +@@ -18,28 +18,35 @@ + + @implementation TFLClassificationOptions (Helpers) + +-+ (char **)cStringArrayFromNSArray:(NSArray<NSString *> *)strings error:(NSError **)error { +++ (char**)cStringArrayFromNSArray:(NSArray<NSString*>*)strings ++ error:(NSError**)error { + if (strings.count <= 0) { +- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"Invalid length of strings found for list type options." +- error:error]; ++ [TFLCommonUtils ++ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError ++ description: ++ @"Invalid length of strings found for list type options." ++ error:error]; + return NULL; + } + +- char **cStrings = (char **)calloc(strings.count, sizeof(char *)); ++ char** cStrings = (char**)calloc(strings.count, sizeof(char*)); + + if (!cStrings) { +- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInternalError +- description:@"Could not initialize list type options." +- error:error]; ++ [TFLCommonUtils ++ customErrorWithCode:TFLSupportErrorCodeInternalError ++ description:@"Could not initialize list type options." ++ error:error]; + return nil; + } + + for (NSInteger i = 0; i < strings.count; i++) { +- char *cString = [TFLCommonUtils +- mallocWithSize:[strings[i] lengthOfBytesUsingEncoding:NSUTF8StringEncoding] + 1 ++ char* cString = [TFLCommonUtils ++ mallocWithSize:[strings[i] ++ lengthOfBytesUsingEncoding:NSUTF8StringEncoding] + ++ 1 + error:error]; +- if (!cString) return nil; ++ if (!cString) ++ return nil; + + strcpy(cString, strings[i].UTF8String); + } +@@ -47,7 +54,7 @@ + return cStrings; + } + +-+ (void)deleteCStringsArray:(char **)cStrings count:(int)count { +++ (void)deleteCStringsArray:(char**)cStrings count:(int)count { + for (NSInteger i = 0; i < count; i++) { + free(cStrings[i]); + } +@@ -56,49 +63,56 @@ + } + + - (BOOL)copyClassificationOptionsToCClassificationOptions: +- (TfLiteClassificationOptions *)cClassificationOptions +- error:(NSError **)error { ++ (TfLiteClassificationOptions*)cClassificationOptions ++ error:(NSError**)error { + cClassificationOptions->score_threshold = self.scoreThreshold; + cClassificationOptions->max_results = (int)self.maxResults; + + if (self.labelDenyList) { +- char **cClassNameBlackList = +- [TFLClassificationOptions cStringArrayFromNSArray:self.labelDenyList error:error]; ++ char** cClassNameBlackList = ++ [TFLClassificationOptions cStringArrayFromNSArray:self.labelDenyList ++ error:error]; + if (!cClassNameBlackList) { + return NO; + } + cClassificationOptions->label_denylist.list = cClassNameBlackList; +- cClassificationOptions->label_denylist.length = (int)self.labelDenyList.count; ++ cClassificationOptions->label_denylist.length = ++ (int)self.labelDenyList.count; + } + + if (self.labelAllowList) { +- char **cClassNameWhiteList = +- [TFLClassificationOptions cStringArrayFromNSArray:self.labelAllowList error:error]; ++ char** cClassNameWhiteList = ++ [TFLClassificationOptions cStringArrayFromNSArray:self.labelAllowList ++ error:error]; + if (!cClassNameWhiteList) { + return NO; + } + + cClassificationOptions->label_allowlist.list = cClassNameWhiteList; +- cClassificationOptions->label_allowlist.length = (int)self.labelAllowList.count; ++ cClassificationOptions->label_allowlist.length = ++ (int)self.labelAllowList.count; + } + + if (self.displayNamesLocal) { +- cClassificationOptions->display_names_local = (char *)self.displayNamesLocal.UTF8String; ++ cClassificationOptions->display_names_local = ++ (char*)self.displayNamesLocal.UTF8String; + } + + return YES; + } + + - (void)deleteCStringArraysOfClassificationOptions: +- (TfLiteClassificationOptions *)cClassificationOptions { ++ (TfLiteClassificationOptions*)cClassificationOptions { + if (self.labelAllowList) { +- [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_allowlist.list +- count:cClassificationOptions->label_allowlist.length]; ++ [TFLClassificationOptions ++ deleteCStringsArray:cClassificationOptions->label_allowlist.list ++ count:cClassificationOptions->label_allowlist.length]; + } + + if (self.labelDenyList) { +- [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_denylist.list +- count:cClassificationOptions->label_denylist.length]; ++ [TFLClassificationOptions ++ deleteCStringsArray:cClassificationOptions->label_denylist.list ++ count:cClassificationOptions->label_denylist.length]; + } + } + @end +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h +index dbe05c8f98d2f..cc0c8a87da148 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h +@@ -22,13 +22,14 @@ NS_ASSUME_NONNULL_BEGIN + @interface TFLClassificationOptions : NSObject <NSCopying> + + /** If set, all classes in this list will be filtered out from the results . */ +-@property(nonatomic, copy) NSArray *labelDenyList; ++@property(nonatomic, copy) NSArray* labelDenyList; + +-/** If set, all classes not in this list will be filtered out from the results . */ +-@property(nonatomic, copy) NSArray *labelAllowList; ++/** If set, all classes not in this list will be filtered out from the results . ++ */ ++@property(nonatomic, copy) NSArray* labelAllowList; + + /** Display names local for display names*/ +-@property(nonatomic, copy) NSString *displayNamesLocal; ++@property(nonatomic, copy) NSString* displayNamesLocal; + + /** Results with score threshold greater than this value are returned . */ + @property(nonatomic, assign) float scoreThreshold; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m +index 784f782ebc271..dca232d673238 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m +@@ -30,8 +30,9 @@ + return self; + } + +-- (id)copyWithZone:(NSZone *)zone { +- TFLClassificationOptions *classificationOptions = [[TFLClassificationOptions alloc] init]; ++- (id)copyWithZone:(NSZone*)zone { ++ TFLClassificationOptions* classificationOptions = ++ [[TFLClassificationOptions alloc] init]; + + [classificationOptions setScoreThreshold:self.scoreThreshold]; + [classificationOptions setMaxResults:self.maxResults]; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h +index 377e02f32045a..c0d6fb335ebf3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h +@@ -20,39 +20,40 @@ NS_ASSUME_NONNULL_BEGIN + @interface TFLCategory : NSObject + + /** Display name of the class. */ +-@property(nonatomic, copy) NSString *displayName; ++@property(nonatomic, copy) NSString* displayName; + + /** Class name of the class . */ +-@property(nonatomic, copy) NSString *label; ++@property(nonatomic, copy) NSString* label; + + /** Confidence score for this class . */ + @property(nonatomic, assign) float score; + +-/** The index of the class in the corresponding label map, usually packed in the TFLite Model +- * Metadata. */ ++/** The index of the class in the corresponding label map, usually packed in the ++ * TFLite Model Metadata. */ + @property(nonatomic, assign) NSInteger classIndex; + + @end + +-/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ ++/** Encapsulates list of predicted classes (aka labels) for a given image ++ * classifier head. */ + @interface TFLClassifications : NSObject + + /** +- * The index of the image classifier head these classes refer to. This is useful for multi-head +- * models. ++ * The index of the image classifier head these classes refer to. This is useful ++ * for multi-head models. + */ + @property(nonatomic, assign) int headIndex; + +-/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low +- * probability). */ +-@property(nonatomic, copy) NSArray<TFLCategory *> *categories; ++/** The array of predicted classes, usually sorted by descending scores ++ * (e.g.from high to low probability). */ ++@property(nonatomic, copy) NSArray<TFLCategory*>* categories; + + @end + + /** Encapsulates results of any classification task. */ + @interface TFLClassificationResult : NSObject + +-@property(nonatomic, copy) NSArray<TFLClassifications *> *classifications; ++@property(nonatomic, copy) NSArray<TFLClassifications*>* classifications; + + @end + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h +index 406eb1e4ceb5a..c52876e9a5d7a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h +@@ -19,22 +19,23 @@ + + NS_ASSUME_NONNULL_BEGIN + +-/** Helper utility for conversion between TFLite Task C Library Classification Results and iOS +- * Classification Results . */ ++/** Helper utility for conversion between TFLite Task C Library Classification ++ * Results and iOS Classification Results . */ + @interface TFLClassificationUtils : NSObject + + /** +- * Creates and retrurns a TFLClassificationResult from a TfLiteClassificationResult returned by +- * TFLite Task C Library Classification tasks. ++ * Creates and retrurns a TFLClassificationResult from a ++ * TfLiteClassificationResult returned by TFLite Task C Library Classification ++ * tasks. + * +- * @param cClassificationResult Classification results returned by TFLite Task C Library +- * Classification tasks ++ * @param cClassificationResult Classification results returned by TFLite Task C ++ * Library Classification tasks + * +- * @return Classification Result of type TFLClassificationResult to be returned by inference methods +- * of the iOS TF Lite Task Classification tasks. ++ * @return Classification Result of type TFLClassificationResult to be returned ++ * by inference methods of the iOS TF Lite Task Classification tasks. + */ +-+ (TFLClassificationResult *)classificationResultFromCClassificationResults: +- (TfLiteClassificationResult *)cClassificationResult; +++ (TFLClassificationResult*)classificationResultFromCClassificationResults: ++ (TfLiteClassificationResult*)cClassificationResult; + + - (instancetype)init NS_UNAVAILABLE; + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m +index a24a91e5c9729..b5d884d39f864 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m +@@ -16,39 +16,44 @@ + + @implementation TFLClassificationUtils + +-+ (TFLClassificationResult *)classificationResultFromCClassificationResults: +- (TfLiteClassificationResult *)cClassificationResult { +- if (cClassificationResult == nil) return nil; +++ (TFLClassificationResult*)classificationResultFromCClassificationResults: ++ (TfLiteClassificationResult*)cClassificationResult { ++ if (cClassificationResult == nil) ++ return nil; + +- NSMutableArray *classificationHeads = [[NSMutableArray alloc] init]; ++ NSMutableArray* classificationHeads = [[NSMutableArray alloc] init]; + for (int i = 0; i < cClassificationResult->size; i++) { +- TfLiteClassifications cClassifications = cClassificationResult->classifications[i]; +- NSMutableArray *classes = [[NSMutableArray alloc] init]; ++ TfLiteClassifications cClassifications = ++ cClassificationResult->classifications[i]; ++ NSMutableArray* classes = [[NSMutableArray alloc] init]; + for (int j = 0; j < cClassifications.size; j++) { + TfLiteCategory cCategory = cClassifications.categories[j]; +- TFLCategory *resultCategory = [[TFLCategory alloc] init]; ++ TFLCategory* resultCategory = [[TFLCategory alloc] init]; + + if (cCategory.display_name != nil) { +- resultCategory.displayName = [NSString stringWithCString:cCategory.display_name +- encoding:NSUTF8StringEncoding]; ++ resultCategory.displayName = ++ [NSString stringWithCString:cCategory.display_name ++ encoding:NSUTF8StringEncoding]; + } + + if (cCategory.label != nil) { +- resultCategory.label = [NSString stringWithCString:cCategory.label +- encoding:NSUTF8StringEncoding]; ++ resultCategory.label = ++ [NSString stringWithCString:cCategory.label ++ encoding:NSUTF8StringEncoding]; + } + + resultCategory.score = cCategory.score; + resultCategory.classIndex = (NSInteger)cCategory.index; + [classes addObject:resultCategory]; + } +- TFLClassifications *classificationHead = [[TFLClassifications alloc] init]; ++ TFLClassifications* classificationHead = [[TFLClassifications alloc] init]; + classificationHead.categories = classes; + classificationHead.headIndex = i; + [classificationHeads addObject:classificationHead]; + } + +- TFLClassificationResult *classificationResult = [[TFLClassificationResult alloc] init]; ++ TFLClassificationResult* classificationResult = ++ [[TFLClassificationResult alloc] init]; + classificationResult.classifications = classificationHeads; + return classificationResult; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h +index 99de5ad04febf..ac81a15ac11c6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h +@@ -27,15 +27,17 @@ NS_ASSUME_NONNULL_BEGIN + @end + + /** +- * Classifier API for NLClassification tasks with Bert models, categorizes string into different +- * classes. The API expects a Bert based TFLite model with metadata populated. ++ * Classifier API for NLClassification tasks with Bert models, categorizes ++ * string into different classes. The API expects a Bert based TFLite model with ++ * metadata populated. + * + * The metadata should contain the following information: + * 1 input_process_unit for Wordpiece/Sentencepiece Tokenizer. + * 3 input tensors with names "ids", "mask" and "segment_ids". +- * 1 output tensor of type float32[1, 2], with a optionally attached label file. If a label +- * file is attached, the file should be a plain text file with one label per line, the number +- * of labels should match the number of categories the model outputs. ++ * 1 output tensor of type float32[1, 2], with a optionally attached label ++ * file. If a label file is attached, the file should be a plain text file with ++ * one label per line, the number of labels should match the number of ++ * categories the model outputs. + */ + @interface TFLBertNLClassifier : NSObject + +@@ -45,7 +47,7 @@ NS_ASSUME_NONNULL_BEGIN + * @param modelPath Path to the classification model. + * @return A TFLBertNLClassifier instance. + */ +-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath +++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath + NS_SWIFT_NAME(bertNLClassifier(modelPath:)); + + /** +@@ -54,8 +56,9 @@ NS_ASSUME_NONNULL_BEGIN + * @param modelPath Path to the classification model. + * @return A TFLBertNLClassifier instance. + */ +-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath +- options:(TFLBertNLClassifierOptions *)options +++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath ++ options: ++ (TFLBertNLClassifierOptions*)options + NS_SWIFT_NAME(bertNLClassifier(modelPath:options:)); + + /** +@@ -65,7 +68,7 @@ NS_ASSUME_NONNULL_BEGIN + * @param text input text to the model. + * @return A NSDictionary of categorization results. + */ +-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text ++- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text + NS_SWIFT_NAME(classify(text:)); + @end + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m +index e9d3b3dbbd1e3..8c45ee62cceea 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m +@@ -25,7 +25,7 @@ NS_ASSUME_NONNULL_BEGIN + + @interface TFLBertNLClassifier () + /** BertNLClassifier backed by C API */ +-@property(nonatomic) TfLiteBertNLClassifier *bertNLClassifier; ++@property(nonatomic) TfLiteBertNLClassifier* bertNLClassifier; + @end + + @implementation TFLBertNLClassifier +@@ -34,24 +34,28 @@ NS_ASSUME_NONNULL_BEGIN + TfLiteBertNLClassifierDelete(_bertNLClassifier); + } + +-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath { +- TfLiteBertNLClassifier *classifier = TfLiteBertNLClassifierCreate(modelPath.UTF8String); +++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath { ++ TfLiteBertNLClassifier* classifier = ++ TfLiteBertNLClassifierCreate(modelPath.UTF8String); + + _GTMDevAssert(classifier, @"Failed to create BertNLClassifier"); + return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier]; + } + +-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath +- options:(TFLBertNLClassifierOptions *)options { +- // Note that maxSeqLen has been deprecated. Passing it to the C API is a no-op. +++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath ++ options: ++ (TFLBertNLClassifierOptions*)options { ++ // Note that maxSeqLen has been deprecated. Passing it to the C API is a ++ // no-op. + TfLiteBertNLClassifierOptions cOptions = {.max_seq_len = options.maxSeqLen}; +- TfLiteBertNLClassifier *classifier = ++ TfLiteBertNLClassifier* classifier = + TfLiteBertNLClassifierCreateFromOptions(modelPath.UTF8String, &cOptions); + _GTMDevAssert(classifier, @"Failed to create BertNLClassifier"); + return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier]; + } + +-- (instancetype)initWithBertNLClassifier:(TfLiteBertNLClassifier *)bertNLClassifier { ++- (instancetype)initWithBertNLClassifier: ++ (TfLiteBertNLClassifier*)bertNLClassifier { + self = [super init]; + if (self) { + _bertNLClassifier = bertNLClassifier; +@@ -59,9 +63,11 @@ NS_ASSUME_NONNULL_BEGIN + return self; + } + +-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text { +- Categories *cCategories = TfLiteBertNLClassifierClassify(_bertNLClassifier, text.UTF8String); +- NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary]; ++- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text { ++ Categories* cCategories = ++ TfLiteBertNLClassifierClassify(_bertNLClassifier, text.UTF8String); ++ NSMutableDictionary<NSString*, NSNumber*>* ret = ++ [NSMutableDictionary dictionary]; + for (int i = 0; i < cCategories->size; i++) { + Category cCategory = cCategories->categories[i]; + [ret setValue:[NSNumber numberWithDouble:cCategory.score] +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h +index ceb8d2ef9a307..41eb0fb76c9ea 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h +@@ -23,14 +23,14 @@ NS_ASSUME_NONNULL_BEGIN + @property(nonatomic) int inputTensorIndex; + @property(nonatomic) int outputScoreTensorIndex; + @property(nonatomic) int outputLabelTensorIndex; +-@property(nonatomic) NSString *inputTensorName; +-@property(nonatomic) NSString *outputScoreTensorName; +-@property(nonatomic) NSString *outputLabelTensorName; ++@property(nonatomic) NSString* inputTensorName; ++@property(nonatomic) NSString* outputScoreTensorName; ++@property(nonatomic) NSString* outputLabelTensorName; + @end + + /** +- * Classifier API for natural language classification tasks, categorizes string into different +- * classes. ++ * Classifier API for natural language classification tasks, categorizes string ++ * into different classes. + * + * The API expects a TFLite model with the following input/output tensor: + * +@@ -39,25 +39,28 @@ NS_ASSUME_NONNULL_BEGIN + * + * Output score tensor + * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool) +- * output scores for each class, if type is one of the Int types, dequantize it, if it +- * is Bool type, convert the values to 0.0 and 1.0 respectively. ++ * output scores for each class, if type is one of the Int types, dequantize ++ * it, if it is Bool type, convert the values to 0.0 and 1.0 respectively. + * +- * can have an optional associated file in metadata for labels, the file should be a +- * plain text file with one label per line, the number of labels should match the number +- * of categories the model outputs. Output label tensor: optional (kTfLiteString) - +- * output classname for each class, should be of the same length with scores. If this +- * tensor is not present, the API uses score indices as classnames. - will be ignored if +- * output score tensor already has an associated label file. ++ * can have an optional associated file in metadata for labels, the file ++ * should be a plain text file with one label per line, the number of labels ++ * should match the number of categories the model outputs. Output label tensor: ++ * optional (kTfLiteString) - output classname for each class, should be of the ++ * same length with scores. If this tensor is not present, the API uses score ++ * indices as classnames. - will be ignored if output score tensor already has ++ * an associated label file. + * + * Optional Output label tensor (kTfLiteString/kTfLiteInt32) +- * output classname for each class, should be of the same length with scores. If this +- * tensor is not present, the API uses score indices as classnames. ++ * output classname for each class, should be of the same length with ++ * scores. If this tensor is not present, the API uses score indices as ++ * classnames. + * +- * will be ignored if output score tensor already has an associated labe file. ++ * will be ignored if output score tensor already has an associated labe ++ * file. + * +- * By default the API tries to find the input/output tensors with default configurations in +- * TFLNLClassifierOptions, with tensor name prioritized over tensor index. The option is +- * configurable for different TFLite models. ++ * By default the API tries to find the input/output tensors with default ++ * configurations in TFLNLClassifierOptions, with tensor name prioritized over ++ * tensor index. The option is configurable for different TFLite models. + */ + @interface TFLNLClassifier : NSObject + +@@ -69,8 +72,8 @@ NS_ASSUME_NONNULL_BEGIN + * + * @return A TFLNLClassifier instance. + */ +-+ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath +- options:(TFLNLClassifierOptions *)options +++ (instancetype)nlClassifierWithModelPath:(NSString*)modelPath ++ options:(TFLNLClassifierOptions*)options + NS_SWIFT_NAME(nlClassifier(modelPath:options:)); + + /** +@@ -80,7 +83,7 @@ NS_ASSUME_NONNULL_BEGIN + * @param text input text to the model. + * @return A NSDictionary of categorization results. + */ +-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text ++- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text + NS_SWIFT_NAME(classify(text:)); + @end + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m +index 8d21a111345d2..39eb15c71681c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m +@@ -30,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN + + @interface TFLNLClassifier () + /** NLClassifier backed by C API */ +-@property(nonatomic) TfLiteNLClassifier *nlClassifier; ++@property(nonatomic) TfLiteNLClassifier* nlClassifier; + @end + + @implementation TFLNLClassifier +@@ -39,8 +39,8 @@ NS_ASSUME_NONNULL_BEGIN + TfLiteNLClassifierDelete(_nlClassifier); + } + +-+ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath +- options:(TFLNLClassifierOptions *)options { +++ (instancetype)nlClassifierWithModelPath:(NSString*)modelPath ++ options:(TFLNLClassifierOptions*)options { + TfLiteNLClassifierOptions cOptions = { + .input_tensor_index = options.inputTensorIndex, + .output_score_tensor_index = options.outputScoreTensorIndex, +@@ -48,13 +48,13 @@ NS_ASSUME_NONNULL_BEGIN + .input_tensor_name = options.inputTensorName.UTF8String, + .output_score_tensor_name = options.outputScoreTensorName.UTF8String, + .output_label_tensor_name = options.outputLabelTensorName.UTF8String}; +- TfLiteNLClassifier *classifier = ++ TfLiteNLClassifier* classifier = + TfLiteNLClassifierCreateFromOptions(modelPath.UTF8String, &cOptions); + _GTMDevAssert(classifier, @"Failed to create NLClassifier"); + return [[TFLNLClassifier alloc] initWithNLClassifier:classifier]; + } + +-- (instancetype)initWithNLClassifier:(TfLiteNLClassifier *)nlClassifier { ++- (instancetype)initWithNLClassifier:(TfLiteNLClassifier*)nlClassifier { + self = [super init]; + if (self) { + _nlClassifier = nlClassifier; +@@ -62,9 +62,11 @@ NS_ASSUME_NONNULL_BEGIN + return self; + } + +-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text { +- Categories *cCategories = TfLiteNLClassifierClassify(_nlClassifier, text.UTF8String); +- NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary]; ++- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text { ++ Categories* cCategories = ++ TfLiteNLClassifierClassify(_nlClassifier, text.UTF8String); ++ NSMutableDictionary<NSString*, NSNumber*>* ret = ++ [NSMutableDictionary dictionary]; + for (int i = 0; i < cCategories->size; i++) { + Category cCategory = cCategories->categories[i]; + [ret setValue:[NSNumber numberWithDouble:cCategory.score] +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m +index 9734fe7987a5e..407be10c1381c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m +@@ -19,8 +19,8 @@ limitations under the License. + NS_ASSUME_NONNULL_BEGIN + + @interface TFLBertNLClassifierTest : XCTestCase +-@property(nonatomic, nullable) NSString *bertModelPath; +-@property(nonatomic, nullable) TFLBertNLClassifierOptions *modelOptions; ++@property(nonatomic, nullable) NSString* bertModelPath; ++@property(nonatomic, nullable) TFLBertNLClassifierOptions* modelOptions; + @end + + @implementation TFLBertNLClassifierTest +@@ -28,30 +28,31 @@ NS_ASSUME_NONNULL_BEGIN + + - (void)setUp { + [super setUp]; +- NSBundle *bundle = [NSBundle bundleForClass:[self class]]; +- self.bertModelPath = [bundle pathForResource:@"bert_nl_classifier" ofType:@"tflite"]; ++ NSBundle* bundle = [NSBundle bundleForClass:[self class]]; ++ self.bertModelPath = [bundle pathForResource:@"bert_nl_classifier" ++ ofType:@"tflite"]; + } + + - (void)testClassifyPositiveResult { +- TFLBertNLClassifier *bertNLClassifier = ++ TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath]; + + XCTAssertNotNil(bertNLClassifier); + +- NSDictionary<NSString *, NSNumber *> *categories = +- [bertNLClassifier classifyWithText:@"it's a charming and often affecting journey"]; ++ NSDictionary<NSString*, NSNumber*>* categories = [bertNLClassifier ++ classifyWithText:@"it's a charming and often affecting journey"]; + + XCTAssertGreaterThan([categories[@"positive"] doubleValue], + [categories[@"negative"] doubleValue]); + } + + - (void)testClassifyNegativeResult { +- TFLBertNLClassifier *bertNLClassifier = ++ TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath]; + + XCTAssertNotNil(bertNLClassifier); + +- NSDictionary<NSString *, NSNumber *> *categories = ++ NSDictionary<NSString*, NSNumber*>* categories = + [bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"]; + + XCTAssertGreaterThan([categories[@"negative"] doubleValue], +@@ -62,14 +63,14 @@ NS_ASSUME_NONNULL_BEGIN + self.modelOptions = [[TFLBertNLClassifierOptions alloc] init]; + [self.modelOptions setMaxSeqLen:128]; + +- TFLBertNLClassifier *bertNLClassifier = ++ TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath + options:self.modelOptions]; + + XCTAssertNotNil(bertNLClassifier); + +- NSDictionary<NSString *, NSNumber *> *categories = +- [bertNLClassifier classifyWithText:@"it's a charming and often affecting journey"]; ++ NSDictionary<NSString*, NSNumber*>* categories = [bertNLClassifier ++ classifyWithText:@"it's a charming and often affecting journey"]; + + XCTAssertGreaterThan([categories[@"positive"] doubleValue], + [categories[@"negative"] doubleValue]); +@@ -79,13 +80,13 @@ NS_ASSUME_NONNULL_BEGIN + self.modelOptions = [[TFLBertNLClassifierOptions alloc] init]; + [self.modelOptions setMaxSeqLen:128]; + +- TFLBertNLClassifier *bertNLClassifier = ++ TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath + options:self.modelOptions]; + + XCTAssertNotNil(bertNLClassifier); + +- NSDictionary<NSString *, NSNumber *> *categories = ++ NSDictionary<NSString*, NSNumber*>* categories = + [bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"]; + + XCTAssertGreaterThan([categories[@"negative"] doubleValue], +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m +index 40814ac6409b0..1dcf08acc8c86 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m +@@ -19,8 +19,8 @@ limitations under the License. + NS_ASSUME_NONNULL_BEGIN + + @interface TFLNLClassifierTest : XCTestCase +-@property(nonatomic, nullable) NSString *modelPath; +-@property(nonatomic, nullable) TFLNLClassifierOptions *modelOptions; ++@property(nonatomic, nullable) NSString* modelPath; ++@property(nonatomic, nullable) TFLNLClassifierOptions* modelOptions; + @end + + @implementation TFLNLClassifierTest +@@ -28,34 +28,38 @@ NS_ASSUME_NONNULL_BEGIN + + - (void)setUp { + [super setUp]; +- NSBundle *bundle = [NSBundle bundleForClass:[self class]]; +- self.modelPath = [bundle pathForResource:@"test_model_nl_classifier_with_regex_tokenizer" +- ofType:@"tflite"]; ++ NSBundle* bundle = [NSBundle bundleForClass:[self class]]; ++ self.modelPath = ++ [bundle pathForResource:@"test_model_nl_classifier_with_regex_tokenizer" ++ ofType:@"tflite"]; + self.modelOptions = [[TFLNLClassifierOptions alloc] init]; + [self.modelOptions setInputTensorName:@"input_text"]; + [self.modelOptions setOutputScoreTensorName:@"probability"]; + } + + - (void)testClassifyPositiveResult { +- TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath +- options:self.modelOptions]; ++ TFLNLClassifier* nlClassifier = ++ [TFLNLClassifier nlClassifierWithModelPath:self.modelPath ++ options:self.modelOptions]; + + XCTAssertNotNil(nlClassifier); + +- NSDictionary<NSString *, NSNumber *> *categories = [nlClassifier +- classifyWithText:@"This is the best movie I’ve seen in recent years. Strongly recommend it!"]; ++ NSDictionary<NSString*, NSNumber*>* categories = ++ [nlClassifier classifyWithText:@"This is the best movie I’ve seen in " ++ @"recent years. Strongly recommend it!"]; + + XCTAssertGreaterThan([categories[@"Positive"] doubleValue], + [categories[@"Negative"] doubleValue]); + } + + - (void)testClassifyNegativeResult { +- TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath +- options:self.modelOptions]; ++ TFLNLClassifier* nlClassifier = ++ [TFLNLClassifier nlClassifierWithModelPath:self.modelPath ++ options:self.modelOptions]; + + XCTAssertNotNil(nlClassifier); + +- NSDictionary<NSString *, NSNumber *> *categories = ++ NSDictionary<NSString*, NSNumber*>* categories = + [nlClassifier classifyWithText:@"What a waste of my time."]; + + XCTAssertGreaterThan([categories[@"Negative"] doubleValue], +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h +index 57b7c69c70f62..446e2cb137dd9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h +@@ -54,13 +54,13 @@ struct TFLPos { + * @param modelPath The file path to the tflite model. + * @return A BertQuestionAnswerer instance. + */ +-+ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath +++ (instancetype)questionAnswererWithModelPath:(NSString*)modelPath + NS_SWIFT_NAME(questionAnswerer(modelPath:)); + + /** + * Answers question based on the context. Could be empty if no answer was found + * from the given context. +- * ++ * + * @param context Context the question bases on. + * @param question Question to ask. + * +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m +index a07f8753fbae3..b470c4643111e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m +@@ -25,7 +25,7 @@ NS_ASSUME_NONNULL_BEGIN + + @interface TFLBertQuestionAnswerer () + /** BertQuestionAnswerer backed by C API */ +-@property(nonatomic) TfLiteBertQuestionAnswerer *bertQuestionAnswerer; ++@property(nonatomic) TfLiteBertQuestionAnswerer* bertQuestionAnswerer; + @end + + @implementation TFLBertQuestionAnswerer +@@ -34,14 +34,16 @@ NS_ASSUME_NONNULL_BEGIN + TfLiteBertQuestionAnswererDelete(_bertQuestionAnswerer); + } + +-+ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath { +- TfLiteBertQuestionAnswerer *bert_qa = TfLiteBertQuestionAnswererCreate(modelPath.UTF8String); +++ (instancetype)questionAnswererWithModelPath:(NSString*)modelPath { ++ TfLiteBertQuestionAnswerer* bert_qa = ++ TfLiteBertQuestionAnswererCreate(modelPath.UTF8String); + + _GTMDevAssert(bert_qa, @"Failed to create BertQuestionAnswerer"); + return [[TFLBertQuestionAnswerer alloc] initWithBertQuestionAnswerer:bert_qa]; + } + +-- (instancetype)initWithBertQuestionAnswerer:(TfLiteBertQuestionAnswerer *)bertQuestionAnswerer { ++- (instancetype)initWithBertQuestionAnswerer: ++ (TfLiteBertQuestionAnswerer*)bertQuestionAnswerer { + self = [super init]; + if (self) { + _bertQuestionAnswerer = bertQuestionAnswerer; +@@ -49,14 +51,17 @@ NS_ASSUME_NONNULL_BEGIN + return self; + } + +-- (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question { +- TfLiteQaAnswers *cAnswers = TfLiteBertQuestionAnswererAnswer( ++- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context ++ question:(NSString*)question { ++ TfLiteQaAnswers* cAnswers = TfLiteBertQuestionAnswererAnswer( + _bertQuestionAnswerer, context.UTF8String, question.UTF8String); +- NSMutableArray<TFLQAAnswer *> *ret = [NSMutableArray arrayWithCapacity:cAnswers->size]; ++ NSMutableArray<TFLQAAnswer*>* ret = ++ [NSMutableArray arrayWithCapacity:cAnswers->size]; + for (int i = 0; i < cAnswers->size; i++) { + TfLiteQaAnswer cAnswer = cAnswers->answers[i]; +- TFLQAAnswer *answer = [[TFLQAAnswer alloc] init]; +- struct TFLPos pos = {.start = cAnswer.start, .end = cAnswer.end, .logit = cAnswer.logit}; ++ TFLQAAnswer* answer = [[TFLQAAnswer alloc] init]; ++ struct TFLPos pos = { ++ .start = cAnswer.start, .end = cAnswer.end, .logit = cAnswer.logit}; + [answer setPos:pos]; + [answer setText:[NSString stringWithUTF8String:cAnswer.text]]; + [ret addObject:answer]; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m +index 9061063096cb4..ac4a1d3be63ef 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m +@@ -16,7 +16,7 @@ limitations under the License. + + #import <XCTest/XCTest.h> + +-static NSString *const kContext = ++static NSString* const kContext = + @"The role of teacher is often formal and ongoing, carried out at a school " + "or other place of formal education. In many countries, a person who " + "wishes to become a teacher must first obtain specified professional " +@@ -27,12 +27,12 @@ static NSString *const kContext = + "continuing professional development. Teachers may use a lesson plan to " + "facilitate student learning, providing a course of study which is called " + "the curriculum."; +-static NSString *const kQuestion = @"What is a course of study called?"; +-static NSString *const kAnswer = @"the curriculum."; ++static NSString* const kQuestion = @"What is a course of study called?"; ++static NSString* const kAnswer = @"the curriculum."; + + @interface TFLBertQuestionAnswererTest : XCTestCase +-@property(nonatomic, nullable) NSString *mobileBertModelPath; +-@property(nonatomic, nullable) NSString *albertModelPath; ++@property(nonatomic, nullable) NSString* mobileBertModelPath; ++@property(nonatomic, nullable) NSString* albertModelPath; + @end + + @implementation TFLBertQuestionAnswererTest +@@ -40,32 +40,33 @@ static NSString *const kAnswer = @"the curriculum."; + + - (void)setUp { + [super setUp]; +- NSBundle *bundle = [NSBundle bundleForClass:[self class]]; +- self.mobileBertModelPath = [bundle pathForResource:@"mobilebert_with_metadata" ofType:@"tflite"]; +- self.albertModelPath = [bundle pathForResource:@"albert_with_metadata" ofType:@"tflite"]; ++ NSBundle* bundle = [NSBundle bundleForClass:[self class]]; ++ self.mobileBertModelPath = [bundle pathForResource:@"mobilebert_with_metadata" ++ ofType:@"tflite"]; ++ self.albertModelPath = [bundle pathForResource:@"albert_with_metadata" ++ ofType:@"tflite"]; + } + + - (void)testInitMobileBert { +- TFLBertQuestionAnswerer* mobileBertAnswerer = +- [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.mobileBertModelPath]; ++ TFLBertQuestionAnswerer* mobileBertAnswerer = [TFLBertQuestionAnswerer ++ questionAnswererWithModelPath:self.mobileBertModelPath]; + + XCTAssertNotNil(mobileBertAnswerer); + + NSArray<TFLQAAnswer*>* answers = +- [mobileBertAnswerer answerWithContext:kContext question:kQuestion]; ++ [mobileBertAnswerer answerWithContext:kContext question:kQuestion]; + + XCTAssertEqualObjects([answers[0] text], kAnswer); + } + + - (void)testInitAlbert { +- TFLBertQuestionAnswerer* albertAnswerer = +- [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.albertModelPath]; ++ TFLBertQuestionAnswerer* albertAnswerer = [TFLBertQuestionAnswerer ++ questionAnswererWithModelPath:self.albertModelPath]; + + XCTAssertNotNil(albertAnswerer); + +- NSArray<TFLQAAnswer*>* answers = +- [albertAnswerer answerWithContext:kContext question:kQuestion]; +- ++ NSArray<TFLQAAnswer*>* answers = [albertAnswerer answerWithContext:kContext ++ question:kQuestion]; + + XCTAssertEqualObjects([answers[0] text], kAnswer); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h +index c37f22f3fb9aa..1b988f2be9737 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h +@@ -30,27 +30,29 @@ NS_ASSUME_NONNULL_BEGIN + * Base options that is used for creation of any type of task. + * @seealso TFLBaseOptions + */ +-@property(nonatomic, copy) TFLBaseOptions *baseOptions; ++@property(nonatomic, copy) TFLBaseOptions* baseOptions; + + /** + * Options that configure the display and filtering of results. + * @seealso TFLClassificationOptions + */ +-@property(nonatomic, copy) TFLClassificationOptions *classificationOptions; ++@property(nonatomic, copy) TFLClassificationOptions* classificationOptions; + + /** +- * Initializes TFLImageClassifierOptions with the model path set to the specified path to a model +- * file. +- * @description The external model file, must be a single standalone TFLite file. It could be packed +- * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary +- * metadata and associated files might result in errors. Check the [documentation] +- * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. ++ * Initializes TFLImageClassifierOptions with the model path set to the ++ * specified path to a model file. ++ * @description The external model file, must be a single standalone TFLite ++ * file. It could be packed with TFLite Model Metadata[1] and associated files ++ * if exist. Fail to provide the necessary metadata and associated files might ++ * result in errors. Check the [documentation] ++ * (https://www.tensorflow.org/lite/convert/metadata) for each task about the ++ * specific requirement. + * + * @param modelPath Path to a TFLite model file. + * @return An instance of TFLImageClassifierOptions set to the specified + * modelPath. + */ +-- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath; ++- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath; + + @end + +@@ -67,8 +69,9 @@ NS_ASSUME_NONNULL_BEGIN + * + * @return A TFLImageClassifier instance. + */ +-+ (nullable instancetype)imageClassifierWithOptions:(nonnull TFLImageClassifierOptions *)options +- error:(NSError **)error +++ (nullable instancetype)imageClassifierWithOptions: ++ (nonnull TFLImageClassifierOptions*)options ++ error:(NSError**)error + NS_SWIFT_NAME(imageClassifier(options:)); + + /** +@@ -79,8 +82,9 @@ NS_ASSUME_NONNULL_BEGIN + * @param image input to the model. + * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + */ +-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image +- error:(NSError *_Nullable *)error ++- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image ++ error:(NSError* _Nullable*) ++ error + NS_SWIFT_NAME(classify(gmlImage:)); + + /** +@@ -94,9 +98,10 @@ NS_ASSUME_NONNULL_BEGIN + * + * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + */ +-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image +- regionOfInterest:(CGRect)roi +- error:(NSError *_Nullable *)error ++- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image ++ regionOfInterest:(CGRect)roi ++ error:(NSError* _Nullable*) ++ error + NS_SWIFT_NAME(classify(gmlImage:regionOfInterest:)); + + - (instancetype)init NS_UNAVAILABLE; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m +index b0a6b005b2a2d..06d6793340269 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m +@@ -24,7 +24,7 @@ + + @interface TFLImageClassifier () + /** ImageClassifier backed by C API */ +-@property(nonatomic) TfLiteImageClassifier *imageClassifier; ++@property(nonatomic) TfLiteImageClassifier* imageClassifier; + @end + + @implementation TFLImageClassifierOptions +@@ -40,7 +40,7 @@ + return self; + } + +-- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath { ++- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath { + self = [self init]; + if (self) { + self.baseOptions.modelFile.filePath = modelPath; +@@ -55,7 +55,8 @@ + TfLiteImageClassifierDelete(_imageClassifier); + } + +-- (instancetype)initWithImageClassifier:(TfLiteImageClassifier *)imageClassifier { ++- (instancetype)initWithImageClassifier: ++ (TfLiteImageClassifier*)imageClassifier { + self = [super init]; + if (self) { + _imageClassifier = imageClassifier; +@@ -63,25 +64,28 @@ + return self; + } + +-+ (nullable instancetype)imageClassifierWithOptions:(nonnull TFLImageClassifierOptions *)options +- error:(NSError **)error { +++ (nullable instancetype)imageClassifierWithOptions: ++ (nonnull TFLImageClassifierOptions*)options ++ error:(NSError**)error { + TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate(); + if (![options.classificationOptions +- copyClassificationOptionsToCClassificationOptions:&(cOptions.classification_options) ++ copyClassificationOptionsToCClassificationOptions: ++ &(cOptions.classification_options) + error:error]) + return nil; + + [options.baseOptions copyBaseOptionsToCBaseOptions:&(cOptions.base_options)]; + +- TfLiteSupportError *createClassifierError = nil; +- TfLiteImageClassifier *imageClassifier = ++ TfLiteSupportError* createClassifierError = nil; ++ TfLiteImageClassifier* imageClassifier = + TfLiteImageClassifierFromOptions(&cOptions, &createClassifierError); + +- [options.classificationOptions +- deleteCStringArraysOfClassificationOptions:&(cOptions.classification_options)]; ++ [options.classificationOptions deleteCStringArraysOfClassificationOptions: ++ &(cOptions.classification_options)]; + + if (!imageClassifier) { +- [TFLCommonUtils errorFromTfLiteSupportError:createClassifierError error:error]; ++ [TFLCommonUtils errorFromTfLiteSupportError:createClassifierError ++ error:error]; + TfLiteSupportErrorDelete(createClassifierError); + return nil; + } +@@ -89,17 +93,20 @@ + return [[TFLImageClassifier alloc] initWithImageClassifier:imageClassifier]; + } + +-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image +- error:(NSError *_Nullable *)error { ++- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image ++ error:(NSError* _Nullable*) ++ error { + return [self classifyWithGMLImage:image + regionOfInterest:CGRectMake(0, 0, image.width, image.height) + error:error]; + } + +-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image +- regionOfInterest:(CGRect)roi +- error:(NSError *_Nullable *)error { +- TfLiteFrameBuffer *cFrameBuffer = [GMLImageUtils cFrameBufferFromGMLImage:image error:error]; ++- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image ++ regionOfInterest:(CGRect)roi ++ error:(NSError* _Nullable*) ++ error { ++ TfLiteFrameBuffer* cFrameBuffer = ++ [GMLImageUtils cFrameBufferFromGMLImage:image error:error]; + + if (!cFrameBuffer) { + return nil; +@@ -110,9 +117,10 @@ + .width = roi.size.width, + .height = roi.size.height}; + +- TfLiteSupportError *classifyError = nil; +- TfLiteClassificationResult *cClassificationResult = TfLiteImageClassifierClassifyWithRoi( +- _imageClassifier, cFrameBuffer, &boundingBox, &classifyError); ++ TfLiteSupportError* classifyError = nil; ++ TfLiteClassificationResult* cClassificationResult = ++ TfLiteImageClassifierClassifyWithRoi(_imageClassifier, cFrameBuffer, ++ &boundingBox, &classifyError); + + free(cFrameBuffer->buffer); + cFrameBuffer->buffer = nil; +@@ -126,8 +134,8 @@ + return nil; + } + +- TFLClassificationResult *classificationHeadsResults = +- [TFLClassificationUtils classificationResultFromCClassificationResults:cClassificationResult]; ++ TFLClassificationResult* classificationHeadsResults = [TFLClassificationUtils ++ classificationResultFromCClassificationResults:cClassificationResult]; + TfLiteClassificationResultDelete(cClassificationResult); + + return classificationHeadsResults; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h +index 4ae67d11665b4..298485b3ceda2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h +@@ -37,8 +37,9 @@ NS_ASSUME_NONNULL_BEGIN + * @return The TfLiteFrameBuffer created from the gmlImage which can be used + * with the TF Lite Task Vision C library. + */ +-+ (nullable TfLiteFrameBuffer *)cFrameBufferFromGMLImage:(GMLImage *)gmlImage +- error:(NSError *_Nullable *)error; +++ (nullable TfLiteFrameBuffer*)cFrameBufferFromGMLImage:(GMLImage*)gmlImage ++ error:(NSError* _Nullable*) ++ error; + + - (instancetype)init NS_UNAVAILABLE; + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m +index 7f2a0611ce1f2..72425b39630d1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m +@@ -24,18 +24,20 @@ + #import <CoreVideo/CoreVideo.h> + + @interface TFLCVPixelBufferUtils : NSObject +-+ (uint8_t *_Nullable)convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer +- error:(NSError **)error; +++ (uint8_t* _Nullable) ++ convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer ++ error:(NSError**)error; + @end + + @interface UIImage (RawPixelDataUtils) +-- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error; ++- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error; + @end + + @implementation TFLCVPixelBufferUtils + +-+ (uint8_t *)convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer +- error:(NSError **)error { +++ (uint8_t*)convertBGRAtoRGBforPixelBufferBaseAddress: ++ (CVPixelBufferRef)pixelBuffer ++ error:(NSError**)error { + size_t width = CVPixelBufferGetWidth(pixelBuffer); + size_t height = CVPixelBufferGetHeight(pixelBuffer); + size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); +@@ -43,17 +45,21 @@ + int destinationChannelCount = 3; + size_t destinationBytesPerRow = destinationChannelCount * width; + +- uint8_t *pixelBufferBaseAddress = (uint8_t *)CVPixelBufferGetBaseAddress(pixelBuffer); ++ uint8_t* pixelBufferBaseAddress = ++ (uint8_t*)CVPixelBufferGetBaseAddress(pixelBuffer); + +- uint8_t *destPixelBufferAddress = [TFLCommonUtils mallocWithSize:height * destinationBytesPerRow +- error:error]; ++ uint8_t* destPixelBufferAddress = ++ [TFLCommonUtils mallocWithSize:height * destinationBytesPerRow ++ error:error]; + + if (!destPixelBufferAddress) { + return NULL; + } + +- vImage_Buffer srcBuffer = { +- .data = pixelBufferBaseAddress, .height = height, .width = width, .rowBytes = stride}; ++ vImage_Buffer srcBuffer = {.data = pixelBufferBaseAddress, ++ .height = height, ++ .width = width, ++ .rowBytes = stride}; + + vImage_Buffer destBuffer = {.data = destPixelBufferAddress, + .height = height, +@@ -61,7 +67,8 @@ + .rowBytes = destinationBytesPerRow}; + + vImage_Error convertError = kvImageNoError; +- convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); ++ convertError = ++ vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); + + if (convertError != kvImageNoError) { + [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeImageProcessingError +@@ -78,8 +85,8 @@ + + @implementation UIImage (RawPixelDataUtils) + +-- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error { +- TfLiteFrameBuffer *frameBuffer = NULL; ++- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error { ++ TfLiteFrameBuffer* frameBuffer = NULL; + + if (self.CGImage) { + frameBuffer = [self frameBufferFromCGImage:self.CGImage error:error]; +@@ -95,23 +102,25 @@ + return frameBuffer; + } + +-+ (UInt8 *_Nullable)pixelDataFromCGImage:(CGImageRef)cgImage error:(NSError **)error { +++ (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage ++ error:(NSError**)error { + long width = CGImageGetWidth(cgImage); + long height = CGImageGetHeight(cgImage); + + int bitsPerComponent = 8; +- UInt8 *buffer_to_return = NULL; ++ UInt8* buffer_to_return = NULL; + + CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); +- CGContextRef context = CGBitmapContextCreate(nil, width, height, bitsPerComponent, 0, colorSpace, +- kCGImageAlphaNoneSkipLast); ++ CGContextRef context = ++ CGBitmapContextCreate(nil, width, height, bitsPerComponent, 0, colorSpace, ++ kCGImageAlphaNoneSkipLast); + + if (context) { + CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage); +- buffer_to_return = +- [UIImage populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context) +- width:width +- height:height]; ++ buffer_to_return = [UIImage ++ populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context) ++ width:width ++ height:height]; + CGContextRelease(context); + } + +@@ -126,15 +135,16 @@ + return buffer_to_return; + } + +-+ (nullable UInt8 *)populateRGBBufferFromSourceRGBABuffer:(UInt8 *)buffer +- width:(size_t)width +- height:(size_t)height { +- if (!buffer) return nil; +++ (nullable UInt8*)populateRGBBufferFromSourceRGBABuffer:(UInt8*)buffer ++ width:(size_t)width ++ height:(size_t)height { ++ if (!buffer) ++ return nil; + + int sourceChannelCount = 4; + int destChannelCount = 3; + +- UInt8 *buffer_to_return = malloc(height * destChannelCount * width); ++ UInt8* buffer_to_return = malloc(height * destChannelCount * width); + if (!buffer_to_return) { + return nil; + } +@@ -150,14 +160,15 @@ + return buffer_to_return; + } + +-- (TfLiteFrameBuffer *)frameBufferFromCGImage:(CGImageRef)cgImage error:(NSError **)error { +- UInt8 *buffer = [UIImage pixelDataFromCGImage:cgImage error:error]; ++- (TfLiteFrameBuffer*)frameBufferFromCGImage:(CGImageRef)cgImage ++ error:(NSError**)error { ++ UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error]; + + if (buffer == NULL) { + return NULL; + } + +- TfLiteFrameBuffer *cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); ++ TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); + + cFrameBuffer->dimension.width = (int)CGImageGetWidth(cgImage); + cFrameBuffer->dimension.height = (int)CGImageGetHeight(cgImage); +@@ -169,14 +180,16 @@ + return cFrameBuffer; + } + +-- (TfLiteFrameBuffer *)frameBufferFromCIImage:(CIImage *)ciImage error:(NSError **)error { +- uint8_t *buffer = nil; ++- (TfLiteFrameBuffer*)frameBufferFromCIImage:(CIImage*)ciImage ++ error:(NSError**)error { ++ uint8_t* buffer = nil; + + int width = 0; + int height = 0; + if (ciImage.pixelBuffer) { +- buffer = [TFLCVPixelBufferUtils convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer +- error:error]; ++ buffer = [TFLCVPixelBufferUtils ++ convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer ++ error:error]; + width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer); + height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer); + +@@ -195,7 +208,7 @@ + return NULL; + } + +- TfLiteFrameBuffer *cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); ++ TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); + cFrameBuffer->buffer = buffer; + cFrameBuffer->dimension.width = width; + cFrameBuffer->dimension.height = height; +@@ -210,41 +223,49 @@ + + @implementation GMLImageUtils + +-+ (nullable TfLiteFrameBuffer *)cFrameBufferFromGMLImage:(GMLImage *)gmlImage +- error:(NSError *_Nullable *)error { +- TfLiteFrameBuffer *cFrameBuffer = NULL; +++ (nullable TfLiteFrameBuffer*)cFrameBufferFromGMLImage:(GMLImage*)gmlImage ++ error:(NSError* _Nullable*) ++ error { ++ TfLiteFrameBuffer* cFrameBuffer = NULL; + + switch (gmlImage.imageSourceType) { + case GMLImageSourceTypeSampleBuffer: { +- CVPixelBufferRef sampleImagePixelBuffer = CMSampleBufferGetImageBuffer(gmlImage.sampleBuffer); +- cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:sampleImagePixelBuffer error:error]; ++ CVPixelBufferRef sampleImagePixelBuffer = ++ CMSampleBufferGetImageBuffer(gmlImage.sampleBuffer); ++ cFrameBuffer = ++ [GMLImageUtils bufferFromCVPixelBuffer:sampleImagePixelBuffer ++ error:error]; + break; + } + case GMLImageSourceTypePixelBuffer: { +- cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:gmlImage.pixelBuffer error:error]; ++ cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:gmlImage.pixelBuffer ++ error:error]; + break; + } + case GMLImageSourceTypeImage: { +- cFrameBuffer = [GMLImageUtils frameBufferFromUIImage:gmlImage.image error:error]; ++ cFrameBuffer = [GMLImageUtils frameBufferFromUIImage:gmlImage.image ++ error:error]; + } + + default: +- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"Invalid source type for GMLImage." +- error:error]; ++ [TFLCommonUtils ++ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError ++ description:@"Invalid source type for GMLImage." ++ error:error]; + break; + } + + return cFrameBuffer; + } + +-+ (TfLiteFrameBuffer *)frameBufferFromUIImage:(UIImage *)image error:(NSError **)error { +++ (TfLiteFrameBuffer*)frameBufferFromUIImage:(UIImage*)image ++ error:(NSError**)error { + return [image frameBufferWithError:error]; + } + +-+ (TfLiteFrameBuffer *)bufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer +- error:(NSError **)error { +- uint8_t *buffer = nil; +++ (TfLiteFrameBuffer*)bufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer ++ error:(NSError**)error { ++ uint8_t* buffer = nil; + enum TfLiteFrameBufferFormat cPixelFormat = kRGB; + + CVPixelBufferLockBaseAddress(pixelBuffer, 0); +@@ -253,25 +274,30 @@ + switch (pixelBufferFormat) { + case kCVPixelFormatType_24RGB: { + cPixelFormat = kRGB; +- buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer error:error]; ++ buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer ++ error:error]; + break; + } + case kCVPixelFormatType_32RGBA: { + cPixelFormat = kRGBA; +- buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer error:error]; ++ buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer ++ error:error]; + break; + } + case kCVPixelFormatType_32BGRA: { + cPixelFormat = kRGB; +- buffer = [TFLCVPixelBufferUtils convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer +- error:error]; ++ buffer = [TFLCVPixelBufferUtils ++ convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer ++ error:error]; + break; + } + + default: { +- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"Unsupported pixel format for TfLiteFrameBufferFormat." +- error:error]; ++ [TFLCommonUtils ++ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError ++ description: ++ @"Unsupported pixel format for TfLiteFrameBufferFormat." ++ error:error]; + break; + } + } +@@ -282,7 +308,7 @@ + return nil; + } + +- TfLiteFrameBuffer *cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); ++ TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); + + cFrameBuffer->dimension.width = (int)CVPixelBufferGetWidth(pixelBuffer); + cFrameBuffer->dimension.height = (int)CVPixelBufferGetHeight(pixelBuffer); +@@ -292,12 +318,14 @@ + return cFrameBuffer; + } + +-+ (UInt8 *)copyPixelufferDataForInference:(CVPixelBufferRef)pixelBuffer error:(NSError **)error { +++ (UInt8*)copyPixelufferDataForInference:(CVPixelBufferRef)pixelBuffer ++ error:(NSError**)error { + size_t height = CVPixelBufferGetHeight(pixelBuffer); + size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); +- UInt8 *buffer = [TFLCommonUtils mallocWithSize:height * stride error:error]; ++ UInt8* buffer = [TFLCommonUtils mallocWithSize:height * stride error:error]; + +- if (buffer) memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride); ++ if (buffer) ++ memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride); + + return buffer; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m +index b5f514397e41d..f26959434bbc9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m +@@ -18,123 +18,140 @@ + NS_ASSUME_NONNULL_BEGIN + + @interface TFLImageClassifierTests : XCTestCase +-@property(nonatomic, nullable) NSString *modelPath; ++@property(nonatomic, nullable) NSString* modelPath; + @end + + @implementation TFLImageClassifierTests + +-- (GMLImage *)imageFromBundleWithName:(NSString *)name ofType:(NSString *)type { +- NSString *imagePath = [[NSBundle bundleForClass:[self class]] pathForResource:name ofType:type]; ++- (GMLImage*)imageFromBundleWithName:(NSString*)name ofType:(NSString*)type { ++ NSString* imagePath = ++ [[NSBundle bundleForClass:[self class]] pathForResource:name ofType:type]; + XCTAssertNotNil(imagePath); +- UIImage *image = [[UIImage alloc] initWithContentsOfFile:imagePath]; ++ UIImage* image = [[UIImage alloc] initWithContentsOfFile:imagePath]; + XCTAssertNotNil(image); + +- GMLImage *gmlImage = [[GMLImage alloc] initWithImage:image]; ++ GMLImage* gmlImage = [[GMLImage alloc] initWithImage:image]; + XCTAssertNotNil(gmlImage); + + return gmlImage; + } + - (void)setUp { +- // Put setup code here. This method is called before the invocation of each test method in the +- // class. static let bundle = Bundle(for: TFLSentencepieceTokenizerTest.self) +- self.modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"mobilenet_v2_1.0_224" +- ofType:@"tflite"]; ++ // Put setup code here. This method is called before the invocation of each ++ // test method in the class. static let bundle = Bundle(for: ++ // TFLSentencepieceTokenizerTest.self) ++ self.modelPath = [[NSBundle bundleForClass:[self class]] ++ pathForResource:@"mobilenet_v2_1.0_224" ++ ofType:@"tflite"]; + XCTAssertNotNil(self.modelPath); + } + + - (void)tearDown { +- // Put teardown code here. This method is called after the invocation of each test method in the +- // class. ++ // Put teardown code here. This method is called after the invocation of each ++ // test method in the class. + } + + - (void)testSuccessfullImageInferenceOnMLImageWithUIImage { +- TFLImageClassifierOptions *imageClassifierOptions = ++ TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + +- TFLImageClassifier *imageClassifier = +- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil]; ++ TFLImageClassifier* imageClassifier = ++ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions ++ error:nil]; + XCTAssertNotNil(imageClassifier); +- GMLImage *gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"]; ++ GMLImage* gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"]; + +- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage +- error:nil]; ++ TFLClassificationResult* classificationResults = ++ [imageClassifier classifyWithGMLImage:gmlImage error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); +- XCTAssertTrue([classificationResults.classifications[0].categories count] > 0); ++ XCTAssertTrue([classificationResults.classifications[0].categories count] > ++ 0); + +- TFLCategory *category = classificationResults.classifications[0].categories[0]; ++ TFLCategory* category = ++ classificationResults.classifications[0].categories[0]; + XCTAssertTrue([category.label isEqual:@"cheeseburger"]); + // TODO: match the score as image_classifier_test.cc + XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001); + } + + - (void)testModelOptionsWithMaxResults { +- TFLImageClassifierOptions *imageClassifierOptions = ++ TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + int maxResults = 3; + imageClassifierOptions.classificationOptions.maxResults = maxResults; + +- TFLImageClassifier *imageClassifier = +- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil]; ++ TFLImageClassifier* imageClassifier = ++ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions ++ error:nil]; + XCTAssertNotNil(imageClassifier); + +- GMLImage *gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"]; ++ GMLImage* gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"]; + +- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage +- error:nil]; ++ TFLClassificationResult* classificationResults = ++ [imageClassifier classifyWithGMLImage:gmlImage error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); +- XCTAssertLessThanOrEqual([classificationResults.classifications[0].categories count], maxResults); ++ XCTAssertLessThanOrEqual( ++ [classificationResults.classifications[0].categories count], maxResults); + +- TFLCategory *category = classificationResults.classifications[0].categories[0]; ++ TFLCategory* category = ++ classificationResults.classifications[0].categories[0]; + XCTAssertTrue([category.label isEqual:@"cheeseburger"]); + // TODO: match the score as image_classifier_test.cc + XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001); + } + + - (void)testInferenceWithBoundingBox { +- TFLImageClassifierOptions *imageClassifierOptions = ++ TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + int maxResults = 3; + imageClassifierOptions.classificationOptions.maxResults = maxResults; + +- TFLImageClassifier *imageClassifier = +- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil]; ++ TFLImageClassifier* imageClassifier = ++ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions ++ error:nil]; + XCTAssertNotNil(imageClassifier); + +- GMLImage *gmlImage = [self imageFromBundleWithName:@"multi_objects" ofType:@"jpg"]; ++ GMLImage* gmlImage = [self imageFromBundleWithName:@"multi_objects" ++ ofType:@"jpg"]; + + CGRect roi = CGRectMake(406, 110, 148, 153); +- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage +- regionOfInterest:roi +- error:nil]; ++ TFLClassificationResult* classificationResults = ++ [imageClassifier classifyWithGMLImage:gmlImage ++ regionOfInterest:roi ++ error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); +- XCTAssertTrue([classificationResults.classifications[0].categories count] > 0); ++ XCTAssertTrue([classificationResults.classifications[0].categories count] > ++ 0); + +- TFLCategory *category = classificationResults.classifications[0].categories[0]; ++ TFLCategory* category = ++ classificationResults.classifications[0].categories[0]; + // TODO: match the label and score as image_classifier_test.cc + // XCTAssertTrue([category.label isEqual:@"soccer ball"]); + // XCTAssertEqualWithAccuracy(category.score, 0.256512, 0.001); + } + + - (void)testInferenceWithRGBAImage { +- TFLImageClassifierOptions *imageClassifierOptions = ++ TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + +- TFLImageClassifier *imageClassifier = +- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil]; ++ TFLImageClassifier* imageClassifier = ++ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions ++ error:nil]; + XCTAssertNotNil(imageClassifier); + +- GMLImage *gmlImage = [self imageFromBundleWithName:@"sparrow" ofType:@"png"]; ++ GMLImage* gmlImage = [self imageFromBundleWithName:@"sparrow" ofType:@"png"]; + XCTAssertNotNil(gmlImage); + +- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage +- error:nil]; ++ TFLClassificationResult* classificationResults = ++ [imageClassifier classifyWithGMLImage:gmlImage error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); +- XCTAssertTrue([classificationResults.classifications[0].categories count] > 0); ++ XCTAssertTrue([classificationResults.classifications[0].categories count] > ++ 0); + +- TFLCategory *category = classificationResults.classifications[0].categories[0]; ++ TFLCategory* category = ++ classificationResults.classifications[0].categories[0]; + XCTAssertTrue([category.label isEqual:@"junco"]); +- // TODO: inspect if score is correct. Better to test againest "burger", because we know the +- // expected result for "burger.jpg". ++ // TODO: inspect if score is correct. Better to test againest "burger", ++ // because we know the expected result for "burger.jpg". + XCTAssertEqualWithAccuracy(category.score, 0.253016, 0.001); + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h +index aa6924893b301..d08f5177ceee9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h +@@ -28,11 +28,13 @@ NS_ASSUME_NONNULL_BEGIN + /** + * Initializes the tokenizer with the path to wordpiece vocabulary file. + */ +-- (instancetype)initWithVocabPath:(NSString *)vocabPath NS_DESIGNATED_INITIALIZER; ++- (instancetype)initWithVocabPath:(NSString*)vocabPath ++ NS_DESIGNATED_INITIALIZER; + + /** + * Initializes the tokenizer with a list of tokens. + */ +-- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab NS_DESIGNATED_INITIALIZER; ++- (instancetype)initWithVocab:(NSArray<NSString*>*)vocab ++ NS_DESIGNATED_INITIALIZER; + @end + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm +index 949cef2b0b7c2..2a028f6cd7d1a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm +@@ -24,7 +24,7 @@ using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer; + std::unique_ptr<BertTokenizerCPP> _bertTokenizer; + } + +-- (instancetype)initWithVocabPath:(NSString *)vocabPath { ++- (instancetype)initWithVocabPath:(NSString*)vocabPath { + self = [super init]; + if (self) { + _bertTokenizer = absl::make_unique<BertTokenizerCPP>(MakeString(vocabPath)); +@@ -32,12 +32,12 @@ using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer; + return self; + } + +-- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab { ++- (instancetype)initWithVocab:(NSArray<NSString*>*)vocab { + self = [super init]; + if (self) { + std::vector<std::string> vocabCpp; + vocabCpp.reserve([vocab count]); +- for (NSString *word in vocab) { ++ for (NSString* word in vocab) { + vocabCpp.emplace_back(MakeString(word)); + } + _bertTokenizer = absl::make_unique<BertTokenizerCPP>(vocabCpp); +@@ -45,11 +45,11 @@ using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer; + return self; + } + +-- (NSArray<NSString *> *)tokensFromInput:(NSString *)input { ++- (NSArray<NSString*>*)tokensFromInput:(NSString*)input { + return Tokenize(_bertTokenizer.get(), input); + } + +-- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens { ++- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens { + return ConvertTokensToIds(_bertTokenizer.get(), tokens); + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h +index eef3bf1e223e6..9813e32ecb5d3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h +@@ -28,6 +28,6 @@ NS_ASSUME_NONNULL_BEGIN + /** + * Initializes the tokenizer with the path to sentencepiece model file. + */ +-- (instancetype)initWithModelPath:(NSString *)modelPath; ++- (instancetype)initWithModelPath:(NSString*)modelPath; + @end + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm +index 1e21cee5c08d2..1ba49923040c1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm +@@ -19,25 +19,27 @@ limitations under the License. + #import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" + + NS_ASSUME_NONNULL_BEGIN +-using SentencepieceTokenizerCPP = ::tflite::support::text::tokenizer::SentencePieceTokenizer; ++using SentencepieceTokenizerCPP = ++ ::tflite::support::text::tokenizer::SentencePieceTokenizer; + + @implementation TFLSentencepieceTokenizer { + std::unique_ptr<SentencepieceTokenizerCPP> _spTokenizer; + } + +-- (instancetype)initWithModelPath:(NSString *)modelPath { ++- (instancetype)initWithModelPath:(NSString*)modelPath { + self = [super init]; + if (self) { +- _spTokenizer = absl::make_unique<SentencepieceTokenizerCPP>(MakeString(modelPath)); ++ _spTokenizer = ++ absl::make_unique<SentencepieceTokenizerCPP>(MakeString(modelPath)); + } + return self; + } + +-- (NSArray<NSString *> *)tokensFromInput:(NSString *)input { ++- (NSArray<NSString*>*)tokensFromInput:(NSString*)input { + return Tokenize(_spTokenizer.get(), input); + } + +-- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens { ++- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens { + return ConvertTokensToIds(_spTokenizer.get(), tokens); + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h +index ee0972f8aba30..bd832060b6e80 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h +@@ -26,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN + * + * @return A list of tokens. + */ +-- (NSArray<NSString *> *)tokensFromInput:(NSString *)input; ++- (NSArray<NSString*>*)tokensFromInput:(NSString*)input; + + /* + * Convert a list of tokens back to their coressponding IDs. +@@ -34,6 +34,6 @@ NS_ASSUME_NONNULL_BEGIN + * + * @return A list of ids. + */ +-- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens; ++- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens; + @end + NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h +index 574b555301616..14e2906675b71 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h +@@ -18,21 +18,24 @@ limitations under the License. + using ::tflite::support::text::tokenizer::Tokenizer; + + /** +- * Invokes the cpp tokenizer's tokenize function and converts input/output to objc. ++ * Invokes the cpp tokenizer's tokenize function and converts input/output to ++ * objc. + * + * @param tokenizer The cpp tokenizer pointer. + * @param input The input string to be tokenized. + * + * @return A list of tokens. + */ +-NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input); ++NSArray<NSString*>* Tokenize(Tokenizer* tokenizer, NSString* input); + + /** +- * Invokes the cpp tokenizer's convertTokensToIds function and converts input/output to objc. ++ * Invokes the cpp tokenizer's convertTokensToIds function and converts ++ * input/output to objc. + * + * @param tokenizer The cpp tokenizer pointer. + * @param input The tokens to be converted. + * + * @return A list of ids. + */ +-NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens); ++NSArray<NSNumber*>* ConvertTokensToIds(Tokenizer* tokenizer, ++ NSArray<NSString*>* tokens); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm +index 52180578170d8..8e92e3712e29e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm +@@ -18,21 +18,24 @@ limitations under the License. + + using ::tflite::support::text::tokenizer::TokenizerResult; + +-NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input) { ++NSArray<NSString*>* Tokenize(Tokenizer* tokenizer, NSString* input) { + TokenizerResult tokenize_result = tokenizer->Tokenize(MakeString(input)); + std::vector<std::string> subwords = tokenize_result.subwords; +- NSMutableArray<NSString *> *ret = [NSMutableArray arrayWithCapacity:subwords.size()]; ++ NSMutableArray<NSString*>* ret = ++ [NSMutableArray arrayWithCapacity:subwords.size()]; + for (int i = 0; i < subwords.size(); ++i) { + [ret addObject:MakeNSString(subwords[i])]; + } + return ret; + } + +-NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens) { +- NSMutableArray<NSNumber *> *ret = [NSMutableArray arrayWithCapacity:[tokens count]]; +- for (NSString *token in tokens) { ++NSArray<NSNumber*>* ConvertTokensToIds(Tokenizer* tokenizer, ++ NSArray<NSString*>* tokens) { ++ NSMutableArray<NSNumber*>* ret = ++ [NSMutableArray arrayWithCapacity:[tokens count]]; ++ for (NSString* token in tokens) { + std::string cc_token = MakeString(token); +- const char *cToken = cc_token.c_str(); ++ const char* cToken = cc_token.c_str(); + int id; + tokenizer->LookupId(cToken, &id); + [ret addObject:[NSNumber numberWithInt:id]]; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java +index 1d8a9767f41c7..e066146eb0c7d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java +@@ -15,19 +15,24 @@ limitations under the License. + + package org.tensorflow.lite.support.audio; + +-import static java.lang.System.arraycopy; + import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + ++import static java.lang.System.arraycopy; ++ + import android.media.AudioFormat; + import android.media.AudioRecord; + import android.os.Build; ++ + import androidx.annotation.RequiresApi; ++ + import com.google.auto.value.AutoValue; ++ ++import org.tensorflow.lite.DataType; ++import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; ++ + import java.nio.ByteBuffer; + import java.nio.ByteOrder; + import java.nio.FloatBuffer; +-import org.tensorflow.lite.DataType; +-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** + * Defines a ring buffer and some utility functions to prepare the input audio samples. +@@ -60,285 +65,282 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * </pre> + */ + public class TensorAudio { ++ private static final String TAG = TensorAudio.class.getSimpleName(); ++ private final FloatRingBuffer buffer; ++ private final TensorAudioFormat format; + +- private static final String TAG = TensorAudio.class.getSimpleName(); +- private final FloatRingBuffer buffer; +- private final TensorAudioFormat format; +- +- /** +- * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code +- * sampleCounts} * {@code format.getChannels()}. +- * +- * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class. +- * @param sampleCounts the number of samples to be fed into the model +- */ +- public static TensorAudio create(TensorAudioFormat format, int sampleCounts) { +- return new TensorAudio(format, sampleCounts); +- } +- +- /** +- * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts} * +- * {@code format.getChannelCount()}. +- * +- * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines +- * the number of channels and sample rate. +- * @param sampleCounts the number of samples to be fed into the model +- */ +- public static TensorAudio create(AudioFormat format, int sampleCounts) { +- return new TensorAudio(TensorAudioFormat.create(format), sampleCounts); +- } +- +- /** +- * Wraps a few constants describing the format of the incoming audio samples, namely number of +- * channels and the sample rate. By default, channels is set to 1. +- */ +- @AutoValue +- public abstract static class TensorAudioFormat { +- private static final int DEFAULT_CHANNELS = 1; +- +- /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */ +- @RequiresApi(Build.VERSION_CODES.M) +- public static TensorAudioFormat create(AudioFormat format) { +- return TensorAudioFormat.builder() +- .setChannels(format.getChannelCount()) +- .setSampleRate(format.getSampleRate()) +- .build(); ++ /** ++ * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code ++ * sampleCounts} * {@code format.getChannels()}. ++ * ++ * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class. ++ * @param sampleCounts the number of samples to be fed into the model ++ */ ++ public static TensorAudio create(TensorAudioFormat format, int sampleCounts) { ++ return new TensorAudio(format, sampleCounts); + } + +- public abstract int getChannels(); +- +- public abstract int getSampleRate(); +- +- public static Builder builder() { +- return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(DEFAULT_CHANNELS); ++ /** ++ * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts} ++ * * ++ * {@code format.getChannelCount()}. ++ * ++ * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines ++ * the number of channels and sample rate. ++ * @param sampleCounts the number of samples to be fed into the model ++ */ ++ public static TensorAudio create(AudioFormat format, int sampleCounts) { ++ return new TensorAudio(TensorAudioFormat.create(format), sampleCounts); + } + +- /** Builder for {@link TensorAudioFormat} */ +- @AutoValue.Builder +- public abstract static class Builder { +- +- /* By default, it's set to have 1 channel. */ +- public abstract Builder setChannels(int value); +- +- public abstract Builder setSampleRate(int value); +- +- abstract TensorAudioFormat autoBuild(); +- +- public TensorAudioFormat build() { +- TensorAudioFormat format = autoBuild(); +- checkArgument(format.getChannels() > 0, "Number of channels should be greater than 0"); +- checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0"); +- return format; +- } +- } +- } +- +- /** +- * Stores the input audio samples {@code src} in the ring buffer. +- * +- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For +- * multi-channel input, the array is interleaved. +- */ +- public void load(float[] src) { +- load(src, 0, src.length); +- } +- +- /** +- * Stores the input audio samples {@code src} in the ring buffer. +- * +- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For +- * multi-channel input, the array is interleaved. +- * @param offsetInFloat starting position in the {@code src} array +- * @param sizeInFloat the number of float values to be copied +- * @throws IllegalArgumentException for incompatible audio format or incorrect input size +- */ +- public void load(float[] src, int offsetInFloat, int sizeInFloat) { +- checkArgument( +- sizeInFloat % format.getChannels() == 0, +- String.format( +- "Size (%d) needs to be a multiplier of the number of channels (%d)", +- sizeInFloat, format.getChannels())); +- buffer.load(src, offsetInFloat, sizeInFloat); +- } +- +- /** +- * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring +- * buffer. +- * +- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For +- * multi-channel input, the array is interleaved. +- */ +- public void load(short[] src) { +- load(src, 0, src.length); +- } +- +- /** +- * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring +- * buffer. +- * +- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For +- * multi-channel input, the array is interleaved. +- * @param offsetInShort starting position in the src array +- * @param sizeInShort the number of short values to be copied +- * @throws IllegalArgumentException if the source array can't be copied +- */ +- public void load(short[] src, int offsetInShort, int sizeInShort) { +- checkArgument( +- offsetInShort + sizeInShort <= src.length, +- String.format( +- "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", +- offsetInShort, sizeInShort, src.length)); +- float[] floatData = new float[sizeInShort]; +- for (int i = offsetInShort; i < sizeInShort; i++) { +- // Convert the data to PCM Float encoding i.e. values between -1 and 1 +- floatData[i] = src[i] / Short.MAX_VALUE; +- } +- load(floatData); +- } +- +- /** +- * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only +- * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT. +- * +- * @param record an instance of {@link android.media.AudioRecord} +- * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If +- * there was no new data in the AudioRecord or an error occurred, this method will return 0. +- * @throws IllegalArgumentException for unsupported audio encoding format +- * @throws IllegalStateException if reading from AudioRecord failed +- */ +- @RequiresApi(Build.VERSION_CODES.M) +- public int load(AudioRecord record) { +- checkArgument( +- this.format.equals(TensorAudioFormat.create(record.getFormat())), +- "Incompatible audio format."); +- int loadedValues = 0; +- if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) { +- float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()]; +- loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); +- if (loadedValues > 0) { +- load(newData, 0, loadedValues); +- return loadedValues; +- } +- } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) { +- short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()]; +- loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); +- if (loadedValues > 0) { +- load(newData, 0, loadedValues); +- return loadedValues; +- } +- } else { +- throw new IllegalArgumentException( +- "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT."); ++ /** ++ * Wraps a few constants describing the format of the incoming audio samples, namely number of ++ * channels and the sample rate. By default, channels is set to 1. ++ */ ++ @AutoValue ++ public abstract static class TensorAudioFormat { ++ private static final int DEFAULT_CHANNELS = 1; ++ ++ /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */ ++ @RequiresApi(Build.VERSION_CODES.M) ++ public static TensorAudioFormat create(AudioFormat format) { ++ return TensorAudioFormat.builder() ++ .setChannels(format.getChannelCount()) ++ .setSampleRate(format.getSampleRate()) ++ .build(); ++ } ++ ++ public abstract int getChannels(); ++ ++ public abstract int getSampleRate(); ++ ++ public static Builder builder() { ++ return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels( ++ DEFAULT_CHANNELS); ++ } ++ ++ /** Builder for {@link TensorAudioFormat} */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /* By default, it's set to have 1 channel. */ ++ public abstract Builder setChannels(int value); ++ ++ public abstract Builder setSampleRate(int value); ++ ++ abstract TensorAudioFormat autoBuild(); ++ ++ public TensorAudioFormat build() { ++ TensorAudioFormat format = autoBuild(); ++ checkArgument( ++ format.getChannels() > 0, "Number of channels should be greater than 0"); ++ checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0"); ++ return format; ++ } ++ } + } + +- switch (loadedValues) { +- case AudioRecord.ERROR_INVALID_OPERATION: +- throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION"); +- +- case AudioRecord.ERROR_BAD_VALUE: +- throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE"); +- +- case AudioRecord.ERROR_DEAD_OBJECT: +- throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT"); ++ /** ++ * Stores the input audio samples {@code src} in the ring buffer. ++ * ++ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For ++ * multi-channel input, the array is interleaved. ++ */ ++ public void load(float[] src) { ++ load(src, 0, src.length); ++ } + +- case AudioRecord.ERROR: +- throw new IllegalStateException("AudioRecord.ERROR"); ++ /** ++ * Stores the input audio samples {@code src} in the ring buffer. ++ * ++ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For ++ * multi-channel input, the array is interleaved. ++ * @param offsetInFloat starting position in the {@code src} array ++ * @param sizeInFloat the number of float values to be copied ++ * @throws IllegalArgumentException for incompatible audio format or incorrect input size ++ */ ++ public void load(float[] src, int offsetInFloat, int sizeInFloat) { ++ checkArgument(sizeInFloat % format.getChannels() == 0, ++ String.format("Size (%d) needs to be a multiplier of the number of channels (%d)", ++ sizeInFloat, format.getChannels())); ++ buffer.load(src, offsetInFloat, sizeInFloat); ++ } + +- default: +- return 0; ++ /** ++ * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ++ * ring buffer. ++ * ++ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For ++ * multi-channel input, the array is interleaved. ++ */ ++ public void load(short[] src) { ++ load(src, 0, src.length); + } +- } +- +- /** +- * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link +- * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1]. +- */ +- public TensorBuffer getTensorBuffer() { +- ByteBuffer byteBuffer = buffer.getBuffer(); +- TensorBuffer tensorBuffer = +- TensorBuffer.createFixedSize( +- new int[] { +- /* batch= */ 1, /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit() +- }, +- DataType.FLOAT32); +- tensorBuffer.loadBuffer(byteBuffer); +- return tensorBuffer; +- } +- +- /* Returns the {@link TensorAudioFormat} associated with the tensor. */ +- public TensorAudioFormat getFormat() { +- return format; +- } +- +- private TensorAudio(TensorAudioFormat format, int sampleCounts) { +- this.format = format; +- this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels()); +- } +- +- /** Actual implementation of the ring buffer. */ +- private static class FloatRingBuffer { +- +- private final float[] buffer; +- private int nextIndex = 0; +- +- public FloatRingBuffer(int flatSize) { +- buffer = new float[flatSize]; ++ ++ /** ++ * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ++ * ring buffer. ++ * ++ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For ++ * multi-channel input, the array is interleaved. ++ * @param offsetInShort starting position in the src array ++ * @param sizeInShort the number of short values to be copied ++ * @throws IllegalArgumentException if the source array can't be copied ++ */ ++ public void load(short[] src, int offsetInShort, int sizeInShort) { ++ checkArgument(offsetInShort + sizeInShort <= src.length, ++ String.format( ++ "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", ++ offsetInShort, sizeInShort, src.length)); ++ float[] floatData = new float[sizeInShort]; ++ for (int i = offsetInShort; i < sizeInShort; i++) { ++ // Convert the data to PCM Float encoding i.e. values between -1 and 1 ++ floatData[i] = src[i] / Short.MAX_VALUE; ++ } ++ load(floatData); + } + + /** +- * Loads the entire float array to the ring buffer. If the float array is longer than ring +- * buffer's capacity, samples with lower indices in the array will be ignored. ++ * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only ++ * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT. ++ * ++ * @param record an instance of {@link android.media.AudioRecord} ++ * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If ++ * there was no new data in the AudioRecord or an error occurred, this method will return 0. ++ * @throws IllegalArgumentException for unsupported audio encoding format ++ * @throws IllegalStateException if reading from AudioRecord failed + */ +- public void load(float[] newData) { +- load(newData, 0, newData.length); ++ @RequiresApi(Build.VERSION_CODES.M) ++ public int load(AudioRecord record) { ++ checkArgument(this.format.equals(TensorAudioFormat.create(record.getFormat())), ++ "Incompatible audio format."); ++ int loadedValues = 0; ++ if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) { ++ float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()]; ++ loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); ++ if (loadedValues > 0) { ++ load(newData, 0, loadedValues); ++ return loadedValues; ++ } ++ } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) { ++ short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()]; ++ loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); ++ if (loadedValues > 0) { ++ load(newData, 0, loadedValues); ++ return loadedValues; ++ } ++ } else { ++ throw new IllegalArgumentException( ++ "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT."); ++ } ++ ++ switch (loadedValues) { ++ case AudioRecord.ERROR_INVALID_OPERATION: ++ throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION"); ++ ++ case AudioRecord.ERROR_BAD_VALUE: ++ throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE"); ++ ++ case AudioRecord.ERROR_DEAD_OBJECT: ++ throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT"); ++ ++ case AudioRecord.ERROR: ++ throw new IllegalStateException("AudioRecord.ERROR"); ++ ++ default: ++ return 0; ++ } + } + + /** +- * Loads a slice of the float array to the ring buffer. If the float array is longer than ring +- * buffer's capacity, samples with lower indices in the array will be ignored. ++ * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link ++ * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1]. + */ +- public void load(float[] newData, int offset, int size) { +- checkArgument( +- offset + size <= newData.length, +- String.format( +- "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", +- offset, size, newData.length)); +- // If buffer can't hold all the data, only keep the most recent data of size buffer.length +- if (size > buffer.length) { +- offset = size - buffer.length; +- size = buffer.length; +- } +- if (nextIndex + size < buffer.length) { +- // No need to wrap nextIndex, just copy newData[offset:offset + size] +- // to buffer[nextIndex:nextIndex+size] +- arraycopy(newData, offset, buffer, nextIndex, size); +- } else { +- // Need to wrap nextIndex, perform copy in two chunks. +- int firstChunkSize = buffer.length - nextIndex; +- // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length] +- arraycopy(newData, offset, buffer, nextIndex, firstChunkSize); +- // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize] +- arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize); +- } +- +- nextIndex = (nextIndex + size) % buffer.length; ++ public TensorBuffer getTensorBuffer() { ++ ByteBuffer byteBuffer = buffer.getBuffer(); ++ TensorBuffer tensorBuffer = TensorBuffer.createFixedSize( ++ new int[] {/* batch= */ 1, ++ /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()}, ++ DataType.FLOAT32); ++ tensorBuffer.loadBuffer(byteBuffer); ++ return tensorBuffer; ++ } ++ ++ /* Returns the {@link TensorAudioFormat} associated with the tensor. */ ++ public TensorAudioFormat getFormat() { ++ return format; + } + +- public ByteBuffer getBuffer() { +- // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which +- // can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so +- // generally we don't create direct buffer for every invocation. +- ByteBuffer byteBuffer = ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length); +- byteBuffer.order(ByteOrder.nativeOrder()); +- FloatBuffer result = byteBuffer.asFloatBuffer(); +- result.put(buffer, nextIndex, buffer.length - nextIndex); +- result.put(buffer, 0, nextIndex); +- byteBuffer.rewind(); +- return byteBuffer; ++ private TensorAudio(TensorAudioFormat format, int sampleCounts) { ++ this.format = format; ++ this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels()); + } + +- public int getCapacity() { +- return buffer.length; ++ /** Actual implementation of the ring buffer. */ ++ private static class FloatRingBuffer { ++ private final float[] buffer; ++ private int nextIndex = 0; ++ ++ public FloatRingBuffer(int flatSize) { ++ buffer = new float[flatSize]; ++ } ++ ++ /** ++ * Loads the entire float array to the ring buffer. If the float array is longer than ring ++ * buffer's capacity, samples with lower indices in the array will be ignored. ++ */ ++ public void load(float[] newData) { ++ load(newData, 0, newData.length); ++ } ++ ++ /** ++ * Loads a slice of the float array to the ring buffer. If the float array is longer than ++ * ring buffer's capacity, samples with lower indices in the array will be ignored. ++ */ ++ public void load(float[] newData, int offset, int size) { ++ checkArgument(offset + size <= newData.length, ++ String.format( ++ "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", ++ offset, size, newData.length)); ++ // If buffer can't hold all the data, only keep the most recent data of size ++ // buffer.length ++ if (size > buffer.length) { ++ offset = size - buffer.length; ++ size = buffer.length; ++ } ++ if (nextIndex + size < buffer.length) { ++ // No need to wrap nextIndex, just copy newData[offset:offset + size] ++ // to buffer[nextIndex:nextIndex+size] ++ arraycopy(newData, offset, buffer, nextIndex, size); ++ } else { ++ // Need to wrap nextIndex, perform copy in two chunks. ++ int firstChunkSize = buffer.length - nextIndex; ++ // First copy newData[offset:offset+firstChunkSize] to ++ // buffer[nextIndex:buffer.length] ++ arraycopy(newData, offset, buffer, nextIndex, firstChunkSize); ++ // Then copy newData[offset+firstChunkSize:offset+size] to ++ // buffer[0:size-firstChunkSize] ++ arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize); ++ } ++ ++ nextIndex = (nextIndex + size) % buffer.length; ++ } ++ ++ public ByteBuffer getBuffer() { ++ // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, ++ // which can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around ++ // 0.01ms), so generally we don't create direct buffer for every invocation. ++ ByteBuffer byteBuffer = ++ ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length); ++ byteBuffer.order(ByteOrder.nativeOrder()); ++ FloatBuffer result = byteBuffer.asFloatBuffer(); ++ result.put(buffer, nextIndex, buffer.length - nextIndex); ++ result.put(buffer, 0, nextIndex); ++ byteBuffer.rewind(); ++ return byteBuffer; ++ } ++ ++ public int getCapacity() { ++ return buffer.length; ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java +index 776391b526b47..6090f85d99083 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java +@@ -17,6 +17,10 @@ package org.tensorflow.lite.support.common; + + import android.content.Context; + import android.content.res.AssetFileDescriptor; ++ ++import org.checkerframework.checker.nullness.qual.NonNull; ++import org.tensorflow.lite.support.common.internal.SupportPreconditions; ++ + import java.io.BufferedReader; + import java.io.FileInputStream; + import java.io.IOException; +@@ -28,160 +32,159 @@ import java.nio.channels.FileChannel; + import java.nio.charset.Charset; + import java.util.ArrayList; + import java.util.List; +-import org.checkerframework.checker.nullness.qual.NonNull; +-import org.tensorflow.lite.support.common.internal.SupportPreconditions; + + /** File I/O utilities. */ + public class FileUtil { +- private FileUtil() {} +- +- /** +- * Loads labels from the label file into a list of strings. +- * +- * <p>A legal label file is the plain text file whose contents are split into lines, and each line +- * is an individual value. The file should be in assets of the context. +- * +- * @param context The context holds assets. +- * @param filePath The path of the label file, relative with assets directory. +- * @return a list of labels. +- * @throws IOException if error occurs to open or read the file. +- */ +- @NonNull +- public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath) +- throws IOException { +- return loadLabels(context, filePath, Charset.defaultCharset()); +- } +- +- /** +- * Loads labels from the label file into a list of strings. +- * +- * <p>A legal label file is the plain text file whose contents are split into lines, and each line +- * is an individual value. The empty lines will be ignored. The file should be in assets of the +- * context. +- * +- * @param context The context holds assets. +- * @param filePath The path of the label file, relative with assets directory. +- * @param cs {@code Charset} to use when decoding content of label file. +- * @return a list of labels. +- * @throws IOException if error occurs to open or read the file. +- */ +- @NonNull +- public static List<String> loadLabels( +- @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { +- SupportPreconditions.checkNotNull(context, "Context cannot be null."); +- SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); +- try (InputStream inputStream = context.getAssets().open(filePath)) { +- return loadLabels(inputStream, cs); ++ private FileUtil() {} ++ ++ /** ++ * Loads labels from the label file into a list of strings. ++ * ++ * <p>A legal label file is the plain text file whose contents are split into lines, and each ++ * line is an individual value. The file should be in assets of the context. ++ * ++ * @param context The context holds assets. ++ * @param filePath The path of the label file, relative with assets directory. ++ * @return a list of labels. ++ * @throws IOException if error occurs to open or read the file. ++ */ ++ @NonNull ++ public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath) ++ throws IOException { ++ return loadLabels(context, filePath, Charset.defaultCharset()); + } +- } +- +- /** +- * Loads labels from an input stream of an opened label file. See details for label files in +- * {@link FileUtil#loadLabels(Context, String)}. +- * +- * @param inputStream the input stream of an opened label file. +- * @return a list of labels. +- * @throws IOException if error occurs to open or read the file. +- */ +- @NonNull +- public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException { +- return loadLabels(inputStream, Charset.defaultCharset()); +- } +- +- /** +- * Loads labels from an input stream of an opened label file. See details for label files in +- * {@link FileUtil#loadLabels(Context, String)}. +- * +- * @param inputStream the input stream of an opened label file. +- * @param cs {@code Charset} to use when decoding content of label file. +- * @return a list of labels. +- * @throws IOException if error occurs to open or read the file. +- */ +- @NonNull +- public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs) +- throws IOException { +- List<String> labels = new ArrayList<>(); +- try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) { +- String line; +- while ((line = reader.readLine()) != null) { +- if (line.trim().length() > 0) { +- labels.add(line); ++ ++ /** ++ * Loads labels from the label file into a list of strings. ++ * ++ * <p>A legal label file is the plain text file whose contents are split into lines, and each ++ * line is an individual value. The empty lines will be ignored. The file should be in assets of ++ * the context. ++ * ++ * @param context The context holds assets. ++ * @param filePath The path of the label file, relative with assets directory. ++ * @param cs {@code Charset} to use when decoding content of label file. ++ * @return a list of labels. ++ * @throws IOException if error occurs to open or read the file. ++ */ ++ @NonNull ++ public static List<String> loadLabels( ++ @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { ++ SupportPreconditions.checkNotNull(context, "Context cannot be null."); ++ SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); ++ try (InputStream inputStream = context.getAssets().open(filePath)) { ++ return loadLabels(inputStream, cs); + } +- } +- return labels; + } +- } +- +- /** +- * Loads a vocabulary file (a single-column text file) into a list of strings. +- * +- * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, +- * and each line is an individual value. The file should be in assets of the context. +- * +- * @param context The context holds assets. +- * @param filePath The path of the vocabulary file, relative with assets directory. +- * @return a list of vocabulary words. +- * @throws IOException if error occurs to open or read the file. +- */ +- @NonNull +- public static List<String> loadSingleColumnTextFile( +- @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { +- return loadLabels(context, filePath, cs); +- } +- +- /** +- * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column +- * text file). +- * +- * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, +- * and each line is an individual value. The file should be in assets of the context. +- * +- * @param inputStream the input stream of an opened vocabulary file. +- * @return a list of vocabulary words. +- * @throws IOException if error occurs to open or read the file. +- */ +- @NonNull +- public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs) +- throws IOException { +- return loadLabels(inputStream, cs); +- } +- +- /** +- * Loads a file from the asset folder through memory mapping. +- * +- * @param context Application context to access assets. +- * @param filePath Asset path of the file. +- * @return the loaded memory mapped file. +- * @throws IOException if an I/O error occurs when loading the tflite model. +- */ +- @NonNull +- public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath) +- throws IOException { +- SupportPreconditions.checkNotNull(context, "Context should not be null."); +- SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); +- try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); +- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { +- FileChannel fileChannel = inputStream.getChannel(); +- long startOffset = fileDescriptor.getStartOffset(); +- long declaredLength = fileDescriptor.getDeclaredLength(); +- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); ++ ++ /** ++ * Loads labels from an input stream of an opened label file. See details for label files in ++ * {@link FileUtil#loadLabels(Context, String)}. ++ * ++ * @param inputStream the input stream of an opened label file. ++ * @return a list of labels. ++ * @throws IOException if error occurs to open or read the file. ++ */ ++ @NonNull ++ public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException { ++ return loadLabels(inputStream, Charset.defaultCharset()); ++ } ++ ++ /** ++ * Loads labels from an input stream of an opened label file. See details for label files in ++ * {@link FileUtil#loadLabels(Context, String)}. ++ * ++ * @param inputStream the input stream of an opened label file. ++ * @param cs {@code Charset} to use when decoding content of label file. ++ * @return a list of labels. ++ * @throws IOException if error occurs to open or read the file. ++ */ ++ @NonNull ++ public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs) ++ throws IOException { ++ List<String> labels = new ArrayList<>(); ++ try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) { ++ String line; ++ while ((line = reader.readLine()) != null) { ++ if (line.trim().length() > 0) { ++ labels.add(line); ++ } ++ } ++ return labels; ++ } ++ } ++ ++ /** ++ * Loads a vocabulary file (a single-column text file) into a list of strings. ++ * ++ * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, ++ * and each line is an individual value. The file should be in assets of the context. ++ * ++ * @param context The context holds assets. ++ * @param filePath The path of the vocabulary file, relative with assets directory. ++ * @return a list of vocabulary words. ++ * @throws IOException if error occurs to open or read the file. ++ */ ++ @NonNull ++ public static List<String> loadSingleColumnTextFile( ++ @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { ++ return loadLabels(context, filePath, cs); ++ } ++ ++ /** ++ * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column ++ * text file). ++ * ++ * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, ++ * and each line is an individual value. The file should be in assets of the context. ++ * ++ * @param inputStream the input stream of an opened vocabulary file. ++ * @return a list of vocabulary words. ++ * @throws IOException if error occurs to open or read the file. ++ */ ++ @NonNull ++ public static List<String> loadSingleColumnTextFile( ++ @NonNull InputStream inputStream, Charset cs) throws IOException { ++ return loadLabels(inputStream, cs); ++ } ++ ++ /** ++ * Loads a file from the asset folder through memory mapping. ++ * ++ * @param context Application context to access assets. ++ * @param filePath Asset path of the file. ++ * @return the loaded memory mapped file. ++ * @throws IOException if an I/O error occurs when loading the tflite model. ++ */ ++ @NonNull ++ public static MappedByteBuffer loadMappedFile( ++ @NonNull Context context, @NonNull String filePath) throws IOException { ++ SupportPreconditions.checkNotNull(context, "Context should not be null."); ++ SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); ++ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); ++ FileInputStream inputStream = ++ new FileInputStream(fileDescriptor.getFileDescriptor())) { ++ FileChannel fileChannel = inputStream.getChannel(); ++ long startOffset = fileDescriptor.getStartOffset(); ++ long declaredLength = fileDescriptor.getDeclaredLength(); ++ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); ++ } ++ } ++ ++ /** ++ * Loads a binary file from the asset folder. ++ * ++ * @param context Application context to access assets. ++ * @param filePath Asset path of the file. ++ * @return the byte array for the binary file. ++ * @throws IOException if an I/O error occurs when loading file. ++ */ ++ @NonNull ++ public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath) ++ throws IOException { ++ ByteBuffer buffer = loadMappedFile(context, filePath); ++ byte[] byteArray = new byte[buffer.remaining()]; ++ buffer.get(byteArray); ++ return byteArray; + } +- } +- +- /** +- * Loads a binary file from the asset folder. +- * +- * @param context Application context to access assets. +- * @param filePath Asset path of the file. +- * @return the byte array for the binary file. +- * @throws IOException if an I/O error occurs when loading file. +- */ +- @NonNull +- public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath) +- throws IOException { +- ByteBuffer buffer = loadMappedFile(context, filePath); +- byte[] byteArray = new byte[buffer.remaining()]; +- buffer.get(byteArray); +- return byteArray; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java +index 38dfe8818cbbc..45dfc4d9d868b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java +@@ -20,12 +20,11 @@ package org.tensorflow.lite.support.common; + * @param <T> The class which Operator handles. + */ + public interface Operator<T> { +- +- /** +- * Applies an operation on a T object, returning a T object. +- * +- * <p>Note: The returned object could probably be the same one with given input, and given input +- * could probably be changed. +- */ +- T apply(T x); ++ /** ++ * Applies an operation on a T object, returning a T object. ++ * ++ * <p>Note: The returned object could probably be the same one with given input, and given input ++ * could probably be changed. ++ */ ++ T apply(T x); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java +index 9d0024b2f5887..a94adb89b8666 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java +@@ -17,5 +17,5 @@ package org.tensorflow.lite.support.common; + + /** Processes T object with prepared {@code Operator<T>}. */ + public interface Processor<T> { +- T process(T input); ++ T process(T input); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java +index af688c863c254..aa900b7c93d87 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java +@@ -15,13 +15,14 @@ limitations under the License. + + package org.tensorflow.lite.support.common; + ++import org.checkerframework.checker.nullness.qual.NonNull; ++import org.tensorflow.lite.support.common.internal.SupportPreconditions; ++ + import java.util.ArrayList; + import java.util.Collections; + import java.util.HashMap; + import java.util.List; + import java.util.Map; +-import org.checkerframework.checker.nullness.qual.NonNull; +-import org.tensorflow.lite.support.common.internal.SupportPreconditions; + + /** + * A processor base class that chains a serial of {@code Operator<T>} and executes them. +@@ -32,52 +33,50 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions; + * @param <T> The type that the Operator is handling. + */ + public class SequentialProcessor<T> implements Processor<T> { ++ /** List of operators added to this {@link SequentialProcessor}. */ ++ protected final List<Operator<T>> operatorList; ++ /** ++ * The {@link Map} between the operator name and the corresponding op indexes in {@code ++ * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}. ++ */ ++ protected final Map<String, List<Integer>> operatorIndex; + +- /** List of operators added to this {@link SequentialProcessor}. */ +- protected final List<Operator<T>> operatorList; +- /** +- * The {@link Map} between the operator name and the corresponding op indexes in {@code +- * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}. +- */ +- protected final Map<String, List<Integer>> operatorIndex; +- +- protected SequentialProcessor(Builder<T> builder) { +- operatorList = builder.operatorList; +- operatorIndex = Collections.unmodifiableMap(builder.operatorIndex); +- } ++ protected SequentialProcessor(Builder<T> builder) { ++ operatorList = builder.operatorList; ++ operatorIndex = Collections.unmodifiableMap(builder.operatorIndex); ++ } + +- @Override +- public T process(T x) { +- for (Operator<T> op : operatorList) { +- x = op.apply(x); ++ @Override ++ public T process(T x) { ++ for (Operator<T> op : operatorList) { ++ x = op.apply(x); ++ } ++ return x; + } +- return x; +- } + +- /** The inner builder class to build a Sequential Processor. */ +- protected static class Builder<T> { ++ /** The inner builder class to build a Sequential Processor. */ ++ protected static class Builder<T> { ++ private final List<Operator<T>> operatorList; ++ private final Map<String, List<Integer>> operatorIndex; + +- private final List<Operator<T>> operatorList; +- private final Map<String, List<Integer>> operatorIndex; ++ protected Builder() { ++ operatorList = new ArrayList<>(); ++ operatorIndex = new HashMap<>(); ++ } + +- protected Builder() { +- operatorList = new ArrayList<>(); +- operatorIndex = new HashMap<>(); +- } +- +- public Builder<T> add(@NonNull Operator<T> op) { +- SupportPreconditions.checkNotNull(op, "Adding null Op is illegal."); +- operatorList.add(op); +- String operatorName = op.getClass().getName(); +- if (!operatorIndex.containsKey(operatorName)) { +- operatorIndex.put(operatorName, new ArrayList<Integer>()); +- } +- operatorIndex.get(operatorName).add(operatorList.size() - 1); +- return this; +- } ++ public Builder<T> add(@NonNull Operator<T> op) { ++ SupportPreconditions.checkNotNull(op, "Adding null Op is illegal."); ++ operatorList.add(op); ++ String operatorName = op.getClass().getName(); ++ if (!operatorIndex.containsKey(operatorName)) { ++ operatorIndex.put(operatorName, new ArrayList<Integer>()); ++ } ++ operatorIndex.get(operatorName).add(operatorList.size() - 1); ++ return this; ++ } + +- public SequentialProcessor<T> build() { +- return new SequentialProcessor<T>(this); ++ public SequentialProcessor<T> build() { ++ return new SequentialProcessor<T>(this); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java +index d1b7021df257c..692c2d479dcce 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java +@@ -21,7 +21,7 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * Applies some operation on TensorBuffers. + */ + public interface TensorOperator extends Operator<TensorBuffer> { +- /** @see Operator#apply(Object) . */ +- @Override +- TensorBuffer apply(TensorBuffer input); ++ /** @see Operator#apply(Object) . */ ++ @Override ++ TensorBuffer apply(TensorBuffer input); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java +index b9d3d620e9c52..4391c4523527f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java +@@ -32,37 +32,36 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * @see TensorProcessor#process to apply the processor on a {@link TensorBuffer}. + */ + public class TensorProcessor extends SequentialProcessor<TensorBuffer> { +- private TensorProcessor(Builder builder) { +- super(builder); +- } +- +- /** The Builder to create an {@link TensorProcessor}, which could be executed later. */ +- public static class Builder extends SequentialProcessor.Builder<TensorBuffer> { +- +- /** +- * Creates a Builder to build {@link TensorProcessor}. +- * +- * @see #add(TensorOperator) to add an Op. +- * @see #build() to complete the building process and get a built Processor. +- */ +- public Builder() { +- super(); ++ private TensorProcessor(Builder builder) { ++ super(builder); + } + +- /** +- * Adds an {@link TensorOperator} into the Operator chain. +- * +- * @param op the Operator instance to be executed then. +- */ +- public TensorProcessor.Builder add(TensorOperator op) { +- super.add(op); +- return this; +- } ++ /** The Builder to create an {@link TensorProcessor}, which could be executed later. */ ++ public static class Builder extends SequentialProcessor.Builder<TensorBuffer> { ++ /** ++ * Creates a Builder to build {@link TensorProcessor}. ++ * ++ * @see #add(TensorOperator) to add an Op. ++ * @see #build() to complete the building process and get a built Processor. ++ */ ++ public Builder() { ++ super(); ++ } ++ ++ /** ++ * Adds an {@link TensorOperator} into the Operator chain. ++ * ++ * @param op the Operator instance to be executed then. ++ */ ++ public TensorProcessor.Builder add(TensorOperator op) { ++ super.add(op); ++ return this; ++ } + +- /** Completes the building process and gets the {@link TensorProcessor} instance. */ +- @Override +- public TensorProcessor build() { +- return new TensorProcessor(this); ++ /** Completes the building process and gets the {@link TensorProcessor} instance. */ ++ @Override ++ public TensorProcessor build() { ++ return new TensorProcessor(this); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java +index e3e962a5f8252..29faa545b71f2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java +@@ -19,164 +19,168 @@ import org.checkerframework.checker.nullness.qual.Nullable; + + /** Static error checking util methods. */ + public final class SupportPreconditions { +- /** +- * Ensures that an object reference passed as a parameter to the calling method is not null. +- * +- * @param reference an object reference +- * @return the non-null reference that was validated +- * @throws NullPointerException if {@code reference} is null +- */ +- public static <T extends Object> T checkNotNull(T reference) { +- if (reference == null) { +- throw new NullPointerException("The object reference is null."); ++ /** ++ * Ensures that an object reference passed as a parameter to the calling method is not null. ++ * ++ * @param reference an object reference ++ * @return the non-null reference that was validated ++ * @throws NullPointerException if {@code reference} is null ++ */ ++ public static <T extends Object> T checkNotNull(T reference) { ++ if (reference == null) { ++ throw new NullPointerException("The object reference is null."); ++ } ++ return reference; + } +- return reference; +- } +- +- /** +- * Ensures that an object reference passed as a parameter to the calling method is not null. +- * +- * @param reference an object reference +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)} +- * @return the non-null reference that was validated +- * @throws NullPointerException if {@code reference} is null +- */ +- public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { +- if (reference == null) { +- throw new NullPointerException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures that an object reference passed as a parameter to the calling method is not null. ++ * ++ * @param reference an object reference ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)} ++ * @return the non-null reference that was validated ++ * @throws NullPointerException if {@code reference} is null ++ */ ++ public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { ++ if (reference == null) { ++ throw new NullPointerException(String.valueOf(errorMessage)); ++ } ++ return reference; ++ } ++ ++ /** ++ * Ensures that the given String is not empty and not null. ++ * ++ * @param string the String to test ++ * @return the non-null non-empty String that was validated ++ * @throws IllegalArgumentException if {@code string} is null or empty ++ */ ++ public static String checkNotEmpty(String string) { ++ if (string == null || string.length() == 0) { ++ throw new IllegalArgumentException("Given String is empty or null."); ++ } ++ return string; + } +- return reference; +- } +- +- /** +- * Ensures that the given String is not empty and not null. +- * +- * @param string the String to test +- * @return the non-null non-empty String that was validated +- * @throws IllegalArgumentException if {@code string} is null or empty +- */ +- public static String checkNotEmpty(String string) { +- if (string == null || string.length() == 0) { +- throw new IllegalArgumentException("Given String is empty or null."); ++ ++ /** ++ * Ensures that the given String is not empty and not null. ++ * ++ * @param string the String to test ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)} ++ * @return the non-null non-empty String that was validated ++ * @throws IllegalArgumentException if {@code string} is null or empty ++ */ ++ public static String checkNotEmpty(String string, Object errorMessage) { ++ if (string == null || string.length() == 0) { ++ throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ } ++ return string; + } +- return string; +- } +- +- /** +- * Ensures that the given String is not empty and not null. +- * +- * @param string the String to test +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)} +- * @return the non-null non-empty String that was validated +- * @throws IllegalArgumentException if {@code string} is null or empty +- */ +- public static String checkNotEmpty(String string, Object errorMessage) { +- if (string == null || string.length() == 0) { +- throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures the truth of an expression involving one or more parameters to the calling method. ++ * ++ * @param expression a boolean expression. ++ * @throws IllegalArgumentException if {@code expression} is false. ++ */ ++ public static void checkArgument(boolean expression) { ++ if (!expression) { ++ throw new IllegalArgumentException(); ++ } + } +- return string; +- } +- +- /** +- * Ensures the truth of an expression involving one or more parameters to the calling method. +- * +- * @param expression a boolean expression. +- * @throws IllegalArgumentException if {@code expression} is false. +- */ +- public static void checkArgument(boolean expression) { +- if (!expression) { +- throw new IllegalArgumentException(); ++ ++ /** ++ * Ensures the truth of an expression involving one or more parameters to the calling method. ++ * ++ * @param expression a boolean expression. ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)}. ++ * @throws IllegalArgumentException if {@code expression} is false. ++ */ ++ public static void checkArgument(boolean expression, @Nullable Object errorMessage) { ++ if (!expression) { ++ throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ } + } +- } +- +- /** +- * Ensures the truth of an expression involving one or more parameters to the calling method. +- * +- * @param expression a boolean expression. +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)}. +- * @throws IllegalArgumentException if {@code expression} is false. +- */ +- public static void checkArgument(boolean expression, @Nullable Object errorMessage) { +- if (!expression) { +- throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of ++ * size ++ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. ++ * ++ * @param index a user-supplied index identifying an element of an array, list or string ++ * @param size the size of that array, list or string ++ * @return the value of {@code index} ++ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code ++ * size} ++ * @throws IllegalArgumentException if {@code size} is negative ++ */ ++ public static int checkElementIndex(int index, int size) { ++ return checkElementIndex(index, size, "index"); + } +- } +- +- /** +- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size +- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. +- * +- * @param index a user-supplied index identifying an element of an array, list or string +- * @param size the size of that array, list or string +- * @return the value of {@code index} +- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size} +- * @throws IllegalArgumentException if {@code size} is negative +- */ +- public static int checkElementIndex(int index, int size) { +- return checkElementIndex(index, size, "index"); +- } +- +- /** +- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size +- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. +- * +- * @param index a user-supplied index identifying an element of an array, list or string +- * @param size the size of that array, list or string +- * @param desc the text to use to describe this index in an error message +- * @return the value of {@code index} +- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size} +- * @throws IllegalArgumentException if {@code size} is negative +- */ +- public static int checkElementIndex(int index, int size, @Nullable String desc) { +- // Carefully optimized for execution by hotspot (explanatory comment above) +- if (index < 0 || index >= size) { +- throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); ++ ++ /** ++ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of ++ * size ++ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. ++ * ++ * @param index a user-supplied index identifying an element of an array, list or string ++ * @param size the size of that array, list or string ++ * @param desc the text to use to describe this index in an error message ++ * @return the value of {@code index} ++ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code ++ * size} ++ * @throws IllegalArgumentException if {@code size} is negative ++ */ ++ public static int checkElementIndex(int index, int size, @Nullable String desc) { ++ // Carefully optimized for execution by hotspot (explanatory comment above) ++ if (index < 0 || index >= size) { ++ throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); ++ } ++ return index; + } +- return index; +- } +- +- /** +- * Ensures the truth of an expression involving the state of the calling instance, but not +- * involving any parameters to the calling method. +- * +- * @param expression a boolean expression +- * @throws IllegalStateException if {@code expression} is false +- */ +- public static void checkState(boolean expression) { +- if (!expression) { +- throw new IllegalStateException(); ++ ++ /** ++ * Ensures the truth of an expression involving the state of the calling instance, but not ++ * involving any parameters to the calling method. ++ * ++ * @param expression a boolean expression ++ * @throws IllegalStateException if {@code expression} is false ++ */ ++ public static void checkState(boolean expression) { ++ if (!expression) { ++ throw new IllegalStateException(); ++ } + } +- } +- +- /** +- * Ensures the truth of an expression involving the state of the calling instance, but not +- * involving any parameters to the calling method. +- * +- * @param expression a boolean expression +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)} +- * @throws IllegalStateException if {@code expression} is false +- */ +- public static void checkState(boolean expression, @Nullable Object errorMessage) { +- if (!expression) { +- throw new IllegalStateException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures the truth of an expression involving the state of the calling instance, but not ++ * involving any parameters to the calling method. ++ * ++ * @param expression a boolean expression ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)} ++ * @throws IllegalStateException if {@code expression} is false ++ */ ++ public static void checkState(boolean expression, @Nullable Object errorMessage) { ++ if (!expression) { ++ throw new IllegalStateException(String.valueOf(errorMessage)); ++ } + } +- } +- +- private static String badElementIndex(int index, int size, @Nullable String desc) { +- if (index < 0) { +- return String.format("%s (%s) must not be negative", desc, index); +- } else if (size < 0) { +- throw new IllegalArgumentException("negative size: " + size); +- } else { // index >= size +- return String.format("%s (%s) must be less than size (%s)", desc, index, size); ++ ++ private static String badElementIndex(int index, int size, @Nullable String desc) { ++ if (index < 0) { ++ return String.format("%s (%s) must not be negative", desc, index); ++ } else if (size < 0) { ++ throw new IllegalArgumentException("negative size: " + size); ++ } else { // index >= size ++ return String.format("%s (%s) must be less than size (%s)", desc, index, size); ++ } + } +- } + +- private SupportPreconditions() { +- throw new AssertionError("SupportPreconditions is Uninstantiable."); +- } ++ private SupportPreconditions() { ++ throw new AssertionError("SupportPreconditions is Uninstantiable."); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java +index 742a1ef90994c..a14cd1f1e503d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java +@@ -22,34 +22,33 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** Casts a {@link TensorBuffer} to a specified data type. */ + public class CastOp implements TensorOperator { ++ private final DataType destinationType; ++ ++ /** ++ * Constructs a CastOp. ++ * ++ * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than ++ * in a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}. ++ * ++ * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code ++ * destinationType}, the original buffer will be directly returned. ++ * ++ * @param destinationType The type of the casted {@link TensorBuffer}. ++ * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8} ++ * nor {@link DataType#FLOAT32}. ++ */ ++ public CastOp(DataType destinationType) { ++ SupportPreconditions.checkArgument( ++ destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32, ++ "Destination type " + destinationType + " is not supported."); ++ this.destinationType = destinationType; ++ } + +- private final DataType destinationType; +- +- /** +- * Constructs a CastOp. +- * +- * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than in +- * a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}. +- * +- * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code +- * destinationType}, the original buffer will be directly returned. +- * +- * @param destinationType The type of the casted {@link TensorBuffer}. +- * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8} +- * nor {@link DataType#FLOAT32}. +- */ +- public CastOp(DataType destinationType) { +- SupportPreconditions.checkArgument( +- destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32, +- "Destination type " + destinationType + " is not supported."); +- this.destinationType = destinationType; +- } +- +- @Override +- public TensorBuffer apply(TensorBuffer input) { +- if (input.getDataType() == destinationType) { +- return input; ++ @Override ++ public TensorBuffer apply(TensorBuffer input) { ++ if (input.getDataType() == destinationType) { ++ return input; ++ } ++ return TensorBuffer.createFrom(input, destinationType); + } +- return TensorBuffer.createFrom(input, destinationType); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java +index 1881747870be3..8b6d183189b7f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java +@@ -32,9 +32,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * as 0. + */ + public class DequantizeOp extends NormalizeOp implements TensorOperator { +- +- public DequantizeOp(float zeroPoint, float scale) { +- // Quantization: f = (q - z) * s +- super(zeroPoint, 1 / scale); +- } ++ public DequantizeOp(float zeroPoint, float scale) { ++ // Quantization: f = (q - z) * s ++ super(zeroPoint, 1 / scale); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java +index cff4d0b55d60a..912df13b59cec 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java +@@ -26,135 +26,134 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat; + * Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev. + */ + public class NormalizeOp implements TensorOperator { ++ // mean.length should always be equal to stddev.length and always >= 1. ++ private final float[] mean; ++ private final float[] stddev; ++ private final int numChannels; ++ private final boolean isIdentityOp; + +- // mean.length should always be equal to stddev.length and always >= 1. +- private final float[] mean; +- private final float[] stddev; +- private final int numChannels; +- private final boolean isIdentityOp; ++ /** ++ * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which ++ * satisfies: ++ * ++ * <pre> ++ * output = (input - mean) / stddev ++ * </pre> ++ * ++ * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the ++ * normalization. <br> ++ * 1. Both {@code mean} and {code stddev} are 0. <br> ++ * 2. {@code mean} is 0 and {stddev} is Infinity. ++ * ++ * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will ++ * happen, and original input will be directly returned in execution. ++ * ++ * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at ++ * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 ++ * and ++ * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned. ++ * ++ * @param mean the mean value to be subtracted first. ++ * @param stddev the standard deviation value to divide then. ++ * @throws IllegalArgumentException if {@code stddev} is zero. ++ */ ++ public NormalizeOp(float mean, float stddev) { ++ // Make exceptions to the cases that ++ // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization ++ // parameters from a tensor which does not have the values populated in the metadata. The ++ // same situation may also happen to the quantization parameters. ++ // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization ++ // parameters from a tensor which does not have the values populated in the metadata, and ++ // then passing the parameters into the DequantizeOp. Bypass both of the two cases, by ++ // reseting stddev to 1.0f. ++ if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) { ++ stddev = 1.0f; ++ } + +- /** +- * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which +- * satisfies: +- * +- * <pre> +- * output = (input - mean) / stddev +- * </pre> +- * +- * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the +- * normalization. <br> +- * 1. Both {@code mean} and {code stddev} are 0. <br> +- * 2. {@code mean} is 0 and {stddev} is Infinity. +- * +- * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will +- * happen, and original input will be directly returned in execution. +- * +- * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at +- * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 and +- * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned. +- * +- * @param mean the mean value to be subtracted first. +- * @param stddev the standard deviation value to divide then. +- * @throws IllegalArgumentException if {@code stddev} is zero. +- */ +- public NormalizeOp(float mean, float stddev) { +- // Make exceptions to the cases that +- // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters +- // from a tensor which does not have the values populated in the metadata. The same situation +- // may also happen to the quantization parameters. +- // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization +- // parameters from a tensor which does not have the values populated in the metadata, and then +- // passing the parameters into the DequantizeOp. +- // Bypass both of the two cases, by reseting stddev to 1.0f. +- if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) { +- stddev = 1.0f; +- } ++ SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero."); ++ boolean meansIsZeroAndDevsIs1 = false; ++ if (mean == 0.0f && stddev == 1.0f) { ++ meansIsZeroAndDevsIs1 = true; ++ } + +- SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero."); +- boolean meansIsZeroAndDevsIs1 = false; +- if (mean == 0.0f && stddev == 1.0f) { +- meansIsZeroAndDevsIs1 = true; ++ this.isIdentityOp = meansIsZeroAndDevsIs1; ++ this.mean = new float[] {mean}; ++ this.stddev = new float[] {stddev}; ++ this.numChannels = 1; + } + +- this.isIdentityOp = meansIsZeroAndDevsIs1; +- this.mean = new float[] {mean}; +- this.stddev = new float[] {stddev}; +- this.numChannels = 1; +- } +- +- /** +- * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which +- * satisfies: +- * +- * <pre> +- * // Pseudo code. [...][i] means a certain element whose channel id is i. +- * output[...][i] = (input[...][i] - mean[i]) / stddev[i] +- * </pre> +- * +- * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no +- * computation will happen, and original input will be directly returned in execution. +- * +- * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at +- * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to +- * 0 and all {@code stddev} are set to 1. +- * +- * @param mean the mean values to be subtracted first for each channel. +- * @param stddev the standard deviation values to divide then for each channel. +- * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different +- * number of elements with {@code stddev}, or any of them is empty. +- */ +- public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) { +- SupportPreconditions.checkNotNull(mean, "Mean cannot be null"); +- SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null"); +- SupportPreconditions.checkArgument( +- mean.length == stddev.length, +- "Per channel normalization requires same number of means and stddevs"); +- SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty."); +- this.mean = mean.clone(); +- this.stddev = stddev.clone(); +- boolean allMeansAreZeroAndAllDevsAre1 = true; +- this.numChannels = mean.length; +- for (int i = 0; i < numChannels; i++) { +- SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero."); +- if (this.stddev[i] != 1 || this.mean[i] != 0) { +- allMeansAreZeroAndAllDevsAre1 = false; +- } ++ /** ++ * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which ++ * satisfies: ++ * ++ * <pre> ++ * // Pseudo code. [...][i] means a certain element whose channel id is i. ++ * output[...][i] = (input[...][i] - mean[i]) / stddev[i] ++ * </pre> ++ * ++ * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no ++ * computation will happen, and original input will be directly returned in execution. ++ * ++ * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at ++ * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set ++ * to 0 and all {@code stddev} are set to 1. ++ * ++ * @param mean the mean values to be subtracted first for each channel. ++ * @param stddev the standard deviation values to divide then for each channel. ++ * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different ++ * number of elements with {@code stddev}, or any of them is empty. ++ */ ++ public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) { ++ SupportPreconditions.checkNotNull(mean, "Mean cannot be null"); ++ SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null"); ++ SupportPreconditions.checkArgument(mean.length == stddev.length, ++ "Per channel normalization requires same number of means and stddevs"); ++ SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty."); ++ this.mean = mean.clone(); ++ this.stddev = stddev.clone(); ++ boolean allMeansAreZeroAndAllDevsAre1 = true; ++ this.numChannels = mean.length; ++ for (int i = 0; i < numChannels; i++) { ++ SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero."); ++ if (this.stddev[i] != 1 || this.mean[i] != 0) { ++ allMeansAreZeroAndAllDevsAre1 = false; ++ } ++ } ++ this.isIdentityOp = allMeansAreZeroAndAllDevsAre1; + } +- this.isIdentityOp = allMeansAreZeroAndAllDevsAre1; +- } + +- /** +- * Applies the defined normalization on given tensor and returns the result. +- * +- * <p>Note: {@code input} is possibly the same instance with the output. +- * +- * @param input input tensor. It may be the same instance with the output. +- * @return output tensor. +- */ +- @Override +- @NonNull +- public TensorBuffer apply(@NonNull TensorBuffer input) { +- if (isIdentityOp) { +- return input; +- } +- int[] shape = input.getShape(); +- SupportPreconditions.checkArgument( +- numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels), +- "Number of means (stddevs) is not same with number of channels (size of last axis)."); +- // TODO(136750944): Eliminate the array copy here. +- float[] values = input.getFloatArray(); +- int j = 0; +- for (int i = 0; i < values.length; i++) { +- values[i] = (values[i] - mean[j]) / stddev[j]; +- j = (j + 1) % numChannels; +- } +- TensorBuffer output; +- if (input.isDynamic()) { +- output = TensorBufferFloat.createDynamic(DataType.FLOAT32); +- } else { +- output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32); ++ /** ++ * Applies the defined normalization on given tensor and returns the result. ++ * ++ * <p>Note: {@code input} is possibly the same instance with the output. ++ * ++ * @param input input tensor. It may be the same instance with the output. ++ * @return output tensor. ++ */ ++ @Override ++ @NonNull ++ public TensorBuffer apply(@NonNull TensorBuffer input) { ++ if (isIdentityOp) { ++ return input; ++ } ++ int[] shape = input.getShape(); ++ SupportPreconditions.checkArgument( ++ numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels), ++ "Number of means (stddevs) is not same with number of channels (size of last axis)."); ++ // TODO(136750944): Eliminate the array copy here. ++ float[] values = input.getFloatArray(); ++ int j = 0; ++ for (int i = 0; i < values.length; i++) { ++ values[i] = (values[i] - mean[j]) / stddev[j]; ++ j = (j + 1) % numChannels; ++ } ++ TensorBuffer output; ++ if (input.isDynamic()) { ++ output = TensorBufferFloat.createDynamic(DataType.FLOAT32); ++ } else { ++ output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32); ++ } ++ output.loadArray(values, shape); ++ return output; + } +- output.loadArray(values, shape); +- return output; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java +index 8b3e82aee13ef..84cb856fd4ed9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java +@@ -33,9 +33,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * as 0. + */ + public class QuantizeOp extends NormalizeOp implements TensorOperator { +- +- public QuantizeOp(float zeroPoint, float scale) { +- // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s +- super(-zeroPoint * scale, scale); +- } ++ public QuantizeOp(float zeroPoint, float scale) { ++ // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s ++ super(-zeroPoint * scale, scale); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java +index 9bee78d139efa..f9b6a1f874bff 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java +@@ -21,67 +21,67 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import android.graphics.Bitmap; + import android.graphics.Bitmap.Config; + import android.media.Image; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** Holds a {@link Bitmap} and converts it to other image formats as needed. */ + final class BitmapContainer implements ImageContainer { +- +- private final Bitmap bitmap; +- +- /** +- * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}. +- * +- * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888 +- */ +- static BitmapContainer create(Bitmap bitmap) { +- return new BitmapContainer(bitmap); +- } +- +- private BitmapContainer(Bitmap bitmap) { +- checkNotNull(bitmap, "Cannot load null bitmap."); +- checkArgument( +- bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps."); +- this.bitmap = bitmap; +- } +- +- @Override +- public BitmapContainer clone() { +- return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable())); +- } +- +- @Override +- public Bitmap getBitmap() { +- // Not making a defensive copy for performance considerations. During image processing, +- // users may need to set and get the bitmap many times. +- return bitmap; +- } +- +- @Override +- public TensorBuffer getTensorBuffer(DataType dataType) { +- TensorBuffer buffer = TensorBuffer.createDynamic(dataType); +- ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer); +- return buffer; +- } +- +- @Override +- public Image getMediaImage() { +- throw new UnsupportedOperationException( +- "Converting from Bitmap to android.media.Image is unsupported."); +- } +- +- @Override +- public int getWidth() { +- return bitmap.getWidth(); +- } +- +- @Override +- public int getHeight() { +- return bitmap.getHeight(); +- } +- +- @Override +- public ColorSpaceType getColorSpaceType() { +- return ColorSpaceType.fromBitmapConfig(bitmap.getConfig()); +- } ++ private final Bitmap bitmap; ++ ++ /** ++ * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}. ++ * ++ * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888 ++ */ ++ static BitmapContainer create(Bitmap bitmap) { ++ return new BitmapContainer(bitmap); ++ } ++ ++ private BitmapContainer(Bitmap bitmap) { ++ checkNotNull(bitmap, "Cannot load null bitmap."); ++ checkArgument(bitmap.getConfig().equals(Config.ARGB_8888), ++ "Only supports loading ARGB_8888 bitmaps."); ++ this.bitmap = bitmap; ++ } ++ ++ @Override ++ public BitmapContainer clone() { ++ return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable())); ++ } ++ ++ @Override ++ public Bitmap getBitmap() { ++ // Not making a defensive copy for performance considerations. During image processing, ++ // users may need to set and get the bitmap many times. ++ return bitmap; ++ } ++ ++ @Override ++ public TensorBuffer getTensorBuffer(DataType dataType) { ++ TensorBuffer buffer = TensorBuffer.createDynamic(dataType); ++ ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer); ++ return buffer; ++ } ++ ++ @Override ++ public Image getMediaImage() { ++ throw new UnsupportedOperationException( ++ "Converting from Bitmap to android.media.Image is unsupported."); ++ } ++ ++ @Override ++ public int getWidth() { ++ return bitmap.getWidth(); ++ } ++ ++ @Override ++ public int getHeight() { ++ return bitmap.getHeight(); ++ } ++ ++ @Override ++ public ColorSpaceType getColorSpaceType() { ++ return ColorSpaceType.fromBitmapConfig(bitmap.getConfig()); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java +index 8571d6227e136..a2e833b68d6d0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java +@@ -18,13 +18,15 @@ package org.tensorflow.lite.support.image; + import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + + import android.graphics.RectF; ++ ++import org.tensorflow.lite.DataType; ++import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; ++ + import java.nio.ByteBuffer; + import java.nio.FloatBuffer; + import java.util.ArrayList; + import java.util.Arrays; + import java.util.List; +-import org.tensorflow.lite.DataType; +-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** + * Helper class for converting values that represents bounding boxes into rectangles. +@@ -37,207 +39,186 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * elements in each type is configurable as well. + */ + public final class BoundingBoxUtil { ++ /** Denotes how a bounding box is represented. */ ++ public enum Type { ++ /** ++ * Represents the bounding box by using the combination of boundaries, {left, top, right, ++ * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated ++ * by an index array. ++ */ ++ BOUNDARIES, ++ /** ++ * Represents the bounding box by using the upper_left corner, width and height. The default ++ * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an ++ * index array. ++ */ ++ UPPER_LEFT, ++ /** ++ * Represents the bounding box by using the center of the box, width and height. The default ++ * order is {center_x, center_y, width, height}. Other orders can be indicated by an index ++ * array. ++ */ ++ CENTER, ++ } ++ ++ /** Denotes if the coordinates are actual pixels or relative ratios. */ ++ public enum CoordinateType { ++ /** The coordinates are relative ratios in range [0, 1]. */ ++ RATIO, ++ /** The coordinates are actual pixel values. */ ++ PIXEL ++ } + +- /** Denotes how a bounding box is represented. */ +- public enum Type { +- /** +- * Represents the bounding box by using the combination of boundaries, {left, top, right, +- * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an +- * index array. +- */ +- BOUNDARIES, +- /** +- * Represents the bounding box by using the upper_left corner, width and height. The default +- * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an +- * index array. +- */ +- UPPER_LEFT, + /** +- * Represents the bounding box by using the center of the box, width and height. The default +- * order is {center_x, center_y, width, height}. Other orders can be indicated by an index +- * array. ++ * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes. ++ * ++ * @param tensor holds the data representing some boxes. ++ * @param valueIndex denotes the order of the elements defined in each bounding box type. An ++ * empty ++ * index array represent the default order of each bounding box type. For example, to denote ++ * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, ++ * 2, 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}. ++ * <p>The index array can be applied to all bounding box types to adjust the order of their ++ * corresponding underlying elements. ++ * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The ++ * size of that dimension is required to be 4. Index here starts from 0. For example, if the ++ * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is ++ * also supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the ++ * axis is likely to be 1 (or -1, equivalently). ++ * @param type defines how values should be converted into boxes. See {@link Type} ++ * @param coordinateType defines how values are interpreted to coordinates. See {@link ++ * CoordinateType} ++ * @param height the height of the image which the boxes belong to. Only has effects when {@code ++ * coordinateType} is {@link CoordinateType#RATIO} ++ * @param width the width of the image which the boxes belong to. Only has effects when {@code ++ * coordinateType} is {@link CoordinateType#RATIO} ++ * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except ++ * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code ++ * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a ++ * list of 20 bounding boxes. ++ * @throws IllegalArgumentException if size of bounding box dimension (set by {@code ++ * boundingBoxAxis}) is not 4. ++ * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} ++ * where ++ * {@code D} is the number of dimensions of the {@code tensor}. ++ * @throws IllegalArgumentException if {@code tensor} has data type other than {@link ++ * DataType#FLOAT32}. + */ +- CENTER, +- } +- +- /** Denotes if the coordinates are actual pixels or relative ratios. */ +- public enum CoordinateType { +- /** The coordinates are relative ratios in range [0, 1]. */ +- RATIO, +- /** The coordinates are actual pixel values. */ +- PIXEL +- } +- +- /** +- * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes. +- * +- * @param tensor holds the data representing some boxes. +- * @param valueIndex denotes the order of the elements defined in each bounding box type. An empty +- * index array represent the default order of each bounding box type. For example, to denote +- * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2, +- * 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}. +- * <p>The index array can be applied to all bounding box types to adjust the order of their +- * corresponding underlying elements. +- * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The +- * size of that dimension is required to be 4. Index here starts from 0. For example, if the +- * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is also +- * supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the +- * axis is likely to be 1 (or -1, equivalently). +- * @param type defines how values should be converted into boxes. See {@link Type} +- * @param coordinateType defines how values are interpreted to coordinates. See {@link +- * CoordinateType} +- * @param height the height of the image which the boxes belong to. Only has effects when {@code +- * coordinateType} is {@link CoordinateType#RATIO} +- * @param width the width of the image which the boxes belong to. Only has effects when {@code +- * coordinateType} is {@link CoordinateType#RATIO} +- * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except +- * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code +- * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a list +- * of 20 bounding boxes. +- * @throws IllegalArgumentException if size of bounding box dimension (set by {@code +- * boundingBoxAxis}) is not 4. +- * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where +- * {@code D} is the number of dimensions of the {@code tensor}. +- * @throws IllegalArgumentException if {@code tensor} has data type other than {@link +- * DataType#FLOAT32}. +- */ +- public static List<RectF> convert( +- TensorBuffer tensor, +- int[] valueIndex, +- int boundingBoxAxis, +- Type type, +- CoordinateType coordinateType, +- int height, +- int width) { +- int[] shape = tensor.getShape(); +- checkArgument( +- boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length, +- String.format( +- "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input" +- + " tensor (shape=%s)", +- boundingBoxAxis, Arrays.toString(shape))); +- if (boundingBoxAxis < 0) { +- boundingBoxAxis = shape.length + boundingBoxAxis; +- } +- checkArgument( +- shape[boundingBoxAxis] == 4, +- String.format( +- "Size of bounding box dimension %d is not 4. Got %d in shape %s", +- boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape))); +- checkArgument( +- valueIndex.length == 4, +- String.format( +- "Bounding box index array length %d is not 4. Got index array %s", +- valueIndex.length, Arrays.toString(valueIndex))); +- checkArgument( +- tensor.getDataType() == DataType.FLOAT32, +- "Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name()); +- List<RectF> boundingBoxList = new ArrayList<>(); +- // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and its +- // four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by +- // i * 4b + k * b + j. +- int a = 1; +- for (int i = 0; i < boundingBoxAxis; i++) { +- a *= shape[i]; ++ public static List<RectF> convert(TensorBuffer tensor, int[] valueIndex, int boundingBoxAxis, ++ Type type, CoordinateType coordinateType, int height, int width) { ++ int[] shape = tensor.getShape(); ++ checkArgument(boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length, ++ String.format( ++ "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input" ++ + " tensor (shape=%s)", ++ boundingBoxAxis, Arrays.toString(shape))); ++ if (boundingBoxAxis < 0) { ++ boundingBoxAxis = shape.length + boundingBoxAxis; ++ } ++ checkArgument(shape[boundingBoxAxis] == 4, ++ String.format("Size of bounding box dimension %d is not 4. Got %d in shape %s", ++ boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape))); ++ checkArgument(valueIndex.length == 4, ++ String.format("Bounding box index array length %d is not 4. Got index array %s", ++ valueIndex.length, Arrays.toString(valueIndex))); ++ checkArgument(tensor.getDataType() == DataType.FLOAT32, ++ "Bounding Boxes only create from FLOAT32 buffers. Got: " ++ + tensor.getDataType().name()); ++ List<RectF> boundingBoxList = new ArrayList<>(); ++ // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and ++ // its four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by ++ // i * 4b + k * b + j. ++ int a = 1; ++ for (int i = 0; i < boundingBoxAxis; i++) { ++ a *= shape[i]; ++ } ++ int b = 1; ++ for (int i = boundingBoxAxis + 1; i < shape.length; i++) { ++ b *= shape[i]; ++ } ++ float[] values = new float[4]; ++ ByteBuffer byteBuffer = tensor.getBuffer(); ++ byteBuffer.rewind(); ++ FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); ++ for (int i = 0; i < a; i++) { ++ for (int j = 0; j < b; j++) { ++ for (int k = 0; k < 4; k++) { ++ values[k] = floatBuffer.get((i * 4 + k) * b + j); ++ } ++ boundingBoxList.add(convertOneBoundingBox( ++ values, valueIndex, type, coordinateType, height, width)); ++ } ++ } ++ byteBuffer.rewind(); ++ return boundingBoxList; + } +- int b = 1; +- for (int i = boundingBoxAxis + 1; i < shape.length; i++) { +- b *= shape[i]; ++ ++ private static RectF convertOneBoundingBox(float[] values, int[] valueIndex, Type type, ++ CoordinateType coordinateType, int height, int width) { ++ float[] orderedValues = new float[4]; ++ for (int i = 0; i < 4; i++) { ++ orderedValues[i] = values[valueIndex[i]]; ++ } ++ return convertOneBoundingBox(orderedValues, type, coordinateType, height, width); + } +- float[] values = new float[4]; +- ByteBuffer byteBuffer = tensor.getBuffer(); +- byteBuffer.rewind(); +- FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); +- for (int i = 0; i < a; i++) { +- for (int j = 0; j < b; j++) { +- for (int k = 0; k < 4; k++) { +- values[k] = floatBuffer.get((i * 4 + k) * b + j); ++ ++ private static RectF convertOneBoundingBox( ++ float[] values, Type type, CoordinateType coordinateType, int height, int width) { ++ switch (type) { ++ case BOUNDARIES: ++ return convertFromBoundaries(values, coordinateType, height, width); ++ case UPPER_LEFT: ++ return convertFromUpperLeft(values, coordinateType, height, width); ++ case CENTER: ++ return convertFromCenter(values, coordinateType, height, width); + } +- boundingBoxList.add( +- convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width)); +- } ++ throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type); + } +- byteBuffer.rewind(); +- return boundingBoxList; +- } +- +- private static RectF convertOneBoundingBox( +- float[] values, +- int[] valueIndex, +- Type type, +- CoordinateType coordinateType, +- int height, +- int width) { +- float[] orderedValues = new float[4]; +- for (int i = 0; i < 4; i++) { +- orderedValues[i] = values[valueIndex[i]]; ++ ++ private static RectF convertFromBoundaries( ++ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { ++ float left = values[0]; ++ float top = values[1]; ++ float right = values[2]; ++ float bottom = values[3]; ++ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); ++ } ++ ++ private static RectF convertFromUpperLeft( ++ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { ++ float left = values[0]; ++ float top = values[1]; ++ float right = values[0] + values[2]; ++ float bottom = values[1] + values[3]; ++ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); + } +- return convertOneBoundingBox(orderedValues, type, coordinateType, height, width); +- } +- +- private static RectF convertOneBoundingBox( +- float[] values, Type type, CoordinateType coordinateType, int height, int width) { +- switch (type) { +- case BOUNDARIES: +- return convertFromBoundaries(values, coordinateType, height, width); +- case UPPER_LEFT: +- return convertFromUpperLeft(values, coordinateType, height, width); +- case CENTER: +- return convertFromCenter(values, coordinateType, height, width); ++ ++ private static RectF convertFromCenter( ++ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { ++ float centerX = values[0]; ++ float centerY = values[1]; ++ float w = values[2]; ++ float h = values[3]; ++ ++ float left = centerX - w / 2; ++ float top = centerY - h / 2; ++ float right = centerX + w / 2; ++ float bottom = centerY + h / 2; ++ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); + } +- throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type); +- } +- +- private static RectF convertFromBoundaries( +- float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { +- float left = values[0]; +- float top = values[1]; +- float right = values[2]; +- float bottom = values[3]; +- return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); +- } +- +- private static RectF convertFromUpperLeft( +- float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { +- float left = values[0]; +- float top = values[1]; +- float right = values[0] + values[2]; +- float bottom = values[1] + values[3]; +- return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); +- } +- +- private static RectF convertFromCenter( +- float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { +- float centerX = values[0]; +- float centerY = values[1]; +- float w = values[2]; +- float h = values[3]; +- +- float left = centerX - w / 2; +- float top = centerY - h / 2; +- float right = centerX + w / 2; +- float bottom = centerY + h / 2; +- return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); +- } +- +- private static RectF getRectF( +- float left, +- float top, +- float right, +- float bottom, +- int imageHeight, +- int imageWidth, +- CoordinateType coordinateType) { +- if (coordinateType == CoordinateType.PIXEL) { +- return new RectF(left, top, right, bottom); +- } else if (coordinateType == CoordinateType.RATIO) { +- return new RectF( +- left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight); +- } else { +- throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType); ++ ++ private static RectF getRectF(float left, float top, float right, float bottom, int imageHeight, ++ int imageWidth, CoordinateType coordinateType) { ++ if (coordinateType == CoordinateType.PIXEL) { ++ return new RectF(left, top, right, bottom); ++ } else if (coordinateType == CoordinateType.RATIO) { ++ return new RectF( ++ left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight); ++ } else { ++ throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType); ++ } + } +- } + +- // Private constructor to prevent initialization. +- private BoundingBoxUtil() {} ++ // Private constructor to prevent initialization. ++ private BoundingBoxUtil() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java +index 457bcf1da1de3..716cacdf7bf51 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java +@@ -20,354 +20,351 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import android.graphics.Bitmap; + import android.graphics.Bitmap.Config; + import android.graphics.ImageFormat; +-import java.util.Arrays; ++ + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.Arrays; ++ + /** Represents the type of color space of an image. */ + public enum ColorSpaceType { +- /** Each pixel has red, green, and blue color components. */ +- RGB(0) { +- +- // The channel axis should always be 3 for RGB images. +- private static final int CHANNEL_VALUE = 3; +- +- @Override +- Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { +- return ImageConversions.convertRgbTensorBufferToBitmap(buffer); ++ /** Each pixel has red, green, and blue color components. */ ++ RGB(0) { ++ // The channel axis should always be 3 for RGB images. ++ private static final int CHANNEL_VALUE = 3; ++ ++ @Override ++ Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { ++ return ImageConversions.convertRgbTensorBufferToBitmap(buffer); ++ } ++ ++ @Override ++ int getChannelValue() { ++ return CHANNEL_VALUE; ++ } ++ ++ @Override ++ int[] getNormalizedShape(int[] shape) { ++ switch (shape.length) { ++ // The shape is in (h, w, c) format. ++ case 3: ++ return insertValue(shape, BATCH_DIM, BATCH_VALUE); ++ case 4: ++ return shape; ++ default: ++ throw new IllegalArgumentException(getShapeInfoMessage() ++ + "The provided image shape is " + Arrays.toString(shape)); ++ } ++ } ++ ++ @Override ++ int getNumElements(int height, int width) { ++ return height * width * CHANNEL_VALUE; ++ } ++ ++ @Override ++ String getShapeInfoMessage() { ++ return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" ++ + " representing R, G, B in order. "; ++ } ++ ++ @Override ++ Config toBitmapConfig() { ++ return Config.ARGB_8888; ++ } ++ }, ++ ++ /** Each pixel is a single element representing only the amount of light. */ ++ GRAYSCALE(1) { ++ // The channel axis should always be 1 for grayscale images. ++ private static final int CHANNEL_VALUE = 1; ++ ++ @Override ++ Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { ++ return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer); ++ } ++ ++ @Override ++ int getChannelValue() { ++ return CHANNEL_VALUE; ++ } ++ ++ @Override ++ int[] getNormalizedShape(int[] shape) { ++ switch (shape.length) { ++ // The shape is in (h, w) format. ++ case 2: ++ int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE); ++ return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE); ++ case 4: ++ return shape; ++ default: ++ // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since ++ // they both have three dimensions, it will require extra info to differentiate ++ // between them. Since we haven't encountered real use cases of these two ++ // shapes, they are not supported at this moment to avoid confusion. We may want ++ // to revisit it in the future. ++ throw new IllegalArgumentException(getShapeInfoMessage() ++ + "The provided image shape is " + Arrays.toString(shape)); ++ } ++ } ++ ++ @Override ++ int getNumElements(int height, int width) { ++ return height * width; ++ } ++ ++ @Override ++ String getShapeInfoMessage() { ++ return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). "; ++ } ++ ++ @Override ++ Config toBitmapConfig() { ++ return Config.ALPHA_8; ++ } ++ }, ++ ++ /** YUV420sp format, encoded as "YYYYYYYY UVUV". */ ++ NV12(2) { ++ @Override ++ int getNumElements(int height, int width) { ++ return getYuv420NumElements(height, width); ++ } ++ }, ++ ++ /** ++ * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1 ++ * preview. ++ */ ++ NV21(3) { ++ @Override ++ int getNumElements(int height, int width) { ++ return getYuv420NumElements(height, width); ++ } ++ }, ++ ++ /** YUV420p format, encoded as "YYYYYYYY VV UU". */ ++ YV12(4) { ++ @Override ++ int getNumElements(int height, int width) { ++ return getYuv420NumElements(height, width); ++ } ++ }, ++ ++ /** YUV420p format, encoded as "YYYYYYYY UU VV". */ ++ YV21(5) { ++ @Override ++ int getNumElements(int height, int width) { ++ return getYuv420NumElements(height, width); ++ } ++ }, ++ ++ /** ++ * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual ++ * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image. ++ * ++ * <p>Use this format only when you load an {@link android.media.Image}. ++ */ ++ YUV_420_888(6) { ++ @Override ++ int getNumElements(int height, int width) { ++ return getYuv420NumElements(height, width); ++ } ++ }; ++ ++ private static final int BATCH_DIM = 0; // The first element of the normalizaed shape. ++ private static final int BATCH_VALUE = 1; // The batch axis should always be one. ++ private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape. ++ private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape. ++ private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape. ++ private final int value; ++ ++ ColorSpaceType(int value) { ++ this.value = value; + } + +- @Override +- int getChannelValue() { +- return CHANNEL_VALUE; ++ /** ++ * Converts a bitmap configuration into the corresponding color space type. ++ * ++ * @throws IllegalArgumentException if the config is unsupported ++ */ ++ static ColorSpaceType fromBitmapConfig(Config config) { ++ switch (config) { ++ case ARGB_8888: ++ return ColorSpaceType.RGB; ++ case ALPHA_8: ++ return ColorSpaceType.GRAYSCALE; ++ default: ++ throw new IllegalArgumentException( ++ "Bitmap configuration: " + config + ", is not supported yet."); ++ } + } + +- @Override +- int[] getNormalizedShape(int[] shape) { +- switch (shape.length) { +- // The shape is in (h, w, c) format. +- case 3: +- return insertValue(shape, BATCH_DIM, BATCH_VALUE); +- case 4: +- return shape; +- default: +- throw new IllegalArgumentException( +- getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); +- } ++ /** ++ * Converts an {@link ImageFormat} value into the corresponding color space type. ++ * ++ * @throws IllegalArgumentException if the config is unsupported ++ */ ++ static ColorSpaceType fromImageFormat(int imageFormat) { ++ switch (imageFormat) { ++ case ImageFormat.NV21: ++ return ColorSpaceType.NV21; ++ case ImageFormat.YV12: ++ return ColorSpaceType.YV12; ++ case ImageFormat.YUV_420_888: ++ return ColorSpaceType.YUV_420_888; ++ default: ++ throw new IllegalArgumentException( ++ "ImageFormat: " + imageFormat + ", is not supported yet."); ++ } + } + +- @Override +- int getNumElements(int height, int width) { +- return height * width * CHANNEL_VALUE; ++ public int getValue() { ++ return value; + } + +- @Override +- String getShapeInfoMessage() { +- return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" +- + " representing R, G, B in order. "; ++ /** ++ * Verifies if the given shape matches the color space type. ++ * ++ * @throws IllegalArgumentException if {@code shape} does not match the color space type ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ ++ void assertShape(int[] shape) { ++ assertRgbOrGrayScale("assertShape()"); ++ ++ int[] normalizedShape = getNormalizedShape(shape); ++ checkArgument(isValidNormalizedShape(normalizedShape), ++ getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); + } + +- @Override +- Config toBitmapConfig() { +- return Config.ARGB_8888; ++ /** ++ * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code ++ * width} under this color space type. For example, the {@code numElements} of an RGB image of ++ * 30 x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x ++ * 20 should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}. ++ * ++ * @throws IllegalArgumentException if {@code shape} does not match the color space type ++ */ ++ void assertNumElements(int numElements, int height, int width) { ++ checkArgument(numElements >= getNumElements(height, width), ++ String.format( ++ "The given number of elements (%d) does not match the image (%s) in %d x %d. The" ++ + " expected number of elements should be at least %d.", ++ numElements, this.name(), height, width, getNumElements(height, width))); + } +- }, +- +- /** Each pixel is a single element representing only the amount of light. */ +- GRAYSCALE(1) { +- +- // The channel axis should always be 1 for grayscale images. +- private static final int CHANNEL_VALUE = 1; + +- @Override ++ /** ++ * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space ++ * type. ++ * ++ * @throws IllegalArgumentException if the shape of buffer does not match the color space type, ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { +- return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer); ++ throw new UnsupportedOperationException( ++ "convertTensorBufferToBitmap() is unsupported for the color space type " ++ + this.name()); + } + +- @Override +- int getChannelValue() { +- return CHANNEL_VALUE; ++ /** ++ * Returns the width of the given shape corresponding to the color space type. ++ * ++ * @throws IllegalArgumentException if {@code shape} does not match the color space type ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ ++ int getWidth(int[] shape) { ++ assertRgbOrGrayScale("getWidth()"); ++ assertShape(shape); ++ return getNormalizedShape(shape)[WIDTH_DIM]; + } + +- @Override +- int[] getNormalizedShape(int[] shape) { +- switch (shape.length) { +- // The shape is in (h, w) format. +- case 2: +- int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE); +- return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE); +- case 4: +- return shape; +- default: +- // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since they +- // both have three dimensions, it will require extra info to differentiate between them. +- // Since we haven't encountered real use cases of these two shapes, they are not supported +- // at this moment to avoid confusion. We may want to revisit it in the future. +- throw new IllegalArgumentException( +- getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); +- } ++ /** ++ * Returns the height of the given shape corresponding to the color space type. ++ * ++ * @throws IllegalArgumentException if {@code shape} does not match the color space type ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ ++ int getHeight(int[] shape) { ++ assertRgbOrGrayScale("getHeight()"); ++ assertShape(shape); ++ return getNormalizedShape(shape)[HEIGHT_DIM]; + } + +- @Override +- int getNumElements(int height, int width) { +- return height * width; ++ /** ++ * Returns the channel value corresponding to the color space type. ++ * ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ ++ int getChannelValue() { ++ throw new UnsupportedOperationException( ++ "getChannelValue() is unsupported for the color space type " + this.name()); ++ } ++ /** ++ * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have ++ * batch or channel axis. ++ * ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ ++ int[] getNormalizedShape(int[] shape) { ++ throw new UnsupportedOperationException( ++ "getNormalizedShape() is unsupported for the color space type " + this.name()); + } + +- @Override ++ /** ++ * Returns the shape information corresponding to the color space type. ++ * ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ + String getShapeInfoMessage() { +- return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). "; ++ throw new UnsupportedOperationException( ++ "getShapeInfoMessage() is unsupported for the color space type " + this.name()); + } + +- @Override ++ /** ++ * Converts the color space type to the corresponding bitmap config. ++ * ++ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE ++ */ + Config toBitmapConfig() { +- return Config.ALPHA_8; ++ throw new UnsupportedOperationException( ++ "toBitmapConfig() is unsupported for the color space type " + this.name()); + } +- }, + +- /** YUV420sp format, encoded as "YYYYYYYY UVUV". */ +- NV12(2) { +- @Override +- int getNumElements(int height, int width) { +- return getYuv420NumElements(height, width); +- } +- }, +- +- /** +- * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1 +- * preview. +- */ +- NV21(3) { +- @Override +- int getNumElements(int height, int width) { +- return getYuv420NumElements(height, width); +- } +- }, ++ /** ++ * Gets the number of elements given the height and width of an image. For example, the number ++ * of elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements ++ * of a NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}. ++ */ ++ abstract int getNumElements(int height, int width); + +- /** YUV420p format, encoded as "YYYYYYYY VV UU". */ +- YV12(4) { +- @Override +- int getNumElements(int height, int width) { +- return getYuv420NumElements(height, width); ++ private static int getYuv420NumElements(int height, int width) { ++ // Height and width of U/V planes are half of the Y plane. ++ return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2; + } +- }, + +- /** YUV420p format, encoded as "YYYYYYYY UU VV". */ +- YV21(5) { +- @Override +- int getNumElements(int height, int width) { +- return getYuv420NumElements(height, width); ++ /** Inserts a value at the specified position and return the new array. */ ++ private static int[] insertValue(int[] array, int pos, int value) { ++ int[] newArray = new int[array.length + 1]; ++ for (int i = 0; i < pos; i++) { ++ newArray[i] = array[i]; ++ } ++ newArray[pos] = value; ++ for (int i = pos + 1; i < newArray.length; i++) { ++ newArray[i] = array[i - 1]; ++ } ++ return newArray; + } +- }, +- +- /** +- * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual +- * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image. +- * +- * <p>Use this format only when you load an {@link android.media.Image}. +- */ +- YUV_420_888(6) { +- @Override +- int getNumElements(int height, int width) { +- return getYuv420NumElements(height, width); +- } +- }; +- +- private static final int BATCH_DIM = 0; // The first element of the normalizaed shape. +- private static final int BATCH_VALUE = 1; // The batch axis should always be one. +- private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape. +- private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape. +- private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape. +- private final int value; +- +- ColorSpaceType(int value) { +- this.value = value; +- } +- +- /** +- * Converts a bitmap configuration into the corresponding color space type. +- * +- * @throws IllegalArgumentException if the config is unsupported +- */ +- static ColorSpaceType fromBitmapConfig(Config config) { +- switch (config) { +- case ARGB_8888: +- return ColorSpaceType.RGB; +- case ALPHA_8: +- return ColorSpaceType.GRAYSCALE; +- default: +- throw new IllegalArgumentException( +- "Bitmap configuration: " + config + ", is not supported yet."); +- } +- } +- +- /** +- * Converts an {@link ImageFormat} value into the corresponding color space type. +- * +- * @throws IllegalArgumentException if the config is unsupported +- */ +- static ColorSpaceType fromImageFormat(int imageFormat) { +- switch (imageFormat) { +- case ImageFormat.NV21: +- return ColorSpaceType.NV21; +- case ImageFormat.YV12: +- return ColorSpaceType.YV12; +- case ImageFormat.YUV_420_888: +- return ColorSpaceType.YUV_420_888; +- default: +- throw new IllegalArgumentException( +- "ImageFormat: " + imageFormat + ", is not supported yet."); +- } +- } +- +- public int getValue() { +- return value; +- } +- +- /** +- * Verifies if the given shape matches the color space type. +- * +- * @throws IllegalArgumentException if {@code shape} does not match the color space type +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- void assertShape(int[] shape) { +- assertRgbOrGrayScale("assertShape()"); +- +- int[] normalizedShape = getNormalizedShape(shape); +- checkArgument( +- isValidNormalizedShape(normalizedShape), +- getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); +- } +- +- /** +- * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code +- * width} under this color space type. For example, the {@code numElements} of an RGB image of 30 +- * x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x 20 +- * should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}. +- * +- * @throws IllegalArgumentException if {@code shape} does not match the color space type +- */ +- void assertNumElements(int numElements, int height, int width) { +- checkArgument( +- numElements >= getNumElements(height, width), +- String.format( +- "The given number of elements (%d) does not match the image (%s) in %d x %d. The" +- + " expected number of elements should be at least %d.", +- numElements, this.name(), height, width, getNumElements(height, width))); +- } +- +- /** +- * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space type. +- * +- * @throws IllegalArgumentException if the shape of buffer does not match the color space type, +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { +- throw new UnsupportedOperationException( +- "convertTensorBufferToBitmap() is unsupported for the color space type " + this.name()); +- } +- +- /** +- * Returns the width of the given shape corresponding to the color space type. +- * +- * @throws IllegalArgumentException if {@code shape} does not match the color space type +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- int getWidth(int[] shape) { +- assertRgbOrGrayScale("getWidth()"); +- assertShape(shape); +- return getNormalizedShape(shape)[WIDTH_DIM]; +- } +- +- /** +- * Returns the height of the given shape corresponding to the color space type. +- * +- * @throws IllegalArgumentException if {@code shape} does not match the color space type +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- int getHeight(int[] shape) { +- assertRgbOrGrayScale("getHeight()"); +- assertShape(shape); +- return getNormalizedShape(shape)[HEIGHT_DIM]; +- } +- +- /** +- * Returns the channel value corresponding to the color space type. +- * +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- int getChannelValue() { +- throw new UnsupportedOperationException( +- "getChannelValue() is unsupported for the color space type " + this.name()); +- } +- /** +- * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have +- * batch or channel axis. +- * +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- int[] getNormalizedShape(int[] shape) { +- throw new UnsupportedOperationException( +- "getNormalizedShape() is unsupported for the color space type " + this.name()); +- } +- +- /** +- * Returns the shape information corresponding to the color space type. +- * +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- String getShapeInfoMessage() { +- throw new UnsupportedOperationException( +- "getShapeInfoMessage() is unsupported for the color space type " + this.name()); +- } +- +- /** +- * Converts the color space type to the corresponding bitmap config. +- * +- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE +- */ +- Config toBitmapConfig() { +- throw new UnsupportedOperationException( +- "toBitmapConfig() is unsupported for the color space type " + this.name()); +- } +- +- /** +- * Gets the number of elements given the height and width of an image. For example, the number of +- * elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements of a +- * NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}. +- */ +- abstract int getNumElements(int height, int width); +- +- private static int getYuv420NumElements(int height, int width) { +- // Height and width of U/V planes are half of the Y plane. +- return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2; +- } +- +- /** Inserts a value at the specified position and return the new array. */ +- private static int[] insertValue(int[] array, int pos, int value) { +- int[] newArray = new int[array.length + 1]; +- for (int i = 0; i < pos; i++) { +- newArray[i] = array[i]; +- } +- newArray[pos] = value; +- for (int i = pos + 1; i < newArray.length; i++) { +- newArray[i] = array[i - 1]; ++ ++ protected boolean isValidNormalizedShape(int[] shape) { ++ return shape[BATCH_DIM] == BATCH_VALUE && shape[HEIGHT_DIM] > 0 && shape[WIDTH_DIM] > 0 ++ && shape[CHANNEL_DIM] == getChannelValue(); + } +- return newArray; +- } +- +- protected boolean isValidNormalizedShape(int[] shape) { +- return shape[BATCH_DIM] == BATCH_VALUE +- && shape[HEIGHT_DIM] > 0 +- && shape[WIDTH_DIM] > 0 +- && shape[CHANNEL_DIM] == getChannelValue(); +- } +- +- /** Some existing methods are only valid for RGB and GRAYSCALE images. */ +- private void assertRgbOrGrayScale(String unsupportedMethodName) { +- if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) { +- throw new UnsupportedOperationException( +- unsupportedMethodName +- + " only supports RGB and GRAYSCALE formats, but not " +- + this.name()); ++ ++ /** Some existing methods are only valid for RGB and GRAYSCALE images. */ ++ private void assertRgbOrGrayScale(String unsupportedMethodName) { ++ if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) { ++ throw new UnsupportedOperationException(unsupportedMethodName ++ + " only supports RGB and GRAYSCALE formats, but not " + this.name()); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java +index 379d14798d62d..5c097da5ecb6d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java +@@ -17,6 +17,7 @@ package org.tensorflow.lite.support.image; + + import android.graphics.Bitmap; + import android.media.Image; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +@@ -32,28 +33,27 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * </ul> + */ + interface ImageContainer { ++ /** Performs deep copy of the {@link ImageContainer}. */ ++ ImageContainer clone(); + +- /** Performs deep copy of the {@link ImageContainer}. */ +- ImageContainer clone(); +- +- /** Returns the width of the image. */ +- int getWidth(); ++ /** Returns the width of the image. */ ++ int getWidth(); + +- /** Returns the height of the image. */ +- int getHeight(); ++ /** Returns the height of the image. */ ++ int getHeight(); + +- /** Gets the {@link Bitmap} representation of the underlying image format. */ +- Bitmap getBitmap(); ++ /** Gets the {@link Bitmap} representation of the underlying image format. */ ++ Bitmap getBitmap(); + +- /** +- * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the +- * underlying image format. +- */ +- TensorBuffer getTensorBuffer(DataType dataType); ++ /** ++ * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the ++ * underlying image format. ++ */ ++ TensorBuffer getTensorBuffer(DataType dataType); + +- /** Gets the {@link Image} representation of the underlying image format. */ +- Image getMediaImage(); ++ /** Gets the {@link Image} representation of the underlying image format. */ ++ Image getMediaImage(); + +- /** Returns the color space type of the image. */ +- ColorSpaceType getColorSpaceType(); ++ /** Returns the color space type of the image. */ ++ ColorSpaceType getColorSpaceType(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java +index 8ed169c49348e..7ed5306fd9f96 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java +@@ -17,128 +17,127 @@ package org.tensorflow.lite.support.image; + + import android.graphics.Bitmap; + import android.graphics.Color; +-import java.nio.ByteBuffer; +-import java.nio.ByteOrder; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.nio.ByteBuffer; ++import java.nio.ByteOrder; ++ + /** + * Implements some stateless image conversion methods. + * + * <p>This class is an internal helper for {@link org.tensorflow.lite.support.image}. + */ + class ImageConversions { ++ /** ++ * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap. ++ * ++ * <p>Data in buffer will be converted into integer to match the Bitmap API. ++ * ++ * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3) ++ * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3) ++ */ ++ static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) { ++ int[] shape = buffer.getShape(); ++ ColorSpaceType rgb = ColorSpaceType.RGB; ++ rgb.assertShape(shape); + +- /** +- * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap. +- * +- * <p>Data in buffer will be converted into integer to match the Bitmap API. +- * +- * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3) +- * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3) +- */ +- static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) { +- int[] shape = buffer.getShape(); +- ColorSpaceType rgb = ColorSpaceType.RGB; +- rgb.assertShape(shape); +- +- int h = rgb.getHeight(shape); +- int w = rgb.getWidth(shape); +- Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig()); +- +- // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. +- int[] intValues = new int[w * h]; +- int[] rgbValues = buffer.getIntArray(); +- for (int i = 0, j = 0; i < intValues.length; i++) { +- int r = rgbValues[j++]; +- int g = rgbValues[j++]; +- int b = rgbValues[j++]; +- intValues[i] = Color.rgb(r, g, b); +- } +- bitmap.setPixels(intValues, 0, w, 0, 0, w, h); +- +- return bitmap; +- } +- +- /** +- * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap. +- * +- * <p>Data in buffer will be converted into integer to match the Bitmap API. +- * +- * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w) +- * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1) +- */ +- static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) { +- // Convert buffer into Uint8 as needed. +- TensorBuffer uint8Buffer = +- buffer.getDataType() == DataType.UINT8 +- ? buffer +- : TensorBuffer.createFrom(buffer, DataType.UINT8); +- +- int[] shape = uint8Buffer.getShape(); +- ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE; +- grayscale.assertShape(shape); +- +- // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config config)` +- // seems to work for internal Android testing framework, but it actually doesn't work for the +- // real Android environment. +- // +- // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to load +- // the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out. +- // Note: for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work. +- Bitmap bitmap = +- Bitmap.createBitmap( +- grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig()); +- uint8Buffer.getBuffer().rewind(); +- bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer()); +- return bitmap; +- } +- +- /** +- * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose memory +- * is already allocated, or could be dynamically allocated. +- * +- * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888 +- * config. +- * @param buffer The destination of the conversion. Needs to be created in advance. If it's +- * fixed-size, its flat size should be w*h*3. +- * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match. +- */ +- static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) { +- int w = bitmap.getWidth(); +- int h = bitmap.getHeight(); +- int[] intValues = new int[w * h]; +- bitmap.getPixels(intValues, 0, w, 0, 0, w, h); +- // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. +- int[] shape = new int[] {h, w, 3}; +- switch (buffer.getDataType()) { +- case UINT8: +- byte[] byteArr = new byte[w * h * 3]; ++ int h = rgb.getHeight(shape); ++ int w = rgb.getWidth(shape); ++ Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig()); ++ ++ // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. ++ int[] intValues = new int[w * h]; ++ int[] rgbValues = buffer.getIntArray(); + for (int i = 0, j = 0; i < intValues.length; i++) { +- byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff); +- byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff); +- byteArr[j++] = (byte) (intValues[i] & 0xff); ++ int r = rgbValues[j++]; ++ int g = rgbValues[j++]; ++ int b = rgbValues[j++]; ++ intValues[i] = Color.rgb(r, g, b); + } +- ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr); +- byteBuffer.order(ByteOrder.nativeOrder()); +- buffer.loadBuffer(byteBuffer, shape); +- break; +- case FLOAT32: +- float[] floatArr = new float[w * h * 3]; +- for (int i = 0, j = 0; i < intValues.length; i++) { +- floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff); +- floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff); +- floatArr[j++] = (float) (intValues[i] & 0xff); ++ bitmap.setPixels(intValues, 0, w, 0, 0, w, h); ++ ++ return bitmap; ++ } ++ ++ /** ++ * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap. ++ * ++ * <p>Data in buffer will be converted into integer to match the Bitmap API. ++ * ++ * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w) ++ * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1) ++ */ ++ static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) { ++ // Convert buffer into Uint8 as needed. ++ TensorBuffer uint8Buffer = buffer.getDataType() == DataType.UINT8 ++ ? buffer ++ : TensorBuffer.createFrom(buffer, DataType.UINT8); ++ ++ int[] shape = uint8Buffer.getShape(); ++ ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE; ++ grayscale.assertShape(shape); ++ ++ // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config ++ // config)` seems to work for internal Android testing framework, but it actually doesn't ++ // work for the real Android environment. ++ // ++ // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to ++ // load the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out. Note: ++ // for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work. ++ Bitmap bitmap = Bitmap.createBitmap( ++ grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig()); ++ uint8Buffer.getBuffer().rewind(); ++ bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer()); ++ return bitmap; ++ } ++ ++ /** ++ * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose ++ * memory is already allocated, or could be dynamically allocated. ++ * ++ * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888 ++ * config. ++ * @param buffer The destination of the conversion. Needs to be created in advance. If it's ++ * fixed-size, its flat size should be w*h*3. ++ * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match. ++ */ ++ static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) { ++ int w = bitmap.getWidth(); ++ int h = bitmap.getHeight(); ++ int[] intValues = new int[w * h]; ++ bitmap.getPixels(intValues, 0, w, 0, 0, w, h); ++ // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. ++ int[] shape = new int[] {h, w, 3}; ++ switch (buffer.getDataType()) { ++ case UINT8: ++ byte[] byteArr = new byte[w * h * 3]; ++ for (int i = 0, j = 0; i < intValues.length; i++) { ++ byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff); ++ byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff); ++ byteArr[j++] = (byte) (intValues[i] & 0xff); ++ } ++ ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr); ++ byteBuffer.order(ByteOrder.nativeOrder()); ++ buffer.loadBuffer(byteBuffer, shape); ++ break; ++ case FLOAT32: ++ float[] floatArr = new float[w * h * 3]; ++ for (int i = 0, j = 0; i < intValues.length; i++) { ++ floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff); ++ floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff); ++ floatArr[j++] = (float) (intValues[i] & 0xff); ++ } ++ buffer.loadArray(floatArr, shape); ++ break; ++ default: ++ // Should never happen. ++ throw new IllegalStateException( ++ "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported."); + } +- buffer.loadArray(floatArr, shape); +- break; +- default: +- // Should never happen. +- throw new IllegalStateException( +- "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported."); + } +- } + +- // Hide the constructor as the class is static. +- private ImageConversions() {} ++ // Hide the constructor as the class is static. ++ private ImageConversions() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java +index 1e546634e90e7..e852569490f0b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java +@@ -16,28 +16,29 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import android.graphics.PointF; ++ + import org.tensorflow.lite.support.common.Operator; + + /** Operates a TensorImage object. Used in ImageProcessor. */ + public interface ImageOperator extends Operator<TensorImage> { +- /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */ +- @Override +- TensorImage apply(TensorImage image); +- +- /** Computes the width of the expected output image when input image size is given. */ +- int getOutputImageWidth(int inputImageHeight, int inputImageWidth); +- +- /** Computes the height of the expected output image when input image size is given. */ +- int getOutputImageHeight(int inputImageHeight, int inputImageWidth); +- +- /** +- * Transforms a point from coordinates system of the result image back to the one of the input +- * image. +- * +- * @param point the point from the result coordinates system. +- * @param inputImageHeight the height of input image. +- * @param inputImageWidth the width of input image. +- * @return the point with the coordinates from the coordinates system of the input image. +- */ +- PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth); ++ /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */ ++ @Override ++ TensorImage apply(TensorImage image); ++ ++ /** Computes the width of the expected output image when input image size is given. */ ++ int getOutputImageWidth(int inputImageHeight, int inputImageWidth); ++ ++ /** Computes the height of the expected output image when input image size is given. */ ++ int getOutputImageHeight(int inputImageHeight, int inputImageWidth); ++ ++ /** ++ * Transforms a point from coordinates system of the result image back to the one of the input ++ * image. ++ * ++ * @param point the point from the result coordinates system. ++ * @param inputImageHeight the height of input image. ++ * @param inputImageWidth the width of input image. ++ * @return the point with the coordinates from the coordinates system of the input image. ++ */ ++ PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java +index ffee8f2c2a706..d7a853ee86de6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java +@@ -20,9 +20,7 @@ import static java.lang.Math.min; + + import android.graphics.PointF; + import android.graphics.RectF; +-import java.util.ArrayList; +-import java.util.List; +-import java.util.ListIterator; ++ + import org.tensorflow.lite.support.common.Operator; + import org.tensorflow.lite.support.common.SequentialProcessor; + import org.tensorflow.lite.support.common.TensorOperator; +@@ -30,6 +28,10 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions; + import org.tensorflow.lite.support.image.ops.Rot90Op; + import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper; + ++import java.util.ArrayList; ++import java.util.List; ++import java.util.ListIterator; ++ + /** + * ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It + * could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}. +@@ -55,156 +57,159 @@ import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper; + * @see ImageProcessor#process(TensorImage) to apply the processor on a {@link TensorImage} + */ + public class ImageProcessor extends SequentialProcessor<TensorImage> { +- private ImageProcessor(Builder builder) { +- super(builder); +- } +- +- /** +- * Transforms a point from coordinates system of the result image back to the one of the input +- * image. +- * +- * @param point the point from the result coordinates system. +- * @param inputImageHeight the height of input image. +- * @param inputImageWidth the width of input image. +- * @return the point with the coordinates from the coordinates system of the input image. +- */ +- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { +- List<Integer> widths = new ArrayList<>(); +- List<Integer> heights = new ArrayList<>(); +- int currentWidth = inputImageWidth; +- int currentHeight = inputImageHeight; +- for (Operator<TensorImage> op : operatorList) { +- widths.add(currentWidth); +- heights.add(currentHeight); +- ImageOperator imageOperator = (ImageOperator) op; +- int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth); +- int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth); +- currentHeight = newHeight; +- currentWidth = newWidth; ++ private ImageProcessor(Builder builder) { ++ super(builder); + } +- ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size()); +- ListIterator<Integer> widthIterator = widths.listIterator(widths.size()); +- ListIterator<Integer> heightIterator = heights.listIterator(heights.size()); +- while (opIterator.hasPrevious()) { +- ImageOperator imageOperator = (ImageOperator) opIterator.previous(); +- int height = heightIterator.previous(); +- int width = widthIterator.previous(); +- point = imageOperator.inverseTransform(point, height, width); ++ ++ /** ++ * Transforms a point from coordinates system of the result image back to the one of the input ++ * image. ++ * ++ * @param point the point from the result coordinates system. ++ * @param inputImageHeight the height of input image. ++ * @param inputImageWidth the width of input image. ++ * @return the point with the coordinates from the coordinates system of the input image. ++ */ ++ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { ++ List<Integer> widths = new ArrayList<>(); ++ List<Integer> heights = new ArrayList<>(); ++ int currentWidth = inputImageWidth; ++ int currentHeight = inputImageHeight; ++ for (Operator<TensorImage> op : operatorList) { ++ widths.add(currentWidth); ++ heights.add(currentHeight); ++ ImageOperator imageOperator = (ImageOperator) op; ++ int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth); ++ int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth); ++ currentHeight = newHeight; ++ currentWidth = newWidth; ++ } ++ ListIterator<Operator<TensorImage>> opIterator = ++ operatorList.listIterator(operatorList.size()); ++ ListIterator<Integer> widthIterator = widths.listIterator(widths.size()); ++ ListIterator<Integer> heightIterator = heights.listIterator(heights.size()); ++ while (opIterator.hasPrevious()) { ++ ImageOperator imageOperator = (ImageOperator) opIterator.previous(); ++ int height = heightIterator.previous(); ++ int width = widthIterator.previous(); ++ point = imageOperator.inverseTransform(point, height, width); ++ } ++ return point; ++ } ++ ++ /** ++ * Transforms a rectangle from coordinates system of the result image back to the one of the ++ * input image. ++ * ++ * @param rect the rectangle from the result coordinates system. ++ * @param inputImageHeight the height of input image. ++ * @param inputImageWidth the width of input image. ++ * @return the rectangle with the coordinates from the coordinates system of the input image. ++ */ ++ public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) { ++ // when rotation is involved, corner order may change - top left changes to bottom right, ++ // .etc ++ PointF p1 = inverseTransform( ++ new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth); ++ PointF p2 = inverseTransform( ++ new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth); ++ return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y)); + } +- return point; +- } +- +- /** +- * Transforms a rectangle from coordinates system of the result image back to the one of the input +- * image. +- * +- * @param rect the rectangle from the result coordinates system. +- * @param inputImageHeight the height of input image. +- * @param inputImageWidth the width of input image. +- * @return the rectangle with the coordinates from the coordinates system of the input image. +- */ +- public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) { +- // when rotation is involved, corner order may change - top left changes to bottom right, .etc +- PointF p1 = +- inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth); +- PointF p2 = +- inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth); +- return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y)); +- } +- +- /** +- * Processes a {@link TensorImage} object with prepared {@link TensorOperator}. +- * +- * @throws IllegalArgumentException if the image is not supported by any op. +- */ +- @Override +- public TensorImage process(TensorImage image) { +- return super.process(image); +- } +- +- /** +- * The Builder to create an ImageProcessor, which could be executed later. +- * +- * @see #add(TensorOperator) to add a general TensorOperator +- * @see #add(ImageOperator) to add an ImageOperator +- * @see #build() complete the building process and get a built Processor +- */ +- public static class Builder extends SequentialProcessor.Builder<TensorImage> { +- public Builder() { +- super(); ++ ++ /** ++ * Processes a {@link TensorImage} object with prepared {@link TensorOperator}. ++ * ++ * @throws IllegalArgumentException if the image is not supported by any op. ++ */ ++ @Override ++ public TensorImage process(TensorImage image) { ++ return super.process(image); + } + + /** +- * Adds an {@link ImageOperator} into the Operator chain. ++ * The Builder to create an ImageProcessor, which could be executed later. + * +- * @param op the Operator instance to be executed then ++ * @see #add(TensorOperator) to add a general TensorOperator ++ * @see #add(ImageOperator) to add an ImageOperator ++ * @see #build() complete the building process and get a built Processor + */ +- public Builder add(ImageOperator op) { +- super.add(op); +- return this; ++ public static class Builder extends SequentialProcessor.Builder<TensorImage> { ++ public Builder() { ++ super(); ++ } ++ ++ /** ++ * Adds an {@link ImageOperator} into the Operator chain. ++ * ++ * @param op the Operator instance to be executed then ++ */ ++ public Builder add(ImageOperator op) { ++ super.add(op); ++ return this; ++ } ++ ++ /** ++ * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls ++ * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by ++ * transforming the underlying {@link ++ * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. ++ * ++ * @param op the Operator instance to be executed then ++ */ ++ public Builder add(TensorOperator op) { ++ return add(new TensorOperatorWrapper(op)); ++ } ++ ++ /** Completes the building process and gets the {@link ImageProcessor} instance. */ ++ @Override ++ public ImageProcessor build() { ++ return new ImageProcessor(this); ++ } + } + + /** +- * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls +- * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming +- * the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. ++ * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}. ++ * ++ * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and ++ * then processing images (using {@link #process}) must be protected from concurrent access with ++ * additional synchronization. + * +- * @param op the Operator instance to be executed then ++ * @param k the number of rotations ++ * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link ++ * ImageProcessor} + */ +- public Builder add(TensorOperator op) { +- return add(new TensorOperatorWrapper(op)); ++ public void updateNumberOfRotations(int k) { ++ updateNumberOfRotations(k, /*occurrence=*/0); + } + +- /** Completes the building process and gets the {@link ImageProcessor} instance. */ +- @Override +- public ImageProcessor build() { +- return new ImageProcessor(this); ++ /** ++ * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in ++ * this ++ * {@link ImageProcessor}. ++ * ++ * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and ++ * then processing images (using {@link #process}) must be protected from concurrent access with ++ * additional synchronization. ++ * ++ * @param k the number of rotations ++ * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For ++ * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be ++ * set to 1. ++ * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the ++ * number of {@link Rot90Op} in this {@link ImageProcessor} ++ * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link ++ * ImageProcessor} ++ */ ++ public synchronized void updateNumberOfRotations(int k, int occurrence) { ++ SupportPreconditions.checkState(operatorIndex.containsKey(Rot90Op.class.getName()), ++ "The Rot90Op has not been added to the ImageProcessor."); ++ ++ List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName()); ++ SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence"); ++ ++ // The index of the Rot90Op to be replaced in operatorList. ++ int index = indexes.get(occurrence); ++ Rot90Op newRot = new Rot90Op(k); ++ operatorList.set(index, newRot); + } +- } +- +- /** +- * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}. +- * +- * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and +- * then processing images (using {@link #process}) must be protected from concurrent access with +- * additional synchronization. +- * +- * @param k the number of rotations +- * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link +- * ImageProcessor} +- */ +- public void updateNumberOfRotations(int k) { +- updateNumberOfRotations(k, /*occurrence=*/ 0); +- } +- +- /** +- * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this +- * {@link ImageProcessor}. +- * +- * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and +- * then processing images (using {@link #process}) must be protected from concurrent access with +- * additional synchronization. +- * +- * @param k the number of rotations +- * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For +- * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be +- * set to 1. +- * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the +- * number of {@link Rot90Op} in this {@link ImageProcessor} +- * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link +- * ImageProcessor} +- */ +- public synchronized void updateNumberOfRotations(int k, int occurrence) { +- SupportPreconditions.checkState( +- operatorIndex.containsKey(Rot90Op.class.getName()), +- "The Rot90Op has not been added to the ImageProcessor."); +- +- List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName()); +- SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence"); +- +- // The index of the Rot90Op to be replaced in operatorList. +- int index = indexes.get(occurrence); +- Rot90Op newRot = new Rot90Op(k); +- operatorList.set(index, newRot); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java +index 96daf85a02f5a..f61f59fa13ce7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java +@@ -26,52 +26,51 @@ import com.google.auto.value.AutoValue; + */ + @AutoValue + public abstract class ImageProperties { ++ private static final int DEFAULT_HEIGHT = -1; ++ private static final int DEFAULT_WIDTH = -1; + +- private static final int DEFAULT_HEIGHT = -1; +- private static final int DEFAULT_WIDTH = -1; +- +- public abstract int getHeight(); +- +- public abstract int getWidth(); +- +- public abstract ColorSpaceType getColorSpaceType(); +- +- public static Builder builder() { +- return new AutoValue_ImageProperties.Builder() +- .setHeight(DEFAULT_HEIGHT) +- .setWidth(DEFAULT_WIDTH); +- } +- +- /** +- * Builder for {@link ImageProperties}. Different image objects may require different properties. +- * See the detais below: +- * +- * <ul> +- * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}: +- * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer +- * object will not be used to determine image height and width. +- * </ul> +- */ +- @AutoValue.Builder +- public abstract static class Builder { +- public abstract Builder setHeight(int height); +- +- public abstract Builder setWidth(int width); +- +- public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType); +- +- abstract ImageProperties autoBuild(); +- +- public ImageProperties build() { +- ImageProperties properties = autoBuild(); +- // If width or hight are not configured by the Builder, they will be -1. +- // Enforcing all properties to be populated (AutoValue will error out if objects, like +- // colorSpaceType, are not set up), since they are required for TensorBuffer images. +- // If in the future we have some image object types that only require a portion of these +- // properties, we can delay the check when TensorImage#load() is executed. +- checkState(properties.getHeight() >= 0, "Negative image height is not allowed."); +- checkState(properties.getWidth() >= 0, "Negative image width is not allowed."); +- return properties; ++ public abstract int getHeight(); ++ ++ public abstract int getWidth(); ++ ++ public abstract ColorSpaceType getColorSpaceType(); ++ ++ public static Builder builder() { ++ return new AutoValue_ImageProperties.Builder() ++ .setHeight(DEFAULT_HEIGHT) ++ .setWidth(DEFAULT_WIDTH); ++ } ++ ++ /** ++ * Builder for {@link ImageProperties}. Different image objects may require different ++ * properties. See the detais below: ++ * ++ * <ul> ++ * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}: ++ * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer ++ * object will not be used to determine image height and width. ++ * </ul> ++ */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ public abstract Builder setHeight(int height); ++ ++ public abstract Builder setWidth(int width); ++ ++ public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType); ++ ++ abstract ImageProperties autoBuild(); ++ ++ public ImageProperties build() { ++ ImageProperties properties = autoBuild(); ++ // If width or hight are not configured by the Builder, they will be -1. ++ // Enforcing all properties to be populated (AutoValue will error out if objects, like ++ // colorSpaceType, are not set up), since they are required for TensorBuffer images. ++ // If in the future we have some image object types that only require a portion of these ++ // properties, we can delay the check when TensorImage#load() is executed. ++ checkState(properties.getHeight() >= 0, "Negative image height is not allowed."); ++ checkState(properties.getWidth() >= 0, "Negative image width is not allowed."); ++ return properties; ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java +index 50d787b5afab1..519aacaf7f20b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java +@@ -21,65 +21,65 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import android.graphics.Bitmap; + import android.graphics.ImageFormat; + import android.media.Image; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** Holds an {@link Image} and converts it to other image formats as needed. */ + final class MediaImageContainer implements ImageContainer { +- +- private final Image image; +- +- /** +- * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}. +- * +- * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888 +- */ +- static MediaImageContainer create(Image image) { +- return new MediaImageContainer(image); +- } +- +- private MediaImageContainer(Image image) { +- checkNotNull(image, "Cannot load null Image."); +- checkArgument( +- image.getFormat() == ImageFormat.YUV_420_888, "Only supports loading YUV_420_888 Image."); +- this.image = image; +- } +- +- @Override +- public MediaImageContainer clone() { +- throw new UnsupportedOperationException( +- "android.media.Image is an abstract class and cannot be cloned."); +- } +- +- @Override +- public Bitmap getBitmap() { +- throw new UnsupportedOperationException( +- "Converting an android.media.Image to Bitmap is not supported."); +- } +- +- @Override +- public TensorBuffer getTensorBuffer(DataType dataType) { +- throw new UnsupportedOperationException( +- "Converting an android.media.Image to TesorBuffer is not supported."); +- } +- +- @Override +- public Image getMediaImage() { +- return image; +- } +- +- @Override +- public int getWidth() { +- return image.getWidth(); +- } +- +- @Override +- public int getHeight() { +- return image.getHeight(); +- } +- +- @Override +- public ColorSpaceType getColorSpaceType() { +- return ColorSpaceType.fromImageFormat(image.getFormat()); +- } ++ private final Image image; ++ ++ /** ++ * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}. ++ * ++ * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888 ++ */ ++ static MediaImageContainer create(Image image) { ++ return new MediaImageContainer(image); ++ } ++ ++ private MediaImageContainer(Image image) { ++ checkNotNull(image, "Cannot load null Image."); ++ checkArgument(image.getFormat() == ImageFormat.YUV_420_888, ++ "Only supports loading YUV_420_888 Image."); ++ this.image = image; ++ } ++ ++ @Override ++ public MediaImageContainer clone() { ++ throw new UnsupportedOperationException( ++ "android.media.Image is an abstract class and cannot be cloned."); ++ } ++ ++ @Override ++ public Bitmap getBitmap() { ++ throw new UnsupportedOperationException( ++ "Converting an android.media.Image to Bitmap is not supported."); ++ } ++ ++ @Override ++ public TensorBuffer getTensorBuffer(DataType dataType) { ++ throw new UnsupportedOperationException( ++ "Converting an android.media.Image to TesorBuffer is not supported."); ++ } ++ ++ @Override ++ public Image getMediaImage() { ++ return image; ++ } ++ ++ @Override ++ public int getWidth() { ++ return image.getWidth(); ++ } ++ ++ @Override ++ public int getHeight() { ++ return image.getHeight(); ++ } ++ ++ @Override ++ public ColorSpaceType getColorSpaceType() { ++ return ColorSpaceType.fromImageFormat(image.getFormat()); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java +index ed066e5308fb9..03017bf733f02 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java +@@ -21,91 +21,99 @@ import com.google.android.odml.image.MediaImageExtractor; + import com.google.android.odml.image.MlImage; + import com.google.android.odml.image.MlImage.ImageFormat; + import com.google.auto.value.AutoValue; ++ + import java.nio.ByteBuffer; + + /** Converts {@code MlImage} to {@link TensorImage} and vice versa. */ + public class MlImageAdapter { ++ /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */ ++ @AutoValue ++ abstract static class ImageFormatProxy { ++ abstract ColorSpaceType getColorSpaceType(); + +- /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */ +- @AutoValue +- abstract static class ImageFormatProxy { +- +- abstract ColorSpaceType getColorSpaceType(); ++ @ImageFormat ++ abstract int getImageFormat(); + +- @ImageFormat +- abstract int getImageFormat(); +- +- static ImageFormatProxy createFromImageFormat(@ImageFormat int format) { +- switch (format) { +- case MlImage.IMAGE_FORMAT_RGB: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.RGB, format); +- case MlImage.IMAGE_FORMAT_NV12: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV12, format); +- case MlImage.IMAGE_FORMAT_NV21: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV21, format); +- case MlImage.IMAGE_FORMAT_YV12: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV12, format); +- case MlImage.IMAGE_FORMAT_YV21: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV21, format); +- case MlImage.IMAGE_FORMAT_YUV_420_888: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YUV_420_888, format); +- case MlImage.IMAGE_FORMAT_ALPHA: +- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.GRAYSCALE, format); +- case MlImage.IMAGE_FORMAT_RGBA: +- case MlImage.IMAGE_FORMAT_JPEG: +- case MlImage.IMAGE_FORMAT_UNKNOWN: +- throw new IllegalArgumentException( +- "Cannot create ColorSpaceType from MlImage format: " + format); +- default: +- throw new AssertionError("Illegal @ImageFormat: " + format); +- } ++ static ImageFormatProxy createFromImageFormat(@ImageFormat int format) { ++ switch (format) { ++ case MlImage.IMAGE_FORMAT_RGB: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.RGB, format); ++ case MlImage.IMAGE_FORMAT_NV12: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.NV12, format); ++ case MlImage.IMAGE_FORMAT_NV21: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.NV21, format); ++ case MlImage.IMAGE_FORMAT_YV12: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.YV12, format); ++ case MlImage.IMAGE_FORMAT_YV21: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.YV21, format); ++ case MlImage.IMAGE_FORMAT_YUV_420_888: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.YUV_420_888, format); ++ case MlImage.IMAGE_FORMAT_ALPHA: ++ return new AutoValue_MlImageAdapter_ImageFormatProxy( ++ ColorSpaceType.GRAYSCALE, format); ++ case MlImage.IMAGE_FORMAT_RGBA: ++ case MlImage.IMAGE_FORMAT_JPEG: ++ case MlImage.IMAGE_FORMAT_UNKNOWN: ++ throw new IllegalArgumentException( ++ "Cannot create ColorSpaceType from MlImage format: " + format); ++ default: ++ throw new AssertionError("Illegal @ImageFormat: " + format); ++ } ++ } + } +- } + +- /** +- * Creates a {@link TensorImage} from an {@link MlImage}. +- * +- * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not +- * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its +- * contained data are immutable. Also, callers should use {@code MlImage#getInternal()#acquire()} +- * and {@code MlImage#release()} to avoid the {@code mlImage} being released unexpectedly. +- * +- * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported container. +- */ +- public static TensorImage createTensorImageFrom(MlImage mlImage) { +- // TODO(b/190670174): Choose the best storage from multiple containers. +- com.google.android.odml.image.ImageProperties mlImageProperties = +- mlImage.getContainedImageProperties().get(0); +- switch (mlImageProperties.getStorageType()) { +- case MlImage.STORAGE_TYPE_BITMAP: +- return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage)); +- case MlImage.STORAGE_TYPE_MEDIA_IMAGE: +- TensorImage mediaTensorImage = new TensorImage(); +- mediaTensorImage.load(MediaImageExtractor.extract(mlImage)); +- return mediaTensorImage; +- case MlImage.STORAGE_TYPE_BYTEBUFFER: +- ByteBuffer buffer = ByteBufferExtractor.extract(mlImage); +- ImageFormatProxy formatProxy = +- ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat()); +- TensorImage byteBufferTensorImage = new TensorImage(); +- ImageProperties properties = +- ImageProperties.builder() +- .setColorSpaceType(formatProxy.getColorSpaceType()) +- .setHeight(mlImage.getHeight()) +- .setWidth(mlImage.getWidth()) +- .build(); +- byteBufferTensorImage.load(buffer, properties); +- return byteBufferTensorImage; +- default: +- throw new IllegalArgumentException( +- "Illegal storage type: " + mlImageProperties.getStorageType()); ++ /** ++ * Creates a {@link TensorImage} from an {@link MlImage}. ++ * ++ * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not ++ * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its ++ * contained data are immutable. Also, callers should use {@code ++ * MlImage#getInternal()#acquire()} and {@code MlImage#release()} to avoid the {@code mlImage} ++ * being released unexpectedly. ++ * ++ * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported ++ * container. ++ */ ++ public static TensorImage createTensorImageFrom(MlImage mlImage) { ++ // TODO(b/190670174): Choose the best storage from multiple containers. ++ com.google.android.odml.image.ImageProperties mlImageProperties = ++ mlImage.getContainedImageProperties().get(0); ++ switch (mlImageProperties.getStorageType()) { ++ case MlImage.STORAGE_TYPE_BITMAP: ++ return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage)); ++ case MlImage.STORAGE_TYPE_MEDIA_IMAGE: ++ TensorImage mediaTensorImage = new TensorImage(); ++ mediaTensorImage.load(MediaImageExtractor.extract(mlImage)); ++ return mediaTensorImage; ++ case MlImage.STORAGE_TYPE_BYTEBUFFER: ++ ByteBuffer buffer = ByteBufferExtractor.extract(mlImage); ++ ImageFormatProxy formatProxy = ++ ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat()); ++ TensorImage byteBufferTensorImage = new TensorImage(); ++ ImageProperties properties = ++ ImageProperties.builder() ++ .setColorSpaceType(formatProxy.getColorSpaceType()) ++ .setHeight(mlImage.getHeight()) ++ .setWidth(mlImage.getWidth()) ++ .build(); ++ byteBufferTensorImage.load(buffer, properties); ++ return byteBufferTensorImage; ++ default: ++ throw new IllegalArgumentException( ++ "Illegal storage type: " + mlImageProperties.getStorageType()); ++ } + } +- } + +- /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */ +- public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) { +- return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType(); +- } ++ /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */ ++ public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) { ++ return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType(); ++ } + +- private MlImageAdapter() {} ++ private MlImageAdapter() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java +index 39e2ceb9db521..6dfef70ba67f7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java +@@ -20,118 +20,108 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import android.graphics.Bitmap; + import android.media.Image; + import android.util.Log; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** Holds a {@link TensorBuffer} and converts it to other image formats as needed. */ + final class TensorBufferContainer implements ImageContainer { ++ private final TensorBuffer buffer; ++ private final ColorSpaceType colorSpaceType; ++ private final int height; ++ private final int width; ++ private static final String TAG = TensorBufferContainer.class.getSimpleName(); ++ ++ /** ++ * Creates a {@link TensorBufferContainer} object with the specified {@link ++ * TensorImage#ColorSpaceType}. ++ * ++ * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link ++ * #create(TensorBuffer, ImageProperties)} for other color space types. ++ * ++ * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the ++ * specified color space type, or if the color space type is not supported ++ */ ++ static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) { ++ checkArgument( ++ colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE, ++ "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" ++ + " `create(TensorBuffer, ImageProperties)` for other color space types."); ++ ++ return new TensorBufferContainer(buffer, colorSpaceType, ++ colorSpaceType.getHeight(buffer.getShape()), ++ colorSpaceType.getWidth(buffer.getShape())); ++ } + +- private final TensorBuffer buffer; +- private final ColorSpaceType colorSpaceType; +- private final int height; +- private final int width; +- private static final String TAG = TensorBufferContainer.class.getSimpleName(); +- +- /** +- * Creates a {@link TensorBufferContainer} object with the specified {@link +- * TensorImage#ColorSpaceType}. +- * +- * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link +- * #create(TensorBuffer, ImageProperties)} for other color space types. +- * +- * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the +- * specified color space type, or if the color space type is not supported +- */ +- static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) { +- checkArgument( +- colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE, +- "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" +- + " `create(TensorBuffer, ImageProperties)` for other color space types."); +- +- return new TensorBufferContainer( +- buffer, +- colorSpaceType, +- colorSpaceType.getHeight(buffer.getShape()), +- colorSpaceType.getWidth(buffer.getShape())); +- } +- +- static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) { +- return new TensorBufferContainer( +- buffer, +- imageProperties.getColorSpaceType(), +- imageProperties.getHeight(), +- imageProperties.getWidth()); +- } +- +- private TensorBufferContainer( +- TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) { +- checkArgument( +- colorSpaceType != ColorSpaceType.YUV_420_888, +- "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12," +- + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image."); +- +- colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); +- this.buffer = buffer; +- this.colorSpaceType = colorSpaceType; +- this.height = height; +- this.width = width; +- } +- +- @Override +- public TensorBufferContainer clone() { +- return new TensorBufferContainer( +- TensorBuffer.createFrom(buffer, buffer.getDataType()), +- colorSpaceType, +- getHeight(), +- getWidth()); +- } +- +- @Override +- public Bitmap getBitmap() { +- if (buffer.getDataType() != DataType.UINT8) { +- // Print warning instead of throwing an exception. When using float models, users may want to +- // convert the resulting float image into Bitmap. That's fine to do so, as long as they are +- // aware of the potential accuracy lost when casting to uint8. +- Log.w( +- TAG, +- "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap" +- + " will cause numeric casting and clamping on the data value."); ++ static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) { ++ return new TensorBufferContainer(buffer, imageProperties.getColorSpaceType(), ++ imageProperties.getHeight(), imageProperties.getWidth()); + } + +- return colorSpaceType.convertTensorBufferToBitmap(buffer); +- } +- +- @Override +- public TensorBuffer getTensorBuffer(DataType dataType) { +- // If the data type of buffer is desired, return it directly. Not making a defensive copy for +- // performance considerations. During image processing, users may need to set and get the +- // TensorBuffer many times. +- // Otherwise, create another one with the expected data type. +- return buffer.getDataType() == dataType ? buffer : TensorBuffer.createFrom(buffer, dataType); +- } +- +- @Override +- public Image getMediaImage() { +- throw new UnsupportedOperationException( +- "Converting from TensorBuffer to android.media.Image is unsupported."); +- } +- +- @Override +- public int getWidth() { +- // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created. +- colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); +- return width; +- } +- +- @Override +- public int getHeight() { +- // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created. +- colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); +- return height; +- } +- +- @Override +- public ColorSpaceType getColorSpaceType() { +- return colorSpaceType; +- } ++ private TensorBufferContainer( ++ TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) { ++ checkArgument(colorSpaceType != ColorSpaceType.YUV_420_888, ++ "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12," ++ + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image."); ++ ++ colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); ++ this.buffer = buffer; ++ this.colorSpaceType = colorSpaceType; ++ this.height = height; ++ this.width = width; ++ } ++ ++ @Override ++ public TensorBufferContainer clone() { ++ return new TensorBufferContainer(TensorBuffer.createFrom(buffer, buffer.getDataType()), ++ colorSpaceType, getHeight(), getWidth()); ++ } ++ ++ @Override ++ public Bitmap getBitmap() { ++ if (buffer.getDataType() != DataType.UINT8) { ++ // Print warning instead of throwing an exception. When using float models, users may ++ // want to convert the resulting float image into Bitmap. That's fine to do so, as long ++ // as they are aware of the potential accuracy lost when casting to uint8. ++ Log.w(TAG, ++ "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap" ++ + " will cause numeric casting and clamping on the data value."); ++ } ++ ++ return colorSpaceType.convertTensorBufferToBitmap(buffer); ++ } ++ ++ @Override ++ public TensorBuffer getTensorBuffer(DataType dataType) { ++ // If the data type of buffer is desired, return it directly. Not making a defensive copy ++ // for performance considerations. During image processing, users may need to set and get ++ // the TensorBuffer many times. Otherwise, create another one with the expected data type. ++ return buffer.getDataType() == dataType ? buffer ++ : TensorBuffer.createFrom(buffer, dataType); ++ } ++ ++ @Override ++ public Image getMediaImage() { ++ throw new UnsupportedOperationException( ++ "Converting from TensorBuffer to android.media.Image is unsupported."); ++ } ++ ++ @Override ++ public int getWidth() { ++ // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created. ++ colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); ++ return width; ++ } ++ ++ @Override ++ public int getHeight() { ++ // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created. ++ colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); ++ return height; ++ } ++ ++ @Override ++ public ColorSpaceType getColorSpaceType() { ++ return colorSpaceType; ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java +index fbb73020e93d9..a5a12520856b5 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java +@@ -19,10 +19,12 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + + import android.graphics.Bitmap; + import android.media.Image; +-import java.nio.ByteBuffer; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.nio.ByteBuffer; ++ + /** + * TensorImage is the wrapper class for Image object. When using image processing utils in + * TFLite.support library, it's common to convert image objects in variant types to TensorImage at +@@ -49,350 +51,357 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + // TODO(b/138907116): Support loading images from TensorBuffer with properties. + // TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary. + public class TensorImage { ++ private final DataType dataType; ++ private ImageContainer container = null; ++ ++ /** ++ * Initializes a {@link TensorImage} object. ++ * ++ * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link ++ * #TensorImage(DataType)} if other data types are preferred. ++ */ ++ public TensorImage() { ++ this(DataType.UINT8); ++ } ++ ++ /** ++ * Initializes a {@link TensorImage} object with the specified data type. ++ * ++ * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage}, ++ * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be ++ * converted to the specified data type. ++ * ++ * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of ++ * the image being loaded to this {@link TensorImage}. ++ * ++ * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is ++ * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, ++ * use ++ * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the ++ * same time. ++ * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor ++ * {@link DataType#FLOAT32} ++ */ ++ public TensorImage(DataType dataType) { ++ checkArgument(dataType == DataType.UINT8 || dataType == DataType.FLOAT32, ++ "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted"); ++ this.dataType = dataType; ++ } ++ ++ /** ++ * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link ++ * android.graphics.Bitmap} . ++ * ++ * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently, ++ * because every call of {@code fromBitmap} creates a new {@link TensorImage}. ++ */ ++ public static TensorImage fromBitmap(Bitmap bitmap) { ++ TensorImage image = new TensorImage(); ++ image.load(bitmap); ++ return image; ++ } ++ ++ /** ++ * Creates a deep-copy of a given {@link TensorImage} with the desired data type. ++ * ++ * @param src the {@link TensorImage} to copy from ++ * @param dataType the expected data type of newly created {@link TensorImage} ++ * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code ++ * dataType} ++ */ ++ public static TensorImage createFrom(TensorImage src, DataType dataType) { ++ TensorImage dst = new TensorImage(dataType); ++ dst.container = src.container.clone(); ++ return dst; ++ } ++ ++ /** ++ * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}. ++ * ++ * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric ++ * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link ++ * TensorBuffer}. ++ * ++ * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. ++ * The ++ * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as ++ * well. In this method, we perform a zero-copy approach for that bitmap, by simply holding its ++ * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary. ++ * ++ * <p>Note: to get the best performance, please load images in the same shape to avoid memory ++ * re-allocation. ++ * ++ * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888 ++ */ ++ public void load(Bitmap bitmap) { ++ container = BitmapContainer.create(bitmap); ++ } ++ ++ /** ++ * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels ++ * inside. ++ * ++ * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32}, ++ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}. ++ * ++ * @param pixels the RGB pixels representing the image ++ * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) ++ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) ++ */ ++ public void load(float[] pixels, int[] shape) { ++ TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); ++ buffer.loadArray(pixels, shape); ++ load(buffer); ++ } + +- private final DataType dataType; +- private ImageContainer container = null; +- +- /** +- * Initializes a {@link TensorImage} object. +- * +- * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link +- * #TensorImage(DataType)} if other data types are preferred. +- */ +- public TensorImage() { +- this(DataType.UINT8); +- } +- +- /** +- * Initializes a {@link TensorImage} object with the specified data type. +- * +- * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage}, +- * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be +- * converted to the specified data type. +- * +- * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of +- * the image being loaded to this {@link TensorImage}. +- * +- * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is +- * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use +- * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the +- * same time. +- * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor +- * {@link DataType#FLOAT32} +- */ +- public TensorImage(DataType dataType) { +- checkArgument( +- dataType == DataType.UINT8 || dataType == DataType.FLOAT32, +- "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted"); +- this.dataType = dataType; +- } +- +- /** +- * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link +- * android.graphics.Bitmap} . +- * +- * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently, +- * because every call of {@code fromBitmap} creates a new {@link TensorImage}. +- */ +- public static TensorImage fromBitmap(Bitmap bitmap) { +- TensorImage image = new TensorImage(); +- image.load(bitmap); +- return image; +- } +- +- /** +- * Creates a deep-copy of a given {@link TensorImage} with the desired data type. +- * +- * @param src the {@link TensorImage} to copy from +- * @param dataType the expected data type of newly created {@link TensorImage} +- * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code +- * dataType} +- */ +- public static TensorImage createFrom(TensorImage src, DataType dataType) { +- TensorImage dst = new TensorImage(dataType); +- dst.container = src.container.clone(); +- return dst; +- } +- +- /** +- * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}. +- * +- * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric +- * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link +- * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link +- * TensorBuffer}. +- * +- * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The +- * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well. +- * In this method, we perform a zero-copy approach for that bitmap, by simply holding its +- * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary. +- * +- * <p>Note: to get the best performance, please load images in the same shape to avoid memory +- * re-allocation. +- * +- * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888 +- */ +- public void load(Bitmap bitmap) { +- container = BitmapContainer.create(bitmap); +- } +- +- /** +- * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels +- * inside. +- * +- * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32}, +- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link +- * #getBuffer}. +- * +- * @param pixels the RGB pixels representing the image +- * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) +- * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) +- */ +- public void load(float[] pixels, int[] shape) { +- TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); +- buffer.loadArray(pixels, shape); +- load(buffer); +- } +- +- /** +- * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside. +- * +- * <p>Note: numeric casting and clamping will be applied to convert the values into the data type +- * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}. +- * +- * @param pixels the RGB pixels representing the image +- * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) +- * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) +- */ +- public void load(int[] pixels, int[] shape) { +- TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); +- buffer.loadArray(pixels, shape); +- load(buffer); +- } +- +- /** +- * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB. +- * +- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, +- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link +- * #getBuffer}. +- * +- * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or +- * (1, h, w, 3) +- * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) +- */ +- public void load(TensorBuffer buffer) { +- load(buffer, ColorSpaceType.RGB); +- } +- +- /** +- * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSpaceType}. +- * +- * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link +- * #load(TensorBuffer, ImageProperties)} for other color space types. +- * +- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, +- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link +- * #getBuffer}. +- * +- * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or +- * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images +- * @throws IllegalArgumentException if the shape of buffer does not match the color space type, or +- * if the color space type is not supported +- */ +- public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) { +- checkArgument( +- colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE, +- "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" +- + " `load(TensorBuffer, ImageProperties)` for other color space types."); +- +- container = TensorBufferContainer.create(buffer, colorSpaceType); +- } +- +- /** +- * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ImageProperties}. +- * +- * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and width. +- * Set image properties through {@link ImageProperties}. +- * +- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, +- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link +- * #getBuffer}. +- * +- * @throws IllegalArgumentException if buffer size is less than the image size indicated by image +- * height, width, and color space type in {@link ImageProperties} +- */ +- public void load(TensorBuffer buffer, ImageProperties imageProperties) { +- container = TensorBufferContainer.create(buffer, imageProperties); +- } +- +- /** +- * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}. +- * +- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, +- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link +- * #getBuffer}. +- * +- * @throws IllegalArgumentException if buffer size is less than the image size indicated by image +- * height, width, and color space type in {@link ImageProperties} +- */ +- public void load(ByteBuffer buffer, ImageProperties imageProperties) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); +- tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()}); +- container = TensorBufferContainer.create(tensorBuffer, imageProperties); +- } +- +- /** +- * Loads an {@link android.media.Image} object into this {@link TensorImage}. +- * +- * <p>The main usage of this method is to load an {@link android.media.Image} object as model +- * input to the <a href="TFLite Task +- * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>. +- * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link +- * ImageProcessor}. +- * +- * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code +- * image} is not YUV_420_888 +- */ +- public void load(Image image) { +- container = MediaImageContainer.create(image); +- } +- +- /** +- * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}. +- * +- * <p>Numeric casting and clamping will be applied if the stored data is not uint8. +- * +- * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code +- * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work. +- * +- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance +- * concern, but if modification is necessary, please make a copy. +- * +- * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A" +- * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of +- * this {@link TensorBuffer}. +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- */ +- public Bitmap getBitmap() { +- if (container == null) { +- throw new IllegalStateException("No image has been loaded yet."); ++ /** ++ * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels ++ * inside. ++ * ++ * <p>Note: numeric casting and clamping will be applied to convert the values into the data ++ * type of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}. ++ * ++ * @param pixels the RGB pixels representing the image ++ * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) ++ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) ++ */ ++ public void load(int[] pixels, int[] shape) { ++ TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); ++ buffer.loadArray(pixels, shape); ++ load(buffer); + } + +- return container.getBitmap(); +- } +- +- /** +- * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data +- * type. +- * +- * <p>Numeric casting and clamping will be applied if the stored data is different from the data +- * type of the {@link TensorImage}. +- * +- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance +- * concern, but if modification is necessary, please make a copy. +- * +- * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}. +- * +- * @return a reference to a {@link ByteBuffer} which holds the image data +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- */ +- public ByteBuffer getBuffer() { +- return getTensorBuffer().getBuffer(); +- } +- +- /** +- * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected +- * data type. +- * +- * <p>Numeric casting and clamping will be applied if the stored data is different from the data +- * type of the {@link TensorImage}. +- * +- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance +- * concern, but if modification is necessary, please make a copy. +- * +- * @return a reference to a {@link TensorBuffer} which holds the image data +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- */ +- public TensorBuffer getTensorBuffer() { +- if (container == null) { +- throw new IllegalStateException("No image has been loaded yet."); ++ /** ++ * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB. ++ * ++ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, ++ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}. ++ * ++ * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or ++ * (1, h, w, 3) ++ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) ++ */ ++ public void load(TensorBuffer buffer) { ++ load(buffer, ColorSpaceType.RGB); + } + +- return container.getTensorBuffer(dataType); +- } +- +- /** +- * Returns an {@link android.media.Image} representation of this {@link TensorImage}. +- * +- * <p>This method only works when the {@link TensorImage} is backed by an {@link +- * android.media.Image}, meaning you need to first load an {@link android.media.Image} through +- * {@link #load(Image)}. +- * +- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance +- * concern, but if modification is necessary, please make a copy. +- * +- * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A" +- * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of +- * this {@link TensorBuffer}. +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- */ +- public Image getMediaImage() { +- if (container == null) { +- throw new IllegalStateException("No image has been loaded yet."); ++ /** ++ * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ++ * ColorSpaceType}. ++ * ++ * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link ++ * #load(TensorBuffer, ImageProperties)} for other color space types. ++ * ++ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, ++ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}. ++ * ++ * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or ++ * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images ++ * @throws IllegalArgumentException if the shape of buffer does not match the color space type, ++ * or ++ * if the color space type is not supported ++ */ ++ public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) { ++ checkArgument( ++ colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE, ++ "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" ++ + " `load(TensorBuffer, ImageProperties)` for other color space types."); ++ ++ container = TensorBufferContainer.create(buffer, colorSpaceType); + } + +- return container.getMediaImage(); +- } +- +- /** +- * Gets the data type of this {@link TensorImage}. +- * +- * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are +- * supported. +- */ +- public DataType getDataType() { +- return dataType; +- } +- +- /** +- * Gets the color space type of this {@link TensorImage}. +- * +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- */ +- public ColorSpaceType getColorSpaceType() { +- if (container == null) { +- throw new IllegalStateException("No image has been loaded yet."); ++ /** ++ * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ++ * ImageProperties}. ++ * ++ * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and ++ * width. Set image properties through {@link ImageProperties}. ++ * ++ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, ++ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}. ++ * ++ * @throws IllegalArgumentException if buffer size is less than the image size indicated by ++ * image ++ * height, width, and color space type in {@link ImageProperties} ++ */ ++ public void load(TensorBuffer buffer, ImageProperties imageProperties) { ++ container = TensorBufferContainer.create(buffer, imageProperties); + } + +- return container.getColorSpaceType(); +- } +- +- /** +- * Gets the image width. +- * +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- * @throws IllegalArgumentException if the underlying data is corrupted +- */ +- public int getWidth() { +- if (container == null) { +- throw new IllegalStateException("No image has been loaded yet."); ++ /** ++ * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}. ++ * ++ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, ++ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link ++ * #getBuffer}. ++ * ++ * @throws IllegalArgumentException if buffer size is less than the image size indicated by ++ * image ++ * height, width, and color space type in {@link ImageProperties} ++ */ ++ public void load(ByteBuffer buffer, ImageProperties imageProperties) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); ++ tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()}); ++ container = TensorBufferContainer.create(tensorBuffer, imageProperties); + } + +- return container.getWidth(); +- } +- +- /** +- * Gets the image height. +- * +- * @throws IllegalStateException if the {@link TensorImage} never loads data +- * @throws IllegalArgumentException if the underlying data is corrupted +- */ +- public int getHeight() { +- if (container == null) { +- throw new IllegalStateException("No image has been loaded yet."); ++ /** ++ * Loads an {@link android.media.Image} object into this {@link TensorImage}. ++ * ++ * <p>The main usage of this method is to load an {@link android.media.Image} object as model ++ * input to the <a href="TFLite Task ++ * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>. ++ * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link ++ * ImageProcessor}. ++ * ++ * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code ++ * image} is not YUV_420_888 ++ */ ++ public void load(Image image) { ++ container = MediaImageContainer.create(image); + } + +- return container.getHeight(); +- } ++ /** ++ * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}. ++ * ++ * <p>Numeric casting and clamping will be applied if the stored data is not uint8. ++ * ++ * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code ++ * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work. ++ * ++ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for ++ * performance concern, but if modification is necessary, please make a copy. ++ * ++ * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A" ++ * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} ++ * of this {@link TensorBuffer}. ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ */ ++ public Bitmap getBitmap() { ++ if (container == null) { ++ throw new IllegalStateException("No image has been loaded yet."); ++ } ++ ++ return container.getBitmap(); ++ } ++ ++ /** ++ * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected ++ * data type. ++ * ++ * <p>Numeric casting and clamping will be applied if the stored data is different from the data ++ * type of the {@link TensorImage}. ++ * ++ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for ++ * performance concern, but if modification is necessary, please make a copy. ++ * ++ * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}. ++ * ++ * @return a reference to a {@link ByteBuffer} which holds the image data ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ */ ++ public ByteBuffer getBuffer() { ++ return getTensorBuffer().getBuffer(); ++ } ++ ++ /** ++ * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected ++ * data type. ++ * ++ * <p>Numeric casting and clamping will be applied if the stored data is different from the data ++ * type of the {@link TensorImage}. ++ * ++ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for ++ * performance concern, but if modification is necessary, please make a copy. ++ * ++ * @return a reference to a {@link TensorBuffer} which holds the image data ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ */ ++ public TensorBuffer getTensorBuffer() { ++ if (container == null) { ++ throw new IllegalStateException("No image has been loaded yet."); ++ } ++ ++ return container.getTensorBuffer(dataType); ++ } ++ ++ /** ++ * Returns an {@link android.media.Image} representation of this {@link TensorImage}. ++ * ++ * <p>This method only works when the {@link TensorImage} is backed by an {@link ++ * android.media.Image}, meaning you need to first load an {@link android.media.Image} through ++ * {@link #load(Image)}. ++ * ++ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for ++ * performance concern, but if modification is necessary, please make a copy. ++ * ++ * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A" ++ * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} ++ * of this {@link TensorBuffer}. ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ */ ++ public Image getMediaImage() { ++ if (container == null) { ++ throw new IllegalStateException("No image has been loaded yet."); ++ } ++ ++ return container.getMediaImage(); ++ } ++ ++ /** ++ * Gets the data type of this {@link TensorImage}. ++ * ++ * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are ++ * supported. ++ */ ++ public DataType getDataType() { ++ return dataType; ++ } ++ ++ /** ++ * Gets the color space type of this {@link TensorImage}. ++ * ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ */ ++ public ColorSpaceType getColorSpaceType() { ++ if (container == null) { ++ throw new IllegalStateException("No image has been loaded yet."); ++ } ++ ++ return container.getColorSpaceType(); ++ } ++ ++ /** ++ * Gets the image width. ++ * ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ * @throws IllegalArgumentException if the underlying data is corrupted ++ */ ++ public int getWidth() { ++ if (container == null) { ++ throw new IllegalStateException("No image has been loaded yet."); ++ } ++ ++ return container.getWidth(); ++ } ++ ++ /** ++ * Gets the image height. ++ * ++ * @throws IllegalStateException if the {@link TensorImage} never loads data ++ * @throws IllegalArgumentException if the underlying data is corrupted ++ */ ++ public int getHeight() { ++ if (container == null) { ++ throw new IllegalStateException("No image has been loaded yet."); ++ } ++ ++ return container.getHeight(); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java +index 06391de9cc3e0..adccf23dc97f0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java +@@ -19,6 +19,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + + import android.graphics.Bitmap; + import android.graphics.PointF; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.support.image.ColorSpaceType; + import org.tensorflow.lite.support.image.ImageOperator; +@@ -32,64 +33,60 @@ import org.tensorflow.lite.support.image.TensorImage; + * @see ResizeWithCropOrPadOp for resizing without content distortion. + */ + public class ResizeOp implements ImageOperator { ++ /** Algorithms for resizing. */ ++ public enum ResizeMethod { BILINEAR, NEAREST_NEIGHBOR } + +- /** Algorithms for resizing. */ +- public enum ResizeMethod { +- BILINEAR, +- NEAREST_NEIGHBOR +- } +- +- private final int targetHeight; +- private final int targetWidth; +- private final boolean useBilinear; ++ private final int targetHeight; ++ private final int targetWidth; ++ private final boolean useBilinear; + +- /** +- * Creates a ResizeOp which can resize images to specified size in specified method. +- * +- * @param targetHeight The expected height of resized image. +- * @param targetWidth The expected width of resized image. +- * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod} +- */ +- public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) { +- this.targetHeight = targetHeight; +- this.targetWidth = targetWidth; +- useBilinear = (resizeMethod == ResizeMethod.BILINEAR); +- } ++ /** ++ * Creates a ResizeOp which can resize images to specified size in specified method. ++ * ++ * @param targetHeight The expected height of resized image. ++ * @param targetWidth The expected width of resized image. ++ * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod} ++ */ ++ public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) { ++ this.targetHeight = targetHeight; ++ this.targetWidth = targetWidth; ++ useBilinear = (resizeMethod == ResizeMethod.BILINEAR); ++ } + +- /** +- * Applies the defined resizing on given image and returns the result. +- * +- * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance +- * with the output. +- * +- * @param image input image. +- * @return output image. +- */ +- @Override +- @NonNull +- public TensorImage apply(@NonNull TensorImage image) { +- checkArgument( +- image.getColorSpaceType() == ColorSpaceType.RGB, +- "Only RGB images are supported in ResizeOp, but not " + image.getColorSpaceType().name()); +- Bitmap scaled = +- Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear); +- image.load(scaled); +- return image; +- } ++ /** ++ * Applies the defined resizing on given image and returns the result. ++ * ++ * <p>Note: the content of input {@code image} will change, and {@code image} is the same ++ * instance with the output. ++ * ++ * @param image input image. ++ * @return output image. ++ */ ++ @Override ++ @NonNull ++ public TensorImage apply(@NonNull TensorImage image) { ++ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, ++ "Only RGB images are supported in ResizeOp, but not " ++ + image.getColorSpaceType().name()); ++ Bitmap scaled = Bitmap.createScaledBitmap( ++ image.getBitmap(), targetWidth, targetHeight, useBilinear); ++ image.load(scaled); ++ return image; ++ } + +- @Override +- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { +- return targetHeight; +- } ++ @Override ++ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { ++ return targetHeight; ++ } + +- @Override +- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { +- return targetWidth; +- } ++ @Override ++ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { ++ return targetWidth; ++ } + +- @Override +- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { +- return new PointF( +- point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight); +- } ++ @Override ++ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { ++ return new PointF( ++ point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java +index 66491090ac9c0..e5de5bbcf50d9 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java +@@ -22,6 +22,7 @@ import android.graphics.Bitmap.Config; + import android.graphics.Canvas; + import android.graphics.PointF; + import android.graphics.Rect; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.support.image.ColorSpaceType; + import org.tensorflow.lite.support.image.ImageOperator; +@@ -37,96 +38,95 @@ import org.tensorflow.lite.support.image.TensorImage; + * @see ResizeOp for reszing images while stretching / compressing the content. + */ + public class ResizeWithCropOrPadOp implements ImageOperator { +- private final int targetHeight; +- private final int targetWidth; +- private final Bitmap output; +- +- /** +- * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts +- * center-crop and zero-padding. +- * +- * @param targetHeight The expected height of cropped/padded image. +- * @param targetWidth The expected width of cropped/padded image. +- */ +- public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) { +- this.targetHeight = targetHeight; +- this.targetWidth = targetWidth; +- output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888); +- } ++ private final int targetHeight; ++ private final int targetWidth; ++ private final Bitmap output; + +- /** +- * Applies the defined resizing with cropping or/and padding on given image and returns the +- * result. +- * +- * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance +- * with the output. +- * +- * @param image input image. +- * @return output image. +- */ +- @Override +- @NonNull +- public TensorImage apply(@NonNull TensorImage image) { +- checkArgument( +- image.getColorSpaceType() == ColorSpaceType.RGB, +- "Only RGB images are supported in ResizeWithCropOrPadOp, but not " +- + image.getColorSpaceType().name()); +- Bitmap input = image.getBitmap(); +- int srcL; +- int srcR; +- int srcT; +- int srcB; +- int dstL; +- int dstR; +- int dstT; +- int dstB; +- int w = input.getWidth(); +- int h = input.getHeight(); +- if (targetWidth > w) { // padding +- srcL = 0; +- srcR = w; +- dstL = (targetWidth - w) / 2; +- dstR = dstL + w; +- } else { // cropping +- dstL = 0; +- dstR = targetWidth; +- srcL = (w - targetWidth) / 2; +- srcR = srcL + targetWidth; ++ /** ++ * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts ++ * center-crop and zero-padding. ++ * ++ * @param targetHeight The expected height of cropped/padded image. ++ * @param targetWidth The expected width of cropped/padded image. ++ */ ++ public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) { ++ this.targetHeight = targetHeight; ++ this.targetWidth = targetWidth; ++ output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888); + } +- if (targetHeight > h) { // padding +- srcT = 0; +- srcB = h; +- dstT = (targetHeight - h) / 2; +- dstB = dstT + h; +- } else { // cropping +- dstT = 0; +- dstB = targetHeight; +- srcT = (h - targetHeight) / 2; +- srcB = srcT + targetHeight; ++ ++ /** ++ * Applies the defined resizing with cropping or/and padding on given image and returns the ++ * result. ++ * ++ * <p>Note: the content of input {@code image} will change, and {@code image} is the same ++ * instance with the output. ++ * ++ * @param image input image. ++ * @return output image. ++ */ ++ @Override ++ @NonNull ++ public TensorImage apply(@NonNull TensorImage image) { ++ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, ++ "Only RGB images are supported in ResizeWithCropOrPadOp, but not " ++ + image.getColorSpaceType().name()); ++ Bitmap input = image.getBitmap(); ++ int srcL; ++ int srcR; ++ int srcT; ++ int srcB; ++ int dstL; ++ int dstR; ++ int dstT; ++ int dstB; ++ int w = input.getWidth(); ++ int h = input.getHeight(); ++ if (targetWidth > w) { // padding ++ srcL = 0; ++ srcR = w; ++ dstL = (targetWidth - w) / 2; ++ dstR = dstL + w; ++ } else { // cropping ++ dstL = 0; ++ dstR = targetWidth; ++ srcL = (w - targetWidth) / 2; ++ srcR = srcL + targetWidth; ++ } ++ if (targetHeight > h) { // padding ++ srcT = 0; ++ srcB = h; ++ dstT = (targetHeight - h) / 2; ++ dstB = dstT + h; ++ } else { // cropping ++ dstT = 0; ++ dstB = targetHeight; ++ srcT = (h - targetHeight) / 2; ++ srcB = srcT + targetHeight; ++ } ++ Rect src = new Rect(srcL, srcT, srcR, srcB); ++ Rect dst = new Rect(dstL, dstT, dstR, dstB); ++ new Canvas(output).drawBitmap(input, src, dst, null); ++ image.load(output); ++ return image; + } +- Rect src = new Rect(srcL, srcT, srcR, srcB); +- Rect dst = new Rect(dstL, dstT, dstR, dstB); +- new Canvas(output).drawBitmap(input, src, dst, null); +- image.load(output); +- return image; +- } + +- @Override +- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { +- return targetHeight; +- } ++ @Override ++ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { ++ return targetHeight; ++ } + +- @Override +- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { +- return targetWidth; +- } ++ @Override ++ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { ++ return targetWidth; ++ } + +- @Override +- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { +- return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth); +- } ++ @Override ++ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { ++ return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth); ++ } + +- private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) { +- return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2); +- } ++ private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) { ++ return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java +index 849b4bc9ef3db..86413c90c69ca 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java +@@ -20,6 +20,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import android.graphics.Bitmap; + import android.graphics.Matrix; + import android.graphics.PointF; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.support.image.ColorSpaceType; + import org.tensorflow.lite.support.image.ImageOperator; +@@ -27,83 +28,83 @@ import org.tensorflow.lite.support.image.TensorImage; + + /** Rotates image counter-clockwise. */ + public class Rot90Op implements ImageOperator { ++ private final int numRotation; + +- private final int numRotation; +- +- /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */ +- public Rot90Op() { +- this(1); +- } ++ /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */ ++ public Rot90Op() { ++ this(1); ++ } + +- /** +- * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times counter-clockwise. +- * +- * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image +- * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise. +- */ +- public Rot90Op(int k) { +- numRotation = k % 4; +- } ++ /** ++ * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times ++ * counter-clockwise. ++ * ++ * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image ++ * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise. ++ */ ++ public Rot90Op(int k) { ++ numRotation = k % 4; ++ } + +- /** +- * Applies the defined rotation on given image and returns the result. +- * +- * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance +- * with the output. +- * +- * @param image input image. +- * @return output image. +- */ +- @NonNull +- @Override +- public TensorImage apply(@NonNull TensorImage image) { +- checkArgument( +- image.getColorSpaceType() == ColorSpaceType.RGB, +- "Only RGB images are supported in Rot90Op, but not " + image.getColorSpaceType().name()); +- Bitmap input = image.getBitmap(); +- if (numRotation == 0) { +- return image; ++ /** ++ * Applies the defined rotation on given image and returns the result. ++ * ++ * <p>Note: the content of input {@code image} will change, and {@code image} is the same ++ * instance with the output. ++ * ++ * @param image input image. ++ * @return output image. ++ */ ++ @NonNull ++ @Override ++ public TensorImage apply(@NonNull TensorImage image) { ++ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, ++ "Only RGB images are supported in Rot90Op, but not " ++ + image.getColorSpaceType().name()); ++ Bitmap input = image.getBitmap(); ++ if (numRotation == 0) { ++ return image; ++ } ++ int w = input.getWidth(); ++ int h = input.getHeight(); ++ Matrix matrix = new Matrix(); ++ matrix.postTranslate(w * 0.5f, h * 0.5f); ++ matrix.postRotate(-90 * numRotation); ++ int newW = (numRotation % 2 == 0) ? w : h; ++ int newH = (numRotation % 2 == 0) ? h : w; ++ matrix.postTranslate(newW * 0.5f, newH * 0.5f); ++ Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false); ++ image.load(output); ++ return image; + } +- int w = input.getWidth(); +- int h = input.getHeight(); +- Matrix matrix = new Matrix(); +- matrix.postTranslate(w * 0.5f, h * 0.5f); +- matrix.postRotate(-90 * numRotation); +- int newW = (numRotation % 2 == 0) ? w : h; +- int newH = (numRotation % 2 == 0) ? h : w; +- matrix.postTranslate(newW * 0.5f, newH * 0.5f); +- Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false); +- image.load(output); +- return image; +- } + +- @Override +- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { +- return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth; +- } ++ @Override ++ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { ++ return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth; ++ } + +- @Override +- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { +- return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight; +- } ++ @Override ++ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { ++ return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight; ++ } + +- @Override +- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { +- int inverseNumRotation = (4 - numRotation) % 4; +- int height = getOutputImageHeight(inputImageHeight, inputImageWidth); +- int width = getOutputImageWidth(inputImageHeight, inputImageWidth); +- return transformImpl(point, height, width, inverseNumRotation); +- } ++ @Override ++ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { ++ int inverseNumRotation = (4 - numRotation) % 4; ++ int height = getOutputImageHeight(inputImageHeight, inputImageWidth); ++ int width = getOutputImageWidth(inputImageHeight, inputImageWidth); ++ return transformImpl(point, height, width, inverseNumRotation); ++ } + +- private static PointF transformImpl(PointF point, int height, int width, int numRotation) { +- if (numRotation == 0) { +- return point; +- } else if (numRotation == 1) { +- return new PointF(point.y, width - point.x); +- } else if (numRotation == 2) { +- return new PointF(width - point.x, height - point.y); +- } else { // numRotation == 3 +- return new PointF(height - point.y, point.x); ++ private static PointF transformImpl(PointF point, int height, int width, int numRotation) { ++ if (numRotation == 0) { ++ return point; ++ } else if (numRotation == 1) { ++ return new PointF(point.y, width - point.x); ++ } else if (numRotation == 2) { ++ return new PointF(width - point.x, height - point.y); ++ } else { // numRotation == 3 ++ return new PointF(height - point.y, point.x); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java +index 5d10ac890e57b..feb2b3b7b0762 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.image.ops; + + import android.graphics.PointF; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.support.common.TensorOperator; + import org.tensorflow.lite.support.common.internal.SupportPreconditions; +@@ -31,48 +32,47 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * @see org.tensorflow.lite.support.image.TensorImage + */ + public class TensorOperatorWrapper implements ImageOperator { ++ private final TensorOperator tensorOp; + +- private final TensorOperator tensorOp; +- +- /** +- * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link +- * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link +- * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. +- * +- * <p>Requirement: The {@code op} should not change coordinate system when applied on an image. +- * +- * @param op The created operator. +- */ +- public TensorOperatorWrapper(TensorOperator op) { +- tensorOp = op; +- } ++ /** ++ * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link ++ * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link ++ * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. ++ * ++ * <p>Requirement: The {@code op} should not change coordinate system when applied on an image. ++ * ++ * @param op The created operator. ++ */ ++ public TensorOperatorWrapper(TensorOperator op) { ++ tensorOp = op; ++ } + +- @Override +- @NonNull +- public TensorImage apply(@NonNull TensorImage image) { +- SupportPreconditions.checkNotNull(image, "Op cannot apply on null image."); +- TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer()); +- // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. Therefore, +- // need to create a new TensorImage with the correct data type. +- // However the underlying ops should not touch the color type. +- ColorSpaceType colorSpaceType = image.getColorSpaceType(); +- TensorImage resImage = new TensorImage(resBuffer.getDataType()); +- resImage.load(resBuffer, colorSpaceType); +- return resImage; +- } ++ @Override ++ @NonNull ++ public TensorImage apply(@NonNull TensorImage image) { ++ SupportPreconditions.checkNotNull(image, "Op cannot apply on null image."); ++ TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer()); ++ // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. ++ // Therefore, need to create a new TensorImage with the correct data type. However the ++ // underlying ops should not touch the color type. ++ ColorSpaceType colorSpaceType = image.getColorSpaceType(); ++ TensorImage resImage = new TensorImage(resBuffer.getDataType()); ++ resImage.load(resBuffer, colorSpaceType); ++ return resImage; ++ } + +- @Override +- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { +- return inputImageHeight; +- } ++ @Override ++ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { ++ return inputImageHeight; ++ } + +- @Override +- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { +- return inputImageWidth; +- } ++ @Override ++ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { ++ return inputImageWidth; ++ } + +- @Override +- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { +- return point; +- } ++ @Override ++ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { ++ return point; ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java +index bd3c10b254ac5..1a6f905b1bffd 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java +@@ -23,6 +23,7 @@ import android.graphics.ColorFilter; + import android.graphics.ColorMatrixColorFilter; + import android.graphics.Paint; + import android.graphics.PointF; ++ + import org.tensorflow.lite.support.image.ColorSpaceType; + import org.tensorflow.lite.support.image.ImageOperator; + import org.tensorflow.lite.support.image.TensorImage; +@@ -41,77 +42,73 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * https://docs.opencv.org/master/de/d25/imgproc_color_conversions.html#color_convert_rgb_gray + */ + public class TransformToGrayscaleOp implements ImageOperator { ++ // A matrix is created that will be applied later to canvas to generate grayscale image ++ // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values ++ // Y = 0.299R + 0.587G + 0.114B ++ private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION = ++ new float[] {0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, ++ 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F}; + +- // A matrix is created that will be applied later to canvas to generate grayscale image +- // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values +- // Y = 0.299R + 0.587G + 0.114B +- private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION = +- new float[] { +- 0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, +- 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F +- }; +- +- /** Creates a TransformToGrayscaleOp. */ +- public TransformToGrayscaleOp() {} ++ /** Creates a TransformToGrayscaleOp. */ ++ public TransformToGrayscaleOp() {} + +- /** +- * Applies the transformation to grayscale and returns a {@link TensorImage}. +- * +- * <p>If the input image is already {@link +- * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op. +- * +- * @throws IllegalArgumentException if the {@code image} is not {@link +- * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link +- * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}. +- */ +- @Override +- public TensorImage apply(TensorImage image) { +- if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) { +- return image; +- } else { +- checkArgument( +- image.getColorSpaceType() == ColorSpaceType.RGB, +- "Only RGB images are supported in TransformToGrayscaleOp, but not " +- + image.getColorSpaceType().name()); +- } +- int h = image.getHeight(); +- int w = image.getWidth(); +- Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888); +- Canvas canvas = new Canvas(bmpGrayscale); +- Paint paint = new Paint(); +- ColorMatrixColorFilter colorMatrixFilter = +- new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION); +- paint.setColorFilter((ColorFilter) colorMatrixFilter); +- canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint); ++ /** ++ * Applies the transformation to grayscale and returns a {@link TensorImage}. ++ * ++ * <p>If the input image is already {@link ++ * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op. ++ * ++ * @throws IllegalArgumentException if the {@code image} is not {@link ++ * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link ++ * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}. ++ */ ++ @Override ++ public TensorImage apply(TensorImage image) { ++ if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) { ++ return image; ++ } else { ++ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, ++ "Only RGB images are supported in TransformToGrayscaleOp, but not " ++ + image.getColorSpaceType().name()); ++ } ++ int h = image.getHeight(); ++ int w = image.getWidth(); ++ Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888); ++ Canvas canvas = new Canvas(bmpGrayscale); ++ Paint paint = new Paint(); ++ ColorMatrixColorFilter colorMatrixFilter = ++ new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION); ++ paint.setColorFilter((ColorFilter) colorMatrixFilter); ++ canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint); + +- // Get the pixels from the generated grayscale image +- int[] intValues = new int[w * h]; +- bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h); +- // Shape with one channel +- int[] shape = new int[] {1, h, w, 1}; ++ // Get the pixels from the generated grayscale image ++ int[] intValues = new int[w * h]; ++ bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h); ++ // Shape with one channel ++ int[] shape = new int[] {1, h, w, 1}; + +- // Get R channel from ARGB color +- for (int i = 0; i < intValues.length; i++) { +- intValues[i] = ((intValues[i] >> 16) & 0xff); ++ // Get R channel from ARGB color ++ for (int i = 0; i < intValues.length; i++) { ++ intValues[i] = ((intValues[i] >> 16) & 0xff); ++ } ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType()); ++ buffer.loadArray(intValues, shape); ++ image.load(buffer, ColorSpaceType.GRAYSCALE); ++ return image; + } +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType()); +- buffer.loadArray(intValues, shape); +- image.load(buffer, ColorSpaceType.GRAYSCALE); +- return image; +- } + +- @Override +- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { +- return inputImageHeight; +- } ++ @Override ++ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { ++ return inputImageHeight; ++ } + +- @Override +- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { +- return inputImageWidth; +- } ++ @Override ++ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { ++ return inputImageWidth; ++ } + +- @Override +- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { +- return point; +- } ++ @Override ++ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { ++ return point; ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java +index 8135ddcc28619..af56b70a77cf3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java +@@ -15,9 +15,10 @@ limitations under the License. + + package org.tensorflow.lite.support.label; + +-import java.util.Objects; + import org.tensorflow.lite.annotations.UsedByReflection; + ++import java.util.Objects; ++ + /** + * Category is a util class, contains a label, its display name, a float value as score, and the + * index of the label in the corresponding label file. Typically it's used as result of +@@ -25,102 +26,97 @@ import org.tensorflow.lite.annotations.UsedByReflection; + */ + @UsedByReflection("TFLiteSupport/Task") + public final class Category { +- private static final int DEFAULT_INDEX = -1; +- private static final float TOLERANCE = 1e-6f; +- private final int index; +- private final String label; +- private final String displayName; +- private final float score; +- +- /** +- * Constructs a {@link Category} object. +- * +- * @param label the label of this category object +- * @param displayName the display name of the label, which may be translated for different +- * locales. For exmaple, a label, "apple", may be translated into Spanish for display purpose, +- * so that the displayName is "manzana". +- * @param score the probability score of this label category +- * @param index the index of the label in the corresponding label file +- */ +- @UsedByReflection("TFLiteSupport/Task") +- public static Category create(String label, String displayName, float score, int index) { +- return new Category(label, displayName, score, index); +- } +- +- /** Constructs a {@link Category} object with the default index (-1). */ +- @UsedByReflection("TFLiteSupport/Task") +- public static Category create(String label, String displayName, float score) { +- return new Category(label, displayName, score, DEFAULT_INDEX); +- } +- +- /** Constructs a {@link Category} object with an empty displayName and the default index (-1). */ +- @UsedByReflection("TFLiteSupport/Task") +- public Category(String label, float score) { +- this(label, /*displayName=*/ "", score, DEFAULT_INDEX); +- } +- +- private Category(String label, String displayName, float score, int index) { +- this.label = label; +- this.displayName = displayName; +- this.score = score; +- this.index = index; +- } +- +- /** Gets the reference of category's label. */ +- public String getLabel() { +- return label; +- } +- +- /** +- * Gets the reference of category's displayName, a name in locale of the label. +- * +- * <p>The display name can be an empty string if this {@link Category} object is constructed +- * without displayName, such as when using {@link #Category(String label, float score)}. +- */ +- public String getDisplayName() { +- return displayName; +- } +- +- /** Gets the score of the category. */ +- public float getScore() { +- return score; +- } +- +- /** +- * Gets the index of the category. The index value might be -1, which means it has not been set up +- * properly and is invalid. +- */ +- public int getIndex() { +- return index; +- } +- +- @Override +- public boolean equals(Object o) { +- if (o instanceof Category) { +- Category other = (Category) o; +- return (other.getLabel().equals(this.label) +- && other.getDisplayName().equals(this.displayName) +- && Math.abs(other.getScore() - this.score) < TOLERANCE +- && other.getIndex() == this.index); ++ private static final int DEFAULT_INDEX = -1; ++ private static final float TOLERANCE = 1e-6f; ++ private final int index; ++ private final String label; ++ private final String displayName; ++ private final float score; ++ ++ /** ++ * Constructs a {@link Category} object. ++ * ++ * @param label the label of this category object ++ * @param displayName the display name of the label, which may be translated for different ++ * locales. For exmaple, a label, "apple", may be translated into Spanish for display ++ * purpose, so that the displayName is "manzana". ++ * @param score the probability score of this label category ++ * @param index the index of the label in the corresponding label file ++ */ ++ @UsedByReflection("TFLiteSupport/Task") ++ public static Category create(String label, String displayName, float score, int index) { ++ return new Category(label, displayName, score, index); ++ } ++ ++ /** Constructs a {@link Category} object with the default index (-1). */ ++ @UsedByReflection("TFLiteSupport/Task") ++ public static Category create(String label, String displayName, float score) { ++ return new Category(label, displayName, score, DEFAULT_INDEX); ++ } ++ ++ /** ++ * Constructs a {@link Category} object with an empty displayName and the default index (-1). ++ */ ++ @UsedByReflection("TFLiteSupport/Task") ++ public Category(String label, float score) { ++ this(label, /*displayName=*/"", score, DEFAULT_INDEX); ++ } ++ ++ private Category(String label, String displayName, float score, int index) { ++ this.label = label; ++ this.displayName = displayName; ++ this.score = score; ++ this.index = index; ++ } ++ ++ /** Gets the reference of category's label. */ ++ public String getLabel() { ++ return label; ++ } ++ ++ /** ++ * Gets the reference of category's displayName, a name in locale of the label. ++ * ++ * <p>The display name can be an empty string if this {@link Category} object is constructed ++ * without displayName, such as when using {@link #Category(String label, float score)}. ++ */ ++ public String getDisplayName() { ++ return displayName; ++ } ++ ++ /** Gets the score of the category. */ ++ public float getScore() { ++ return score; ++ } ++ ++ /** ++ * Gets the index of the category. The index value might be -1, which means it has not been set ++ * up properly and is invalid. ++ */ ++ public int getIndex() { ++ return index; ++ } ++ ++ @Override ++ public boolean equals(Object o) { ++ if (o instanceof Category) { ++ Category other = (Category) o; ++ return (other.getLabel().equals(this.label) ++ && other.getDisplayName().equals(this.displayName) ++ && Math.abs(other.getScore() - this.score) < TOLERANCE ++ && other.getIndex() == this.index); ++ } ++ return false; ++ } ++ ++ @Override ++ public int hashCode() { ++ return Objects.hash(label, displayName, score, index); ++ } ++ ++ @Override ++ public String toString() { ++ return "<Category \"" + label + "\" (displayName=" + displayName + " score=" + score ++ + " index=" + index + ")>"; + } +- return false; +- } +- +- @Override +- public int hashCode() { +- return Objects.hash(label, displayName, score, index); +- } +- +- @Override +- public String toString() { +- return "<Category \"" +- + label +- + "\" (displayName=" +- + displayName +- + " score=" +- + score +- + " index=" +- + index +- + ")>"; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java +index af21d74e25f5d..56ee89f091e03 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java +@@ -16,49 +16,52 @@ limitations under the License. + package org.tensorflow.lite.support.label; + + import android.util.Log; +-import java.util.ArrayList; +-import java.util.Arrays; +-import java.util.List; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.support.common.internal.SupportPreconditions; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.ArrayList; ++import java.util.Arrays; ++import java.util.List; ++ + /** Label operation utils. */ + public class LabelUtil { +- /** +- * Maps an int value tensor to a list of string labels. It takes an array of strings as the +- * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background", +- * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"]. +- * +- * @param tensorBuffer A tensor with index values. The values should be non-negative integers, and +- * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is +- * given as a float {@link TensorBuffer}, values will be cast to integers. All values that are +- * out of bound will map to empty string. +- * @param labels A list of strings, used as a dictionary to look up. The index of the array +- * element will be used as the key. To get better performance, use an object that implements +- * RandomAccess, such as {@link ArrayList}. +- * @param offset The offset value when look up int values in the {@code labels}. +- * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}. +- * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null. +- */ +- public static List<String> mapValueToLabels( +- @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) { +- SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null"); +- SupportPreconditions.checkNotNull(labels, "Given labels should not be null"); +- int[] values = tensorBuffer.getIntArray(); +- Log.d("values", Arrays.toString(values)); +- List<String> result = new ArrayList<>(); +- for (int v : values) { +- int index = v + offset; +- if (index < 0 || index >= labels.size()) { +- result.add(""); +- } else { +- result.add(labels.get(index)); +- } ++ /** ++ * Maps an int value tensor to a list of string labels. It takes an array of strings as the ++ * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background", ++ * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"]. ++ * ++ * @param tensorBuffer A tensor with index values. The values should be non-negative integers, ++ * and ++ * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is ++ * given as a float {@link TensorBuffer}, values will be cast to integers. All values that ++ * are out of bound will map to empty string. ++ * @param labels A list of strings, used as a dictionary to look up. The index of the array ++ * element will be used as the key. To get better performance, use an object that implements ++ * RandomAccess, such as {@link ArrayList}. ++ * @param offset The offset value when look up int values in the {@code labels}. ++ * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}. ++ * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null. ++ */ ++ public static List<String> mapValueToLabels( ++ @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) { ++ SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null"); ++ SupportPreconditions.checkNotNull(labels, "Given labels should not be null"); ++ int[] values = tensorBuffer.getIntArray(); ++ Log.d("values", Arrays.toString(values)); ++ List<String> result = new ArrayList<>(); ++ for (int v : values) { ++ int index = v + offset; ++ if (index < 0 || index >= labels.size()) { ++ result.add(""); ++ } else { ++ result.add(labels.get(index)); ++ } ++ } ++ return result; + } +- return result; +- } + +- // Private constructor to prevent initialization. +- private LabelUtil() {} ++ // Private constructor to prevent initialization. ++ private LabelUtil() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java +index bdab7cf464c1b..edd683cd08126 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java +@@ -16,16 +16,18 @@ limitations under the License. + package org.tensorflow.lite.support.label; + + import android.content.Context; ++ ++import org.checkerframework.checker.nullness.qual.NonNull; ++import org.tensorflow.lite.DataType; ++import org.tensorflow.lite.support.common.internal.SupportPreconditions; ++import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; ++ + import java.nio.ByteBuffer; + import java.util.ArrayList; + import java.util.Arrays; + import java.util.LinkedHashMap; + import java.util.List; + import java.util.Map; +-import org.checkerframework.checker.nullness.qual.NonNull; +-import org.tensorflow.lite.DataType; +-import org.tensorflow.lite.support.common.internal.SupportPreconditions; +-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** + * TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis. +@@ -56,169 +58,170 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * a label file (plain text file whose each line is a label) in assets simply. + */ + public class TensorLabel { +- private final Map<Integer, List<String>> axisLabels; +- private final TensorBuffer tensorBuffer; +- private final int[] shape; +- +- /** +- * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors. +- * +- * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding +- * labels. Note: The size of labels should be same with the size of the tensor on that axis. +- * @param tensorBuffer The TensorBuffer to be labeled. +- * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any +- * value in {@code axisLabels} is null. +- * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to +- * the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code +- * tensorBuffer} on the given dimension. +- */ +- public TensorLabel( +- @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) { +- SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null."); +- SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null."); +- this.axisLabels = axisLabels; +- this.tensorBuffer = tensorBuffer; +- this.shape = tensorBuffer.getShape(); +- for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) { +- int axis = entry.getKey(); +- SupportPreconditions.checkArgument( +- axis >= 0 && axis < shape.length, "Invalid axis id: " + axis); +- SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis); +- SupportPreconditions.checkArgument( +- shape[axis] == entry.getValue().size(), +- "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis); ++ private final Map<Integer, List<String>> axisLabels; ++ private final TensorBuffer tensorBuffer; ++ private final int[] shape; ++ ++ /** ++ * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors. ++ * ++ * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding ++ * labels. Note: The size of labels should be same with the size of the tensor on that axis. ++ * @param tensorBuffer The TensorBuffer to be labeled. ++ * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any ++ * value in {@code axisLabels} is null. ++ * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared ++ * to ++ * the shape of {@code tensorBuffer}, or any value (labels) has different size with the ++ * {@code tensorBuffer} on the given dimension. ++ */ ++ public TensorLabel( ++ @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) { ++ SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null."); ++ SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null."); ++ this.axisLabels = axisLabels; ++ this.tensorBuffer = tensorBuffer; ++ this.shape = tensorBuffer.getShape(); ++ for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) { ++ int axis = entry.getKey(); ++ SupportPreconditions.checkArgument( ++ axis >= 0 && axis < shape.length, "Invalid axis id: " + axis); ++ SupportPreconditions.checkNotNull( ++ entry.getValue(), "Label list is null on axis " + axis); ++ SupportPreconditions.checkArgument(shape[axis] == entry.getValue().size(), ++ "Label number " + entry.getValue().size() + " mismatch the shape on axis " ++ + axis); ++ } + } +- } +- +- /** +- * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors. +- * +- * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if +- * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from +- * 0), and size of {@code axisLabels} should be 10 as well. +- * +- * @param axisLabels A list of labels, whose size should be same with the size of the tensor on +- * the to-be-labeled axis. +- * @param tensorBuffer The TensorBuffer to be labeled. +- */ +- public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) { +- this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer); +- } +- +- /** +- * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the +- * mapping on the first axis with size greater than 1 currently. +- */ +- @NonNull +- public Map<String, TensorBuffer> getMapWithTensorBuffer() { +- int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); +- +- Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>(); +- SupportPreconditions.checkArgument( +- axisLabels.containsKey(labeledAxis), +- "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis."); +- List<String> labels = axisLabels.get(labeledAxis); +- +- DataType dataType = tensorBuffer.getDataType(); +- int typeSize = tensorBuffer.getTypeSize(); +- int flatSize = tensorBuffer.getFlatSize(); +- +- // Gets the underlying bytes that could be used to generate the sub-array later. +- ByteBuffer byteBuffer = tensorBuffer.getBuffer(); +- byteBuffer.rewind(); +- +- // Note: computation below is only correct when labeledAxis is the first axis with size greater +- // than 1. +- int subArrayLength = flatSize / shape[labeledAxis] * typeSize; +- int i = 0; +- SupportPreconditions.checkNotNull(labels, "Label list should never be null"); +- for (String label : labels) { +- // Gets the corresponding TensorBuffer. +- byteBuffer.position(i * subArrayLength); +- ByteBuffer subBuffer = byteBuffer.slice(); +- // ByteBuffer.slice doesn't keep order. Modify it to align with the original one. +- subBuffer.order(byteBuffer.order()).limit(subArrayLength); +- TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType); +- labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length)); +- labelToTensorMap.put(label, labelBuffer); +- i += 1; ++ ++ /** ++ * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors. ++ * ++ * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, ++ * if the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting ++ * from 0), and size of {@code axisLabels} should be 10 as well. ++ * ++ * @param axisLabels A list of labels, whose size should be same with the size of the tensor on ++ * the to-be-labeled axis. ++ * @param tensorBuffer The TensorBuffer to be labeled. ++ */ ++ public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) { ++ this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer); + } +- return labelToTensorMap; +- } +- +- /** +- * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater +- * than 1, and the axis should be effectively the last axis (which means every sub tensor +- * specified by this axis should have a flat size of 1). +- * +- * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result. +- * +- * @throws IllegalStateException if size of a sub tensor on each label is not 1. +- */ +- @NonNull +- public Map<String, Float> getMapWithFloatValue() { +- int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); +- SupportPreconditions.checkState( +- labeledAxis == shape.length - 1, +- "get a <String, Scalar> map is only valid when the only labeled axis is the last one."); +- List<String> labels = axisLabels.get(labeledAxis); +- float[] data = tensorBuffer.getFloatArray(); +- SupportPreconditions.checkState(labels.size() == data.length); +- Map<String, Float> result = new LinkedHashMap<>(); +- int i = 0; +- for (String label : labels) { +- result.put(label, data[i]); +- i += 1; ++ ++ /** ++ * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the ++ * mapping on the first axis with size greater than 1 currently. ++ */ ++ @NonNull ++ public Map<String, TensorBuffer> getMapWithTensorBuffer() { ++ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); ++ ++ Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>(); ++ SupportPreconditions.checkArgument(axisLabels.containsKey(labeledAxis), ++ "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis."); ++ List<String> labels = axisLabels.get(labeledAxis); ++ ++ DataType dataType = tensorBuffer.getDataType(); ++ int typeSize = tensorBuffer.getTypeSize(); ++ int flatSize = tensorBuffer.getFlatSize(); ++ ++ // Gets the underlying bytes that could be used to generate the sub-array later. ++ ByteBuffer byteBuffer = tensorBuffer.getBuffer(); ++ byteBuffer.rewind(); ++ ++ // Note: computation below is only correct when labeledAxis is the first axis with size ++ // greater than 1. ++ int subArrayLength = flatSize / shape[labeledAxis] * typeSize; ++ int i = 0; ++ SupportPreconditions.checkNotNull(labels, "Label list should never be null"); ++ for (String label : labels) { ++ // Gets the corresponding TensorBuffer. ++ byteBuffer.position(i * subArrayLength); ++ ByteBuffer subBuffer = byteBuffer.slice(); ++ // ByteBuffer.slice doesn't keep order. Modify it to align with the original one. ++ subBuffer.order(byteBuffer.order()).limit(subArrayLength); ++ TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType); ++ labelBuffer.loadBuffer( ++ subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length)); ++ labelToTensorMap.put(label, labelBuffer); ++ i += 1; ++ } ++ return labelToTensorMap; + } +- return result; +- } +- +- /** +- * Gets a list of {@link Category} from the {@link TensorLabel} object. +- * +- * <p>The axis of label should be effectively the last axis (which means every sub tensor +- * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be +- * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}} +- * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}. +- * +- * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as +- * the result. +- * +- * @throws IllegalStateException if size of a sub tensor on each label is not 1. +- */ +- @NonNull +- public List<Category> getCategoryList() { +- int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); +- SupportPreconditions.checkState( +- labeledAxis == shape.length - 1, +- "get a Category list is only valid when the only labeled axis is the last one."); +- List<String> labels = axisLabels.get(labeledAxis); +- float[] data = tensorBuffer.getFloatArray(); +- SupportPreconditions.checkState(labels.size() == data.length); +- List<Category> result = new ArrayList<>(); +- int i = 0; +- for (String label : labels) { +- result.add(new Category(label, data[i])); +- i += 1; ++ ++ /** ++ * Gets a map that maps label to float. Only allow the mapping on the first axis with size ++ * greater than 1, and the axis should be effectively the last axis (which means every sub ++ * tensor specified by this axis should have a flat size of 1). ++ * ++ * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result. ++ * ++ * @throws IllegalStateException if size of a sub tensor on each label is not 1. ++ */ ++ @NonNull ++ public Map<String, Float> getMapWithFloatValue() { ++ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); ++ SupportPreconditions.checkState(labeledAxis == shape.length - 1, ++ "get a <String, Scalar> map is only valid when the only labeled axis is the last one."); ++ List<String> labels = axisLabels.get(labeledAxis); ++ float[] data = tensorBuffer.getFloatArray(); ++ SupportPreconditions.checkState(labels.size() == data.length); ++ Map<String, Float> result = new LinkedHashMap<>(); ++ int i = 0; ++ for (String label : labels) { ++ result.put(label, data[i]); ++ i += 1; ++ } ++ return result; + } +- return result; +- } +- +- private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) { +- int[] shape = tensorBuffer.getShape(); +- for (int i = 0; i < shape.length; i++) { +- if (shape[i] > 1) { +- return i; +- } ++ ++ /** ++ * Gets a list of {@link Category} from the {@link TensorLabel} object. ++ * ++ * <p>The axis of label should be effectively the last axis (which means every sub tensor ++ * specified by this axis should have a flat size of 1), so that each labelled sub tensor could ++ * be converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, ++ * 5, 3}} and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link ++ * Category}. ++ * ++ * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as ++ * the result. ++ * ++ * @throws IllegalStateException if size of a sub tensor on each label is not 1. ++ */ ++ @NonNull ++ public List<Category> getCategoryList() { ++ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); ++ SupportPreconditions.checkState(labeledAxis == shape.length - 1, ++ "get a Category list is only valid when the only labeled axis is the last one."); ++ List<String> labels = axisLabels.get(labeledAxis); ++ float[] data = tensorBuffer.getFloatArray(); ++ SupportPreconditions.checkState(labels.size() == data.length); ++ List<Category> result = new ArrayList<>(); ++ int i = 0; ++ for (String label : labels) { ++ result.add(new Category(label, data[i])); ++ i += 1; ++ } ++ return result; ++ } ++ ++ private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) { ++ int[] shape = tensorBuffer.getShape(); ++ for (int i = 0; i < shape.length; i++) { ++ if (shape[i] > 1) { ++ return i; ++ } ++ } ++ throw new IllegalArgumentException( ++ "Cannot find an axis to label. A valid axis to label should have size larger than 1."); ++ } ++ ++ // Helper function to wrap the List<String> to a one-entry map. ++ private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) { ++ Map<Integer, List<String>> map = new LinkedHashMap<>(); ++ map.put(axis, labels); ++ return map; + } +- throw new IllegalArgumentException( +- "Cannot find an axis to label. A valid axis to label should have size larger than 1."); +- } +- +- // Helper function to wrap the List<String> to a one-entry map. +- private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) { +- Map<Integer, List<String>> map = new LinkedHashMap<>(); +- map.put(axis, labels); +- return map; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java +index ed47f65a726a6..e44edc64f4969 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java +@@ -16,16 +16,18 @@ limitations under the License. + package org.tensorflow.lite.support.label.ops; + + import android.content.Context; +-import java.io.IOException; +-import java.util.HashMap; +-import java.util.List; +-import java.util.Map; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.support.common.FileUtil; + import org.tensorflow.lite.support.common.internal.SupportPreconditions; + import org.tensorflow.lite.support.label.TensorLabel; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.io.IOException; ++import java.util.HashMap; ++import java.util.List; ++import java.util.Map; ++ + /** + * Labels TensorBuffer with axisLabels for outputs. + * +@@ -33,42 +35,42 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + * a pair of the label name and the corresponding TensorBuffer value. + */ + public class LabelAxisOp { +- // Axis and its corresponding label names. +- private final Map<Integer, List<String>> axisLabels; +- +- protected LabelAxisOp(Builder builder) { +- axisLabels = builder.axisLabels; +- } +- +- public TensorLabel apply(@NonNull TensorBuffer buffer) { +- SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null."); +- return new TensorLabel(axisLabels, buffer); +- } +- +- /** The inner builder class to build a LabelTensor Operator. */ +- public static class Builder { ++ // Axis and its corresponding label names. + private final Map<Integer, List<String>> axisLabels; + +- protected Builder() { +- axisLabels = new HashMap<>(); ++ protected LabelAxisOp(Builder builder) { ++ axisLabels = builder.axisLabels; + } + +- public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath) +- throws IOException { +- SupportPreconditions.checkNotNull(context, "Context cannot be null."); +- SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); +- List<String> labels = FileUtil.loadLabels(context, filePath); +- axisLabels.put(axis, labels); +- return this; ++ public TensorLabel apply(@NonNull TensorBuffer buffer) { ++ SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null."); ++ return new TensorLabel(axisLabels, buffer); + } + +- public Builder addAxisLabel(int axis, @NonNull List<String> labels) { +- axisLabels.put(axis, labels); +- return this; +- } ++ /** The inner builder class to build a LabelTensor Operator. */ ++ public static class Builder { ++ private final Map<Integer, List<String>> axisLabels; ++ ++ protected Builder() { ++ axisLabels = new HashMap<>(); ++ } ++ ++ public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath) ++ throws IOException { ++ SupportPreconditions.checkNotNull(context, "Context cannot be null."); ++ SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); ++ List<String> labels = FileUtil.loadLabels(context, filePath); ++ axisLabels.put(axis, labels); ++ return this; ++ } ++ ++ public Builder addAxisLabel(int axis, @NonNull List<String> labels) { ++ axisLabels.put(axis, labels); ++ return this; ++ } + +- public LabelAxisOp build() { +- return new LabelAxisOp(this); ++ public LabelAxisOp build() { ++ return new LabelAxisOp(this); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java +index 9cfcf923dedee..ada9b33fb0eea 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java +@@ -16,54 +16,55 @@ limitations under the License. + package org.tensorflow.lite.support.model; + + import android.util.Log; +-import java.io.Closeable; +-import java.io.IOException; ++ + import org.checkerframework.checker.nullness.qual.Nullable; + import org.tensorflow.lite.Delegate; + ++import java.io.Closeable; ++import java.io.IOException; ++ + /** + * Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict + * dependency. + */ + class GpuDelegateProxy implements Delegate, Closeable { ++ private static final String TAG = "GpuDelegateProxy"; + +- private static final String TAG = "GpuDelegateProxy"; +- +- private final Delegate proxiedDelegate; +- private final Closeable proxiedCloseable; ++ private final Delegate proxiedDelegate; ++ private final Closeable proxiedCloseable; + +- @Nullable +- public static GpuDelegateProxy maybeNewInstance() { +- try { +- Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate"); +- Object instance = clazz.getDeclaredConstructor().newInstance(); +- return new GpuDelegateProxy(instance); +- } catch (ReflectiveOperationException e) { +- Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e); +- return null; ++ @Nullable ++ public static GpuDelegateProxy maybeNewInstance() { ++ try { ++ Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate"); ++ Object instance = clazz.getDeclaredConstructor().newInstance(); ++ return new GpuDelegateProxy(instance); ++ } catch (ReflectiveOperationException e) { ++ Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e); ++ return null; ++ } + } +- } + +- /** Calls {@code close()} method of the delegate. */ +- @Override +- public void close() { +- try { +- proxiedCloseable.close(); +- } catch (IOException e) { +- // Should not trigger, because GpuDelegate#close never throws. The catch is required because +- // of Closeable#close. +- Log.e(TAG, "Failed to close the GpuDelegate.", e); ++ /** Calls {@code close()} method of the delegate. */ ++ @Override ++ public void close() { ++ try { ++ proxiedCloseable.close(); ++ } catch (IOException e) { ++ // Should not trigger, because GpuDelegate#close never throws. The catch is required ++ // because of Closeable#close. ++ Log.e(TAG, "Failed to close the GpuDelegate.", e); ++ } + } +- } + +- /** Calls {@code getNativeHandle()} method of the delegate. */ +- @Override +- public long getNativeHandle() { +- return proxiedDelegate.getNativeHandle(); +- } ++ /** Calls {@code getNativeHandle()} method of the delegate. */ ++ @Override ++ public long getNativeHandle() { ++ return proxiedDelegate.getNativeHandle(); ++ } + +- private GpuDelegateProxy(Object instance) { +- this.proxiedCloseable = (Closeable) instance; +- this.proxiedDelegate = (Delegate) instance; +- } ++ private GpuDelegateProxy(Object instance) { ++ this.proxiedCloseable = (Closeable) instance; ++ this.proxiedDelegate = (Delegate) instance; ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java +index 1c37c1b3d800d..af2061e948970 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java +@@ -16,9 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.model; + + import android.content.Context; +-import java.io.IOException; +-import java.nio.MappedByteBuffer; +-import java.util.Map; ++ + import org.checkerframework.checker.nullness.qual.NonNull; + import org.checkerframework.checker.nullness.qual.Nullable; + import org.tensorflow.lite.InterpreterApi; +@@ -27,6 +25,10 @@ import org.tensorflow.lite.Tensor; + import org.tensorflow.lite.support.common.FileUtil; + import org.tensorflow.lite.support.common.internal.SupportPreconditions; + ++import java.io.IOException; ++import java.nio.MappedByteBuffer; ++import java.util.Map; ++ + /** + * The wrapper class for a TFLite model and a TFLite interpreter. + * +@@ -34,253 +36,244 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions; + * interpreter instance to run it. + */ + public class Model { ++ /** The runtime device type used for executing classification. */ ++ public enum Device { CPU, NNAPI, GPU } + +- /** The runtime device type used for executing classification. */ +- public enum Device { +- CPU, +- NNAPI, +- GPU +- } +- +- /** +- * Options for running the model. Configurable parameters includes: +- * +- * <ul> +- * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model. +- * The default value is {@link Device#CPU}. +- * <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads +- * used by TFLite inference. It's only effective when device is set to {@link Device#CPU} +- * and default value is 1. +- * </ul> +- */ +- public static class Options { +- private final Device device; +- private final int numThreads; +- +- /** Builder of {@link Options}. See its doc for details. */ +- public static class Builder { +- private Device device = Device.CPU; +- private int numThreads = 1; +- +- public Builder setDevice(Device device) { +- this.device = device; +- return this; +- } +- +- public Builder setNumThreads(int numThreads) { +- this.numThreads = numThreads; +- return this; +- } +- +- public Options build() { +- return new Options(this); +- } ++ /** ++ * Options for running the model. Configurable parameters includes: ++ * ++ * <ul> ++ * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the ++ * model. The default value is {@link Device#CPU}. <li>{@code numThreads} {@link ++ * Builder#setNumThreads(int)} specifies the number of threads used by TFLite inference. It's ++ * only effective when device is set to {@link Device#CPU} and default value is 1. ++ * </ul> ++ */ ++ public static class Options { ++ private final Device device; ++ private final int numThreads; ++ ++ /** Builder of {@link Options}. See its doc for details. */ ++ public static class Builder { ++ private Device device = Device.CPU; ++ private int numThreads = 1; ++ ++ public Builder setDevice(Device device) { ++ this.device = device; ++ return this; ++ } ++ ++ public Builder setNumThreads(int numThreads) { ++ this.numThreads = numThreads; ++ return this; ++ } ++ ++ public Options build() { ++ return new Options(this); ++ } ++ } ++ ++ private Options(Builder builder) { ++ device = builder.device; ++ numThreads = builder.numThreads; ++ } + } + +- private Options(Builder builder) { +- device = builder.device; +- numThreads = builder.numThreads; +- } +- } ++ /** An instance of the driver class to run model inference with Tensorflow Lite. */ ++ private final InterpreterApi interpreter; + +- /** An instance of the driver class to run model inference with Tensorflow Lite. */ +- private final InterpreterApi interpreter; ++ /** Path to tflite model file in asset folder. */ ++ private final String modelPath; + +- /** Path to tflite model file in asset folder. */ +- private final String modelPath; ++ /** The memory-mapped model data. */ ++ private final MappedByteBuffer byteModel; + +- /** The memory-mapped model data. */ +- private final MappedByteBuffer byteModel; ++ private final GpuDelegateProxy gpuDelegateProxy; + +- private final GpuDelegateProxy gpuDelegateProxy; ++ /** ++ * Builder for {@link Model}. ++ * ++ * @deprecated Please use {@link Model#createModel(Context, String, Options)}. ++ */ ++ @Deprecated ++ public static class Builder { ++ private Device device = Device.CPU; ++ private int numThreads = 1; ++ private final String modelPath; ++ private final MappedByteBuffer byteModel; ++ ++ /** ++ * Creates a builder which loads tflite model from asset folder using memory-mapped files. ++ * ++ * @param context Application context to access assets. ++ * @param modelPath Asset path of the model (.tflite file). ++ * @throws IOException if an I/O error occurs when loading the tflite model. ++ */ ++ @NonNull ++ public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException { ++ this.modelPath = modelPath; ++ byteModel = FileUtil.loadMappedFile(context, modelPath); ++ } ++ ++ /** Sets running device. By default, TFLite will run on CPU. */ ++ @NonNull ++ public Builder setDevice(Device device) { ++ this.device = device; ++ return this; ++ } ++ ++ /** Sets number of threads. By default it's 1. */ ++ @NonNull ++ public Builder setNumThreads(int numThreads) { ++ this.numThreads = numThreads; ++ return this; ++ } ++ ++ // Note: The implementation is copied from `Model#createModel`. As the builder is going to ++ // be deprecated, this function is also to be removed. ++ @NonNull ++ public Model build() { ++ Options options = ++ new Options.Builder().setNumThreads(numThreads).setDevice(device).build(); ++ return createModel(byteModel, modelPath, options); ++ } ++ } + +- /** +- * Builder for {@link Model}. +- * +- * @deprecated Please use {@link Model#createModel(Context, String, Options)}. +- */ +- @Deprecated +- public static class Builder { +- private Device device = Device.CPU; +- private int numThreads = 1; +- private final String modelPath; +- private final MappedByteBuffer byteModel; ++ /** ++ * Loads a model from assets and initialize TFLite interpreter. ++ * ++ * <p>The default options are: (1) CPU device; (2) one thread. ++ * ++ * @param context The App Context. ++ * @param modelPath The path of the model file. ++ * @throws IOException if any exception occurs when open the model file. ++ */ ++ public static Model createModel(@NonNull Context context, @NonNull String modelPath) ++ throws IOException { ++ return createModel(context, modelPath, new Options.Builder().build()); ++ } + + /** +- * Creates a builder which loads tflite model from asset folder using memory-mapped files. ++ * Loads a model from assets and initialize TFLite interpreter with given options. + * +- * @param context Application context to access assets. +- * @param modelPath Asset path of the model (.tflite file). +- * @throws IOException if an I/O error occurs when loading the tflite model. ++ * @see Options for details. ++ * @param context The App Context. ++ * @param modelPath The path of the model file. ++ * @param options The options for running the model. ++ * @throws IOException if any exception occurs when open the model file. + */ +- @NonNull +- public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException { +- this.modelPath = modelPath; +- byteModel = FileUtil.loadMappedFile(context, modelPath); ++ public static Model createModel(@NonNull Context context, @NonNull String modelPath, ++ @NonNull Options options) throws IOException { ++ SupportPreconditions.checkNotEmpty( ++ modelPath, "Model path in the asset folder cannot be empty."); ++ MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath); ++ return createModel(byteModel, modelPath, options); + } + +- /** Sets running device. By default, TFLite will run on CPU. */ +- @NonNull +- public Builder setDevice(Device device) { +- this.device = device; +- return this; ++ /** ++ * Creates a model with loaded {@link MappedByteBuffer}. ++ * ++ * @see Options for details. ++ * @param byteModel The loaded TFLite model. ++ * @param modelPath The original path of the model. It can be fetched later by {@link ++ * Model#getPath()}. ++ * @param options The options for running the model. ++ * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but ++ * "tensorflow-lite-gpu" is not linked to the project. ++ */ ++ public static Model createModel(@NonNull MappedByteBuffer byteModel, @NonNull String modelPath, ++ @NonNull Options options) { ++ InterpreterApi.Options interpreterOptions = new InterpreterApi.Options(); ++ GpuDelegateProxy gpuDelegateProxy = null; ++ switch (options.device) { ++ case NNAPI: ++ interpreterOptions.setUseNNAPI(true); ++ break; ++ case GPU: ++ gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance(); ++ SupportPreconditions.checkArgument(gpuDelegateProxy != null, ++ "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?"); ++ interpreterOptions.addDelegate(gpuDelegateProxy); ++ break; ++ case CPU: ++ break; ++ } ++ interpreterOptions.setNumThreads(options.numThreads); ++ InterpreterApi interpreter = new InterpreterFactory().create(byteModel, interpreterOptions); ++ return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy); + } + +- /** Sets number of threads. By default it's 1. */ ++ /** Returns the memory-mapped model data. */ + @NonNull +- public Builder setNumThreads(int numThreads) { +- this.numThreads = numThreads; +- return this; ++ public MappedByteBuffer getData() { ++ return byteModel; + } + +- // Note: The implementation is copied from `Model#createModel`. As the builder is going to be +- // deprecated, this function is also to be removed. ++ /** Returns the path of the model file stored in Assets. */ + @NonNull +- public Model build() { +- Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build(); +- return createModel(byteModel, modelPath, options); ++ public String getPath() { ++ return modelPath; ++ } ++ ++ /** ++ * Gets the Tensor associated with the provided input index. ++ * ++ * @throws IllegalStateException if the interpreter is closed. ++ */ ++ public Tensor getInputTensor(int inputIndex) { ++ return interpreter.getInputTensor(inputIndex); ++ } ++ ++ /** ++ * Gets the Tensor associated with the provided output index. ++ * ++ * @throws IllegalStateException if the interpreter is closed. ++ */ ++ public Tensor getOutputTensor(int outputIndex) { ++ return interpreter.getOutputTensor(outputIndex); + } +- } +- +- /** +- * Loads a model from assets and initialize TFLite interpreter. +- * +- * <p>The default options are: (1) CPU device; (2) one thread. +- * +- * @param context The App Context. +- * @param modelPath The path of the model file. +- * @throws IOException if any exception occurs when open the model file. +- */ +- public static Model createModel(@NonNull Context context, @NonNull String modelPath) +- throws IOException { +- return createModel(context, modelPath, new Options.Builder().build()); +- } +- +- /** +- * Loads a model from assets and initialize TFLite interpreter with given options. +- * +- * @see Options for details. +- * @param context The App Context. +- * @param modelPath The path of the model file. +- * @param options The options for running the model. +- * @throws IOException if any exception occurs when open the model file. +- */ +- public static Model createModel( +- @NonNull Context context, @NonNull String modelPath, @NonNull Options options) +- throws IOException { +- SupportPreconditions.checkNotEmpty( +- modelPath, "Model path in the asset folder cannot be empty."); +- MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath); +- return createModel(byteModel, modelPath, options); +- } +- +- /** +- * Creates a model with loaded {@link MappedByteBuffer}. +- * +- * @see Options for details. +- * @param byteModel The loaded TFLite model. +- * @param modelPath The original path of the model. It can be fetched later by {@link +- * Model#getPath()}. +- * @param options The options for running the model. +- * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but +- * "tensorflow-lite-gpu" is not linked to the project. +- */ +- public static Model createModel( +- @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) { +- InterpreterApi.Options interpreterOptions = new InterpreterApi.Options(); +- GpuDelegateProxy gpuDelegateProxy = null; +- switch (options.device) { +- case NNAPI: +- interpreterOptions.setUseNNAPI(true); +- break; +- case GPU: +- gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance(); +- SupportPreconditions.checkArgument( +- gpuDelegateProxy != null, +- "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?"); +- interpreterOptions.addDelegate(gpuDelegateProxy); +- break; +- case CPU: +- break; ++ ++ /** ++ * Returns the output shape. Useful if output shape is only determined when graph is created. ++ * ++ * @throws IllegalStateException if the interpreter is closed. ++ */ ++ public int[] getOutputTensorShape(int outputIndex) { ++ return interpreter.getOutputTensor(outputIndex).shape(); + } +- interpreterOptions.setNumThreads(options.numThreads); +- InterpreterApi interpreter = new InterpreterFactory().create(byteModel, interpreterOptions); +- return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy); +- } +- +- /** Returns the memory-mapped model data. */ +- @NonNull +- public MappedByteBuffer getData() { +- return byteModel; +- } +- +- /** Returns the path of the model file stored in Assets. */ +- @NonNull +- public String getPath() { +- return modelPath; +- } +- +- /** +- * Gets the Tensor associated with the provided input index. +- * +- * @throws IllegalStateException if the interpreter is closed. +- */ +- public Tensor getInputTensor(int inputIndex) { +- return interpreter.getInputTensor(inputIndex); +- } +- +- /** +- * Gets the Tensor associated with the provided output index. +- * +- * @throws IllegalStateException if the interpreter is closed. +- */ +- public Tensor getOutputTensor(int outputIndex) { +- return interpreter.getOutputTensor(outputIndex); +- } +- +- /** +- * Returns the output shape. Useful if output shape is only determined when graph is created. +- * +- * @throws IllegalStateException if the interpreter is closed. +- */ +- public int[] getOutputTensorShape(int outputIndex) { +- return interpreter.getOutputTensor(outputIndex).shape(); +- } +- +- /** +- * Runs model inference on multiple inputs, and returns multiple outputs. +- * +- * @param inputs an array of input data. The inputs should be in the same order as inputs of the +- * model. Each input can be an array or multidimensional array, or a {@link +- * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link +- * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types +- * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is +- * used, its content should remain unchanged until model inference is done. +- * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link +- * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only +- * needs to keep entries for the outputs to be used. +- */ +- public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) { +- interpreter.runForMultipleInputsOutputs(inputs, outputs); +- } +- +- public void close() { +- if (interpreter != null) { +- interpreter.close(); ++ ++ /** ++ * Runs model inference on multiple inputs, and returns multiple outputs. ++ * ++ * @param inputs an array of input data. The inputs should be in the same order as inputs of the ++ * model. Each input can be an array or multidimensional array, or a {@link ++ * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link ++ * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types ++ * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} ++ * is used, its content should remain unchanged until model inference is done. ++ * @param outputs a map mapping output indices to multidimensional arrays of output data or ++ * {@link ++ * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only ++ * needs to keep entries for the outputs to be used. ++ */ ++ public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) { ++ interpreter.runForMultipleInputsOutputs(inputs, outputs); + } +- if (gpuDelegateProxy != null) { +- gpuDelegateProxy.close(); ++ ++ public void close() { ++ if (interpreter != null) { ++ interpreter.close(); ++ } ++ if (gpuDelegateProxy != null) { ++ gpuDelegateProxy.close(); ++ } ++ } ++ ++ private Model(@NonNull String modelPath, @NonNull MappedByteBuffer byteModel, ++ @NonNull InterpreterApi interpreter, @Nullable GpuDelegateProxy gpuDelegateProxy) { ++ this.modelPath = modelPath; ++ this.byteModel = byteModel; ++ this.interpreter = interpreter; ++ this.gpuDelegateProxy = gpuDelegateProxy; + } +- } +- +- private Model( +- @NonNull String modelPath, +- @NonNull MappedByteBuffer byteModel, +- @NonNull InterpreterApi interpreter, +- @Nullable GpuDelegateProxy gpuDelegateProxy) { +- this.modelPath = modelPath; +- this.byteModel = byteModel; +- this.interpreter = interpreter; +- this.gpuDelegateProxy = gpuDelegateProxy; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java +index 9e0204bdc2e71..ec6c800ef557a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java +@@ -19,473 +19,476 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull; + import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState; + ++import org.checkerframework.checker.nullness.qual.NonNull; ++import org.tensorflow.lite.DataType; ++ + import java.nio.ByteBuffer; + import java.nio.ByteOrder; + import java.util.Arrays; +-import org.checkerframework.checker.nullness.qual.NonNull; +-import org.tensorflow.lite.DataType; + + /** Represents the data buffer for either a model's input or its output. */ + public abstract class TensorBuffer { +- /** Where the data is stored. */ +- protected ByteBuffer buffer; +- +- /** Shape of the tensor stored in this buffer. */ +- protected int[] shape; +- +- /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */ +- protected int flatSize = -1; +- +- /** +- * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have +- * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed. +- */ +- protected final boolean isDynamic; +- +- /** +- * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some +- * examples: +- * +- * <pre> +- * // Creating a float TensorBuffer with shape {2, 3}: +- * int[] shape = new int[] {2, 3}; +- * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- * </pre> +- * +- * <pre> +- * // Creating an uint8 TensorBuffer of a scalar: +- * int[] shape = new int[] {}; +- * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- * </pre> +- * +- * <pre> +- * // Creating an empty uint8 TensorBuffer: +- * int[] shape = new int[] {0}; +- * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- * </pre> +- * +- * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created. +- * +- * @param shape The shape of the {@link TensorBuffer} to be created. +- * @param dataType The dataType of the {@link TensorBuffer} to be created. +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if {@code shape} has non-positive elements. +- */ +- @NonNull +- public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) { +- switch (dataType) { +- case FLOAT32: +- return new TensorBufferFloat(shape); +- case UINT8: +- return new TensorBufferUint8(shape); +- default: +- throw new AssertionError("TensorBuffer does not support data type: " + dataType); ++ /** Where the data is stored. */ ++ protected ByteBuffer buffer; ++ ++ /** Shape of the tensor stored in this buffer. */ ++ protected int[] shape; ++ ++ /** ++ * Number of elements in the buffer. It will be changed to a proper value in the constructor. ++ */ ++ protected int flatSize = -1; ++ ++ /** ++ * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have ++ * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed. ++ */ ++ protected final boolean isDynamic; ++ ++ /** ++ * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are ++ * some examples: ++ * ++ * <pre> ++ * // Creating a float TensorBuffer with shape {2, 3}: ++ * int[] shape = new int[] {2, 3}; ++ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ * </pre> ++ * ++ * <pre> ++ * // Creating an uint8 TensorBuffer of a scalar: ++ * int[] shape = new int[] {}; ++ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ * </pre> ++ * ++ * <pre> ++ * // Creating an empty uint8 TensorBuffer: ++ * int[] shape = new int[] {0}; ++ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ * </pre> ++ * ++ * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created. ++ * ++ * @param shape The shape of the {@link TensorBuffer} to be created. ++ * @param dataType The dataType of the {@link TensorBuffer} to be created. ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if {@code shape} has non-positive elements. ++ */ ++ @NonNull ++ public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) { ++ switch (dataType) { ++ case FLOAT32: ++ return new TensorBufferFloat(shape); ++ case UINT8: ++ return new TensorBufferUint8(shape); ++ default: ++ throw new AssertionError("TensorBuffer does not support data type: " + dataType); ++ } ++ } ++ ++ /** ++ * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of ++ * the created {@link TensorBuffer} is {0}. ++ * ++ * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of ++ * different buffer sizes. Here are some examples: ++ * ++ * <pre> ++ * // Creating a float dynamic TensorBuffer: ++ * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ * // Loading a float array: ++ * float[] arr1 = new float[] {1, 2, 3}; ++ * tensorBuffer.loadArray(arr, new int[] {arr1.length}); ++ * // loading another float array: ++ * float[] arr2 = new float[] {1, 2, 3, 4, 5}; ++ * tensorBuffer.loadArray(arr, new int[] {arr2.length}); ++ * // loading a third float array with the same size as arr2, assuming shape doesn't change: ++ * float[] arr3 = new float[] {5, 4, 3, 2, 1}; ++ * tensorBuffer.loadArray(arr); ++ * // loading a forth float array with different size as arr3 and omitting the shape will result ++ * // in error: ++ * float[] arr4 = new float[] {3, 2, 1}; ++ * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match. ++ * </pre> ++ * ++ * @param dataType The dataType of the {@link TensorBuffer} to be created. ++ */ ++ @NonNull ++ public static TensorBuffer createDynamic(DataType dataType) { ++ switch (dataType) { ++ case FLOAT32: ++ return new TensorBufferFloat(); ++ case UINT8: ++ return new TensorBufferUint8(); ++ default: ++ throw new AssertionError("TensorBuffer does not support data type: " + dataType); ++ } + } +- } +- +- /** +- * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the +- * created {@link TensorBuffer} is {0}. +- * +- * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of +- * different buffer sizes. Here are some examples: +- * +- * <pre> +- * // Creating a float dynamic TensorBuffer: +- * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- * // Loading a float array: +- * float[] arr1 = new float[] {1, 2, 3}; +- * tensorBuffer.loadArray(arr, new int[] {arr1.length}); +- * // loading another float array: +- * float[] arr2 = new float[] {1, 2, 3, 4, 5}; +- * tensorBuffer.loadArray(arr, new int[] {arr2.length}); +- * // loading a third float array with the same size as arr2, assuming shape doesn't change: +- * float[] arr3 = new float[] {5, 4, 3, 2, 1}; +- * tensorBuffer.loadArray(arr); +- * // loading a forth float array with different size as arr3 and omitting the shape will result +- * // in error: +- * float[] arr4 = new float[] {3, 2, 1}; +- * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match. +- * </pre> +- * +- * @param dataType The dataType of the {@link TensorBuffer} to be created. +- */ +- @NonNull +- public static TensorBuffer createDynamic(DataType dataType) { +- switch (dataType) { +- case FLOAT32: +- return new TensorBufferFloat(); +- case UINT8: +- return new TensorBufferUint8(); +- default: +- throw new AssertionError("TensorBuffer does not support data type: " + dataType); ++ ++ /** ++ * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link ++ * DataType}. ++ * ++ * @param buffer the source {@link TensorBuffer} to copy from. ++ * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}. ++ * @throws NullPointerException if {@code buffer} is null. ++ */ ++ @NonNull ++ public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) { ++ checkNotNull(buffer, "Cannot create a buffer from null"); ++ TensorBuffer result; ++ if (buffer.isDynamic()) { ++ result = createDynamic(dataType); ++ } else { ++ result = createFixedSize(buffer.shape, dataType); ++ } ++ // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as ++ // intermediate container. ++ // The assumption is not true when we support other data types. ++ if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) { ++ float[] data = buffer.getFloatArray(); ++ result.loadArray(data, buffer.shape); ++ } else { ++ int[] data = buffer.getIntArray(); ++ result.loadArray(data, buffer.shape); ++ } ++ return result; + } +- } +- +- /** +- * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}. +- * +- * @param buffer the source {@link TensorBuffer} to copy from. +- * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}. +- * @throws NullPointerException if {@code buffer} is null. +- */ +- @NonNull +- public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) { +- checkNotNull(buffer, "Cannot create a buffer from null"); +- TensorBuffer result; +- if (buffer.isDynamic()) { +- result = createDynamic(dataType); +- } else { +- result = createFixedSize(buffer.shape, dataType); ++ ++ /** Returns the data buffer. */ ++ @NonNull ++ public ByteBuffer getBuffer() { ++ return buffer; + } +- // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as +- // intermediate container. +- // The assumption is not true when we support other data types. +- if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) { +- float[] data = buffer.getFloatArray(); +- result.loadArray(data, buffer.shape); +- } else { +- int[] data = buffer.getIntArray(); +- result.loadArray(data, buffer.shape); ++ ++ /** ++ * Gets the flatSize of the buffer. ++ * ++ * @throws IllegalStateException if the underlying data is corrupted ++ */ ++ public int getFlatSize() { ++ assertShapeIsCorrect(); ++ return flatSize; + } +- return result; +- } +- +- /** Returns the data buffer. */ +- @NonNull +- public ByteBuffer getBuffer() { +- return buffer; +- } +- +- /** +- * Gets the flatSize of the buffer. +- * +- * @throws IllegalStateException if the underlying data is corrupted +- */ +- public int getFlatSize() { +- assertShapeIsCorrect(); +- return flatSize; +- } +- +- /** +- * Gets the current shape. (returning a copy here to avoid unexpected modification.) +- * +- * @throws IllegalStateException if the underlying data is corrupted +- */ +- @NonNull +- public int[] getShape() { +- assertShapeIsCorrect(); +- return Arrays.copyOf(shape, shape.length); +- } +- +- /** Returns the data type of this buffer. */ +- public abstract DataType getDataType(); +- +- /** +- * Returns a float array of the values stored in this buffer. If the buffer is of different types +- * than float, the values will be converted into float. For example, values in {@link +- * TensorBufferUint8} will be converted from uint8 to float. +- */ +- @NonNull +- public abstract float[] getFloatArray(); +- +- /** +- * Returns a float value at a given index. If the buffer is of different types than float, the +- * value will be converted into float. For example, when reading a value from {@link +- * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from +- * uint8 to float. +- * +- * <pre> +- * For example, a TensorBuffer with shape {2, 3} that represents the following array, +- * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. +- * +- * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by: +- * float v = tensorBuffer.getFloatValue(3); +- * </pre> +- * +- * @param absIndex The absolute index of the value to be read. +- */ +- public abstract float getFloatValue(int absIndex); +- +- /** +- * Returns an int array of the values stored in this buffer. If the buffer is of different type +- * than int, the values will be converted into int, and loss of precision may apply. For example, +- * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output +- * is {400, 23}. +- */ +- @NonNull +- public abstract int[] getIntArray(); +- +- /** +- * Returns an int value at a given index. If the buffer is of different types than int, the value +- * will be converted into int. For example, when reading a value from {@link TensorBufferFloat}, +- * the value will be first read out as float, and then will be converted from float to int. Loss +- * of precision may apply. +- * +- * <pre> +- * For example, a TensorBuffer with shape {2, 3} that represents the following array, +- * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. +- * +- * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by: +- * int v = tensorBuffer.getIntValue(3); +- * Note that v is converted from 3.0f to 3 as a result of type conversion. +- * </pre> +- * +- * @param absIndex The absolute index of the value to be read. +- */ +- public abstract int getIntValue(int absIndex); +- +- /** +- * Returns the number of bytes of a single element in the array. For example, a float buffer will +- * return 4, and a byte buffer will return 1. +- */ +- public abstract int getTypeSize(); +- +- /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */ +- public boolean isDynamic() { +- return isDynamic; +- } +- +- /** +- * Loads an int array into this buffer with specific shape. If the buffer is of different types +- * than int, the values will be converted into the buffer's type before being loaded into the +- * buffer, and loss of precision may apply. For example, loading an int array with values {400, +- * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be +- * casted to uint8 by {255, 0}. +- * +- * @param src The source array to be loaded. +- * @param shape Shape of the tensor that {@code src} represents. +- * @throws NullPointerException if {@code src} is null. +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if the size of the array to be loaded does not match the +- * specified shape. +- */ +- public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape); +- +- /** +- * Loads an int array into this buffer. If the buffer is of different types than int, the values +- * will be converted into the buffer's type before being loaded into the buffer, and loss of +- * precision may apply. For example, loading an int array with values {400, -23} into a {@link +- * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by +- * {255, 0}. +- * +- * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this +- * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match +- * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link +- * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape. +- * +- * @param src The source array to be loaded. +- */ +- public void loadArray(@NonNull int[] src) { +- loadArray(src, shape); +- } +- +- /** +- * Loads a float array into this buffer with specific shape. If the buffer is of different types +- * than float, the values will be converted into the buffer's type before being loaded into the +- * buffer, and loss of precision may apply. For example, loading a float array into a {@link +- * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and +- * then be casted to uint8 by {255, 0}. +- * +- * @param src The source array to be loaded. +- * @param shape Shape of the tensor that {@code src} represents. +- * @throws NullPointerException if {@code src} is null. +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if the size of the array to be loaded does not match the +- * specified shape. +- */ +- public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape); +- +- /** +- * Loads a float array into this buffer. If the buffer is of different types than float, the +- * values will be converted into the buffer's type before being loaded into the buffer, and loss +- * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8} +- * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to +- * uint8 by {255, 0}. +- * +- * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this +- * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match +- * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link +- * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape. +- * +- * @param src The source array to be loaded. +- */ +- public void loadArray(@NonNull float[] src) { +- loadArray(src, shape); +- } +- +- /** +- * Loads a byte buffer into this {@link TensorBuffer} with specific shape. +- * +- * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for +- * performance concern, but if modification is necessary, please make a copy. +- * +- * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer} +- * backed by an array. +- * +- * @param buffer The byte buffer to load. +- * @throws NullPointerException if {@code buffer} is null. +- * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not +- * match or the size of {@code buffer} and {@code flatSize} do not match. +- */ +- public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) { +- checkNotNull(buffer, "Byte buffer cannot be null."); +- checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); +- +- int flatSize = computeFlatSize(shape); +- checkArgument( +- (buffer.limit() == getTypeSize() * flatSize), +- "The size of byte buffer and the shape do not match. Expected: " +- + getTypeSize() * flatSize +- + " Actual: " +- + buffer.limit()); +- +- if (!isDynamic) { +- // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. +- checkArgument(Arrays.equals(shape, this.shape)); ++ ++ /** ++ * Gets the current shape. (returning a copy here to avoid unexpected modification.) ++ * ++ * @throws IllegalStateException if the underlying data is corrupted ++ */ ++ @NonNull ++ public int[] getShape() { ++ assertShapeIsCorrect(); ++ return Arrays.copyOf(shape, shape.length); + } + +- // Update to the new shape, since shape dim values might change. +- this.shape = shape.clone(); +- this.flatSize = flatSize; +- +- buffer.rewind(); +- this.buffer = buffer; +- } +- +- /** +- * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of +- * this {@link TensorBuffer}. +- * +- * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of this +- * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always +- * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link +- * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different +- * shape. +- * +- * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for +- * performance concern, but if modification is necessary, please make a copy. +- * +- * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer} +- * backed by an array. +- * +- * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance. +- * +- * @param buffer The byte buffer to load. +- */ +- public void loadBuffer(@NonNull ByteBuffer buffer) { +- loadBuffer(buffer, shape); +- } +- +- /** +- * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}. +- * +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if {@code shape} has non-positive elements. +- */ +- protected TensorBuffer(@NonNull int[] shape) { +- isDynamic = false; +- allocateMemory(shape); +- } +- +- /** Constructs a dynamic {@link TensorBuffer} which can be resized. */ +- protected TensorBuffer() { +- isDynamic = true; +- // Initialize the dynamic TensorBuffer with an empty ByteBuffer. +- allocateMemory(new int[] {0}); +- } +- +- /** Calculates number of elements in the buffer. */ +- protected static int computeFlatSize(@NonNull int[] shape) { +- checkNotNull(shape, "Shape cannot be null."); +- int prod = 1; +- for (int s : shape) { +- prod = prod * s; ++ /** Returns the data type of this buffer. */ ++ public abstract DataType getDataType(); ++ ++ /** ++ * Returns a float array of the values stored in this buffer. If the buffer is of different ++ * types than float, the values will be converted into float. For example, values in {@link ++ * TensorBufferUint8} will be converted from uint8 to float. ++ */ ++ @NonNull ++ public abstract float[] getFloatArray(); ++ ++ /** ++ * Returns a float value at a given index. If the buffer is of different types than float, the ++ * value will be converted into float. For example, when reading a value from {@link ++ * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted ++ * from uint8 to float. ++ * ++ * <pre> ++ * For example, a TensorBuffer with shape {2, 3} that represents the following array, ++ * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. ++ * ++ * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by: ++ * float v = tensorBuffer.getFloatValue(3); ++ * </pre> ++ * ++ * @param absIndex The absolute index of the value to be read. ++ */ ++ public abstract float getFloatValue(int absIndex); ++ ++ /** ++ * Returns an int array of the values stored in this buffer. If the buffer is of different type ++ * than int, the values will be converted into int, and loss of precision may apply. For ++ * example, getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, ++ * the output is {400, 23}. ++ */ ++ @NonNull ++ public abstract int[] getIntArray(); ++ ++ /** ++ * Returns an int value at a given index. If the buffer is of different types than int, the ++ * value will be converted into int. For example, when reading a value from {@link ++ * TensorBufferFloat}, the value will be first read out as float, and then will be converted ++ * from float to int. Loss of precision may apply. ++ * ++ * <pre> ++ * For example, a TensorBuffer with shape {2, 3} that represents the following array, ++ * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. ++ * ++ * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by: ++ * int v = tensorBuffer.getIntValue(3); ++ * Note that v is converted from 3.0f to 3 as a result of type conversion. ++ * </pre> ++ * ++ * @param absIndex The absolute index of the value to be read. ++ */ ++ public abstract int getIntValue(int absIndex); ++ ++ /** ++ * Returns the number of bytes of a single element in the array. For example, a float buffer ++ * will return 4, and a byte buffer will return 1. ++ */ ++ public abstract int getTypeSize(); ++ ++ /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */ ++ public boolean isDynamic() { ++ return isDynamic; + } +- return prod; +- } +- +- /** +- * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code +- * shape} of src fits the buffer size. +- */ +- protected void resize(@NonNull int[] shape) { +- if (isDynamic) { +- allocateMemory(shape); +- } else { +- // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. +- checkArgument(Arrays.equals(shape, this.shape)); +- this.shape = shape.clone(); ++ ++ /** ++ * Loads an int array into this buffer with specific shape. If the buffer is of different types ++ * than int, the values will be converted into the buffer's type before being loaded into the ++ * buffer, and loss of precision may apply. For example, loading an int array with values {400, ++ * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be ++ * casted to uint8 by {255, 0}. ++ * ++ * @param src The source array to be loaded. ++ * @param shape Shape of the tensor that {@code src} represents. ++ * @throws NullPointerException if {@code src} is null. ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if the size of the array to be loaded does not match the ++ * specified shape. ++ */ ++ public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape); ++ ++ /** ++ * Loads an int array into this buffer. If the buffer is of different types than int, the values ++ * will be converted into the buffer's type before being loaded into the buffer, and loss of ++ * precision may apply. For example, loading an int array with values {400, -23} into a {@link ++ * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by ++ * {255, 0}. ++ * ++ * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this ++ * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always ++ * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link ++ * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape. ++ * ++ * @param src The source array to be loaded. ++ */ ++ public void loadArray(@NonNull int[] src) { ++ loadArray(src, shape); ++ } ++ ++ /** ++ * Loads a float array into this buffer with specific shape. If the buffer is of different types ++ * than float, the values will be converted into the buffer's type before being loaded into the ++ * buffer, and loss of precision may apply. For example, loading a float array into a {@link ++ * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and ++ * then be casted to uint8 by {255, 0}. ++ * ++ * @param src The source array to be loaded. ++ * @param shape Shape of the tensor that {@code src} represents. ++ * @throws NullPointerException if {@code src} is null. ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if the size of the array to be loaded does not match the ++ * specified shape. ++ */ ++ public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape); ++ ++ /** ++ * Loads a float array into this buffer. If the buffer is of different types than float, the ++ * values will be converted into the buffer's type before being loaded into the buffer, and loss ++ * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8} ++ * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to ++ * uint8 by {255, 0}. ++ * ++ * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this ++ * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always ++ * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link ++ * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape. ++ * ++ * @param src The source array to be loaded. ++ */ ++ public void loadArray(@NonNull float[] src) { ++ loadArray(src, shape); + } +- } + +- /** Copies the underlying {@link ByteBuffer} if it's readonly. */ +- protected synchronized void copyByteBufferIfReadOnly() { +- if (!buffer.isReadOnly()) { +- return; ++ /** ++ * Loads a byte buffer into this {@link TensorBuffer} with specific shape. ++ * ++ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here ++ * for performance concern, but if modification is necessary, please make a copy. ++ * ++ * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer} ++ * backed by an array. ++ * ++ * @param buffer The byte buffer to load. ++ * @throws NullPointerException if {@code buffer} is null. ++ * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not ++ * match or the size of {@code buffer} and {@code flatSize} do not match. ++ */ ++ public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) { ++ checkNotNull(buffer, "Byte buffer cannot be null."); ++ checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); ++ ++ int flatSize = computeFlatSize(shape); ++ checkArgument((buffer.limit() == getTypeSize() * flatSize), ++ "The size of byte buffer and the shape do not match. Expected: " ++ + getTypeSize() * flatSize + " Actual: " + buffer.limit()); ++ ++ if (!isDynamic) { ++ // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. ++ checkArgument(Arrays.equals(shape, this.shape)); ++ } ++ ++ // Update to the new shape, since shape dim values might change. ++ this.shape = shape.clone(); ++ this.flatSize = flatSize; ++ ++ buffer.rewind(); ++ this.buffer = buffer; + } +- ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity()); +- newByteBuffer.order(buffer.order()); +- newByteBuffer.put(buffer); +- newByteBuffer.rewind(); +- buffer = newByteBuffer; +- } +- +- /** +- * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this +- * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1. +- * +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if {@code shape} has negative elements. +- */ +- private void allocateMemory(@NonNull int[] shape) { +- checkNotNull(shape, "TensorBuffer shape cannot be null."); +- checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); +- +- // Check if the new shape is the same as current shape. +- int newFlatSize = computeFlatSize(shape); +- this.shape = shape.clone(); +- if (flatSize == newFlatSize) { +- return; ++ ++ /** ++ * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of ++ * this {@link TensorBuffer}. ++ * ++ * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of ++ * this ++ * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always ++ * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link ++ * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different ++ * shape. ++ * ++ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here ++ * for performance concern, but if modification is necessary, please make a copy. ++ * ++ * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer} ++ * backed by an array. ++ * ++ * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance. ++ * ++ * @param buffer The byte buffer to load. ++ */ ++ public void loadBuffer(@NonNull ByteBuffer buffer) { ++ loadBuffer(buffer, shape); ++ } ++ ++ /** ++ * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}. ++ * ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if {@code shape} has non-positive elements. ++ */ ++ protected TensorBuffer(@NonNull int[] shape) { ++ isDynamic = false; ++ allocateMemory(shape); ++ } ++ ++ /** Constructs a dynamic {@link TensorBuffer} which can be resized. */ ++ protected TensorBuffer() { ++ isDynamic = true; ++ // Initialize the dynamic TensorBuffer with an empty ByteBuffer. ++ allocateMemory(new int[] {0}); ++ } ++ ++ /** Calculates number of elements in the buffer. */ ++ protected static int computeFlatSize(@NonNull int[] shape) { ++ checkNotNull(shape, "Shape cannot be null."); ++ int prod = 1; ++ for (int s : shape) { ++ prod = prod * s; ++ } ++ return prod; ++ } ++ ++ /** ++ * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code ++ * shape} of src fits the buffer size. ++ */ ++ protected void resize(@NonNull int[] shape) { ++ if (isDynamic) { ++ allocateMemory(shape); ++ } else { ++ // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. ++ checkArgument(Arrays.equals(shape, this.shape)); ++ this.shape = shape.clone(); ++ } ++ } ++ ++ /** Copies the underlying {@link ByteBuffer} if it's readonly. */ ++ protected synchronized void copyByteBufferIfReadOnly() { ++ if (!buffer.isReadOnly()) { ++ return; ++ } ++ ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity()); ++ newByteBuffer.order(buffer.order()); ++ newByteBuffer.put(buffer); ++ newByteBuffer.rewind(); ++ buffer = newByteBuffer; ++ } ++ ++ /** ++ * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, ++ * this ++ * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1. ++ * ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if {@code shape} has negative elements. ++ */ ++ private void allocateMemory(@NonNull int[] shape) { ++ checkNotNull(shape, "TensorBuffer shape cannot be null."); ++ checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); ++ ++ // Check if the new shape is the same as current shape. ++ int newFlatSize = computeFlatSize(shape); ++ this.shape = shape.clone(); ++ if (flatSize == newFlatSize) { ++ return; ++ } ++ ++ // Update to the new shape. ++ flatSize = newFlatSize; ++ buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize()); ++ buffer.order(ByteOrder.nativeOrder()); + } + +- // Update to the new shape. +- flatSize = newFlatSize; +- buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize()); +- buffer.order(ByteOrder.nativeOrder()); +- } +- +- /** +- * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link +- * ByteBuffer}. +- */ +- private void assertShapeIsCorrect() { +- int flatSize = computeFlatSize(shape); +- checkState( +- (buffer.limit() == getTypeSize() * flatSize), +- String.format( +- "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The" +- + " ByteBuffer may have been changed.", +- buffer.limit(), Arrays.toString(shape))); +- } +- +- /** +- * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape} +- * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar. +- */ +- private static boolean isShapeValid(@NonNull int[] shape) { +- if (shape.length == 0) { +- // This shape refers to a scalar. +- return true; ++ /** ++ * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link ++ * ByteBuffer}. ++ */ ++ private void assertShapeIsCorrect() { ++ int flatSize = computeFlatSize(shape); ++ checkState((buffer.limit() == getTypeSize() * flatSize), ++ String.format( ++ "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The" ++ + " ByteBuffer may have been changed.", ++ buffer.limit(), Arrays.toString(shape))); + } + +- // This shape refers to a multidimensional array. +- for (int s : shape) { +- // All elements in shape should be non-negative. +- if (s < 0) { +- return false; +- } ++ /** ++ * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape} ++ * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to ++ * scalar. ++ */ ++ private static boolean isShapeValid(@NonNull int[] shape) { ++ if (shape.length == 0) { ++ // This shape refers to a scalar. ++ return true; ++ } ++ ++ // This shape refers to a multidimensional array. ++ for (int s : shape) { ++ // All elements in shape should be non-negative. ++ if (s < 0) { ++ return false; ++ } ++ } ++ return true; + } +- return true; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java +index 8d2bc5ad0c84d..632db6c886b17 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java +@@ -15,103 +15,102 @@ limitations under the License. + + package org.tensorflow.lite.support.tensorbuffer; + +-import java.nio.FloatBuffer; + import org.checkerframework.checker.nullness.qual.NonNull; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.common.internal.SupportPreconditions; + ++import java.nio.FloatBuffer; ++ + /** Represents data buffer with float values. */ + public final class TensorBufferFloat extends TensorBuffer { +- private static final DataType DATA_TYPE = DataType.FLOAT32; +- +- /** +- * Creates a {@link TensorBufferFloat} with specified {@code shape}. +- * +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if {@code shape} has non-positive elements. +- */ +- TensorBufferFloat(@NonNull int[] shape) { +- super(shape); +- } +- +- TensorBufferFloat() { +- super(); +- } +- +- @Override +- public DataType getDataType() { +- return DATA_TYPE; +- } +- +- @Override +- @NonNull +- public float[] getFloatArray() { +- buffer.rewind(); +- float[] arr = new float[flatSize]; +- +- FloatBuffer floatBuffer = buffer.asFloatBuffer(); +- floatBuffer.get(arr); +- return arr; +- } +- +- @Override +- public float getFloatValue(int absIndex) { +- return buffer.getFloat(absIndex << 2); +- } +- +- @Override +- @NonNull +- public int[] getIntArray() { +- buffer.rewind(); +- float[] floatArr = new float[flatSize]; +- buffer.asFloatBuffer().get(floatArr); +- +- int[] intArr = new int[flatSize]; +- for (int i = 0; i < flatSize; i++) { +- intArr[i] = (int) floatArr[i]; ++ private static final DataType DATA_TYPE = DataType.FLOAT32; ++ ++ /** ++ * Creates a {@link TensorBufferFloat} with specified {@code shape}. ++ * ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if {@code shape} has non-positive elements. ++ */ ++ TensorBufferFloat(@NonNull int[] shape) { ++ super(shape); + } +- return intArr; +- } +- +- @Override +- public int getIntValue(int absIndex) { +- return (int) buffer.getFloat(absIndex << 2); +- } +- +- @Override +- public int getTypeSize() { +- return DATA_TYPE.byteSize(); +- } +- +- @Override +- public void loadArray(@NonNull float[] src, @NonNull int[] shape) { +- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); +- SupportPreconditions.checkArgument( +- src.length == computeFlatSize(shape), +- "The size of the array to be loaded does not match the specified shape."); +- copyByteBufferIfReadOnly(); +- resize(shape); +- buffer.rewind(); +- +- FloatBuffer floatBuffer = buffer.asFloatBuffer(); +- floatBuffer.put(src); +- } +- +- @Override +- public void loadArray(@NonNull int[] src, @NonNull int[] shape) { +- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); +- SupportPreconditions.checkArgument( +- src.length == computeFlatSize(shape), +- "The size of the array to be loaded does not match the specified shape."); +- copyByteBufferIfReadOnly(); +- resize(shape); +- buffer.rewind(); +- +- float[] floatArray = new float[src.length]; +- int cnt = 0; +- for (int a : src) { +- floatArray[cnt++] = (float) a; ++ ++ TensorBufferFloat() { ++ super(); ++ } ++ ++ @Override ++ public DataType getDataType() { ++ return DATA_TYPE; ++ } ++ ++ @Override ++ @NonNull ++ public float[] getFloatArray() { ++ buffer.rewind(); ++ float[] arr = new float[flatSize]; ++ ++ FloatBuffer floatBuffer = buffer.asFloatBuffer(); ++ floatBuffer.get(arr); ++ return arr; ++ } ++ ++ @Override ++ public float getFloatValue(int absIndex) { ++ return buffer.getFloat(absIndex << 2); ++ } ++ ++ @Override ++ @NonNull ++ public int[] getIntArray() { ++ buffer.rewind(); ++ float[] floatArr = new float[flatSize]; ++ buffer.asFloatBuffer().get(floatArr); ++ ++ int[] intArr = new int[flatSize]; ++ for (int i = 0; i < flatSize; i++) { ++ intArr[i] = (int) floatArr[i]; ++ } ++ return intArr; ++ } ++ ++ @Override ++ public int getIntValue(int absIndex) { ++ return (int) buffer.getFloat(absIndex << 2); ++ } ++ ++ @Override ++ public int getTypeSize() { ++ return DATA_TYPE.byteSize(); ++ } ++ ++ @Override ++ public void loadArray(@NonNull float[] src, @NonNull int[] shape) { ++ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); ++ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), ++ "The size of the array to be loaded does not match the specified shape."); ++ copyByteBufferIfReadOnly(); ++ resize(shape); ++ buffer.rewind(); ++ ++ FloatBuffer floatBuffer = buffer.asFloatBuffer(); ++ floatBuffer.put(src); ++ } ++ ++ @Override ++ public void loadArray(@NonNull int[] src, @NonNull int[] shape) { ++ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); ++ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), ++ "The size of the array to be loaded does not match the specified shape."); ++ copyByteBufferIfReadOnly(); ++ resize(shape); ++ buffer.rewind(); ++ ++ float[] floatArray = new float[src.length]; ++ int cnt = 0; ++ for (int a : src) { ++ floatArray[cnt++] = (float) a; ++ } ++ buffer.asFloatBuffer().put(floatArray); + } +- buffer.asFloatBuffer().put(floatArray); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java +index b2fa466e5be92..2924ef0af6c11 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java +@@ -21,103 +21,101 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions; + + /** Represents data buffer with 8-bit unsigned integer values. */ + public final class TensorBufferUint8 extends TensorBuffer { +- private static final DataType DATA_TYPE = DataType.UINT8; +- +- /** +- * Creates a {@link TensorBufferUint8} with specified {@code shape}. +- * +- * @throws NullPointerException if {@code shape} is null. +- * @throws IllegalArgumentException if {@code shape} has non-positive elements. +- */ +- TensorBufferUint8(@NonNull int[] shape) { +- super(shape); +- } +- +- TensorBufferUint8() { +- super(); +- } +- +- @Override +- public DataType getDataType() { +- return DATA_TYPE; +- } +- +- @Override +- @NonNull +- public float[] getFloatArray() { +- buffer.rewind(); +- byte[] byteArr = new byte[flatSize]; +- buffer.get(byteArr); +- +- float[] floatArr = new float[flatSize]; +- for (int i = 0; i < flatSize; i++) { +- floatArr[i] = (float) (byteArr[i] & 0xff); ++ private static final DataType DATA_TYPE = DataType.UINT8; ++ ++ /** ++ * Creates a {@link TensorBufferUint8} with specified {@code shape}. ++ * ++ * @throws NullPointerException if {@code shape} is null. ++ * @throws IllegalArgumentException if {@code shape} has non-positive elements. ++ */ ++ TensorBufferUint8(@NonNull int[] shape) { ++ super(shape); + } +- return floatArr; +- } +- +- @Override +- public float getFloatValue(int index) { +- return (float) (buffer.get(index) & 0xff); +- } +- +- @Override +- @NonNull +- public int[] getIntArray() { +- buffer.rewind(); +- byte[] byteArr = new byte[flatSize]; +- buffer.get(byteArr); +- +- int[] intArr = new int[flatSize]; +- for (int i = 0; i < flatSize; i++) { +- intArr[i] = byteArr[i] & 0xff; ++ ++ TensorBufferUint8() { ++ super(); ++ } ++ ++ @Override ++ public DataType getDataType() { ++ return DATA_TYPE; ++ } ++ ++ @Override ++ @NonNull ++ public float[] getFloatArray() { ++ buffer.rewind(); ++ byte[] byteArr = new byte[flatSize]; ++ buffer.get(byteArr); ++ ++ float[] floatArr = new float[flatSize]; ++ for (int i = 0; i < flatSize; i++) { ++ floatArr[i] = (float) (byteArr[i] & 0xff); ++ } ++ return floatArr; ++ } ++ ++ @Override ++ public float getFloatValue(int index) { ++ return (float) (buffer.get(index) & 0xff); + } +- return intArr; +- } +- +- @Override +- public int getIntValue(int index) { +- return buffer.get(index) & 0xff; +- } +- +- @Override +- public int getTypeSize() { +- return DATA_TYPE.byteSize(); +- } +- +- @Override +- public void loadArray(@NonNull float[] src, @NonNull int[] shape) { +- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); +- SupportPreconditions.checkArgument( +- src.length == computeFlatSize(shape), +- "The size of the array to be loaded does not match the specified shape."); +- copyByteBufferIfReadOnly(); +- resize(shape); +- buffer.rewind(); +- +- byte[] byteArr = new byte[src.length]; +- int cnt = 0; +- for (float a : src) { +- byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0); ++ ++ @Override ++ @NonNull ++ public int[] getIntArray() { ++ buffer.rewind(); ++ byte[] byteArr = new byte[flatSize]; ++ buffer.get(byteArr); ++ ++ int[] intArr = new int[flatSize]; ++ for (int i = 0; i < flatSize; i++) { ++ intArr[i] = byteArr[i] & 0xff; ++ } ++ return intArr; + } +- buffer.put(byteArr); +- } +- +- @Override +- public void loadArray(@NonNull int[] src, @NonNull int[] shape) { +- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); +- SupportPreconditions.checkArgument( +- src.length == computeFlatSize(shape), +- "The size of the array to be loaded does not match the specified shape."); +- copyByteBufferIfReadOnly(); +- resize(shape); +- buffer.rewind(); +- +- byte[] byteArr = new byte[src.length]; +- int cnt = 0; +- for (float a : src) { +- byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0); ++ ++ @Override ++ public int getIntValue(int index) { ++ return buffer.get(index) & 0xff; ++ } ++ ++ @Override ++ public int getTypeSize() { ++ return DATA_TYPE.byteSize(); ++ } ++ ++ @Override ++ public void loadArray(@NonNull float[] src, @NonNull int[] shape) { ++ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); ++ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), ++ "The size of the array to be loaded does not match the specified shape."); ++ copyByteBufferIfReadOnly(); ++ resize(shape); ++ buffer.rewind(); ++ ++ byte[] byteArr = new byte[src.length]; ++ int cnt = 0; ++ for (float a : src) { ++ byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0); ++ } ++ buffer.put(byteArr); ++ } ++ ++ @Override ++ public void loadArray(@NonNull int[] src, @NonNull int[] shape) { ++ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); ++ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), ++ "The size of the array to be loaded does not match the specified shape."); ++ copyByteBufferIfReadOnly(); ++ resize(shape); ++ buffer.rewind(); ++ ++ byte[] byteArr = new byte[src.length]; ++ int cnt = 0; ++ for (float a : src) { ++ byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0); ++ } ++ buffer.put(byteArr); + } +- buffer.put(byteArr); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java +index 4f15f3d6b7d64..b3eb11fb32f5f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java +@@ -22,13 +22,7 @@ import android.media.AudioFormat; + import android.media.AudioRecord; + import android.media.MediaRecorder; + import android.os.ParcelFileDescriptor; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.ArrayList; +-import java.util.Collections; +-import java.util.List; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.annotations.UsedByReflection; + import org.tensorflow.lite.support.audio.TensorAudio; +@@ -40,6 +34,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils; + import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.List; ++ + /** + * Performs classification on audio waveforms. + * +@@ -72,468 +74,437 @@ import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + * CLI demo tool</a> for easily trying out this API. + */ + public final class AudioClassifier extends BaseTaskApi { ++ private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; ++ ++ /** ++ * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}. ++ * ++ * @param modelPath path of the classification model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static AudioClassifier createFromFile(Context context, String modelPath) ++ throws IOException { ++ return createFromFileAndOptions( ++ context, modelPath, AudioClassifierOptions.builder().build()); ++ } + +- private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; +- +- /** +- * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}. +- * +- * @param modelPath path of the classification model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static AudioClassifier createFromFile(Context context, String modelPath) +- throws IOException { +- return createFromFileAndOptions(context, modelPath, AudioClassifierOptions.builder().build()); +- } +- +- /** +- * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}. +- * +- * @param modelFile the classification model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static AudioClassifier createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build()); +- } +- +- /** +- * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link +- * AudioClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * classification model +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- */ +- public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) { +- return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build()); +- } +- +- /** +- * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}. +- * +- * @param modelPath path of the classification model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static AudioClassifier createFromFileAndOptions( +- Context context, String modelPath, AudioClassifierOptions options) throws IOException { +- return new AudioClassifier( +- TaskJniUtils.createHandleFromFdAndOptions( +- context, +- new FdAndOptionsHandleProvider<AudioClassifierOptions>() { +- @Override +- public long createHandle( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- AudioClassifierOptions options) { +- return initJniWithModelFdAndOptions( +- fileDescriptor, +- fileDescriptorLength, +- fileDescriptorOffset, +- options, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- AUDIO_CLASSIFIER_NATIVE_LIB, +- modelPath, +- options)); +- } +- +- /** +- * Creates an {@link AudioClassifier} instance. +- * +- * @param modelFile the classification model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static AudioClassifier createFromFileAndOptions( +- File modelFile, final AudioClassifierOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return new AudioClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new TaskJniUtils.EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithModelFdAndOptions( +- descriptor.getFd(), +- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- options, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- AUDIO_CLASSIFIER_NATIVE_LIB)); ++ /** ++ * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}. ++ * ++ * @param modelFile the classification model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static AudioClassifier createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build()); + } +- } +- +- /** +- * Creates an {@link AudioClassifier} instance with a model buffer and {@link +- * AudioClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * classification model +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- */ +- public static AudioClassifier createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final AudioClassifierOptions options) { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ ++ /** ++ * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link ++ * AudioClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * classification model ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ */ ++ public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) { ++ return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build()); + } +- return new AudioClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- options, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- AUDIO_CLASSIFIER_NATIVE_LIB)); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++ +- */ +- private AudioClassifier(long nativeHandle) { +- super(nativeHandle); +- } +- +- /** Options for setting up an {@link AudioClassifier}. */ +- @UsedByReflection("audio_classifier_jni.cc") +- public static class AudioClassifierOptions { +- // Not using AutoValue for this class because scoreThreshold cannot have default value +- // (otherwise, the default value would override the one in the model metadata) and `Optional` is +- // not an option here, because +- // 1. java.util.Optional require Java 8 while we need to support Java 7. +- // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the +- // comments for labelAllowList. +- private final BaseOptions baseOptions; +- private final String displayNamesLocale; +- private final int maxResults; +- private final float scoreThreshold; +- private final boolean isScoreThresholdSet; +- // As an open source project, we've been trying avoiding depending on common java libraries, +- // such as Guava, because it may introduce conflicts with clients who also happen to use those +- // libraries. Therefore, instead of using ImmutableList here, we convert the List into +- // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less +- // vulnerable. +- private final List<String> labelAllowList; +- private final List<String> labelDenyList; +- +- public static Builder builder() { +- return new Builder(); ++ ++ /** ++ * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}. ++ * ++ * @param modelPath path of the classification model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static AudioClassifier createFromFileAndOptions( ++ Context context, String modelPath, AudioClassifierOptions options) throws IOException { ++ return new AudioClassifier(TaskJniUtils.createHandleFromFdAndOptions( ++ context, new FdAndOptionsHandleProvider<AudioClassifierOptions>() { ++ @Override ++ public long createHandle(int fileDescriptor, long fileDescriptorLength, ++ long fileDescriptorOffset, AudioClassifierOptions options) { ++ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, ++ fileDescriptorOffset, options, ++ TaskJniUtils.createProtoBaseOptionsHandle( ++ options.getBaseOptions())); ++ } ++ }, AUDIO_CLASSIFIER_NATIVE_LIB, modelPath, options)); + } + +- /** A builder that helps to configure an instance of AudioClassifierOptions. */ +- public static class Builder { +- private BaseOptions baseOptions = BaseOptions.builder().build(); +- private String displayNamesLocale = "en"; +- private int maxResults = -1; +- private float scoreThreshold; +- private boolean isScoreThresholdSet; +- private List<String> labelAllowList = new ArrayList<>(); +- private List<String> labelDenyList = new ArrayList<>(); +- +- private Builder() {} +- +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public Builder setBaseOptions(BaseOptions baseOptions) { +- this.baseOptions = baseOptions; +- return this; +- } +- +- /** +- * Sets the locale to use for display names specified through the TFLite Model Metadata, if +- * any. +- * +- * <p>Defaults to English({@code "en"}). See the <a +- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite +- * Metadata schema file.</a> for the accepted pattern of locale. +- */ +- public Builder setDisplayNamesLocale(String displayNamesLocale) { +- this.displayNamesLocale = displayNamesLocale; +- return this; +- } +- +- /** +- * Sets the maximum number of top scored results to return. +- * +- * @param maxResults if < 0, all results will be returned. If 0, an invalid argument error is +- * returned. Defaults to -1. +- * @throws IllegalArgumentException if maxResults is 0 +- */ +- public Builder setMaxResults(int maxResults) { +- if (maxResults == 0) { +- throw new IllegalArgumentException("maxResults cannot be 0."); ++ /** ++ * Creates an {@link AudioClassifier} instance. ++ * ++ * @param modelFile the classification model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static AudioClassifier createFromFileAndOptions( ++ File modelFile, final AudioClassifierOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return new AudioClassifier( ++ TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithModelFdAndOptions(descriptor.getFd(), ++ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options, ++ TaskJniUtils.createProtoBaseOptionsHandle( ++ options.getBaseOptions())); ++ } ++ }, AUDIO_CLASSIFIER_NATIVE_LIB)); + } +- this.maxResults = maxResults; +- return this; +- } +- +- /** +- * Sets the score threshold. +- * +- * <p>It overrides the one provided in the model metadata (if any). Results below this value +- * are rejected. +- */ +- public Builder setScoreThreshold(float scoreThreshold) { +- this.scoreThreshold = scoreThreshold; +- isScoreThresholdSet = true; +- return this; +- } +- +- /** +- * Sets the optional allowlist of labels. +- * +- * <p>If non-empty, classifications whose label is not in this set will be filtered out. +- * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. +- */ +- public Builder setLabelAllowList(List<String> labelAllowList) { +- this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); +- return this; +- } +- +- /** +- * Sets the optional denylist of labels. +- * +- * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate +- * or unknown labels are ignored. Mutually exclusive with labelAllowList. +- */ +- public Builder setLabelDenyList(List<String> labelDenyList) { +- this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); +- return this; +- } +- +- public AudioClassifierOptions build() { +- return new AudioClassifierOptions(this); +- } + } + +- @UsedByReflection("audio_classifier_jni.cc") +- public String getDisplayNamesLocale() { +- return displayNamesLocale; ++ /** ++ * Creates an {@link AudioClassifier} instance with a model buffer and {@link ++ * AudioClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * classification model ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ */ ++ public static AudioClassifier createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final AudioClassifierOptions options) { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ return new AudioClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, options, ++ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); ++ } ++ }, AUDIO_CLASSIFIER_NATIVE_LIB)); + } + +- @UsedByReflection("audio_classifier_jni.cc") +- public int getMaxResults() { +- return maxResults; ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++ ++ */ ++ private AudioClassifier(long nativeHandle) { ++ super(nativeHandle); + } + ++ /** Options for setting up an {@link AudioClassifier}. */ + @UsedByReflection("audio_classifier_jni.cc") +- public float getScoreThreshold() { +- return scoreThreshold; ++ public static class AudioClassifierOptions { ++ // Not using AutoValue for this class because scoreThreshold cannot have default value ++ // (otherwise, the default value would override the one in the model metadata) and ++ // `Optional` is not an option here, because ++ // 1. java.util.Optional require Java 8 while we need to support Java 7. ++ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See ++ // the comments for labelAllowList. ++ private final BaseOptions baseOptions; ++ private final String displayNamesLocale; ++ private final int maxResults; ++ private final float scoreThreshold; ++ private final boolean isScoreThresholdSet; ++ // As an open source project, we've been trying avoiding depending on common java libraries, ++ // such as Guava, because it may introduce conflicts with clients who also happen to use ++ // those libraries. Therefore, instead of using ImmutableList here, we convert the List into ++ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less ++ // vulnerable. ++ private final List<String> labelAllowList; ++ private final List<String> labelDenyList; ++ ++ public static Builder builder() { ++ return new Builder(); ++ } ++ ++ /** A builder that helps to configure an instance of AudioClassifierOptions. */ ++ public static class Builder { ++ private BaseOptions baseOptions = BaseOptions.builder().build(); ++ private String displayNamesLocale = "en"; ++ private int maxResults = -1; ++ private float scoreThreshold; ++ private boolean isScoreThresholdSet; ++ private List<String> labelAllowList = new ArrayList<>(); ++ private List<String> labelDenyList = new ArrayList<>(); ++ ++ private Builder() {} ++ ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public Builder setBaseOptions(BaseOptions baseOptions) { ++ this.baseOptions = baseOptions; ++ return this; ++ } ++ ++ /** ++ * Sets the locale to use for display names specified through the TFLite Model Metadata, ++ * if any. ++ * ++ * <p>Defaults to English({@code "en"}). See the <a ++ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite ++ * Metadata schema file.</a> for the accepted pattern of locale. ++ */ ++ public Builder setDisplayNamesLocale(String displayNamesLocale) { ++ this.displayNamesLocale = displayNamesLocale; ++ return this; ++ } ++ ++ /** ++ * Sets the maximum number of top scored results to return. ++ * ++ * @param maxResults if < 0, all results will be returned. If 0, an invalid argument ++ * error is ++ * returned. Defaults to -1. ++ * @throws IllegalArgumentException if maxResults is 0 ++ */ ++ public Builder setMaxResults(int maxResults) { ++ if (maxResults == 0) { ++ throw new IllegalArgumentException("maxResults cannot be 0."); ++ } ++ this.maxResults = maxResults; ++ return this; ++ } ++ ++ /** ++ * Sets the score threshold. ++ * ++ * <p>It overrides the one provided in the model metadata (if any). Results below this ++ * value are rejected. ++ */ ++ public Builder setScoreThreshold(float scoreThreshold) { ++ this.scoreThreshold = scoreThreshold; ++ isScoreThresholdSet = true; ++ return this; ++ } ++ ++ /** ++ * Sets the optional allowlist of labels. ++ * ++ * <p>If non-empty, classifications whose label is not in this set will be filtered out. ++ * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. ++ */ ++ public Builder setLabelAllowList(List<String> labelAllowList) { ++ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); ++ return this; ++ } ++ ++ /** ++ * Sets the optional denylist of labels. ++ * ++ * <p>If non-empty, classifications whose label is in this set will be filtered out. ++ * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList. ++ */ ++ public Builder setLabelDenyList(List<String> labelDenyList) { ++ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); ++ return this; ++ } ++ ++ public AudioClassifierOptions build() { ++ return new AudioClassifierOptions(this); ++ } ++ } ++ ++ @UsedByReflection("audio_classifier_jni.cc") ++ public String getDisplayNamesLocale() { ++ return displayNamesLocale; ++ } ++ ++ @UsedByReflection("audio_classifier_jni.cc") ++ public int getMaxResults() { ++ return maxResults; ++ } ++ ++ @UsedByReflection("audio_classifier_jni.cc") ++ public float getScoreThreshold() { ++ return scoreThreshold; ++ } ++ ++ @UsedByReflection("audio_classifier_jni.cc") ++ public boolean getIsScoreThresholdSet() { ++ return isScoreThresholdSet; ++ } ++ ++ @UsedByReflection("audio_classifier_jni.cc") ++ public List<String> getLabelAllowList() { ++ return new ArrayList<>(labelAllowList); ++ } ++ ++ @UsedByReflection("audio_classifier_jni.cc") ++ public List<String> getLabelDenyList() { ++ return new ArrayList<>(labelDenyList); ++ } ++ ++ public BaseOptions getBaseOptions() { ++ return baseOptions; ++ } ++ ++ private AudioClassifierOptions(Builder builder) { ++ displayNamesLocale = builder.displayNamesLocale; ++ maxResults = builder.maxResults; ++ scoreThreshold = builder.scoreThreshold; ++ isScoreThresholdSet = builder.isScoreThresholdSet; ++ labelAllowList = builder.labelAllowList; ++ labelDenyList = builder.labelDenyList; ++ baseOptions = builder.baseOptions; ++ } + } + +- @UsedByReflection("audio_classifier_jni.cc") +- public boolean getIsScoreThresholdSet() { +- return isScoreThresholdSet; ++ /** ++ * Performs actual classification on the provided audio tensor. ++ * ++ * @param tensor a {@link TensorAudio} containing the input audio clip in float with values ++ * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite ++ * model's input tensor. It's recommended to create {@code tensor} using {@code ++ * createInputTensorAudio} method. ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if error occurs when classifying the audio clip from the native ++ * code ++ */ ++ public List<Classifications> classify(TensorAudio tensor) { ++ TensorBuffer buffer = tensor.getTensorBuffer(); ++ TensorAudioFormat format = tensor.getFormat(); ++ checkState(buffer.getBuffer().hasArray(), ++ "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly" ++ + " buffer)."); ++ return classifyNative(getNativeHandle(), buffer.getBuffer().array(), format.getChannels(), ++ format.getSampleRate()); + } + +- @UsedByReflection("audio_classifier_jni.cc") +- public List<String> getLabelAllowList() { +- return new ArrayList<>(labelAllowList); ++ /** ++ * Creates a {@link TensorAudio} instance to store input audio samples. ++ * ++ * @return a {@link TensorAudio} with the same size as model input tensor ++ * @throws IllegalArgumentException if the model is not compatible ++ */ ++ public TensorAudio createInputTensorAudio() { ++ TensorAudioFormat format = getRequiredTensorAudioFormat(); ++ ++ long bufferSize = getRequiredInputBufferSize(); ++ long samples = bufferSize / format.getChannels(); ++ return TensorAudio.create(format, (int) samples); + } + +- @UsedByReflection("audio_classifier_jni.cc") +- public List<String> getLabelDenyList() { +- return new ArrayList<>(labelDenyList); ++ /** Returns the required input buffer size in number of float elements. */ ++ public long getRequiredInputBufferSize() { ++ return getRequiredInputBufferSizeNative(getNativeHandle()); + } + +- public BaseOptions getBaseOptions() { +- return baseOptions; ++ /** ++ * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned ++ * AudioRecord instance is initialized and client needs to call {@link ++ * android.media.AudioRecord#startRecording} method to start recording. ++ * ++ * @return an {@link android.media.AudioRecord} instance in {@link ++ * android.media.AudioRecord#STATE_INITIALIZED} ++ * @throws IllegalArgumentException if the model required channel count is unsupported ++ * @throws IllegalStateException if AudioRecord instance failed to initialize ++ */ ++ public AudioRecord createAudioRecord() { ++ TensorAudioFormat format = getRequiredTensorAudioFormat(); ++ int channelConfig = 0; ++ ++ switch (format.getChannels()) { ++ case 1: ++ channelConfig = AudioFormat.CHANNEL_IN_MONO; ++ break; ++ case 2: ++ channelConfig = AudioFormat.CHANNEL_IN_STEREO; ++ break; ++ default: ++ throw new IllegalArgumentException(String.format( ++ "Number of channels required by the model is %d. getAudioRecord method only" ++ + " supports 1 or 2 audio channels.", ++ format.getChannels())); ++ } ++ ++ int bufferSizeInBytes = AudioRecord.getMinBufferSize( ++ format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT); ++ if (bufferSizeInBytes == AudioRecord.ERROR ++ || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) { ++ throw new IllegalStateException(String.format( ++ "AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes)); ++ } ++ // The buffer of AudioRecord should be strictly longer than what model requires so that ++ // clients could run `TensorAudio::load(record)` together with `AudioClassifier::classify`. ++ int bufferSizeMultiplier = 2; ++ int modelRequiredBufferSize = (int) getRequiredInputBufferSize() ++ * DataType.FLOAT32.byteSize() * bufferSizeMultiplier; ++ if (bufferSizeInBytes < modelRequiredBufferSize) { ++ bufferSizeInBytes = modelRequiredBufferSize; ++ } ++ AudioRecord audioRecord = new AudioRecord( ++ // including MIC, UNPROCESSED, and CAMCORDER. ++ MediaRecorder.AudioSource.VOICE_RECOGNITION, format.getSampleRate(), channelConfig, ++ AudioFormat.ENCODING_PCM_FLOAT, bufferSizeInBytes); ++ checkState(audioRecord.getState() == AudioRecord.STATE_INITIALIZED, ++ "AudioRecord failed to initialize"); ++ return audioRecord; + } + +- private AudioClassifierOptions(Builder builder) { +- displayNamesLocale = builder.displayNamesLocale; +- maxResults = builder.maxResults; +- scoreThreshold = builder.scoreThreshold; +- isScoreThresholdSet = builder.isScoreThresholdSet; +- labelAllowList = builder.labelAllowList; +- labelDenyList = builder.labelDenyList; +- baseOptions = builder.baseOptions; ++ /** Returns the {@link TensorAudioFormat} required by the model. */ ++ public TensorAudioFormat getRequiredTensorAudioFormat() { ++ return TensorAudioFormat.builder() ++ .setChannels(getRequiredChannels()) ++ .setSampleRate(getRequiredSampleRate()) ++ .build(); + } +- } +- +- /** +- * Performs actual classification on the provided audio tensor. +- * +- * @param tensor a {@link TensorAudio} containing the input audio clip in float with values +- * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite +- * model's input tensor. It's recommended to create {@code tensor} using {@code +- * createInputTensorAudio} method. +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if error occurs when classifying the audio clip from the native +- * code +- */ +- public List<Classifications> classify(TensorAudio tensor) { +- TensorBuffer buffer = tensor.getTensorBuffer(); +- TensorAudioFormat format = tensor.getFormat(); +- checkState( +- buffer.getBuffer().hasArray(), +- "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly" +- + " buffer)."); +- return classifyNative( +- getNativeHandle(), +- buffer.getBuffer().array(), +- format.getChannels(), +- format.getSampleRate()); +- } +- +- /** +- * Creates a {@link TensorAudio} instance to store input audio samples. +- * +- * @return a {@link TensorAudio} with the same size as model input tensor +- * @throws IllegalArgumentException if the model is not compatible +- */ +- public TensorAudio createInputTensorAudio() { +- TensorAudioFormat format = getRequiredTensorAudioFormat(); +- +- long bufferSize = getRequiredInputBufferSize(); +- long samples = bufferSize / format.getChannels(); +- return TensorAudio.create(format, (int) samples); +- } +- +- /** Returns the required input buffer size in number of float elements. */ +- public long getRequiredInputBufferSize() { +- return getRequiredInputBufferSizeNative(getNativeHandle()); +- } +- +- /** +- * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned +- * AudioRecord instance is initialized and client needs to call {@link +- * android.media.AudioRecord#startRecording} method to start recording. +- * +- * @return an {@link android.media.AudioRecord} instance in {@link +- * android.media.AudioRecord#STATE_INITIALIZED} +- * @throws IllegalArgumentException if the model required channel count is unsupported +- * @throws IllegalStateException if AudioRecord instance failed to initialize +- */ +- public AudioRecord createAudioRecord() { +- TensorAudioFormat format = getRequiredTensorAudioFormat(); +- int channelConfig = 0; +- +- switch (format.getChannels()) { +- case 1: +- channelConfig = AudioFormat.CHANNEL_IN_MONO; +- break; +- case 2: +- channelConfig = AudioFormat.CHANNEL_IN_STEREO; +- break; +- default: +- throw new IllegalArgumentException( +- String.format( +- "Number of channels required by the model is %d. getAudioRecord method only" +- + " supports 1 or 2 audio channels.", +- format.getChannels())); ++ ++ private int getRequiredChannels() { ++ return getRequiredChannelsNative(getNativeHandle()); + } + +- int bufferSizeInBytes = +- AudioRecord.getMinBufferSize( +- format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT); +- if (bufferSizeInBytes == AudioRecord.ERROR +- || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) { +- throw new IllegalStateException( +- String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes)); ++ private int getRequiredSampleRate() { ++ return getRequiredSampleRateNative(getNativeHandle()); + } +- // The buffer of AudioRecord should be strictly longer than what model requires so that clients +- // could run `TensorAudio::load(record)` together with `AudioClassifier::classify`. +- int bufferSizeMultiplier = 2; +- int modelRequiredBufferSize = +- (int) getRequiredInputBufferSize() * DataType.FLOAT32.byteSize() * bufferSizeMultiplier; +- if (bufferSizeInBytes < modelRequiredBufferSize) { +- bufferSizeInBytes = modelRequiredBufferSize; ++ ++ // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms ++ // each time. Consider combining the native getter methods into 1 and cache it in Java layer. ++ private static native long getRequiredInputBufferSizeNative(long nativeHandle); ++ ++ private static native int getRequiredChannelsNative(long nativeHandle); ++ ++ private static native int getRequiredSampleRateNative(long nativeHandle); ++ ++ private static native List<Classifications> classifyNative( ++ long nativeHandle, byte[] audioBuffer, int channels, int sampleRate); ++ ++ private static native long initJniWithModelFdAndOptions(int fileDescriptor, ++ long fileDescriptorLength, long fileDescriptorOffset, AudioClassifierOptions options, ++ long baseOptionsHandle); ++ ++ private static native long initJniWithByteBuffer( ++ ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle); ++ ++ /** ++ * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); + } +- AudioRecord audioRecord = +- new AudioRecord( +- // including MIC, UNPROCESSED, and CAMCORDER. +- MediaRecorder.AudioSource.VOICE_RECOGNITION, +- format.getSampleRate(), +- channelConfig, +- AudioFormat.ENCODING_PCM_FLOAT, +- bufferSizeInBytes); +- checkState( +- audioRecord.getState() == AudioRecord.STATE_INITIALIZED, +- "AudioRecord failed to initialize"); +- return audioRecord; +- } +- +- /** Returns the {@link TensorAudioFormat} required by the model. */ +- public TensorAudioFormat getRequiredTensorAudioFormat() { +- return TensorAudioFormat.builder() +- .setChannels(getRequiredChannels()) +- .setSampleRate(getRequiredSampleRate()) +- .build(); +- } +- +- private int getRequiredChannels() { +- return getRequiredChannelsNative(getNativeHandle()); +- } +- +- private int getRequiredSampleRate() { +- return getRequiredSampleRateNative(getNativeHandle()); +- } +- +- // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms +- // each time. Consider combining the native getter methods into 1 and cache it in Java layer. +- private static native long getRequiredInputBufferSizeNative(long nativeHandle); +- +- private static native int getRequiredChannelsNative(long nativeHandle); +- +- private static native int getRequiredSampleRateNative(long nativeHandle); +- +- private static native List<Classifications> classifyNative( +- long nativeHandle, byte[] audioBuffer, int channels, int sampleRate); +- +- private static native long initJniWithModelFdAndOptions( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- AudioClassifierOptions options, +- long baseOptionsHandle); +- +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle); +- +- /** +- * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- /** +- * Native method to release memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` +- * instance. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private static native void deinitJni(long nativeHandle); ++ ++ /** ++ * Native method to release memory pointed by {@code nativeHandle}, namely a C++ ++ * `AudioClassifier` instance. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private static native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java +index 446d328441a97..7d5b07fa735cd 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java +@@ -16,11 +16,13 @@ limitations under the License. + package org.tensorflow.lite.task.audio.classifier; + + import com.google.auto.value.AutoValue; ++ ++import org.tensorflow.lite.annotations.UsedByReflection; ++import org.tensorflow.lite.support.label.Category; ++ + import java.util.ArrayList; + import java.util.Collections; + import java.util.List; +-import org.tensorflow.lite.annotations.UsedByReflection; +-import org.tensorflow.lite.support.label.Category; + + /** + * The classification results of one head in a multihead (a.k.a. multi-output) {@link +@@ -31,18 +33,18 @@ import org.tensorflow.lite.support.label.Category; + @AutoValue + @UsedByReflection("audio_classifier_jni.cc") + public abstract class Classifications { ++ @UsedByReflection("audio_classifier_jni.cc") ++ static Classifications create(List<Category> categories, int headIndex, String headName) { ++ return new AutoValue_Classifications( ++ Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, ++ headName); ++ } + +- @UsedByReflection("audio_classifier_jni.cc") +- static Classifications create(List<Category> categories, int headIndex, String headName) { +- return new AutoValue_Classifications( +- Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, headName); +- } +- +- // Same reason for not using ImmutableList as stated in +- // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}. +- public abstract List<Category> getCategories(); ++ // Same reason for not using ImmutableList as stated in ++ // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}. ++ public abstract List<Category> getCategories(); + +- public abstract int getHeadIndex(); ++ public abstract int getHeadIndex(); + +- public abstract String getHeadName(); ++ public abstract String getHeadName(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java +index 242414bd21bdb..b2d722332c954 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java +@@ -20,65 +20,66 @@ import com.google.auto.value.AutoValue; + /** Options to configure Task APIs in general. */ + @AutoValue + public abstract class BaseOptions { +- private static final int DEFAULT_NUM_THREADS = -1; ++ private static final int DEFAULT_NUM_THREADS = -1; + +- /** Builder for {@link BaseOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { ++ /** Builder for {@link BaseOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** ++ * Sets the advanced accelerator options. ++ * ++ * <p>Note: this method will override those highlevel API to choose an delegate, such as ++ * {@link #useGpu} and {@link #useNnapi}. ++ */ ++ public abstract Builder setComputeSettings(ComputeSettings computeSettings); + +- /** +- * Sets the advanced accelerator options. +- * +- * <p>Note: this method will override those highlevel API to choose an delegate, such as {@link +- * #useGpu} and {@link #useNnapi}. +- */ +- public abstract Builder setComputeSettings(ComputeSettings computeSettings); ++ /** ++ * Sets the number of threads to be used for TFLite ops that support multi-threading when ++ * running inference with CPU. Defaults to -1. ++ * ++ * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1 ++ * has the effect to let TFLite runtime set the value. ++ */ ++ public abstract Builder setNumThreads(int numThreads); + +- /** +- * Sets the number of threads to be used for TFLite ops that support multi-threading when +- * running inference with CPU. Defaults to -1. +- * +- * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1 has +- * the effect to let TFLite runtime set the value. +- */ +- public abstract Builder setNumThreads(int numThreads); ++ /** ++ * Uses GPU for inference. The advanced GPU configuration settings will be set to default ++ * values. ++ * ++ * <p>Note: this method will override the settings from {@link #setComputeSettings}. ++ * ++ * <p>To manipulate the advanced GPU configuration settings, use {@link ++ * #setComputeSettings}. ++ */ ++ public Builder useGpu() { ++ return setComputeSettings( ++ ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build()); ++ } + +- /** +- * Uses GPU for inference. The advanced GPU configuration settings will be set to default +- * values. +- * +- * <p>Note: this method will override the settings from {@link #setComputeSettings}. +- * +- * <p>To manipulate the advanced GPU configuration settings, use {@link #setComputeSettings}. +- */ +- public Builder useGpu() { +- return setComputeSettings( +- ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build()); +- } ++ /** ++ * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to ++ * default values. ++ * ++ * <p>Note: this method will override the settings from {@link #setComputeSettings}. ++ * ++ * <p>To manipulate the advanced NNAPI configuration settings, use {@link ++ * #setComputeSettings}. ++ */ ++ public Builder useNnapi() { ++ return setComputeSettings( ++ ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build()); ++ } + +- /** +- * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to default +- * values. +- * +- * <p>Note: this method will override the settings from {@link #setComputeSettings}. +- * +- * <p>To manipulate the advanced NNAPI configuration settings, use {@link #setComputeSettings}. +- */ +- public Builder useNnapi() { +- return setComputeSettings( +- ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build()); ++ public abstract BaseOptions build(); + } + +- public abstract BaseOptions build(); +- } +- +- public static Builder builder() { +- return new AutoValue_BaseOptions.Builder() +- .setComputeSettings(ComputeSettings.builder().build()) +- .setNumThreads(DEFAULT_NUM_THREADS); +- } ++ public static Builder builder() { ++ return new AutoValue_BaseOptions.Builder() ++ .setComputeSettings(ComputeSettings.builder().build()) ++ .setNumThreads(DEFAULT_NUM_THREADS); ++ } + +- abstract ComputeSettings getComputeSettings(); ++ abstract ComputeSettings getComputeSettings(); + +- abstract int getNumThreads(); ++ abstract int getNumThreads(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java +index b3fe9def83c69..a8ae65cd1cf3b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java +@@ -16,76 +16,78 @@ limitations under the License. + package org.tensorflow.lite.task.core; + + import android.util.Log; ++ + import java.io.Closeable; + + /** + * Base class for Task API, provides shared logic to load/unload native libs to its C++ counterpart. + */ + public abstract class BaseTaskApi implements Closeable { +- private static final String TAG = BaseTaskApi.class.getSimpleName(); +- +- /** +- * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is +- * initialized from subclasses and must be released by calling {@link #deinit} after it is no +- * longer needed. +- */ +- private final long nativeHandle; +- +- /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */ +- private boolean closed; +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++. +- */ +- protected BaseTaskApi(long nativeHandle) { +- if (nativeHandle == TaskJniUtils.INVALID_POINTER) { +- throw new IllegalArgumentException("Failed to load C++ pointer from JNI"); ++ private static final String TAG = BaseTaskApi.class.getSimpleName(); ++ ++ /** ++ * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is ++ * initialized from subclasses and must be released by calling {@link #deinit} after it is no ++ * longer needed. ++ */ ++ private final long nativeHandle; ++ ++ /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */ ++ private boolean closed; ++ ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++. ++ */ ++ protected BaseTaskApi(long nativeHandle) { ++ if (nativeHandle == TaskJniUtils.INVALID_POINTER) { ++ throw new IllegalArgumentException("Failed to load C++ pointer from JNI"); ++ } ++ this.nativeHandle = nativeHandle; ++ } ++ ++ public boolean isClosed() { ++ return closed; + } +- this.nativeHandle = nativeHandle; +- } +- +- public boolean isClosed() { +- return closed; +- } +- +- /** Release the memory allocated from C++ and deregister the library from the static holder. */ +- @Override +- public synchronized void close() { +- if (closed) { +- return; ++ ++ /** Release the memory allocated from C++ and deregister the library from the static holder. */ ++ @Override ++ public synchronized void close() { ++ if (closed) { ++ return; ++ } ++ deinit(nativeHandle); ++ closed = true; + } +- deinit(nativeHandle); +- closed = true; +- } + +- public long getNativeHandle() { +- return nativeHandle; +- } ++ public long getNativeHandle() { ++ return nativeHandle; ++ } + +- protected void checkNotClosed() { +- if (isClosed()) { +- throw new IllegalStateException("Internal error: The task lib has already been closed."); ++ protected void checkNotClosed() { ++ if (isClosed()) { ++ throw new IllegalStateException( ++ "Internal error: The task lib has already been closed."); ++ } + } +- } +- +- @Override +- protected void finalize() throws Throwable { +- try { +- if (!closed) { +- Log.w(TAG, "Closing an already closed native lib"); +- close(); +- } +- } finally { +- super.finalize(); ++ ++ @Override ++ protected void finalize() throws Throwable { ++ try { ++ if (!closed) { ++ Log.w(TAG, "Closing an already closed native lib"); ++ close(); ++ } ++ } finally { ++ super.finalize(); ++ } + } +- } +- +- /** +- * Releases memory pointed by the pointer in the native layer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- protected abstract void deinit(long nativeHandle); ++ ++ /** ++ * Releases memory pointed by the pointer in the native layer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ protected abstract void deinit(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java +index 80a9e82ff3802..0c2d04283594d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java +@@ -20,38 +20,36 @@ import com.google.auto.value.AutoValue; + /** Options to configure how to accelerate the model inference using dedicated delegates. */ + @AutoValue + public abstract class ComputeSettings { ++ /** TFLite accelerator delegate options. */ ++ public enum Delegate { ++ NONE(0), ++ NNAPI(1), ++ GPU(2); + +- /** TFLite accelerator delegate options. */ +- public enum Delegate { +- NONE(0), +- NNAPI(1), +- GPU(2); ++ private final int value; + +- private final int value; ++ Delegate(int value) { ++ this.value = value; ++ } + +- Delegate(int value) { +- this.value = value; ++ public int getValue() { ++ return value; ++ } + } + +- public int getValue() { +- return value; +- } +- } +- +- /** Builder for {@link ComputeSettings}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- +- public abstract Builder setDelegate(Delegate delegate); ++ /** Builder for {@link ComputeSettings}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ public abstract Builder setDelegate(Delegate delegate); + +- public abstract ComputeSettings build(); +- } ++ public abstract ComputeSettings build(); ++ } + +- public static Builder builder() { +- return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE); +- } ++ public static Builder builder() { ++ return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE); ++ } + +- public abstract Delegate getDelegate(); ++ public abstract Delegate getDelegate(); + +- private static final Delegate DEFAULT_DELEGATE = Delegate.NONE; ++ private static final Delegate DEFAULT_DELEGATE = Delegate.NONE; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java +index 76109f453b01f..9d5b775456c43 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java +@@ -18,6 +18,7 @@ package org.tensorflow.lite.task.core; + import android.content.Context; + import android.content.res.AssetFileDescriptor; + import android.util.Log; ++ + import java.io.FileInputStream; + import java.io.IOException; + import java.nio.ByteBuffer; +@@ -26,156 +27,146 @@ import java.nio.channels.FileChannel; + + /** JNI utils for Task API. */ + public class TaskJniUtils { +- public static final long INVALID_POINTER = 0; +- private static final String TAG = TaskJniUtils.class.getSimpleName(); +- /** Syntax sugar to get nativeHandle from empty param list. */ +- public interface EmptyHandleProvider { +- long createHandle(); +- } +- +- /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */ +- public interface MultipleBuffersHandleProvider { +- long createHandle(ByteBuffer... buffers); +- } +- +- /** Syntax sugar to get nativeHandle from file descriptor and options. */ +- public interface FdAndOptionsHandleProvider<T> { +- long createHandle( +- int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options); +- } +- +- /** +- * Initializes the JNI and returns C++ handle with file descriptor and options for task API. +- * +- * @param context the Android app context +- * @param provider provider to get C++ handle, usually returned from native call +- * @param libName name of C++ lib to be loaded +- * @param filePath path of the file to be loaded +- * @param options options to set up the task API, used by the provider +- * @return C++ handle as long +- * @throws IOException If model file fails to load. +- */ +- public static <T> long createHandleFromFdAndOptions( +- Context context, +- final FdAndOptionsHandleProvider<T> provider, +- String libName, +- String filePath, +- final T options) +- throws IOException { +- try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) { +- return createHandleFromLibrary( +- new EmptyHandleProvider() { ++ public static final long INVALID_POINTER = 0; ++ private static final String TAG = TaskJniUtils.class.getSimpleName(); ++ /** Syntax sugar to get nativeHandle from empty param list. */ ++ public interface EmptyHandleProvider { ++ long createHandle(); ++ } ++ ++ /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */ ++ public interface MultipleBuffersHandleProvider { ++ long createHandle(ByteBuffer... buffers); ++ } ++ ++ /** Syntax sugar to get nativeHandle from file descriptor and options. */ ++ public interface FdAndOptionsHandleProvider<T> { ++ long createHandle(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, ++ T options); ++ } ++ ++ /** ++ * Initializes the JNI and returns C++ handle with file descriptor and options for task API. ++ * ++ * @param context the Android app context ++ * @param provider provider to get C++ handle, usually returned from native call ++ * @param libName name of C++ lib to be loaded ++ * @param filePath path of the file to be loaded ++ * @param options options to set up the task API, used by the provider ++ * @return C++ handle as long ++ * @throws IOException If model file fails to load. ++ */ ++ public static <T> long createHandleFromFdAndOptions(Context context, ++ final FdAndOptionsHandleProvider<T> provider, String libName, String filePath, ++ final T options) throws IOException { ++ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) { ++ return createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return provider.createHandle( ++ /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor() ++ .getFd(), ++ /*fileDescriptorLength=*/assetFileDescriptor.getLength(), ++ /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); ++ } ++ }, libName); ++ } ++ } ++ ++ /** ++ * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes ++ * {@link EmptyHandleProvider#createHandle()}. ++ * ++ * @param provider provider to get C++ handle, usually returned from native call ++ * @return C++ handle as long ++ */ ++ public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) { ++ tryLoadLibrary(libName); ++ try { ++ return provider.createHandle(); ++ } catch (RuntimeException e) { ++ String errorMessage = "Error getting native address of native library: " + libName; ++ Log.e(TAG, errorMessage, e); ++ throw new IllegalStateException(errorMessage, e); ++ } ++ } ++ ++ /** ++ * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes ++ * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}. ++ * ++ * @param context app context ++ * @param provider provider to get C++ pointer, usually returned from native call ++ * @param libName name of C++ lib to load ++ * @param filePaths file paths to load ++ * @return C++ pointer as long ++ * @throws IOException If model file fails to load. ++ */ ++ public static long createHandleWithMultipleAssetFilesFromLibrary(Context context, ++ final MultipleBuffersHandleProvider provider, String libName, String... filePaths) ++ throws IOException { ++ final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length]; ++ for (int i = 0; i < filePaths.length; i++) { ++ buffers[i] = loadMappedFile(context, filePaths[i]); ++ } ++ return createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { +- return provider.createHandle( +- /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), +- /*fileDescriptorLength=*/ assetFileDescriptor.getLength(), +- /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), +- options); ++ return provider.createHandle(buffers); + } +- }, +- libName); +- } +- } +- +- /** +- * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes +- * {@link EmptyHandleProvider#createHandle()}. +- * +- * @param provider provider to get C++ handle, usually returned from native call +- * @return C++ handle as long +- */ +- public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) { +- tryLoadLibrary(libName); +- try { +- return provider.createHandle(); +- } catch (RuntimeException e) { +- String errorMessage = "Error getting native address of native library: " + libName; +- Log.e(TAG, errorMessage, e); +- throw new IllegalStateException(errorMessage, e); +- } +- } +- +- /** +- * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes +- * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}. +- * +- * @param context app context +- * @param provider provider to get C++ pointer, usually returned from native call +- * @param libName name of C++ lib to load +- * @param filePaths file paths to load +- * @return C++ pointer as long +- * @throws IOException If model file fails to load. +- */ +- public static long createHandleWithMultipleAssetFilesFromLibrary( +- Context context, +- final MultipleBuffersHandleProvider provider, +- String libName, +- String... filePaths) +- throws IOException { +- final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length]; +- for (int i = 0; i < filePaths.length; i++) { +- buffers[i] = loadMappedFile(context, filePaths[i]); ++ }, libName); + } +- return createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return provider.createHandle(buffers); +- } +- }, +- libName); +- } +- +- /** +- * Loads a file from the asset folder through memory mapping. +- * +- * @param context Application context to access assets. +- * @param filePath Asset path of the file. +- * @return the loaded memory mapped file. +- * @throws IOException If model file fails to load. +- */ +- public static MappedByteBuffer loadMappedFile(Context context, String filePath) +- throws IOException { +- try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); +- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { +- FileChannel fileChannel = inputStream.getChannel(); +- long startOffset = fileDescriptor.getStartOffset(); +- long declaredLength = fileDescriptor.getDeclaredLength(); +- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); ++ ++ /** ++ * Loads a file from the asset folder through memory mapping. ++ * ++ * @param context Application context to access assets. ++ * @param filePath Asset path of the file. ++ * @return the loaded memory mapped file. ++ * @throws IOException If model file fails to load. ++ */ ++ public static MappedByteBuffer loadMappedFile(Context context, String filePath) ++ throws IOException { ++ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); ++ FileInputStream inputStream = ++ new FileInputStream(fileDescriptor.getFileDescriptor())) { ++ FileChannel fileChannel = inputStream.getChannel(); ++ long startOffset = fileDescriptor.getStartOffset(); ++ long declaredLength = fileDescriptor.getDeclaredLength(); ++ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); ++ } + } +- } +- +- /** +- * Try loading a native library, if it's already loaded return directly. +- * +- * @param libName name of the lib +- */ +- public static void tryLoadLibrary(String libName) { +- try { +- System.loadLibrary(libName); +- } catch (UnsatisfiedLinkError e) { +- String errorMessage = "Error loading native library: " + libName; +- Log.e(TAG, errorMessage, e); +- throw new UnsatisfiedLinkError(errorMessage); ++ ++ /** ++ * Try loading a native library, if it's already loaded return directly. ++ * ++ * @param libName name of the lib ++ */ ++ public static void tryLoadLibrary(String libName) { ++ try { ++ System.loadLibrary(libName); ++ } catch (UnsatisfiedLinkError e) { ++ String errorMessage = "Error loading native library: " + libName; ++ Log.e(TAG, errorMessage, e); ++ throw new UnsatisfiedLinkError(errorMessage); ++ } + } +- } + +- public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) { +- return createProtoBaseOptionsHandleWithLegacyNumThreads(baseOptions, /*legacyNumThreads =*/ -1); +- } ++ public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) { ++ return createProtoBaseOptionsHandleWithLegacyNumThreads( ++ baseOptions, /*legacyNumThreads =*/-1); ++ } + +- public static long createProtoBaseOptionsHandleWithLegacyNumThreads( +- BaseOptions baseOptions, int legacyNumThreads) { +- // NumThreads should be configured through BaseOptions. However, if NumThreads is configured +- // through the legacy API of the Task Java API (then it will not equal to -1, the default +- // value), use it to overide the one in baseOptions. +- return createProtoBaseOptions( +- baseOptions.getComputeSettings().getDelegate().getValue(), +- legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads); +- } ++ public static long createProtoBaseOptionsHandleWithLegacyNumThreads( ++ BaseOptions baseOptions, int legacyNumThreads) { ++ // NumThreads should be configured through BaseOptions. However, if NumThreads is configured ++ // through the legacy API of the Task Java API (then it will not equal to -1, the default ++ // value), use it to overide the one in baseOptions. ++ return createProtoBaseOptions(baseOptions.getComputeSettings().getDelegate().getValue(), ++ legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads); ++ } + +- private TaskJniUtils() {} ++ private TaskJniUtils() {} + +- private static native long createProtoBaseOptions(int delegate, int numThreads); ++ private static native long createProtoBaseOptions(int delegate, int numThreads); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java +index 287ba444c386b..b1784d02f2362 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.task.core.vision; + + import android.graphics.Rect; ++ + import com.google.auto.value.AutoValue; + + /** +@@ -45,74 +46,74 @@ import com.google.auto.value.AutoValue; + */ + @AutoValue + public abstract class ImageProcessingOptions { +- +- /** +- * Orientation type that follows EXIF specification. +- * +- * <p>The name of each enum value defines the position of the 0th row and the 0th column of the +- * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation +- * documentation</a> for details. +- */ +- public enum Orientation { +- TOP_LEFT(0), +- TOP_RIGHT(1), +- BOTTOM_RIGHT(2), +- BOTTOM_LEFT(3), +- LEFT_TOP(4), +- RIGHT_TOP(5), +- RIGHT_BOTTOM(6), +- LEFT_BOTTOM(7); +- +- private final int value; +- +- Orientation(int value) { +- this.value = value; +- } +- +- public int getValue() { +- return value; +- } +- }; +- +- private static final Rect defaultRoi = new Rect(); +- private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT; +- +- public abstract Rect getRoi(); +- +- public abstract Orientation getOrientation(); +- +- public static Builder builder() { +- return new AutoValue_ImageProcessingOptions.Builder() +- .setRoi(defaultRoi) +- .setOrientation(DEFAULT_ORIENTATION); +- } +- +- /** Builder for {@link ImageProcessingOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- + /** +- * Sets the region of interest (ROI) of the image. Defaults to the entire image. ++ * Orientation type that follows EXIF specification. + * +- * <p>Cropping according to this region of interest is prepended to the pre-processing +- * operations. ++ * <p>The name of each enum value defines the position of the 0th row and the 0th column of the ++ * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation ++ * documentation</a> for details. + */ +- public abstract Builder setRoi(Rect roi); ++ public enum Orientation { ++ TOP_LEFT(0), ++ TOP_RIGHT(1), ++ BOTTOM_RIGHT(2), ++ BOTTOM_LEFT(3), ++ LEFT_TOP(4), ++ RIGHT_TOP(5), ++ RIGHT_BOTTOM(6), ++ LEFT_BOTTOM(7); ++ ++ private final int value; ++ ++ Orientation(int value) { ++ this.value = value; ++ } ++ ++ public int getValue() { ++ return value; ++ } ++ } ++ ; + +- /** +- * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}. +- * +- * <p>Rotation will be applied accordingly so that inference is performed on an "upright" image. +- */ +- public abstract Builder setOrientation(Orientation orientation); ++ private static final Rect defaultRoi = new Rect(); ++ private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT; + +- abstract Rect getRoi(); ++ public abstract Rect getRoi(); + +- abstract ImageProcessingOptions autoBuild(); ++ public abstract Orientation getOrientation(); ++ ++ public static Builder builder() { ++ return new AutoValue_ImageProcessingOptions.Builder() ++ .setRoi(defaultRoi) ++ .setOrientation(DEFAULT_ORIENTATION); ++ } + +- public ImageProcessingOptions build() { +- setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable. +- return autoBuild(); ++ /** Builder for {@link ImageProcessingOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** ++ * Sets the region of interest (ROI) of the image. Defaults to the entire image. ++ * ++ * <p>Cropping according to this region of interest is prepended to the pre-processing ++ * operations. ++ */ ++ public abstract Builder setRoi(Rect roi); ++ ++ /** ++ * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}. ++ * ++ * <p>Rotation will be applied accordingly so that inference is performed on an "upright" ++ * image. ++ */ ++ public abstract Builder setOrientation(Orientation orientation); ++ ++ abstract Rect getRoi(); ++ ++ abstract ImageProcessingOptions autoBuild(); ++ ++ public ImageProcessingOptions build() { ++ setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable. ++ return autoBuild(); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +index d0ac3f83b4ed5..ce912c96e29de 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +@@ -17,12 +17,9 @@ package org.tensorflow.lite.task.text.nlclassifier; + + import android.content.Context; + import android.os.ParcelFileDescriptor; ++ + import com.google.auto.value.AutoValue; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.List; ++ + import org.tensorflow.lite.annotations.UsedByReflection; + import org.tensorflow.lite.support.label.Category; + import org.tensorflow.lite.task.core.BaseOptions; +@@ -30,6 +27,12 @@ import org.tensorflow.lite.task.core.BaseTaskApi; + import org.tensorflow.lite.task.core.TaskJniUtils; + import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.List; ++ + /** + * Classifier API for NLClassification tasks with Bert models, categorizes string into different + * classes. The API expects a Bert based TFLite model with metadata populated. +@@ -45,209 +48,199 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + * </ul> + */ + public class BertNLClassifier extends BaseTaskApi { ++ private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; ++ ++ /** Options to configure BertNLClassifier. */ ++ @AutoValue ++ @UsedByReflection("bert_nl_classifier_jni.cc") ++ public abstract static class BertNLClassifierOptions { ++ static final int DEFAULT_MAX_SEQ_LEN = 128; ++ ++ abstract int getMaxSeqLen(); ++ ++ abstract BaseOptions getBaseOptions(); ++ ++ public static Builder builder() { ++ return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder() ++ .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN) ++ .setBaseOptions(BaseOptions.builder().build()); ++ } ++ ++ /** Builder for {@link BertNLClassifierOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public abstract Builder setBaseOptions(BaseOptions baseOptions); ++ ++ /** ++ * Set the maximum sequence length. ++ * ++ * @deprecated maximum sequence length is now read from the model (i.e. input tensor ++ * size) ++ * automatically ++ */ ++ @Deprecated ++ public abstract Builder setMaxSeqLen(int value); ++ ++ public abstract BertNLClassifierOptions build(); ++ } ++ } + +- private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; +- +- /** Options to configure BertNLClassifier. */ +- @AutoValue +- @UsedByReflection("bert_nl_classifier_jni.cc") +- public abstract static class BertNLClassifierOptions { +- static final int DEFAULT_MAX_SEQ_LEN = 128; +- +- abstract int getMaxSeqLen(); ++ /** ++ * Creates {@link BertNLClassifier} from a model file with metadata and default {@link ++ * BertNLClassifierOptions}. ++ * ++ * @param context Android context ++ * @param modelPath Path to the classification model ++ * @return a {@link BertNLClassifier} instance ++ * @throws IOException If model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertNLClassifier createFromFile(final Context context, final String modelPath) ++ throws IOException { ++ return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath)); ++ } + +- abstract BaseOptions getBaseOptions(); ++ /** ++ * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link ++ * BertNLClassifierOptions}. ++ * ++ * @param modelFile The classification model {@link File} instance ++ * @return a {@link BertNLClassifier} instance ++ * @throws IOException If model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertNLClassifier createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build()); ++ } + +- public static Builder builder() { +- return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder() +- .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN) +- .setBaseOptions(BaseOptions.builder().build()); ++ /** ++ * Creates {@link BertNLClassifier} from a model file with metadata and {@link ++ * BertNLClassifierOptions}. ++ * ++ * @param context Android context. ++ * @param modelPath Path to the classification model ++ * @param options to configure the classifier ++ * @return a {@link BertNLClassifier} instance ++ * @throws IOException If model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertNLClassifier createFromFileAndOptions(final Context context, ++ final String modelPath, BertNLClassifierOptions options) throws IOException { ++ return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options); + } + +- /** Builder for {@link BertNLClassifierOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { ++ /** ++ * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link ++ * BertNLClassifierOptions}. ++ * ++ * @param modelFile The classification model {@link File} instance ++ * @param options to configure the classifier ++ * @return a {@link BertNLClassifier} instance ++ * @throws IOException If model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertNLClassifier createFromFileAndOptions( ++ File modelFile, final BertNLClassifierOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return new BertNLClassifier( ++ TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithFileDescriptor(descriptor.getFd(), options, ++ TaskJniUtils.createProtoBaseOptionsHandle( ++ options.getBaseOptions())); ++ } ++ }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); ++ } ++ } + +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public abstract Builder setBaseOptions(BaseOptions baseOptions); ++ /** ++ * Creates {@link BertNLClassifier} with a model buffer and default {@link ++ * BertNLClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model ++ * @return a {@link BertNLClassifier} instance ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) { ++ return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build()); ++ } + +- /** +- * Set the maximum sequence length. +- * +- * @deprecated maximum sequence length is now read from the model (i.e. input tensor size) +- * automatically +- */ +- @Deprecated +- public abstract Builder setMaxSeqLen(int value); ++ /** ++ * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model ++ * @param options to configure the classifier ++ * @return a {@link BertNLClassifier} instance ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertNLClassifier createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final BertNLClassifierOptions options) { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ return new BertNLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, options, ++ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); ++ } ++ }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); ++ } + +- public abstract BertNLClassifierOptions build(); ++ /** ++ * Performs classification on a string input, returns classified {@link Category}s. ++ * ++ * @param text input text to the model. ++ * @return A list of Category results. ++ */ ++ public List<Category> classify(String text) { ++ return classifyNative(getNativeHandle(), text); + } +- } +- +- /** +- * Creates {@link BertNLClassifier} from a model file with metadata and default {@link +- * BertNLClassifierOptions}. +- * +- * @param context Android context +- * @param modelPath Path to the classification model +- * @return a {@link BertNLClassifier} instance +- * @throws IOException If model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertNLClassifier createFromFile(final Context context, final String modelPath) +- throws IOException { +- return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath)); +- } +- +- /** +- * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link +- * BertNLClassifierOptions}. +- * +- * @param modelFile The classification model {@link File} instance +- * @return a {@link BertNLClassifier} instance +- * @throws IOException If model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertNLClassifier createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build()); +- } +- +- /** +- * Creates {@link BertNLClassifier} from a model file with metadata and {@link +- * BertNLClassifierOptions}. +- * +- * @param context Android context. +- * @param modelPath Path to the classification model +- * @param options to configure the classifier +- * @return a {@link BertNLClassifier} instance +- * @throws IOException If model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertNLClassifier createFromFileAndOptions( +- final Context context, final String modelPath, BertNLClassifierOptions options) +- throws IOException { +- return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options); +- } +- +- /** +- * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link +- * BertNLClassifierOptions}. +- * +- * @param modelFile The classification model {@link File} instance +- * @param options to configure the classifier +- * @return a {@link BertNLClassifier} instance +- * @throws IOException If model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertNLClassifier createFromFileAndOptions( +- File modelFile, final BertNLClassifierOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return new BertNLClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithFileDescriptor( +- descriptor.getFd(), +- options, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); ++ ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++. ++ */ ++ private BertNLClassifier(long nativeHandle) { ++ super(nativeHandle); + } +- } +- +- /** +- * Creates {@link BertNLClassifier} with a model buffer and default {@link +- * BertNLClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model +- * @return a {@link BertNLClassifier} instance +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) { +- return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build()); +- } +- +- /** +- * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model +- * @param options to configure the classifier +- * @return a {@link BertNLClassifier} instance +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertNLClassifier createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final BertNLClassifierOptions options) { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ ++ private static native long initJniWithByteBuffer( ++ ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle); ++ ++ private static native long initJniWithFileDescriptor( ++ int fd, BertNLClassifierOptions options, long baseOptionsHandle); ++ ++ private static native List<Category> classifyNative(long nativeHandle, String text); ++ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); + } +- return new BertNLClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- options, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); +- } +- +- /** +- * Performs classification on a string input, returns classified {@link Category}s. +- * +- * @param text input text to the model. +- * @return A list of Category results. +- */ +- public List<Category> classify(String text) { +- return classifyNative(getNativeHandle(), text); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++. +- */ +- private BertNLClassifier(long nativeHandle) { +- super(nativeHandle); +- } +- +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle); +- +- private static native long initJniWithFileDescriptor( +- int fd, BertNLClassifierOptions options, long baseOptionsHandle); +- +- private static native List<Category> classifyNative(long nativeHandle, String text); +- +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java +index ff573bf415759..b8aa32be94dc5 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java +@@ -17,13 +17,11 @@ package org.tensorflow.lite.task.text.nlclassifier; + + import android.content.Context; + import android.os.ParcelFileDescriptor; ++ + import androidx.annotation.Nullable; ++ + import com.google.auto.value.AutoValue; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.List; ++ + import org.tensorflow.lite.annotations.UsedByReflection; + import org.tensorflow.lite.support.label.Category; + import org.tensorflow.lite.task.core.BaseOptions; +@@ -31,6 +29,12 @@ import org.tensorflow.lite.task.core.BaseTaskApi; + import org.tensorflow.lite.task.core.TaskJniUtils; + import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.List; ++ + /** + * Classifier API for natural language classification tasks, categorizes string into different + * classes. +@@ -67,294 +71,296 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + * configurable for different TFLite models. + */ + public class NLClassifier extends BaseTaskApi { +- +- /** Options to identify input and output tensors of the model. */ +- @AutoValue +- @UsedByReflection("nl_classifier_jni.cc") +- public abstract static class NLClassifierOptions { +- private static final int DEFAULT_INPUT_TENSOR_INDEX = 0; +- private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0; +- // By default there is no output label tensor. The label file can be attached +- // to the output score tensor metadata. +- private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1; +- private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT"; +- private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE"; +- private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL"; +- +- @UsedByReflection("nl_classifier_jni.cc") +- abstract int getInputTensorIndex(); +- +- @UsedByReflection("nl_classifier_jni.cc") +- abstract int getOutputScoreTensorIndex(); +- ++ /** Options to identify input and output tensors of the model. */ ++ @AutoValue + @UsedByReflection("nl_classifier_jni.cc") +- abstract int getOutputLabelTensorIndex(); +- +- @UsedByReflection("nl_classifier_jni.cc") +- abstract String getInputTensorName(); ++ public abstract static class NLClassifierOptions { ++ private static final int DEFAULT_INPUT_TENSOR_INDEX = 0; ++ private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0; ++ // By default there is no output label tensor. The label file can be attached ++ // to the output score tensor metadata. ++ private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1; ++ private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT"; ++ private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE"; ++ private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL"; ++ ++ @UsedByReflection("nl_classifier_jni.cc") ++ abstract int getInputTensorIndex(); ++ ++ @UsedByReflection("nl_classifier_jni.cc") ++ abstract int getOutputScoreTensorIndex(); ++ ++ @UsedByReflection("nl_classifier_jni.cc") ++ abstract int getOutputLabelTensorIndex(); ++ ++ @UsedByReflection("nl_classifier_jni.cc") ++ abstract String getInputTensorName(); ++ ++ @UsedByReflection("nl_classifier_jni.cc") ++ abstract String getOutputScoreTensorName(); ++ ++ @UsedByReflection("nl_classifier_jni.cc") ++ abstract String getOutputLabelTensorName(); ++ ++ @Nullable ++ abstract BaseOptions getBaseOptions(); ++ ++ public static Builder builder() { ++ return new AutoValue_NLClassifier_NLClassifierOptions.Builder() ++ .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX) ++ .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX) ++ .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX) ++ .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME) ++ .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME) ++ .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME); ++ } ++ ++ /** Builder for {@link NLClassifierOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions); ++ ++ /** ++ * Configure the input/output tensors for NLClassifier: ++ * ++ * <p>- No special configuration is needed if the model has only one input tensor and ++ * one output tensor. ++ * ++ * <p>- When the model has multiple input or output tensors, use the following ++ * configurations to specifiy the desired tensors: <br> ++ * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code ++ * outputLabelTensorName}<br> ++ * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code ++ * outputLabelTensorIndex} <br> ++ * Tensor names has higher priorities than tensor indices in locating the tensors. It ++ * means the tensors will be first located according to tensor names. If not found, then ++ * the tensors will be located according to tensor indices. ++ * ++ * <p>- Failing to match the input text tensor or output score tensor with neither ++ * tensor names nor tensor indices will trigger a runtime error. However, failing to ++ * locate the output label tensor will not trigger an error because the label tensor is ++ * optional. ++ */ ++ ++ /** ++ * Set the name of the input text tensor, if the model has multiple inputs. Only the ++ * input tensor specified will be used for inference; other input tensors will be ++ * ignored. Dafualt to {@code "INPUT"}. ++ * ++ * <p>See the section, Configure the input/output tensors for NLClassifier, for more ++ * details. ++ */ ++ public abstract Builder setInputTensorName(String inputTensorName); ++ ++ /** ++ * Set the name of the output score tensor, if the model has multiple outputs. Dafualt ++ * to ++ * {@code "OUTPUT_SCORE"}. ++ * ++ * <p>See the section, Configure the input/output tensors for NLClassifier, for more ++ * details. ++ */ ++ public abstract Builder setOutputScoreTensorName(String outputScoreTensorName); ++ ++ /** ++ * Set the name of the output label tensor, if the model has multiple outputs. Dafualt ++ * to ++ * {@code "OUTPUT_LABEL"}. ++ * ++ * <p>See the section, Configure the input/output tensors for NLClassifier, for more ++ * details. ++ * ++ * <p>By default, label file should be packed with the output score tensor through Model ++ * Metadata. See the <a ++ * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter ++ * for NLClassifier</a>. NLClassifier reads and parses labels from the label file ++ * automatically. However, some models may output a specific label tensor instead. In ++ * this case, NLClassifier reads labels from the output label tensor. ++ */ ++ public abstract Builder setOutputLabelTensorName(String outputLabelTensorName); ++ ++ /** ++ * Set the index of the input text tensor among all input tensors, if the model has ++ * multiple inputs. Only the input tensor specified will be used for inference; other ++ * input tensors will be ignored. Dafualt to 0. ++ * ++ * <p>See the section, Configure the input/output tensors for NLClassifier, for more ++ * details. ++ */ ++ public abstract Builder setInputTensorIndex(int inputTensorIndex); ++ ++ /** ++ * Set the index of the output score tensor among all output tensors, if the model has ++ * multiple outputs. Dafualt to 0. ++ * ++ * <p>See the section, Configure the input/output tensors for NLClassifier, for more ++ * details. ++ */ ++ public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex); ++ ++ /** ++ * Set the index of the optional output label tensor among all output tensors, if the ++ * model has multiple outputs. ++ * ++ * <p>See the document above {@code outputLabelTensorName} for more information about ++ * what the output label tensor is. ++ * ++ * <p>See the section, Configure the input/output tensors for NLClassifier, for more ++ * details. ++ * ++ * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label ++ * tensor. ++ */ ++ public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex); ++ ++ public abstract NLClassifierOptions build(); ++ } ++ } + +- @UsedByReflection("nl_classifier_jni.cc") +- abstract String getOutputScoreTensorName(); ++ private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; ++ ++ /** ++ * Creates {@link NLClassifier} from default {@link NLClassifierOptions}. ++ * ++ * @param context Android context ++ * @param modelPath path to the classification model relative to asset dir ++ * @return an {@link NLClassifier} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static NLClassifier createFromFile(Context context, String modelPath) ++ throws IOException { ++ return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build()); ++ } + +- @UsedByReflection("nl_classifier_jni.cc") +- abstract String getOutputLabelTensorName(); +- +- @Nullable +- abstract BaseOptions getBaseOptions(); +- +- public static Builder builder() { +- return new AutoValue_NLClassifier_NLClassifierOptions.Builder() +- .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX) +- .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX) +- .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX) +- .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME) +- .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME) +- .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME); ++ /** ++ * Creates {@link NLClassifier} from default {@link NLClassifierOptions}. ++ * ++ * @param modelFile the classification model {@link File} instance ++ * @return an {@link NLClassifier} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static NLClassifier createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); + } + +- /** Builder for {@link NLClassifierOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions); +- +- /** +- * Configure the input/output tensors for NLClassifier: +- * +- * <p>- No special configuration is needed if the model has only one input tensor and one +- * output tensor. +- * +- * <p>- When the model has multiple input or output tensors, use the following configurations +- * to specifiy the desired tensors: <br> +- * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code +- * outputLabelTensorName}<br> +- * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code +- * outputLabelTensorIndex} <br> +- * Tensor names has higher priorities than tensor indices in locating the tensors. It means +- * the tensors will be first located according to tensor names. If not found, then the tensors +- * will be located according to tensor indices. +- * +- * <p>- Failing to match the input text tensor or output score tensor with neither tensor +- * names nor tensor indices will trigger a runtime error. However, failing to locate the +- * output label tensor will not trigger an error because the label tensor is optional. +- */ +- +- /** +- * Set the name of the input text tensor, if the model has multiple inputs. Only the input +- * tensor specified will be used for inference; other input tensors will be ignored. Dafualt +- * to {@code "INPUT"}. +- * +- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details. +- */ +- public abstract Builder setInputTensorName(String inputTensorName); +- +- /** +- * Set the name of the output score tensor, if the model has multiple outputs. Dafualt to +- * {@code "OUTPUT_SCORE"}. +- * +- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details. +- */ +- public abstract Builder setOutputScoreTensorName(String outputScoreTensorName); +- +- /** +- * Set the name of the output label tensor, if the model has multiple outputs. Dafualt to +- * {@code "OUTPUT_LABEL"}. +- * +- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details. +- * +- * <p>By default, label file should be packed with the output score tensor through Model +- * Metadata. See the <a +- * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter +- * for NLClassifier</a>. NLClassifier reads and parses labels from the label file +- * automatically. However, some models may output a specific label tensor instead. In this +- * case, NLClassifier reads labels from the output label tensor. +- */ +- public abstract Builder setOutputLabelTensorName(String outputLabelTensorName); +- +- /** +- * Set the index of the input text tensor among all input tensors, if the model has multiple +- * inputs. Only the input tensor specified will be used for inference; other input tensors +- * will be ignored. Dafualt to 0. +- * +- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details. +- */ +- public abstract Builder setInputTensorIndex(int inputTensorIndex); +- +- /** +- * Set the index of the output score tensor among all output tensors, if the model has +- * multiple outputs. Dafualt to 0. +- * +- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details. +- */ +- public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex); +- +- /** +- * Set the index of the optional output label tensor among all output tensors, if the model +- * has multiple outputs. +- * +- * <p>See the document above {@code outputLabelTensorName} for more information about what the +- * output label tensor is. +- * +- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details. +- * +- * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label +- * tensor. +- */ +- public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex); +- +- public abstract NLClassifierOptions build(); ++ /** ++ * Creates {@link NLClassifier} from {@link NLClassifierOptions}. ++ * ++ * @param context Android context ++ * @param modelPath path to the classification model relative to asset dir ++ * @param options configurations for the model. ++ * @return an {@link NLClassifier} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static NLClassifier createFromFileAndOptions( ++ Context context, String modelPath, NLClassifierOptions options) throws IOException { ++ return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options); + } +- } +- +- private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; +- +- /** +- * Creates {@link NLClassifier} from default {@link NLClassifierOptions}. +- * +- * @param context Android context +- * @param modelPath path to the classification model relative to asset dir +- * @return an {@link NLClassifier} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static NLClassifier createFromFile(Context context, String modelPath) throws IOException { +- return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build()); +- } +- +- /** +- * Creates {@link NLClassifier} from default {@link NLClassifierOptions}. +- * +- * @param modelFile the classification model {@link File} instance +- * @return an {@link NLClassifier} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static NLClassifier createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); +- } +- +- /** +- * Creates {@link NLClassifier} from {@link NLClassifierOptions}. +- * +- * @param context Android context +- * @param modelPath path to the classification model relative to asset dir +- * @param options configurations for the model. +- * @return an {@link NLClassifier} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static NLClassifier createFromFileAndOptions( +- Context context, String modelPath, NLClassifierOptions options) throws IOException { +- return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options); +- } +- +- /** +- * Creates {@link NLClassifier} from {@link NLClassifierOptions}. +- * +- * @param modelFile the classification model {@link File} instance +- * @param options configurations for the model +- * @return an {@link NLClassifier} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static NLClassifier createFromFileAndOptions( +- File modelFile, final NLClassifierOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return new NLClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { ++ ++ /** ++ * Creates {@link NLClassifier} from {@link NLClassifierOptions}. ++ * ++ * @param modelFile the classification model {@link File} instance ++ * @param options configurations for the model ++ * @return an {@link NLClassifier} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static NLClassifier createFromFileAndOptions( ++ File modelFile, final NLClassifierOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { +- long baseOptionsHandle = +- options.getBaseOptions() == null +- ? 0 // pass an invalid native handle +- : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()); +- return initJniWithFileDescriptor(options, descriptor.getFd(), baseOptionsHandle); ++ long baseOptionsHandle = options.getBaseOptions() == null ++ ? 0 // pass an invalid native handle ++ : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()); ++ return initJniWithFileDescriptor( ++ options, descriptor.getFd(), baseOptionsHandle); + } +- }, +- NL_CLASSIFIER_NATIVE_LIBNAME)); +- } +- } +- +- /** +- * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * classification model +- * @param options configurations for the model +- * @return {@link NLClassifier} instance +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- */ +- public static NLClassifier createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final NLClassifierOptions options) { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ }, NL_CLASSIFIER_NATIVE_LIBNAME)); ++ } + } + +- return new NLClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- long baseOptionsHandle = +- options.getBaseOptions() == null ++ /** ++ * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * classification model ++ * @param options configurations for the model ++ * @return {@link NLClassifier} instance ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ */ ++ public static NLClassifier createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final NLClassifierOptions options) { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ ++ return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ long baseOptionsHandle = options.getBaseOptions() == null + ? 0 // pass an invalid native handle + : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()); + return initJniWithByteBuffer(options, modelBuffer, baseOptionsHandle); +- } +- }, +- NL_CLASSIFIER_NATIVE_LIBNAME)); +- } +- +- /** +- * Performs classification on a string input, returns classified {@link Category}s. +- * +- * @param text input text to the model +- * @return a list of Category results +- */ +- public List<Category> classify(String text) { +- return classifyNative(getNativeHandle(), text); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++. +- */ +- protected NLClassifier(long nativeHandle) { +- super(nativeHandle); +- } +- +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- private static native long initJniWithByteBuffer( +- NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle); +- +- private static native long initJniWithFileDescriptor( +- NLClassifierOptions options, int fd, long baseOptionsHandle); +- +- private static native List<Category> classifyNative(long nativeHandle, String text); +- +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ } ++ }, NL_CLASSIFIER_NATIVE_LIBNAME)); ++ } ++ ++ /** ++ * Performs classification on a string input, returns classified {@link Category}s. ++ * ++ * @param text input text to the model ++ * @return a list of Category results ++ */ ++ public List<Category> classify(String text) { ++ return classifyNative(getNativeHandle(), text); ++ } ++ ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++. ++ */ ++ protected NLClassifier(long nativeHandle) { ++ super(nativeHandle); ++ } ++ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); ++ } ++ ++ private static native long initJniWithByteBuffer( ++ NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle); ++ ++ private static native long initJniWithFileDescriptor( ++ NLClassifierOptions options, int fd, long baseOptionsHandle); ++ ++ private static native List<Category> classifyNative(long nativeHandle, String text); ++ ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java +index aafa2c88c55e8..39648d9bb4042 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java +@@ -17,11 +17,9 @@ package org.tensorflow.lite.task.text.qa; + + import android.content.Context; + import android.os.ParcelFileDescriptor; ++ + import com.google.auto.value.AutoValue; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.util.List; ++ + import org.tensorflow.lite.task.core.BaseOptions; + import org.tensorflow.lite.task.core.BaseTaskApi; + import org.tensorflow.lite.task.core.TaskJniUtils; +@@ -29,6 +27,11 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.util.List; ++ + /** + * Returns the most possible answers on a given question for QA models (BERT, Albert, etc.). + * +@@ -45,225 +48,204 @@ import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider; + * </ul> + */ + public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer { +- private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; +- +- /** +- * Creates a {@link BertQuestionAnswerer} instance from the default {@link +- * BertQuestionAnswererOptions}. +- * +- * @param context android context +- * @param modelPath file path to the model with metadata. Note: The model should not be compressed +- * @return a {@link BertQuestionAnswerer} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertQuestionAnswerer createFromFile(Context context, String modelPath) +- throws IOException { +- return createFromFileAndOptions( +- context, modelPath, BertQuestionAnswererOptions.builder().build()); +- } ++ private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; ++ ++ /** ++ * Creates a {@link BertQuestionAnswerer} instance from the default {@link ++ * BertQuestionAnswererOptions}. ++ * ++ * @param context android context ++ * @param modelPath file path to the model with metadata. Note: The model should not be ++ * compressed ++ * @return a {@link BertQuestionAnswerer} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertQuestionAnswerer createFromFile(Context context, String modelPath) ++ throws IOException { ++ return createFromFileAndOptions( ++ context, modelPath, BertQuestionAnswererOptions.builder().build()); ++ } + +- /** +- * Creates a {@link BertQuestionAnswerer} instance from the default {@link +- * BertQuestionAnswererOptions}. +- * +- * @param modelFile a {@link File} object of the model +- * @return a {@link BertQuestionAnswerer} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build()); +- } ++ /** ++ * Creates a {@link BertQuestionAnswerer} instance from the default {@link ++ * BertQuestionAnswererOptions}. ++ * ++ * @param modelFile a {@link File} object of the model ++ * @return a {@link BertQuestionAnswerer} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build()); ++ } + +- /** +- * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}. +- * +- * @param context android context +- * @param modelPath file path to the model with metadata. Note: The model should not be compressed +- * @return a {@link BertQuestionAnswerer} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertQuestionAnswerer createFromFileAndOptions( +- Context context, String modelPath, BertQuestionAnswererOptions options) throws IOException { +- return new BertQuestionAnswerer( +- TaskJniUtils.createHandleFromFdAndOptions( +- context, +- new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() { +- @Override +- public long createHandle( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- BertQuestionAnswererOptions options) { +- return initJniWithFileDescriptor( +- fileDescriptor, +- fileDescriptorLength, +- fileDescriptorOffset, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, +- modelPath, +- options)); +- } ++ /** ++ * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}. ++ * ++ * @param context android context ++ * @param modelPath file path to the model with metadata. Note: The model should not be ++ * compressed ++ * @return a {@link BertQuestionAnswerer} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertQuestionAnswerer createFromFileAndOptions(Context context, String modelPath, ++ BertQuestionAnswererOptions options) throws IOException { ++ return new BertQuestionAnswerer(TaskJniUtils.createHandleFromFdAndOptions( ++ context, new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() { ++ @Override ++ public long createHandle(int fileDescriptor, long fileDescriptorLength, ++ long fileDescriptorOffset, BertQuestionAnswererOptions options) { ++ return initJniWithFileDescriptor(fileDescriptor, fileDescriptorLength, ++ fileDescriptorOffset, ++ TaskJniUtils.createProtoBaseOptionsHandle( ++ options.getBaseOptions())); ++ } ++ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, options)); ++ } + +- /** +- * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}. +- * +- * @param modelFile a {@link File} object of the model +- * @return a {@link BertQuestionAnswerer} instance +- * @throws IOException if model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertQuestionAnswerer createFromFileAndOptions( +- File modelFile, final BertQuestionAnswererOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return new BertQuestionAnswerer( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithFileDescriptor( +- /*fileDescriptor=*/ descriptor.getFd(), +- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); +- } +- }, +- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME)); ++ /** ++ * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}. ++ * ++ * @param modelFile a {@link File} object of the model ++ * @return a {@link BertQuestionAnswerer} instance ++ * @throws IOException if model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertQuestionAnswerer createFromFileAndOptions( ++ File modelFile, final BertQuestionAnswererOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return new BertQuestionAnswerer( ++ TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithFileDescriptor( ++ /*fileDescriptor=*/descriptor.getFd(), ++ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, ++ TaskJniUtils.createProtoBaseOptionsHandle( ++ options.getBaseOptions())); ++ } ++ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME)); ++ } + } +- } + +- /** +- * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file. +- * +- * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +- * +- * @param context android context +- * @param modelPath file path to the Bert model. Note: The model should not be compressed +- * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed +- * @return a {@link BertQuestionAnswerer} instance +- * @throws IOException If model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertQuestionAnswerer createBertQuestionAnswererFromFile( +- Context context, String modelPath, String vocabPath) throws IOException { +- return new BertQuestionAnswerer( +- TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( +- context, +- new MultipleBuffersHandleProvider() { +- @Override +- public long createHandle(ByteBuffer... buffers) { +- return initJniWithBertByteBuffers(buffers); +- } +- }, +- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, +- modelPath, +- vocabPath)); +- } ++ /** ++ * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file. ++ * ++ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 ++ * ++ * @param context android context ++ * @param modelPath file path to the Bert model. Note: The model should not be compressed ++ * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed ++ * @return a {@link BertQuestionAnswerer} instance ++ * @throws IOException If model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertQuestionAnswerer createBertQuestionAnswererFromFile( ++ Context context, String modelPath, String vocabPath) throws IOException { ++ return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( ++ context, new MultipleBuffersHandleProvider() { ++ @Override ++ public long createHandle(ByteBuffer... buffers) { ++ return initJniWithBertByteBuffers(buffers); ++ } ++ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, vocabPath)); ++ } + +- /** +- * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece model +- * file. +- * +- * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +- * +- * @param context android context +- * @param modelPath file path to the Albert model. Note: The model should not be compressed +- * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model +- * should not be compressed +- * @return a {@link BertQuestionAnswerer} instance +- * @throws IOException If model file fails to load +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile( +- Context context, String modelPath, String sentencePieceModelPath) throws IOException { +- return new BertQuestionAnswerer( +- TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( +- context, +- new MultipleBuffersHandleProvider() { +- @Override +- public long createHandle(ByteBuffer... buffers) { +- return initJniWithAlbertByteBuffers(buffers); +- } +- }, +- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, +- modelPath, +- sentencePieceModelPath)); +- } ++ /** ++ * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece ++ * model file. ++ * ++ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 ++ * ++ * @param context android context ++ * @param modelPath file path to the Albert model. Note: The model should not be compressed ++ * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model ++ * should not be compressed ++ * @return a {@link BertQuestionAnswerer} instance ++ * @throws IOException If model file fails to load ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile( ++ Context context, String modelPath, String sentencePieceModelPath) throws IOException { ++ return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( ++ context, new MultipleBuffersHandleProvider() { ++ @Override ++ public long createHandle(ByteBuffer... buffers) { ++ return initJniWithAlbertByteBuffers(buffers); ++ } ++ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, sentencePieceModelPath)); ++ } + +- /** Options for setting up a {@link BertQuestionAnswerer}. */ +- @AutoValue +- public abstract static class BertQuestionAnswererOptions { +- abstract BaseOptions getBaseOptions(); ++ /** Options for setting up a {@link BertQuestionAnswerer}. */ ++ @AutoValue ++ public abstract static class BertQuestionAnswererOptions { ++ abstract BaseOptions getBaseOptions(); + +- public static Builder builder() { +- return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder() +- .setBaseOptions(BaseOptions.builder().build()); +- } ++ public static Builder builder() { ++ return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder() ++ .setBaseOptions(BaseOptions.builder().build()); ++ } + +- /** Builder for {@link BertQuestionAnswererOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public abstract Builder setBaseOptions(BaseOptions baseOptions); ++ /** Builder for {@link BertQuestionAnswererOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public abstract Builder setBaseOptions(BaseOptions baseOptions); + +- public abstract BertQuestionAnswererOptions build(); ++ public abstract BertQuestionAnswererOptions build(); ++ } + } +- } + +- @Override +- public List<QaAnswer> answer(String context, String question) { +- checkNotClosed(); +- return answerNative(getNativeHandle(), context, question); +- } ++ @Override ++ public List<QaAnswer> answer(String context, String question) { ++ checkNotClosed(); ++ return answerNative(getNativeHandle(), context, question); ++ } + +- private BertQuestionAnswerer(long nativeHandle) { +- super(nativeHandle); +- } ++ private BertQuestionAnswerer(long nativeHandle) { ++ super(nativeHandle); ++ } + +- // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. +- private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); ++ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. ++ private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); + +- // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file +- // buffer. +- private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers); ++ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file ++ // buffer. ++ private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers); + +- private static native long initJniWithFileDescriptor( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- long baseOptionsHandle); ++ private static native long initJniWithFileDescriptor(int fileDescriptor, ++ long fileDescriptorLength, long fileDescriptorOffset, long baseOptionsHandle); + +- private static native List<QaAnswer> answerNative( +- long nativeHandle, String context, String question); ++ private static native List<QaAnswer> answerNative( ++ long nativeHandle, String context, String question); + +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); ++ } + +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java +index 4259a69794059..955da9988ca0a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java +@@ -22,37 +22,37 @@ import org.tensorflow.lite.annotations.UsedByReflection; + * position information to the context. + */ + public class QaAnswer { +- public Pos pos; +- public String text; +- +- @UsedByReflection("bert_question_answerer_jni.cc") +- public QaAnswer(String text, Pos pos) { +- this.text = text; +- this.pos = pos; +- } +- +- public QaAnswer(String text, int start, int end, float logit) { +- this(text, new Pos(start, end, logit)); +- } +- +- /** +- * Position information of the answer relative to context. It is sortable in descending order +- * based on logit. +- */ +- public static class Pos implements Comparable<Pos> { +- public int start; +- public int end; +- public float logit; +- +- public Pos(int start, int end, float logit) { +- this.start = start; +- this.end = end; +- this.logit = logit; ++ public Pos pos; ++ public String text; ++ ++ @UsedByReflection("bert_question_answerer_jni.cc") ++ public QaAnswer(String text, Pos pos) { ++ this.text = text; ++ this.pos = pos; ++ } ++ ++ public QaAnswer(String text, int start, int end, float logit) { ++ this(text, new Pos(start, end, logit)); + } + +- @Override +- public int compareTo(Pos other) { +- return Float.compare(other.logit, this.logit); ++ /** ++ * Position information of the answer relative to context. It is sortable in descending order ++ * based on logit. ++ */ ++ public static class Pos implements Comparable<Pos> { ++ public int start; ++ public int end; ++ public float logit; ++ ++ public Pos(int start, int end, float logit) { ++ this.start = start; ++ this.end = end; ++ this.logit = logit; ++ } ++ ++ @Override ++ public int compareTo(Pos other) { ++ return Float.compare(other.logit, this.logit); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java +index 8df6d3794e1b5..7a59a99d7fddf 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java +@@ -19,14 +19,13 @@ import java.util.List; + + /** API to answer questions based on context. */ + public interface QuestionAnswerer { +- +- /** +- * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be +- * empty if no answer was found from the given context. +- * +- * @param context context the question bases on +- * @param question question to ask +- * @return a list of possible answers in {@link QaAnswer} +- */ +- List<QaAnswer> answer(String context, String question); ++ /** ++ * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be ++ * empty if no answer was found from the given context. ++ * ++ * @param context context the question bases on ++ * @param question question to ask ++ * @return a list of possible answers in {@link QaAnswer} ++ */ ++ List<QaAnswer> answer(String context, String question); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java +index d33f0fbbdd497..0d35443a7de5d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java +@@ -16,11 +16,13 @@ limitations under the License. + package org.tensorflow.lite.task.vision.classifier; + + import com.google.auto.value.AutoValue; ++ ++import org.tensorflow.lite.annotations.UsedByReflection; ++import org.tensorflow.lite.support.label.Category; ++ + import java.util.ArrayList; + import java.util.Collections; + import java.util.List; +-import org.tensorflow.lite.annotations.UsedByReflection; +-import org.tensorflow.lite.support.label.Category; + + /** + * The classification results of one head in a multihead (a.k.a. multi-output) {@link +@@ -31,16 +33,15 @@ import org.tensorflow.lite.support.label.Category; + @AutoValue + @UsedByReflection("image_classifier_jni.cc") + public abstract class Classifications { ++ @UsedByReflection("image_classifier_jni.cc") ++ static Classifications create(List<Category> categories, int headIndex) { ++ return new AutoValue_Classifications( ++ Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex); ++ } + +- @UsedByReflection("image_classifier_jni.cc") +- static Classifications create(List<Category> categories, int headIndex) { +- return new AutoValue_Classifications( +- Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex); +- } +- +- // Same reason for not using ImmutableList as stated in +- // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}. +- public abstract List<Category> getCategories(); ++ // Same reason for not using ImmutableList as stated in ++ // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}. ++ public abstract List<Category> getCategories(); + +- public abstract int getHeadIndex(); ++ public abstract int getHeadIndex(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java +index 2bf3fa8a465b4..48038f6a1c04e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java +@@ -18,14 +18,9 @@ package org.tensorflow.lite.task.vision.classifier; + import android.content.Context; + import android.graphics.Rect; + import android.os.ParcelFileDescriptor; ++ + import com.google.android.odml.image.MlImage; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.ArrayList; +-import java.util.Collections; +-import java.util.List; ++ + import org.tensorflow.lite.annotations.UsedByReflection; + import org.tensorflow.lite.support.image.MlImageAdapter; + import org.tensorflow.lite.support.image.TensorImage; +@@ -37,6 +32,14 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.List; ++ + /** + * Performs classification on images. + * +@@ -71,476 +74,449 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + * Hub.</a>. + */ + public final class ImageClassifier extends BaseVisionTaskApi { ++ private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; ++ ++ /** ++ * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. ++ * ++ * @param modelPath path of the classification model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageClassifier createFromFile(Context context, String modelPath) ++ throws IOException { ++ return createFromFileAndOptions( ++ context, modelPath, ImageClassifierOptions.builder().build()); ++ } + +- private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; +- +- /** +- * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. +- * +- * @param modelPath path of the classification model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageClassifier createFromFile(Context context, String modelPath) +- throws IOException { +- return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. +- * +- * @param modelFile the classification model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageClassifier createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link +- * ImageClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * classification model +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) { +- return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}. +- * +- * @param modelPath path of the classification model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageClassifier createFromFileAndOptions( +- Context context, String modelPath, ImageClassifierOptions options) throws IOException { +- return new ImageClassifier( +- TaskJniUtils.createHandleFromFdAndOptions( +- context, +- new FdAndOptionsHandleProvider<ImageClassifierOptions>() { +- @Override +- public long createHandle( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- ImageClassifierOptions options) { +- return initJniWithModelFdAndOptions( +- fileDescriptor, +- fileDescriptorLength, +- fileDescriptorOffset, +- options, +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- IMAGE_CLASSIFIER_NATIVE_LIB, +- modelPath, +- options)); +- } +- +- /** +- * Creates an {@link ImageClassifier} instance. +- * +- * @param modelFile the classification model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageClassifier createFromFileAndOptions( +- File modelFile, final ImageClassifierOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return new ImageClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new TaskJniUtils.EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithModelFdAndOptions( +- descriptor.getFd(), +- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- options, +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- IMAGE_CLASSIFIER_NATIVE_LIB)); ++ /** ++ * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. ++ * ++ * @param modelFile the classification model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageClassifier createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build()); + } +- } +- +- /** +- * Creates an {@link ImageClassifier} instance with a model buffer and {@link +- * ImageClassifierOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * classification model +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageClassifier createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final ImageClassifierOptions options) { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ ++ /** ++ * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link ++ * ImageClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * classification model ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) { ++ return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build()); + } +- return new ImageClassifier( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- options, +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- IMAGE_CLASSIFIER_NATIVE_LIB)); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++ +- */ +- ImageClassifier(long nativeHandle) { +- super(nativeHandle); +- } +- +- /** Options for setting up an ImageClassifier. */ +- @UsedByReflection("image_classifier_jni.cc") +- public static class ImageClassifierOptions { +- // Not using AutoValue for this class because scoreThreshold cannot have default value +- // (otherwise, the default value would override the one in the model metadata) and `Optional` is +- // not an option here, because +- // 1. java.util.Optional require Java 8 while we need to support Java 7. +- // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the +- // comments for labelAllowList. +- private final BaseOptions baseOptions; +- private final String displayNamesLocale; +- private final int maxResults; +- private final float scoreThreshold; +- private final boolean isScoreThresholdSet; +- // As an open source project, we've been trying avoiding depending on common java libraries, +- // such as Guava, because it may introduce conflicts with clients who also happen to use those +- // libraries. Therefore, instead of using ImmutableList here, we convert the List into +- // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less +- // vulnerable. +- private final List<String> labelAllowList; +- private final List<String> labelDenyList; +- private final int numThreads; +- +- public static Builder builder() { +- return new Builder(); ++ ++ /** ++ * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}. ++ * ++ * @param modelPath path of the classification model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageClassifier createFromFileAndOptions( ++ Context context, String modelPath, ImageClassifierOptions options) throws IOException { ++ return new ImageClassifier(TaskJniUtils.createHandleFromFdAndOptions( ++ context, new FdAndOptionsHandleProvider<ImageClassifierOptions>() { ++ @Override ++ public long createHandle(int fileDescriptor, long fileDescriptorLength, ++ long fileDescriptorOffset, ImageClassifierOptions options) { ++ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, ++ fileDescriptorOffset, options, ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, IMAGE_CLASSIFIER_NATIVE_LIB, modelPath, options)); + } + +- /** A builder that helps to configure an instance of ImageClassifierOptions. */ +- public static class Builder { +- private BaseOptions baseOptions = BaseOptions.builder().build(); +- private String displayNamesLocale = "en"; +- private int maxResults = -1; +- private float scoreThreshold; +- private boolean isScoreThresholdSet = false; +- private List<String> labelAllowList = new ArrayList<>(); +- private List<String> labelDenyList = new ArrayList<>(); +- private int numThreads = -1; +- +- Builder() {} +- +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public Builder setBaseOptions(BaseOptions baseOptions) { +- this.baseOptions = baseOptions; +- return this; +- } +- +- /** +- * Sets the locale to use for display names specified through the TFLite Model Metadata, if +- * any. +- * +- * <p>Defaults to English({@code "en"}). See the <a +- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite +- * Metadata schema file.</a> for the accepted pattern of locale. +- */ +- public Builder setDisplayNamesLocale(String displayNamesLocale) { +- this.displayNamesLocale = displayNamesLocale; +- return this; +- } +- +- /** +- * Sets the maximum number of top scored results to return. +- * +- * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned. +- * Defaults to -1. +- * +- * @throws IllegalArgumentException if maxResults is 0. +- */ +- public Builder setMaxResults(int maxResults) { +- if (maxResults == 0) { +- throw new IllegalArgumentException("maxResults cannot be 0."); ++ /** ++ * Creates an {@link ImageClassifier} instance. ++ * ++ * @param modelFile the classification model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageClassifier createFromFileAndOptions( ++ File modelFile, final ImageClassifierOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return new ImageClassifier( ++ TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithModelFdAndOptions(descriptor.getFd(), ++ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options, ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, IMAGE_CLASSIFIER_NATIVE_LIB)); + } +- this.maxResults = maxResults; +- return this; +- } +- +- /** +- * Sets the score threshold. +- * +- * <p>It overrides the one provided in the model metadata (if any). Results below this value +- * are rejected. +- */ +- public Builder setScoreThreshold(float scoreThreshold) { +- this.scoreThreshold = scoreThreshold; +- isScoreThresholdSet = true; +- return this; +- } +- +- /** +- * Sets the optional allowlist of labels. +- * +- * <p>If non-empty, classifications whose label is not in this set will be filtered out. +- * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. +- */ +- public Builder setLabelAllowList(List<String> labelAllowList) { +- this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); +- return this; +- } +- +- /** +- * Sets the optional denylist of labels. +- * +- * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate +- * or unknown labels are ignored. Mutually exclusive with labelAllowList. +- */ +- public Builder setLabelDenyList(List<String> labelDenyList) { +- this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); +- return this; +- } +- +- /** +- * Sets the number of threads to be used for TFLite ops that support multi-threading when +- * running inference with CPU. Defaults to -1. +- * +- * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the +- * effect to let TFLite runtime set the value. +- * +- * @deprecated use {@link BaseOptions} to configure number of threads instead. This method +- * will override the number of threads configured from {@link BaseOptions}. +- */ +- @Deprecated +- public Builder setNumThreads(int numThreads) { +- this.numThreads = numThreads; +- return this; +- } +- +- public ImageClassifierOptions build() { +- return new ImageClassifierOptions(this); +- } + } + +- @UsedByReflection("image_classifier_jni.cc") +- public String getDisplayNamesLocale() { +- return displayNamesLocale; ++ /** ++ * Creates an {@link ImageClassifier} instance with a model buffer and {@link ++ * ImageClassifierOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * classification model ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageClassifier createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final ImageClassifierOptions options) { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ return new ImageClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, options, ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, IMAGE_CLASSIFIER_NATIVE_LIB)); + } + +- @UsedByReflection("image_classifier_jni.cc") +- public int getMaxResults() { +- return maxResults; ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++ ++ */ ++ ImageClassifier(long nativeHandle) { ++ super(nativeHandle); + } + ++ /** Options for setting up an ImageClassifier. */ + @UsedByReflection("image_classifier_jni.cc") +- public float getScoreThreshold() { +- return scoreThreshold; ++ public static class ImageClassifierOptions { ++ // Not using AutoValue for this class because scoreThreshold cannot have default value ++ // (otherwise, the default value would override the one in the model metadata) and ++ // `Optional` is not an option here, because ++ // 1. java.util.Optional require Java 8 while we need to support Java 7. ++ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See ++ // the comments for labelAllowList. ++ private final BaseOptions baseOptions; ++ private final String displayNamesLocale; ++ private final int maxResults; ++ private final float scoreThreshold; ++ private final boolean isScoreThresholdSet; ++ // As an open source project, we've been trying avoiding depending on common java libraries, ++ // such as Guava, because it may introduce conflicts with clients who also happen to use ++ // those libraries. Therefore, instead of using ImmutableList here, we convert the List into ++ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less ++ // vulnerable. ++ private final List<String> labelAllowList; ++ private final List<String> labelDenyList; ++ private final int numThreads; ++ ++ public static Builder builder() { ++ return new Builder(); ++ } ++ ++ /** A builder that helps to configure an instance of ImageClassifierOptions. */ ++ public static class Builder { ++ private BaseOptions baseOptions = BaseOptions.builder().build(); ++ private String displayNamesLocale = "en"; ++ private int maxResults = -1; ++ private float scoreThreshold; ++ private boolean isScoreThresholdSet = false; ++ private List<String> labelAllowList = new ArrayList<>(); ++ private List<String> labelDenyList = new ArrayList<>(); ++ private int numThreads = -1; ++ ++ Builder() {} ++ ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public Builder setBaseOptions(BaseOptions baseOptions) { ++ this.baseOptions = baseOptions; ++ return this; ++ } ++ ++ /** ++ * Sets the locale to use for display names specified through the TFLite Model Metadata, ++ * if any. ++ * ++ * <p>Defaults to English({@code "en"}). See the <a ++ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite ++ * Metadata schema file.</a> for the accepted pattern of locale. ++ */ ++ public Builder setDisplayNamesLocale(String displayNamesLocale) { ++ this.displayNamesLocale = displayNamesLocale; ++ return this; ++ } ++ ++ /** ++ * Sets the maximum number of top scored results to return. ++ * ++ * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned. ++ * Defaults to -1. ++ * ++ * @throws IllegalArgumentException if maxResults is 0. ++ */ ++ public Builder setMaxResults(int maxResults) { ++ if (maxResults == 0) { ++ throw new IllegalArgumentException("maxResults cannot be 0."); ++ } ++ this.maxResults = maxResults; ++ return this; ++ } ++ ++ /** ++ * Sets the score threshold. ++ * ++ * <p>It overrides the one provided in the model metadata (if any). Results below this ++ * value are rejected. ++ */ ++ public Builder setScoreThreshold(float scoreThreshold) { ++ this.scoreThreshold = scoreThreshold; ++ isScoreThresholdSet = true; ++ return this; ++ } ++ ++ /** ++ * Sets the optional allowlist of labels. ++ * ++ * <p>If non-empty, classifications whose label is not in this set will be filtered out. ++ * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. ++ */ ++ public Builder setLabelAllowList(List<String> labelAllowList) { ++ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); ++ return this; ++ } ++ ++ /** ++ * Sets the optional denylist of labels. ++ * ++ * <p>If non-empty, classifications whose label is in this set will be filtered out. ++ * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList. ++ */ ++ public Builder setLabelDenyList(List<String> labelDenyList) { ++ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); ++ return this; ++ } ++ ++ /** ++ * Sets the number of threads to be used for TFLite ops that support multi-threading ++ * when running inference with CPU. Defaults to -1. ++ * ++ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has ++ * the effect to let TFLite runtime set the value. ++ * ++ * @deprecated use {@link BaseOptions} to configure number of threads instead. This ++ * method ++ * will override the number of threads configured from {@link BaseOptions}. ++ */ ++ @Deprecated ++ public Builder setNumThreads(int numThreads) { ++ this.numThreads = numThreads; ++ return this; ++ } ++ ++ public ImageClassifierOptions build() { ++ return new ImageClassifierOptions(this); ++ } ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public String getDisplayNamesLocale() { ++ return displayNamesLocale; ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public int getMaxResults() { ++ return maxResults; ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public float getScoreThreshold() { ++ return scoreThreshold; ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public boolean getIsScoreThresholdSet() { ++ return isScoreThresholdSet; ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public List<String> getLabelAllowList() { ++ return new ArrayList<>(labelAllowList); ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public List<String> getLabelDenyList() { ++ return new ArrayList<>(labelDenyList); ++ } ++ ++ @UsedByReflection("image_classifier_jni.cc") ++ public int getNumThreads() { ++ return numThreads; ++ } ++ ++ public BaseOptions getBaseOptions() { ++ return baseOptions; ++ } ++ ++ ImageClassifierOptions(Builder builder) { ++ displayNamesLocale = builder.displayNamesLocale; ++ maxResults = builder.maxResults; ++ scoreThreshold = builder.scoreThreshold; ++ isScoreThresholdSet = builder.isScoreThresholdSet; ++ labelAllowList = builder.labelAllowList; ++ labelDenyList = builder.labelDenyList; ++ numThreads = builder.numThreads; ++ baseOptions = builder.baseOptions; ++ } + } + +- @UsedByReflection("image_classifier_jni.cc") +- public boolean getIsScoreThresholdSet() { +- return isScoreThresholdSet; ++ /** ++ * Performs actual classification on the provided {@link TensorImage}. ++ * ++ * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Classifications> classify(TensorImage image) { ++ return classify(image, ImageProcessingOptions.builder().build()); + } + +- @UsedByReflection("image_classifier_jni.cc") +- public List<String> getLabelAllowList() { +- return new ArrayList<>(labelAllowList); ++ /** ++ * Performs actual classification on the provided {@link TensorImage} with {@link ++ * ImageProcessingOptions}. ++ * ++ * <p>{@link ImageClassifier} supports the following options: ++ * ++ * <ul> ++ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It ++ * defaults to the entire image. ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. ++ * </ul> ++ * ++ * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) { ++ return run(new InferenceProvider<List<Classifications>>() { ++ @Override ++ public List<Classifications> run( ++ long frameBufferHandle, int width, int height, ImageProcessingOptions options) { ++ return classify(frameBufferHandle, width, height, options); ++ } ++ }, image, options); + } + +- @UsedByReflection("image_classifier_jni.cc") +- public List<String> getLabelDenyList() { +- return new ArrayList<>(labelDenyList); ++ /** ++ * Performs actual classification on the provided {@code MlImage}. ++ * ++ * @param image an {@code MlImage} object that represents an image ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<Classifications> classify(MlImage image) { ++ return classify(image, ImageProcessingOptions.builder().build()); + } + +- @UsedByReflection("image_classifier_jni.cc") +- public int getNumThreads() { +- return numThreads; ++ /** ++ * Performs actual classification on the provided {@code MlImage} with {@link ++ * ImageProcessingOptions}. ++ * ++ * <p>{@link ImageClassifier} supports the following options: ++ * ++ * <ul> ++ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It ++ * defaults to the entire image. ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link ++ * MlImage#getRotation()} is not effective. ++ * </ul> ++ * ++ * @param image a {@code MlImage} object that represents an image ++ * @param options configures options including ROI and rotation ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<Classifications> classify(MlImage image, ImageProcessingOptions options) { ++ image.getInternal().acquire(); ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ List<Classifications> result = classify(tensorImage, options); ++ image.close(); ++ return result; + } + +- public BaseOptions getBaseOptions() { +- return baseOptions; ++ private List<Classifications> classify( ++ long frameBufferHandle, int width, int height, ImageProcessingOptions options) { ++ checkNotClosed(); ++ ++ Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi(); ++ ++ return classifyNative(getNativeHandle(), frameBufferHandle, ++ new int[] {roi.left, roi.top, roi.width(), roi.height()}); + } + +- ImageClassifierOptions(Builder builder) { +- displayNamesLocale = builder.displayNamesLocale; +- maxResults = builder.maxResults; +- scoreThreshold = builder.scoreThreshold; +- isScoreThresholdSet = builder.isScoreThresholdSet; +- labelAllowList = builder.labelAllowList; +- labelDenyList = builder.labelDenyList; +- numThreads = builder.numThreads; +- baseOptions = builder.baseOptions; ++ private static native long initJniWithModelFdAndOptions(int fileDescriptor, ++ long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options, ++ long baseOptionsHandle); ++ ++ private static native long initJniWithByteBuffer( ++ ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle); ++ ++ /** ++ * The native method to classify an image with the ROI and orientation. ++ * ++ * @param roi the ROI of the input image, an array representing the bounding box as {left, top, ++ * width, height} ++ */ ++ private static native List<Classifications> classifyNative( ++ long nativeHandle, long frameBufferHandle, int[] roi); ++ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); + } +- } +- +- /** +- * Performs actual classification on the provided {@link TensorImage}. +- * +- * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Classifications> classify(TensorImage image) { +- return classify(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs actual classification on the provided {@link TensorImage} with {@link +- * ImageProcessingOptions}. +- * +- * <p>{@link ImageClassifier} supports the following options: +- * +- * <ul> +- * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It +- * defaults to the entire image. +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. +- * </ul> +- * +- * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) { +- return run( +- new InferenceProvider<List<Classifications>>() { +- @Override +- public List<Classifications> run( +- long frameBufferHandle, int width, int height, ImageProcessingOptions options) { +- return classify(frameBufferHandle, width, height, options); +- } +- }, +- image, +- options); +- } +- +- /** +- * Performs actual classification on the provided {@code MlImage}. +- * +- * @param image an {@code MlImage} object that represents an image +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<Classifications> classify(MlImage image) { +- return classify(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs actual classification on the provided {@code MlImage} with {@link +- * ImageProcessingOptions}. +- * +- * <p>{@link ImageClassifier} supports the following options: +- * +- * <ul> +- * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It +- * defaults to the entire image. +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link +- * MlImage#getRotation()} is not effective. +- * </ul> +- * +- * @param image a {@code MlImage} object that represents an image +- * @param options configures options including ROI and rotation +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<Classifications> classify(MlImage image, ImageProcessingOptions options) { +- image.getInternal().acquire(); +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- List<Classifications> result = classify(tensorImage, options); +- image.close(); +- return result; +- } +- +- private List<Classifications> classify( +- long frameBufferHandle, int width, int height, ImageProcessingOptions options) { +- checkNotClosed(); +- +- Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi(); +- +- return classifyNative( +- getNativeHandle(), +- frameBufferHandle, +- new int[] {roi.left, roi.top, roi.width(), roi.height()}); +- } +- +- private static native long initJniWithModelFdAndOptions( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- ImageClassifierOptions options, +- long baseOptionsHandle); +- +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle); +- +- /** +- * The native method to classify an image with the ROI and orientation. +- * +- * @param roi the ROI of the input image, an array representing the bounding box as {left, top, +- * width, height} +- */ +- private static native List<Classifications> classifyNative( +- long nativeHandle, long frameBufferHandle, int[] roi); +- +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java +index fdc898f451337..59ab62a949a25 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java +@@ -21,213 +21,184 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c + import android.graphics.ImageFormat; + import android.media.Image; + import android.media.Image.Plane; ++ + import com.google.auto.value.AutoValue; +-import java.nio.ByteBuffer; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.image.ColorSpaceType; + import org.tensorflow.lite.support.image.TensorImage; + import org.tensorflow.lite.task.core.BaseTaskApi; + import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + ++import java.nio.ByteBuffer; ++ + /** Base class for Task Vision APIs. */ + public abstract class BaseVisionTaskApi extends BaseTaskApi { +- +- /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */ +- public interface InferenceProvider<T> { +- T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options); +- } +- +- protected BaseVisionTaskApi(long nativeHandle) { +- super(nativeHandle); +- } +- +- /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */ +- protected <T> T run( +- InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) { +- FrameBufferData frameBufferData = createFrameBuffer(image, options.getOrientation().getValue()); +- T results = +- provider.run( +- frameBufferData.getFrameBufferHandle(), image.getWidth(), image.getHeight(), options); +- deleteFrameBuffer( +- frameBufferData.getFrameBufferHandle(), +- frameBufferData.getByteArrayHandle(), +- frameBufferData.getByteArray()); +- return results; +- } +- +- private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) { +- ColorSpaceType colorSpaceType = image.getColorSpaceType(); +- switch (colorSpaceType) { +- case RGB: +- case NV12: +- case NV21: +- case YV12: +- case YV21: +- // All these types can be converted to ByteBuffer inside TensorImage. Creating FrameBuffer +- // base on the image ByteBuffer. +- return createFrameBufferFromByteBuffer(image, orientation); +- case YUV_420_888: +- // YUV_420_888 is a specific type for android.media.Image. +- return createFrameBufferFromMediaImage(image, orientation); +- default: +- throw new IllegalArgumentException( +- "Color space type, " + colorSpaceType.name() + ", is unsupported."); ++ /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */ ++ public interface InferenceProvider<T> { ++ T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options); + } +- } +- +- /** +- * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link +- * TensorImage}. +- */ +- private static FrameBufferData createFrameBufferFromMediaImage( +- TensorImage image, int orientation) { +- Image mediaImage = image.getMediaImage(); +- +- checkArgument( +- mediaImage.getFormat() == ImageFormat.YUV_420_888, +- "Only supports loading YUV_420_888 Image."); +- +- Plane[] planes = mediaImage.getPlanes(); +- checkArgument( +- planes.length == 3, +- String.format("The input image should have 3 planes, but got %d plane(s).", planes.length)); +- +- // Verify and rewind planes. +- for (Plane plane : planes) { +- ByteBuffer buffer = plane.getBuffer(); +- checkNotNull(buffer, "The image buffer is corrupted and the plane is null."); +- // From the public documentation, plane.getBuffer() should always return a direct ByteBuffer. +- // See https://developer.android.com/reference/android/media/Image.Plane#getBuffer() +- checkArgument( +- buffer.isDirect(), +- "The image plane buffer is not a direct ByteBuffer, and is not supported."); +- buffer.rewind(); ++ ++ protected BaseVisionTaskApi(long nativeHandle) { ++ super(nativeHandle); + } + +- return FrameBufferData.create( +- createFrameBufferFromPlanes( +- planes[0].getBuffer(), +- planes[1].getBuffer(), +- planes[2].getBuffer(), +- mediaImage.getWidth(), +- mediaImage.getHeight(), +- planes[0].getRowStride(), +- // row_stride and pixel_stride should be identical for U/V planes. +- planes[1].getRowStride(), +- planes[1].getPixelStride(), +- orientation), +- // FrameBuffer created with direct ByteBuffer does not require memory freeing. +- /*byteArrayHandle=*/ 0, +- /*byteArray=*/ new byte[0]); +- } +- +- /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */ +- private static FrameBufferData createFrameBufferFromByteBuffer( +- TensorImage image, int orientation) { +- // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8. +- TensorImage imageUint8 = +- image.getDataType() == DataType.UINT8 +- ? image +- : TensorImage.createFrom(image, DataType.UINT8); +- +- ByteBuffer byteBuffer = imageUint8.getBuffer(); +- byteBuffer.rewind(); +- ColorSpaceType colorSpaceType = image.getColorSpaceType(); +- if (byteBuffer.isDirect()) { +- return FrameBufferData.create( +- createFrameBufferFromByteBuffer( +- byteBuffer, +- imageUint8.getWidth(), +- imageUint8.getHeight(), +- orientation, +- colorSpaceType.getValue()), +- // FrameBuffer created with direct ByteBuffer does not require memory freeing. +- /*byteArrayHandle=*/ 0, +- /*byteArray=*/ new byte[0]); +- } else { +- // If the byte array is copied in jni (during GetByteArrayElements), need to free +- // the copied array once inference is done. +- long[] byteArrayHandle = new long[1]; +- byte[] byteArray = getBytesFromByteBuffer(byteBuffer); +- return FrameBufferData.create( +- createFrameBufferFromBytes( +- byteArray, +- imageUint8.getWidth(), +- imageUint8.getHeight(), +- orientation, +- colorSpaceType.getValue(), +- byteArrayHandle), +- byteArrayHandle[0], +- byteArray); ++ /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */ ++ protected <T> T run( ++ InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) { ++ FrameBufferData frameBufferData = ++ createFrameBuffer(image, options.getOrientation().getValue()); ++ T results = provider.run(frameBufferData.getFrameBufferHandle(), image.getWidth(), ++ image.getHeight(), options); ++ deleteFrameBuffer(frameBufferData.getFrameBufferHandle(), ++ frameBufferData.getByteArrayHandle(), frameBufferData.getByteArray()); ++ return results; + } +- } + +- /** Holds the FrameBuffer and the underlying data pointers in C++. */ +- @AutoValue +- abstract static class FrameBufferData { ++ private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) { ++ ColorSpaceType colorSpaceType = image.getColorSpaceType(); ++ switch (colorSpaceType) { ++ case RGB: ++ case NV12: ++ case NV21: ++ case YV12: ++ case YV21: ++ // All these types can be converted to ByteBuffer inside TensorImage. Creating ++ // FrameBuffer base on the image ByteBuffer. ++ return createFrameBufferFromByteBuffer(image, orientation); ++ case YUV_420_888: ++ // YUV_420_888 is a specific type for android.media.Image. ++ return createFrameBufferFromMediaImage(image, orientation); ++ default: ++ throw new IllegalArgumentException( ++ "Color space type, " + colorSpaceType.name() + ", is unsupported."); ++ } ++ } + + /** +- * Initializes a {@link FrameBufferData} object. +- * +- * @param frameBufferHandle the native handle to the FrameBuffer object. +- * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer +- * object. If the FrameBuffer is created on a byte array, this byte array need to be freed +- * after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no byte +- * array needs to be freed, and byteArrayHandle will be 0. +- * @param byteArray the byte array that is used to create the c++ byte array object, which is +- * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct +- * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code +- * byteArray}. ++ * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link ++ * TensorImage}. + */ +- public static FrameBufferData create( +- long frameBufferHandle, long byteArrayHandle, byte[] byteArray) { +- return new AutoValue_BaseVisionTaskApi_FrameBufferData( +- frameBufferHandle, byteArrayHandle, byteArray); ++ private static FrameBufferData createFrameBufferFromMediaImage( ++ TensorImage image, int orientation) { ++ Image mediaImage = image.getMediaImage(); ++ ++ checkArgument(mediaImage.getFormat() == ImageFormat.YUV_420_888, ++ "Only supports loading YUV_420_888 Image."); ++ ++ Plane[] planes = mediaImage.getPlanes(); ++ checkArgument(planes.length == 3, ++ String.format("The input image should have 3 planes, but got %d plane(s).", ++ planes.length)); ++ ++ // Verify and rewind planes. ++ for (Plane plane : planes) { ++ ByteBuffer buffer = plane.getBuffer(); ++ checkNotNull(buffer, "The image buffer is corrupted and the plane is null."); ++ // From the public documentation, plane.getBuffer() should always return a direct ++ // ByteBuffer. See ++ // https://developer.android.com/reference/android/media/Image.Plane#getBuffer() ++ checkArgument(buffer.isDirect(), ++ "The image plane buffer is not a direct ByteBuffer, and is not supported."); ++ buffer.rewind(); ++ } ++ ++ return FrameBufferData.create( ++ createFrameBufferFromPlanes(planes[0].getBuffer(), planes[1].getBuffer(), ++ planes[2].getBuffer(), mediaImage.getWidth(), mediaImage.getHeight(), ++ planes[0].getRowStride(), ++ // row_stride and pixel_stride should be identical for U/V planes. ++ planes[1].getRowStride(), planes[1].getPixelStride(), orientation), ++ // FrameBuffer created with direct ByteBuffer does not require memory freeing. ++ /*byteArrayHandle=*/0, ++ /*byteArray=*/new byte[0]); ++ } ++ ++ /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */ ++ private static FrameBufferData createFrameBufferFromByteBuffer( ++ TensorImage image, int orientation) { ++ // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8. ++ TensorImage imageUint8 = image.getDataType() == DataType.UINT8 ++ ? image ++ : TensorImage.createFrom(image, DataType.UINT8); ++ ++ ByteBuffer byteBuffer = imageUint8.getBuffer(); ++ byteBuffer.rewind(); ++ ColorSpaceType colorSpaceType = image.getColorSpaceType(); ++ if (byteBuffer.isDirect()) { ++ return FrameBufferData.create( ++ createFrameBufferFromByteBuffer(byteBuffer, imageUint8.getWidth(), ++ imageUint8.getHeight(), orientation, colorSpaceType.getValue()), ++ // FrameBuffer created with direct ByteBuffer does not require memory freeing. ++ /*byteArrayHandle=*/0, ++ /*byteArray=*/new byte[0]); ++ } else { ++ // If the byte array is copied in jni (during GetByteArrayElements), need to free ++ // the copied array once inference is done. ++ long[] byteArrayHandle = new long[1]; ++ byte[] byteArray = getBytesFromByteBuffer(byteBuffer); ++ return FrameBufferData.create( ++ createFrameBufferFromBytes(byteArray, imageUint8.getWidth(), ++ imageUint8.getHeight(), orientation, colorSpaceType.getValue(), ++ byteArrayHandle), ++ byteArrayHandle[0], byteArray); ++ } ++ } ++ ++ /** Holds the FrameBuffer and the underlying data pointers in C++. */ ++ @AutoValue ++ abstract static class FrameBufferData { ++ /** ++ * Initializes a {@link FrameBufferData} object. ++ * ++ * @param frameBufferHandle the native handle to the FrameBuffer object. ++ * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer ++ * object. If the FrameBuffer is created on a byte array, this byte array need to be ++ * freed after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no ++ * byte array needs to be freed, and byteArrayHandle will be 0. ++ * @param byteArray the byte array that is used to create the c++ byte array object, which ++ * is ++ * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct ++ * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code ++ * byteArray}. ++ */ ++ public static FrameBufferData create( ++ long frameBufferHandle, long byteArrayHandle, byte[] byteArray) { ++ return new AutoValue_BaseVisionTaskApi_FrameBufferData( ++ frameBufferHandle, byteArrayHandle, byteArray); ++ } ++ ++ abstract long getFrameBufferHandle(); ++ ++ abstract long getByteArrayHandle(); ++ ++ // Package private method for transferring data. ++ @SuppressWarnings("mutable") ++ abstract byte[] getByteArray(); + } + +- abstract long getFrameBufferHandle(); +- +- abstract long getByteArrayHandle(); +- +- // Package private method for transferring data. +- @SuppressWarnings("mutable") +- abstract byte[] getByteArray(); +- } +- +- private static native long createFrameBufferFromByteBuffer( +- ByteBuffer image, int width, int height, int orientation, int colorSpaceType); +- +- private static native long createFrameBufferFromBytes( +- byte[] image, +- int width, +- int height, +- int orientation, +- int colorSpaceType, +- long[] byteArrayHandle); +- +- private static native long createFrameBufferFromPlanes( +- ByteBuffer yBuffer, +- ByteBuffer uBuffer, +- ByteBuffer vBuffer, +- int width, +- int height, +- int yRowStride, +- int uvRowStride, +- int uvPixelStride, +- int orientation); +- +- private static native void deleteFrameBuffer( +- long frameBufferHandle, long byteArrayHandle, byte[] byteArray); +- +- private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) { +- // If the ByteBuffer has a back up array, use it directly without copy. +- if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) { +- return byteBuffer.array(); ++ private static native long createFrameBufferFromByteBuffer( ++ ByteBuffer image, int width, int height, int orientation, int colorSpaceType); ++ ++ private static native long createFrameBufferFromBytes(byte[] image, int width, int height, ++ int orientation, int colorSpaceType, long[] byteArrayHandle); ++ ++ private static native long createFrameBufferFromPlanes(ByteBuffer yBuffer, ByteBuffer uBuffer, ++ ByteBuffer vBuffer, int width, int height, int yRowStride, int uvRowStride, ++ int uvPixelStride, int orientation); ++ ++ private static native void deleteFrameBuffer( ++ long frameBufferHandle, long byteArrayHandle, byte[] byteArray); ++ ++ private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) { ++ // If the ByteBuffer has a back up array, use it directly without copy. ++ if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) { ++ return byteBuffer.array(); ++ } ++ // Copy out the data otherwise. ++ byteBuffer.rewind(); ++ byte[] bytes = new byte[byteBuffer.limit()]; ++ byteBuffer.get(bytes, 0, bytes.length); ++ return bytes; + } +- // Copy out the data otherwise. +- byteBuffer.rewind(); +- byte[] bytes = new byte[byteBuffer.limit()]; +- byteBuffer.get(bytes, 0, bytes.length); +- return bytes; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java +index 007e032d8b331..7106fe8a08b35 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java +@@ -16,27 +16,29 @@ limitations under the License. + package org.tensorflow.lite.task.vision.detector; + + import android.graphics.RectF; ++ + import com.google.auto.value.AutoValue; ++ ++import org.tensorflow.lite.annotations.UsedByReflection; ++import org.tensorflow.lite.support.label.Category; ++ + import java.util.ArrayList; + import java.util.Collections; + import java.util.List; +-import org.tensorflow.lite.annotations.UsedByReflection; +-import org.tensorflow.lite.support.label.Category; + + /** Represents one detected object in the results of a {@link ObjectDetector}. */ + @AutoValue + @UsedByReflection("object_detection_jni.cc") + public abstract class Detection { ++ @UsedByReflection("object_detection_jni.cc") ++ public static Detection create(RectF boundingBox, List<Category> categories) { ++ return new AutoValue_Detection(new RectF(boundingBox), ++ Collections.unmodifiableList(new ArrayList<Category>(categories))); ++ } + +- @UsedByReflection("object_detection_jni.cc") +- public static Detection create(RectF boundingBox, List<Category> categories) { +- return new AutoValue_Detection( +- new RectF(boundingBox), Collections.unmodifiableList(new ArrayList<Category>(categories))); +- } +- +- public abstract RectF getBoundingBox(); ++ public abstract RectF getBoundingBox(); + +- // Same reason for not using ImmutableList as stated in +- // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}. +- public abstract List<Category> getCategories(); ++ // Same reason for not using ImmutableList as stated in ++ // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}. ++ public abstract List<Category> getCategories(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java +index e2046d15a7351..c0585b8eda6aa 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java +@@ -17,14 +17,9 @@ package org.tensorflow.lite.task.vision.detector; + + import android.content.Context; + import android.os.ParcelFileDescriptor; ++ + import com.google.android.odml.image.MlImage; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.ArrayList; +-import java.util.Collections; +-import java.util.List; ++ + import org.tensorflow.lite.annotations.UsedByReflection; + import org.tensorflow.lite.support.image.MlImageAdapter; + import org.tensorflow.lite.support.image.TensorImage; +@@ -35,6 +30,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.List; ++ + /** + * Performs object detection on images. + * +@@ -86,469 +89,447 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; + * Hub.</a>. + */ + public final class ObjectDetector extends BaseVisionTaskApi { ++ private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; ++ ++ /** ++ * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. ++ * ++ * @param modelPath path to the detection model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ObjectDetector createFromFile(Context context, String modelPath) ++ throws IOException { ++ return createFromFileAndOptions( ++ context, modelPath, ObjectDetectorOptions.builder().build()); ++ } + +- private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; +- +- /** +- * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. +- * +- * @param modelPath path to the detection model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ObjectDetector createFromFile(Context context, String modelPath) +- throws IOException { +- return createFromFileAndOptions(context, modelPath, ObjectDetectorOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. +- * +- * @param modelFile the detection model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ObjectDetector createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link +- * ObjectDetectorOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection +- * model +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) { +- return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. +- * +- * @param modelPath path to the detection model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ObjectDetector createFromFileAndOptions( +- Context context, String modelPath, ObjectDetectorOptions options) throws IOException { +- return new ObjectDetector( +- TaskJniUtils.createHandleFromFdAndOptions( +- context, +- new FdAndOptionsHandleProvider<ObjectDetectorOptions>() { +- @Override +- public long createHandle( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- ObjectDetectorOptions options) { +- return initJniWithModelFdAndOptions( +- fileDescriptor, +- fileDescriptorLength, +- fileDescriptorOffset, +- options, +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- OBJECT_DETECTOR_NATIVE_LIB, +- modelPath, +- options)); +- } +- +- /** +- * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. +- * +- * @param modelFile the detection model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ObjectDetector createFromFileAndOptions( +- File modelFile, final ObjectDetectorOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return new ObjectDetector( +- TaskJniUtils.createHandleFromLibrary( +- new TaskJniUtils.EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithModelFdAndOptions( +- descriptor.getFd(), +- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- options, +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- OBJECT_DETECTOR_NATIVE_LIB)); ++ /** ++ * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. ++ * ++ * @param modelFile the detection model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ObjectDetector createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build()); + } +- } +- +- /** +- * Creates an {@link ObjectDetector} instance with a model buffer and {@link +- * ObjectDetectorOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection +- * model +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ObjectDetector createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final ObjectDetectorOptions options) { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ ++ /** ++ * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link ++ * ObjectDetectorOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection ++ * model ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) { ++ return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build()); + } +- return new ObjectDetector( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- options, +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- OBJECT_DETECTOR_NATIVE_LIB)); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++ +- */ +- private ObjectDetector(long nativeHandle) { +- super(nativeHandle); +- } +- +- /** Options for setting up an ObjectDetector. */ +- @UsedByReflection("object_detector_jni.cc") +- public static class ObjectDetectorOptions { +- // Not using AutoValue for this class because scoreThreshold cannot have default value +- // (otherwise, the default value would override the one in the model metadata) and `Optional` is +- // not an option here, because +- // 1. java.util.Optional require Java 8 while we need to support Java 7. +- // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the +- // comments for labelAllowList. +- private final BaseOptions baseOptions; +- private final String displayNamesLocale; +- private final int maxResults; +- private final float scoreThreshold; +- private final boolean isScoreThresholdSet; +- // As an open source project, we've been trying avoiding depending on common java libraries, +- // such as Guava, because it may introduce conflicts with clients who also happen to use those +- // libraries. Therefore, instead of using ImmutableList here, we convert the List into +- // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less +- // vulnerable. +- private final List<String> labelAllowList; +- private final List<String> labelDenyList; +- private final int numThreads; +- +- public static Builder builder() { +- return new Builder(); ++ ++ /** ++ * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. ++ * ++ * @param modelPath path to the detection model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ObjectDetector createFromFileAndOptions( ++ Context context, String modelPath, ObjectDetectorOptions options) throws IOException { ++ return new ObjectDetector(TaskJniUtils.createHandleFromFdAndOptions( ++ context, new FdAndOptionsHandleProvider<ObjectDetectorOptions>() { ++ @Override ++ public long createHandle(int fileDescriptor, long fileDescriptorLength, ++ long fileDescriptorOffset, ObjectDetectorOptions options) { ++ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, ++ fileDescriptorOffset, options, ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, OBJECT_DETECTOR_NATIVE_LIB, modelPath, options)); + } + +- /** A builder that helps to configure an instance of ObjectDetectorOptions. */ +- public static class Builder { +- private BaseOptions baseOptions = BaseOptions.builder().build(); +- private String displayNamesLocale = "en"; +- private int maxResults = -1; +- private float scoreThreshold; +- private boolean isScoreThresholdSet = false; +- private List<String> labelAllowList = new ArrayList<>(); +- private List<String> labelDenyList = new ArrayList<>(); +- private int numThreads = -1; +- +- private Builder() {} +- +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public Builder setBaseOptions(BaseOptions baseOptions) { +- this.baseOptions = baseOptions; +- return this; +- } +- +- /** +- * Sets the locale to use for display names specified through the TFLite Model Metadata, if +- * any. +- * +- * <p>Defaults to English({@code "en"}). See the <a +- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite +- * Metadata schema file.</a> for the accepted pattern of locale. +- */ +- public Builder setDisplayNamesLocale(String displayNamesLocale) { +- this.displayNamesLocale = displayNamesLocale; +- return this; +- } +- +- /** +- * Sets the maximum number of top-scored detection results to return. +- * +- * <p>If < 0, all available results will be returned. If 0, an invalid argument error is +- * returned. Note that models may intrinsically be limited to returning a maximum number of +- * results N: if the provided value here is above N, only N results will be returned. Defaults +- * to -1. +- * +- * @throws IllegalArgumentException if maxResults is 0. +- */ +- public Builder setMaxResults(int maxResults) { +- if (maxResults == 0) { +- throw new IllegalArgumentException("maxResults cannot be 0."); ++ /** ++ * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. ++ * ++ * @param modelFile the detection model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ObjectDetector createFromFileAndOptions( ++ File modelFile, final ObjectDetectorOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return new ObjectDetector( ++ TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithModelFdAndOptions(descriptor.getFd(), ++ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options, ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, OBJECT_DETECTOR_NATIVE_LIB)); + } +- this.maxResults = maxResults; +- return this; +- } +- +- /** +- * Sets the score threshold that overrides the one provided in the model metadata (if any). +- * Results below this value are rejected. +- */ +- public Builder setScoreThreshold(float scoreThreshold) { +- this.scoreThreshold = scoreThreshold; +- this.isScoreThresholdSet = true; +- return this; +- } +- +- /** +- * Sets the optional allow list of labels. +- * +- * <p>If non-empty, detection results whose label is not in this set will be filtered out. +- * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelDenyList}. It +- * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if +- * both {@code labelDenyList} and {@code labelAllowList} are set. +- */ +- public Builder setLabelAllowList(List<String> labelAllowList) { +- this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); +- return this; +- } +- +- /** +- * Sets the optional deny list of labels. +- * +- * <p>If non-empty, detection results whose label is in this set will be filtered out. +- * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelAllowList}. It +- * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if +- * both {@code labelDenyList} and {@code labelAllowList} are set. +- */ +- public Builder setLabelDenyList(List<String> labelDenyList) { +- this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); +- return this; +- } +- +- /** +- * Sets the number of threads to be used for TFLite ops that support multi-threading when +- * running inference with CPU. Defaults to -1. +- * +- * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the +- * effect to let TFLite runtime set the value. +- * +- * @deprecated use {@link BaseOptions} to configure number of threads instead. This method +- * will override the number of threads configured from {@link BaseOptions}. +- */ +- @Deprecated +- public Builder setNumThreads(int numThreads) { +- this.numThreads = numThreads; +- return this; +- } +- +- public ObjectDetectorOptions build() { +- return new ObjectDetectorOptions(this); +- } + } + +- @UsedByReflection("object_detector_jni.cc") +- public String getDisplayNamesLocale() { +- return displayNamesLocale; ++ /** ++ * Creates an {@link ObjectDetector} instance with a model buffer and {@link ++ * ObjectDetectorOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection ++ * model ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ObjectDetector createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final ObjectDetectorOptions options) { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ return new ObjectDetector(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, options, ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, OBJECT_DETECTOR_NATIVE_LIB)); + } + +- @UsedByReflection("object_detector_jni.cc") +- public int getMaxResults() { +- return maxResults; ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++ ++ */ ++ private ObjectDetector(long nativeHandle) { ++ super(nativeHandle); + } + ++ /** Options for setting up an ObjectDetector. */ + @UsedByReflection("object_detector_jni.cc") +- public float getScoreThreshold() { +- return scoreThreshold; ++ public static class ObjectDetectorOptions { ++ // Not using AutoValue for this class because scoreThreshold cannot have default value ++ // (otherwise, the default value would override the one in the model metadata) and ++ // `Optional` is not an option here, because ++ // 1. java.util.Optional require Java 8 while we need to support Java 7. ++ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See ++ // the comments for labelAllowList. ++ private final BaseOptions baseOptions; ++ private final String displayNamesLocale; ++ private final int maxResults; ++ private final float scoreThreshold; ++ private final boolean isScoreThresholdSet; ++ // As an open source project, we've been trying avoiding depending on common java libraries, ++ // such as Guava, because it may introduce conflicts with clients who also happen to use ++ // those libraries. Therefore, instead of using ImmutableList here, we convert the List into ++ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less ++ // vulnerable. ++ private final List<String> labelAllowList; ++ private final List<String> labelDenyList; ++ private final int numThreads; ++ ++ public static Builder builder() { ++ return new Builder(); ++ } ++ ++ /** A builder that helps to configure an instance of ObjectDetectorOptions. */ ++ public static class Builder { ++ private BaseOptions baseOptions = BaseOptions.builder().build(); ++ private String displayNamesLocale = "en"; ++ private int maxResults = -1; ++ private float scoreThreshold; ++ private boolean isScoreThresholdSet = false; ++ private List<String> labelAllowList = new ArrayList<>(); ++ private List<String> labelDenyList = new ArrayList<>(); ++ private int numThreads = -1; ++ ++ private Builder() {} ++ ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public Builder setBaseOptions(BaseOptions baseOptions) { ++ this.baseOptions = baseOptions; ++ return this; ++ } ++ ++ /** ++ * Sets the locale to use for display names specified through the TFLite Model Metadata, ++ * if any. ++ * ++ * <p>Defaults to English({@code "en"}). See the <a ++ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite ++ * Metadata schema file.</a> for the accepted pattern of locale. ++ */ ++ public Builder setDisplayNamesLocale(String displayNamesLocale) { ++ this.displayNamesLocale = displayNamesLocale; ++ return this; ++ } ++ ++ /** ++ * Sets the maximum number of top-scored detection results to return. ++ * ++ * <p>If < 0, all available results will be returned. If 0, an invalid argument error is ++ * returned. Note that models may intrinsically be limited to returning a maximum number ++ * of results N: if the provided value here is above N, only N results will be returned. ++ * Defaults to -1. ++ * ++ * @throws IllegalArgumentException if maxResults is 0. ++ */ ++ public Builder setMaxResults(int maxResults) { ++ if (maxResults == 0) { ++ throw new IllegalArgumentException("maxResults cannot be 0."); ++ } ++ this.maxResults = maxResults; ++ return this; ++ } ++ ++ /** ++ * Sets the score threshold that overrides the one provided in the model metadata (if ++ * any). Results below this value are rejected. ++ */ ++ public Builder setScoreThreshold(float scoreThreshold) { ++ this.scoreThreshold = scoreThreshold; ++ this.isScoreThresholdSet = true; ++ return this; ++ } ++ ++ /** ++ * Sets the optional allow list of labels. ++ * ++ * <p>If non-empty, detection results whose label is not in this set will be filtered ++ * out. Duplicate or unknown labels are ignored. Mutually exclusive with {@code ++ * labelDenyList}. It will cause {@link IllegalStateException} when calling {@link ++ * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList} ++ * are set. ++ */ ++ public Builder setLabelAllowList(List<String> labelAllowList) { ++ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); ++ return this; ++ } ++ ++ /** ++ * Sets the optional deny list of labels. ++ * ++ * <p>If non-empty, detection results whose label is in this set will be filtered out. ++ * Duplicate or unknown labels are ignored. Mutually exclusive with {@code ++ * labelAllowList}. It will cause {@link IllegalStateException} when calling {@link ++ * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList} ++ * are set. ++ */ ++ public Builder setLabelDenyList(List<String> labelDenyList) { ++ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); ++ return this; ++ } ++ ++ /** ++ * Sets the number of threads to be used for TFLite ops that support multi-threading ++ * when running inference with CPU. Defaults to -1. ++ * ++ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has ++ * the effect to let TFLite runtime set the value. ++ * ++ * @deprecated use {@link BaseOptions} to configure number of threads instead. This ++ * method ++ * will override the number of threads configured from {@link BaseOptions}. ++ */ ++ @Deprecated ++ public Builder setNumThreads(int numThreads) { ++ this.numThreads = numThreads; ++ return this; ++ } ++ ++ public ObjectDetectorOptions build() { ++ return new ObjectDetectorOptions(this); ++ } ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public String getDisplayNamesLocale() { ++ return displayNamesLocale; ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public int getMaxResults() { ++ return maxResults; ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public float getScoreThreshold() { ++ return scoreThreshold; ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public boolean getIsScoreThresholdSet() { ++ return isScoreThresholdSet; ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public List<String> getLabelAllowList() { ++ return new ArrayList<>(labelAllowList); ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public List<String> getLabelDenyList() { ++ return new ArrayList<>(labelDenyList); ++ } ++ ++ @UsedByReflection("object_detector_jni.cc") ++ public int getNumThreads() { ++ return numThreads; ++ } ++ ++ public BaseOptions getBaseOptions() { ++ return baseOptions; ++ } ++ ++ private ObjectDetectorOptions(Builder builder) { ++ displayNamesLocale = builder.displayNamesLocale; ++ maxResults = builder.maxResults; ++ scoreThreshold = builder.scoreThreshold; ++ isScoreThresholdSet = builder.isScoreThresholdSet; ++ labelAllowList = builder.labelAllowList; ++ labelDenyList = builder.labelDenyList; ++ numThreads = builder.numThreads; ++ baseOptions = builder.baseOptions; ++ } + } + +- @UsedByReflection("object_detector_jni.cc") +- public boolean getIsScoreThresholdSet() { +- return isScoreThresholdSet; ++ /** ++ * Performs actual detection on the provided image. ++ * ++ * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Detection> detect(TensorImage image) { ++ return detect(image, ImageProcessingOptions.builder().build()); + } + +- @UsedByReflection("object_detector_jni.cc") +- public List<String> getLabelAllowList() { +- return new ArrayList<>(labelAllowList); ++ /** ++ * Performs actual detection on the provided image. ++ * ++ * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * <p>{@link ObjectDetector} supports the following options: ++ * ++ * <ul> ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @param options the options to configure how to preprocess the image ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Detection> detect(TensorImage image, ImageProcessingOptions options) { ++ return run(new InferenceProvider<List<Detection>>() { ++ @Override ++ public List<Detection> run( ++ long frameBufferHandle, int width, int height, ImageProcessingOptions options) { ++ return detect(frameBufferHandle, options); ++ } ++ }, image, options); + } + +- @UsedByReflection("object_detector_jni.cc") +- public List<String> getLabelDenyList() { +- return new ArrayList<>(labelDenyList); ++ /** ++ * Performs actual detection on the provided {@code MlImage}. ++ * ++ * @param image an {@code MlImage} object that represents an image ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<Detection> detect(MlImage image) { ++ return detect(image, ImageProcessingOptions.builder().build()); + } + +- @UsedByReflection("object_detector_jni.cc") +- public int getNumThreads() { +- return numThreads; ++ /** ++ * Performs actual detection on the provided {@code MlImage} with {@link ++ * ImageProcessingOptions}. ++ * ++ * <p>{@link ObjectDetector} supports the following options: ++ * ++ * <ul> ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link ++ * MlImage#getRotation()} is not effective. ++ * </ul> ++ * ++ * @param image an {@code MlImage} object that represents an image ++ * @param options the options to configure how to preprocess the image ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<Detection> detect(MlImage image, ImageProcessingOptions options) { ++ image.getInternal().acquire(); ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ List<Detection> result = detect(tensorImage, options); ++ image.close(); ++ return result; + } + +- public BaseOptions getBaseOptions() { +- return baseOptions; ++ private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) { ++ checkNotClosed(); ++ ++ return detectNative(getNativeHandle(), frameBufferHandle); + } + +- private ObjectDetectorOptions(Builder builder) { +- displayNamesLocale = builder.displayNamesLocale; +- maxResults = builder.maxResults; +- scoreThreshold = builder.scoreThreshold; +- isScoreThresholdSet = builder.isScoreThresholdSet; +- labelAllowList = builder.labelAllowList; +- labelDenyList = builder.labelDenyList; +- numThreads = builder.numThreads; +- baseOptions = builder.baseOptions; ++ private static native long initJniWithModelFdAndOptions(int fileDescriptor, ++ long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options, ++ long baseOptionsHandle); ++ ++ private static native long initJniWithByteBuffer( ++ ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle); ++ ++ private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle); ++ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); + } +- } +- +- /** +- * Performs actual detection on the provided image. +- * +- * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Detection> detect(TensorImage image) { +- return detect(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs actual detection on the provided image. +- * +- * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * <p>{@link ObjectDetector} supports the following options: +- * +- * <ul> +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @param options the options to configure how to preprocess the image +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Detection> detect(TensorImage image, ImageProcessingOptions options) { +- return run( +- new InferenceProvider<List<Detection>>() { +- @Override +- public List<Detection> run( +- long frameBufferHandle, int width, int height, ImageProcessingOptions options) { +- return detect(frameBufferHandle, options); +- } +- }, +- image, +- options); +- } +- +- /** +- * Performs actual detection on the provided {@code MlImage}. +- * +- * @param image an {@code MlImage} object that represents an image +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<Detection> detect(MlImage image) { +- return detect(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs actual detection on the provided {@code MlImage} with {@link ImageProcessingOptions}. +- * +- * <p>{@link ObjectDetector} supports the following options: +- * +- * <ul> +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link +- * MlImage#getRotation()} is not effective. +- * </ul> +- * +- * @param image an {@code MlImage} object that represents an image +- * @param options the options to configure how to preprocess the image +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<Detection> detect(MlImage image, ImageProcessingOptions options) { +- image.getInternal().acquire(); +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- List<Detection> result = detect(tensorImage, options); +- image.close(); +- return result; +- } +- +- private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) { +- checkNotClosed(); +- +- return detectNative(getNativeHandle(), frameBufferHandle); +- } +- +- private static native long initJniWithModelFdAndOptions( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- ObjectDetectorOptions options, +- long baseOptionsHandle); +- +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle); +- +- private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle); +- +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java +index 0defaa9f16b96..991fedeeae9c2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java +@@ -17,72 +17,74 @@ package org.tensorflow.lite.task.vision.segmenter; + + import android.graphics.Color; + import android.os.Build; ++ + import androidx.annotation.RequiresApi; ++ + import com.google.auto.value.AutoValue; ++ + import org.tensorflow.lite.annotations.UsedByReflection; + + /** Represents a label associated with a color for display purposes. */ + @AutoValue + @UsedByReflection("image_segmentation_jni.cc") + public abstract class ColoredLabel { ++ /** ++ * Creates a {@link ColoredLabel} object with an ARGB color int. ++ * ++ * @param label the label string, as provided in the label map packed in the TFLite Model ++ * Metadata. ++ * @param displayName the display name of label, as configured through {@link ++ * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale} ++ * @param argb the color components for the label in ARGB. See <a ++ * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android ++ * Color ints.</a> for more details. ++ */ ++ @UsedByReflection("image_segmentation_jni.cc") ++ public static ColoredLabel create(String label, String displayName, int argb) { ++ return new AutoValue_ColoredLabel(label, displayName, argb); ++ } + +- /** +- * Creates a {@link ColoredLabel} object with an ARGB color int. +- * +- * @param label the label string, as provided in the label map packed in the TFLite Model +- * Metadata. +- * @param displayName the display name of label, as configured through {@link +- * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale} +- * @param argb the color components for the label in ARGB. See <a +- * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android +- * Color ints.</a> for more details. +- */ +- @UsedByReflection("image_segmentation_jni.cc") +- public static ColoredLabel create(String label, String displayName, int argb) { +- return new AutoValue_ColoredLabel(label, displayName, argb); +- } +- +- /** +- * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance. +- * +- * @param label the label string, as provided in the label map packed in the TFLite Model +- * Metadata. +- * @param displayName the display name of label, as configured through {@link +- * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale} +- * @param color the color components for the label. The Color instatnce is supported on Android +- * API level 26 and above. For API level lower than 26, use {@link #create(String, String, +- * int)}. See <a +- * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android +- * Color instances.</a> for more details. +- */ +- @RequiresApi(Build.VERSION_CODES.O) +- public static ColoredLabel create(String label, String displayName, Color color) { +- return new AutoValue_ColoredLabel(label, displayName, color.toArgb()); +- } ++ /** ++ * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance. ++ * ++ * @param label the label string, as provided in the label map packed in the TFLite Model ++ * Metadata. ++ * @param displayName the display name of label, as configured through {@link ++ * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale} ++ * @param color the color components for the label. The Color instatnce is supported on Android ++ * API level 26 and above. For API level lower than 26, use {@link #create(String, String, ++ * int)}. See <a ++ * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android ++ * Color instances.</a> for more details. ++ */ ++ @RequiresApi(Build.VERSION_CODES.O) ++ public static ColoredLabel create(String label, String displayName, Color color) { ++ return new AutoValue_ColoredLabel(label, displayName, color.toArgb()); ++ } + +- public abstract String getlabel(); ++ public abstract String getlabel(); + +- public abstract String getDisplayName(); ++ public abstract String getDisplayName(); + +- /** +- * Gets the ARGB int that represents the color. +- * +- * <p>See <a +- * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android Color +- * ints.</a> for more details. +- */ +- public abstract int getArgb(); ++ /** ++ * Gets the ARGB int that represents the color. ++ * ++ * <p>See <a ++ * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android ++ * Color ints.</a> for more details. ++ */ ++ public abstract int getArgb(); + +- /** +- * Gets the {@link android.graphics.Color} instance of the underlying color. +- * +- * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower than +- * 26, use {@link #getArgb()}. See <a +- * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android +- * Color instances.</a> for more details. +- */ +- @RequiresApi(Build.VERSION_CODES.O) +- public Color getColor() { +- return Color.valueOf(getArgb()); +- } ++ /** ++ * Gets the {@link android.graphics.Color} instance of the underlying color. ++ * ++ * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower ++ * than 26, use {@link #getArgb()}. See <a ++ * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android ++ * Color instances.</a> for more details. ++ */ ++ @RequiresApi(Build.VERSION_CODES.O) ++ public Color getColor() { ++ return Color.valueOf(getArgb()); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java +index 0caa7a33e1729..4c3b36304a0e3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java +@@ -18,16 +18,10 @@ package org.tensorflow.lite.task.vision.segmenter; + import android.content.Context; + import android.content.res.AssetFileDescriptor; + import android.os.ParcelFileDescriptor; ++ + import com.google.android.odml.image.MlImage; + import com.google.auto.value.AutoValue; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.ByteOrder; +-import java.nio.MappedByteBuffer; +-import java.util.ArrayList; +-import java.util.Arrays; +-import java.util.List; ++ + import org.tensorflow.lite.support.image.MlImageAdapter; + import org.tensorflow.lite.support.image.TensorImage; + import org.tensorflow.lite.task.core.BaseOptions; +@@ -37,6 +31,15 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.ByteOrder; ++import java.nio.MappedByteBuffer; ++import java.util.ArrayList; ++import java.util.Arrays; ++import java.util.List; ++ + /** + * Performs segmentation on images. + * +@@ -75,394 +78,365 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>. + */ + public final class ImageSegmenter extends BaseVisionTaskApi { ++ private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; ++ ++ private final OutputType outputType; ++ ++ /** ++ * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. ++ * ++ * @param modelPath path of the segmentation model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSegmenter createFromFile(Context context, String modelPath) ++ throws IOException { ++ return createFromFileAndOptions( ++ context, modelPath, ImageSegmenterOptions.builder().build()); ++ } + +- private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; +- +- private final OutputType outputType; +- +- /** +- * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. +- * +- * @param modelPath path of the segmentation model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSegmenter createFromFile(Context context, String modelPath) +- throws IOException { +- return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. +- * +- * @param modelFile the segmentation model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSegmenter createFromFile(File modelFile) throws IOException { +- return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link +- * ImageSegmenterOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * segmentation model +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- */ +- public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) { +- return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build()); +- } +- +- /** +- * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. +- * +- * @param modelPath path of the segmentation model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSegmenter createFromFileAndOptions( +- Context context, String modelPath, final ImageSegmenterOptions options) throws IOException { +- try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { +- return createFromModelFdAndOptions( +- /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), +- /*fileDescriptorLength=*/ assetFileDescriptor.getLength(), +- /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), +- options); ++ /** ++ * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. ++ * ++ * @param modelFile the segmentation model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSegmenter createFromFile(File modelFile) throws IOException { ++ return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build()); + } +- } +- +- /** +- * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. +- * +- * @param modelFile the segmentation model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSegmenter createFromFileAndOptions( +- File modelFile, final ImageSegmenterOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromModelFdAndOptions( +- /*fileDescriptor=*/ descriptor.getFd(), +- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- options); ++ ++ /** ++ * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link ++ * ImageSegmenterOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * segmentation model ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ */ ++ public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) { ++ return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build()); ++ } ++ ++ /** ++ * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. ++ * ++ * @param modelPath path of the segmentation model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSegmenter createFromFileAndOptions(Context context, String modelPath, ++ final ImageSegmenterOptions options) throws IOException { ++ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { ++ return createFromModelFdAndOptions( ++ /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(), ++ /*fileDescriptorLength=*/assetFileDescriptor.getLength(), ++ /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); ++ } ++ } ++ ++ /** ++ * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. ++ * ++ * @param modelFile the segmentation model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSegmenter createFromFileAndOptions( ++ File modelFile, final ImageSegmenterOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return createFromModelFdAndOptions( ++ /*fileDescriptor=*/descriptor.getFd(), ++ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options); ++ } ++ } ++ ++ /** ++ * Creates an {@link ImageSegmenter} instance with a model buffer and {@link ++ * ImageSegmenterOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the ++ * segmentation model ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ */ ++ public static ImageSegmenter createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final ImageSegmenterOptions options) { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ return new ImageSegmenter(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, options.getDisplayNamesLocale(), ++ options.getOutputType().getValue(), ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, IMAGE_SEGMENTER_NATIVE_LIB), options.getOutputType()); ++ } ++ ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++ ++ */ ++ private ImageSegmenter(long nativeHandle, OutputType outputType) { ++ super(nativeHandle); ++ this.outputType = outputType; ++ } ++ ++ /** Options for setting up an {@link ImageSegmenter}. */ ++ @AutoValue ++ public abstract static class ImageSegmenterOptions { ++ private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en"; ++ private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK; ++ private static final int NUM_THREADS = -1; ++ ++ public abstract BaseOptions getBaseOptions(); ++ ++ public abstract String getDisplayNamesLocale(); ++ ++ public abstract OutputType getOutputType(); ++ ++ public abstract int getNumThreads(); ++ ++ public static Builder builder() { ++ return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() ++ .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE) ++ .setOutputType(DEFAULT_OUTPUT_TYPE) ++ .setNumThreads(NUM_THREADS) ++ .setBaseOptions(BaseOptions.builder().build()); ++ } ++ ++ /** Builder for {@link ImageSegmenterOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public abstract Builder setBaseOptions(BaseOptions baseOptions); ++ ++ /** ++ * Sets the locale to use for display names specified through the TFLite Model Metadata, ++ * if any. ++ * ++ * <p>Defaults to English({@code "en"}). See the <a ++ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite ++ * Metadata schema file.</a> for the accepted pattern of locale. ++ */ ++ public abstract Builder setDisplayNamesLocale(String displayNamesLocale); ++ ++ public abstract Builder setOutputType(OutputType outputType); ++ ++ /** ++ * Sets the number of threads to be used for TFLite ops that support multi-threading ++ * when running inference with CPU. Defaults to -1. ++ * ++ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has ++ * the effect to let TFLite runtime set the value. ++ * ++ * @deprecated use {@link BaseOptions} to configure number of threads instead. This ++ * method ++ * will override the number of threads configured from {@link BaseOptions}. ++ */ ++ @Deprecated ++ public abstract Builder setNumThreads(int numThreads); ++ ++ public abstract ImageSegmenterOptions build(); ++ } ++ } ++ ++ /** ++ * Performs actual segmentation on the provided image. ++ * ++ * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @return results of performing image segmentation. Note that at the time, a single {@link ++ * Segmentation} element is expected to be returned. The result is stored in a {@link List} ++ * for later extension to e.g. instance segmentation models, which may return one ++ * segmentation per object. ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Segmentation> segment(TensorImage image) { ++ return segment(image, ImageProcessingOptions.builder().build()); ++ } ++ ++ /** ++ * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}. ++ * ++ * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * <p>{@link ImageSegmenter} supports the following options: ++ * ++ * <ul> ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @param options the options configure how to preprocess the image ++ * @return results of performing image segmentation. Note that at the time, a single {@link ++ * Segmentation} element is expected to be returned. The result is stored in a {@link List} ++ * for later extension to e.g. instance segmentation models, which may return one ++ * segmentation per object. ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) { ++ return run(new InferenceProvider<List<Segmentation>>() { ++ @Override ++ public List<Segmentation> run( ++ long frameBufferHandle, int width, int height, ImageProcessingOptions options) { ++ return segment(frameBufferHandle, options); ++ } ++ }, image, options); + } +- } +- +- /** +- * Creates an {@link ImageSegmenter} instance with a model buffer and {@link +- * ImageSegmenterOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the +- * segmentation model +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- */ +- public static ImageSegmenter createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final ImageSegmenterOptions options) { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ ++ /** ++ * Performs actual segmentation on the provided {@code MlImage}. ++ * ++ * @param image an {@code MlImage} to segment. ++ * @return results of performing image segmentation. Note that at the time, a single {@link ++ * Segmentation} element is expected to be returned. The result is stored in a {@link List} ++ * for later extension to e.g. instance segmentation models, which may return one ++ * segmentation per object. ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<Segmentation> segment(MlImage image) { ++ return segment(image, ImageProcessingOptions.builder().build()); + } +- return new ImageSegmenter( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- options.getDisplayNamesLocale(), +- options.getOutputType().getValue(), +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- IMAGE_SEGMENTER_NATIVE_LIB), +- options.getOutputType()); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++ +- */ +- private ImageSegmenter(long nativeHandle, OutputType outputType) { +- super(nativeHandle); +- this.outputType = outputType; +- } +- +- /** Options for setting up an {@link ImageSegmenter}. */ +- @AutoValue +- public abstract static class ImageSegmenterOptions { +- private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en"; +- private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK; +- private static final int NUM_THREADS = -1; +- +- public abstract BaseOptions getBaseOptions(); +- +- public abstract String getDisplayNamesLocale(); +- +- public abstract OutputType getOutputType(); +- +- public abstract int getNumThreads(); +- +- public static Builder builder() { +- return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() +- .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE) +- .setOutputType(DEFAULT_OUTPUT_TYPE) +- .setNumThreads(NUM_THREADS) +- .setBaseOptions(BaseOptions.builder().build()); ++ ++ /** ++ * Performs actual segmentation on the provided {@code MlImage} with {@link ++ * ImageProcessingOptions}. ++ * ++ * <p>{@link ImageSegmenter} supports the following options: ++ * ++ * <ul> ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link ++ * MlImage#getRotation()} is not effective. ++ * </ul> ++ * ++ * @param image an {@code MlImage} to segment. ++ * @param options the options configure how to preprocess the image. ++ * @return results of performing image segmentation. Note that at the time, a single {@link ++ * Segmentation} element is expected to be returned. The result is stored in a {@link List} ++ * for later extension to e.g. instance segmentation models, which may return one ++ * segmentation per object. ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) { ++ image.getInternal().acquire(); ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ List<Segmentation> result = segment(tensorImage, options); ++ image.close(); ++ return result; + } + +- /** Builder for {@link ImageSegmenterOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public abstract Builder setBaseOptions(BaseOptions baseOptions); +- +- /** +- * Sets the locale to use for display names specified through the TFLite Model Metadata, if +- * any. +- * +- * <p>Defaults to English({@code "en"}). See the <a +- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite +- * Metadata schema file.</a> for the accepted pattern of locale. +- */ +- public abstract Builder setDisplayNamesLocale(String displayNamesLocale); +- +- public abstract Builder setOutputType(OutputType outputType); +- +- /** +- * Sets the number of threads to be used for TFLite ops that support multi-threading when +- * running inference with CPU. Defaults to -1. +- * +- * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the +- * effect to let TFLite runtime set the value. +- * +- * @deprecated use {@link BaseOptions} to configure number of threads instead. This method +- * will override the number of threads configured from {@link BaseOptions}. +- */ +- @Deprecated +- public abstract Builder setNumThreads(int numThreads); +- +- public abstract ImageSegmenterOptions build(); ++ public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) { ++ checkNotClosed(); ++ ++ List<byte[]> maskByteArrays = new ArrayList<>(); ++ List<ColoredLabel> coloredLabels = new ArrayList<>(); ++ int[] maskShape = new int[2]; ++ segmentNative( ++ getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels); ++ ++ List<ByteBuffer> maskByteBuffers = new ArrayList<>(); ++ for (byte[] bytes : maskByteArrays) { ++ ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); ++ // Change the byte order to little_endian, since the buffers were generated in jni. ++ byteBuffer.order(ByteOrder.LITTLE_ENDIAN); ++ maskByteBuffers.add(byteBuffer); ++ } ++ ++ return Arrays.asList(Segmentation.create(outputType, ++ outputType.createMasksFromBuffer(maskByteBuffers, maskShape), coloredLabels)); + } +- } +- +- /** +- * Performs actual segmentation on the provided image. +- * +- * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @return results of performing image segmentation. Note that at the time, a single {@link +- * Segmentation} element is expected to be returned. The result is stored in a {@link List} +- * for later extension to e.g. instance segmentation models, which may return one segmentation +- * per object. +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Segmentation> segment(TensorImage image) { +- return segment(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}. +- * +- * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * <p>{@link ImageSegmenter} supports the following options: +- * +- * <ul> +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @param options the options configure how to preprocess the image +- * @return results of performing image segmentation. Note that at the time, a single {@link +- * Segmentation} element is expected to be returned. The result is stored in a {@link List} +- * for later extension to e.g. instance segmentation models, which may return one segmentation +- * per object. +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) { +- return run( +- new InferenceProvider<List<Segmentation>>() { +- @Override +- public List<Segmentation> run( +- long frameBufferHandle, int width, int height, ImageProcessingOptions options) { +- return segment(frameBufferHandle, options); +- } +- }, +- image, +- options); +- } +- +- /** +- * Performs actual segmentation on the provided {@code MlImage}. +- * +- * @param image an {@code MlImage} to segment. +- * @return results of performing image segmentation. Note that at the time, a single {@link +- * Segmentation} element is expected to be returned. The result is stored in a {@link List} +- * for later extension to e.g. instance segmentation models, which may return one segmentation +- * per object. +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<Segmentation> segment(MlImage image) { +- return segment(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs actual segmentation on the provided {@code MlImage} with {@link +- * ImageProcessingOptions}. +- * +- * <p>{@link ImageSegmenter} supports the following options: +- * +- * <ul> +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link +- * MlImage#getRotation()} is not effective. +- * </ul> +- * +- * @param image an {@code MlImage} to segment. +- * @param options the options configure how to preprocess the image. +- * @return results of performing image segmentation. Note that at the time, a single {@link +- * Segmentation} element is expected to be returned. The result is stored in a {@link List} +- * for later extension to e.g. instance segmentation models, which may return one segmentation +- * per object. +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) { +- image.getInternal().acquire(); +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- List<Segmentation> result = segment(tensorImage, options); +- image.close(); +- return result; +- } +- +- public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) { +- checkNotClosed(); +- +- List<byte[]> maskByteArrays = new ArrayList<>(); +- List<ColoredLabel> coloredLabels = new ArrayList<>(); +- int[] maskShape = new int[2]; +- segmentNative(getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels); +- +- List<ByteBuffer> maskByteBuffers = new ArrayList<>(); +- for (byte[] bytes : maskByteArrays) { +- ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); +- // Change the byte order to little_endian, since the buffers were generated in jni. +- byteBuffer.order(ByteOrder.LITTLE_ENDIAN); +- maskByteBuffers.add(byteBuffer); ++ ++ private static ImageSegmenter createFromModelFdAndOptions(final int fileDescriptor, ++ final long fileDescriptorLength, final long fileDescriptorOffset, ++ final ImageSegmenterOptions options) { ++ long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, ++ fileDescriptorOffset, options.getDisplayNamesLocale(), ++ options.getOutputType().getValue(), ++ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( ++ options.getBaseOptions(), options.getNumThreads())); ++ } ++ }, IMAGE_SEGMENTER_NATIVE_LIB); ++ return new ImageSegmenter(nativeHandle, options.getOutputType()); ++ } ++ ++ private static native long initJniWithModelFdAndOptions(int fileDescriptor, ++ long fileDescriptorLength, long fileDescriptorOffset, String displayNamesLocale, ++ int outputType, long baseOptionsHandle); ++ ++ private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, ++ String displayNamesLocale, int outputType, long baseOptionsHandle); ++ ++ /** ++ * The native method to segment the image. ++ * ++ * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the ++ * native layer. ++ */ ++ private static native void segmentNative(long nativeHandle, long frameBufferHandle, ++ List<byte[]> maskByteArrays, int[] maskShape, List<ColoredLabel> coloredLabels); ++ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); + } + +- return Arrays.asList( +- Segmentation.create( +- outputType, +- outputType.createMasksFromBuffer(maskByteBuffers, maskShape), +- coloredLabels)); +- } +- +- private static ImageSegmenter createFromModelFdAndOptions( +- final int fileDescriptor, +- final long fileDescriptorLength, +- final long fileDescriptorOffset, +- final ImageSegmenterOptions options) { +- long nativeHandle = +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithModelFdAndOptions( +- fileDescriptor, +- fileDescriptorLength, +- fileDescriptorOffset, +- options.getDisplayNamesLocale(), +- options.getOutputType().getValue(), +- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( +- options.getBaseOptions(), options.getNumThreads())); +- } +- }, +- IMAGE_SEGMENTER_NATIVE_LIB); +- return new ImageSegmenter(nativeHandle, options.getOutputType()); +- } +- +- private static native long initJniWithModelFdAndOptions( +- int fileDescriptor, +- long fileDescriptorLength, +- long fileDescriptorOffset, +- String displayNamesLocale, +- int outputType, +- long baseOptionsHandle); +- +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, String displayNamesLocale, int outputType, long baseOptionsHandle); +- +- /** +- * The native method to segment the image. +- * +- * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native +- * layer. +- */ +- private static native void segmentNative( +- long nativeHandle, +- long frameBufferHandle, +- List<byte[]> maskByteArrays, +- int[] maskShape, +- List<ColoredLabel> coloredLabels); +- +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java +index 26ace1eaa1783..8c69cf5d152a0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java +@@ -20,126 +20,128 @@ import static org.tensorflow.lite.DataType.UINT8; + import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE; + ++import org.tensorflow.lite.support.image.TensorImage; ++import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; ++ + import java.nio.ByteBuffer; + import java.util.ArrayList; + import java.util.List; +-import org.tensorflow.lite.support.image.TensorImage; +-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + /** + * Output mask type. This allows specifying the type of post-processing to perform on the raw model + * results. + */ + public enum OutputType { +- +- /** +- * Gives a single output mask where each pixel represents the class which the pixel in the +- * original image was predicted to belong to. +- */ +- CATEGORY_MASK(0) { + /** +- * {@inheritDoc} +- * +- * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if the +- * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE} ++ * Gives a single output mask where each pixel represents the class which the pixel in the ++ * original image was predicted to belong to. + */ +- @Override +- void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) { +- checkArgument( +- masks.size() == 1, +- "CATRGORY_MASK only allows one TensorImage in the list, providing " + masks.size()); +- +- TensorImage mask = masks.get(0); +- checkArgument( +- mask.getColorSpaceType() == GRAYSCALE, +- "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing " +- + mask.getColorSpaceType()); +- } ++ CATEGORY_MASK(0) { ++ /** ++ * {@inheritDoc} ++ * ++ * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if ++ * the ++ * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE} ++ */ ++ @Override ++ void assertMasksMatchColoredLabels( ++ List<TensorImage> masks, List<ColoredLabel> coloredLabels) { ++ checkArgument(masks.size() == 1, ++ "CATRGORY_MASK only allows one TensorImage in the list, providing " ++ + masks.size()); ++ ++ TensorImage mask = masks.get(0); ++ checkArgument(mask.getColorSpaceType() == GRAYSCALE, ++ "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing " ++ + mask.getColorSpaceType()); ++ } ++ ++ /** ++ * {@inheritDoc} ++ * ++ * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the ++ * list ++ */ ++ @Override ++ List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) { ++ checkArgument(buffers.size() == 1, ++ "CATRGORY_MASK only allows one mask in the buffer list, providing " ++ + buffers.size()); ++ ++ List<TensorImage> masks = new ArrayList<>(); ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8); ++ tensorBuffer.loadBuffer(buffers.get(0), maskShape); ++ TensorImage tensorImage = new TensorImage(UINT8); ++ tensorImage.load(tensorBuffer, GRAYSCALE); ++ masks.add(tensorImage); ++ ++ return masks; ++ } ++ }, + + /** +- * {@inheritDoc} +- * +- * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the list ++ * Gives a list of output masks where, for each mask, each pixel represents the prediction ++ * confidence, usually in the [0, 1] range. + */ +- @Override +- List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) { +- checkArgument( +- buffers.size() == 1, +- "CATRGORY_MASK only allows one mask in the buffer list, providing " + buffers.size()); +- +- List<TensorImage> masks = new ArrayList<>(); +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8); +- tensorBuffer.loadBuffer(buffers.get(0), maskShape); +- TensorImage tensorImage = new TensorImage(UINT8); +- tensorImage.load(tensorBuffer, GRAYSCALE); +- masks.add(tensorImage); +- +- return masks; ++ CONFIDENCE_MASK(1) { ++ /** ++ * {@inheritDoc} ++ * ++ * @throws IllegalArgumentException if more the size of the masks list does not match the ++ * size ++ * of the coloredlabels list, or if the color space type of the any mask is not {@link ++ * ColorSpaceType#GRAYSCALE} ++ */ ++ @Override ++ void assertMasksMatchColoredLabels( ++ List<TensorImage> masks, List<ColoredLabel> coloredLabels) { ++ checkArgument(masks.size() == coloredLabels.size(), ++ String.format( ++ "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of" ++ + " coloredLabels (%d).", ++ masks.size(), coloredLabels.size())); ++ ++ for (TensorImage mask : masks) { ++ checkArgument(mask.getColorSpaceType() == GRAYSCALE, ++ "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing " ++ + mask.getColorSpaceType()); ++ } ++ } ++ ++ @Override ++ List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) { ++ List<TensorImage> masks = new ArrayList<>(); ++ for (ByteBuffer buffer : buffers) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32); ++ tensorBuffer.loadBuffer(buffer, maskShape); ++ TensorImage tensorImage = new TensorImage(FLOAT32); ++ tensorImage.load(tensorBuffer, GRAYSCALE); ++ masks.add(tensorImage); ++ } ++ return masks; ++ } ++ }; ++ ++ public int getValue() { ++ return value; + } +- }, + +- /** +- * Gives a list of output masks where, for each mask, each pixel represents the prediction +- * confidence, usually in the [0, 1] range. +- */ +- CONFIDENCE_MASK(1) { + /** +- * {@inheritDoc} ++ * Verifies that the given list of masks matches the list of colored labels. + * +- * @throws IllegalArgumentException if more the size of the masks list does not match the size +- * of the coloredlabels list, or if the color space type of the any mask is not {@link +- * ColorSpaceType#GRAYSCALE} ++ * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the ++ * output type + */ +- @Override +- void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) { +- checkArgument( +- masks.size() == coloredLabels.size(), +- String.format( +- "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of" +- + " coloredLabels (%d).", +- masks.size(), coloredLabels.size())); +- +- for (TensorImage mask : masks) { +- checkArgument( +- mask.getColorSpaceType() == GRAYSCALE, +- "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing " +- + mask.getColorSpaceType()); +- } +- } +- +- @Override +- List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) { +- List<TensorImage> masks = new ArrayList<>(); +- for (ByteBuffer buffer : buffers) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32); +- tensorBuffer.loadBuffer(buffer, maskShape); +- TensorImage tensorImage = new TensorImage(FLOAT32); +- tensorImage.load(tensorBuffer, GRAYSCALE); +- masks.add(tensorImage); +- } +- return masks; +- } +- }; ++ abstract void assertMasksMatchColoredLabels( ++ List<TensorImage> masks, List<ColoredLabel> coloredLabels); + +- public int getValue() { +- return value; +- } ++ /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */ ++ abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape); + +- /** +- * Verifies that the given list of masks matches the list of colored labels. +- * +- * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the +- * output type +- */ +- abstract void assertMasksMatchColoredLabels( +- List<TensorImage> masks, List<ColoredLabel> coloredLabels); ++ private final int value; + +- /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */ +- abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape); +- +- private final int value; +- +- private OutputType(int value) { +- this.value = value; +- } ++ private OutputType(int value) { ++ this.value = value; ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java +index 018482c7e82db..f5062bc8745f0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java +@@ -16,67 +16,69 @@ limitations under the License. + package org.tensorflow.lite.task.vision.segmenter; + + import com.google.auto.value.AutoValue; ++ ++import org.tensorflow.lite.support.image.TensorImage; ++ + import java.util.ArrayList; + import java.util.Collections; + import java.util.List; +-import org.tensorflow.lite.support.image.TensorImage; + + /** Represents the segmentation result of an {@link ImageSegmenter}. */ + @AutoValue + public abstract class Segmentation { ++ /** ++ * Creates a {@link Segmentation} object. ++ * ++ * <p>{@link Segmentation} provides two types of outputs as indicated through {@link ++ * OutputType}: ++ * ++ * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a ++ * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of ++ * each pixel in this mask represents the class to which the pixel in the mask belongs. The ++ * pixel values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code ++ * i} is associated with {@code coloredLabels.get(i)}. ++ * ++ * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which ++ * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated ++ * with ++ * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with ++ * shape (height, width), in row major order. The value of each pixel in these masks represents ++ * the confidence score for this particular class. ++ * ++ * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br> ++ * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code ++ * Orientation} flag of the input FrameBuffer, <br> ++ * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer ++ * dimensions. ++ * ++ * <p>Example of such post-processing, assuming: <br> ++ * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image ++ * will be rotated 90° clockwise during preprocessing to make it "upright"), <br> ++ * \* a model outputting masks of size 224x224. <br> ++ * In order to be directly displayable on top of the input image assumed to be displayed *with* ++ * the {@code Orientation} flag taken into account (according to the <a ++ * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to ++ * be: re-scaled to 640 x 480, then rotated 90° clockwise. ++ * ++ * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the ++ * {@code outputType} ++ */ ++ static Segmentation create( ++ OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) { ++ outputType.assertMasksMatchColoredLabels(masks, coloredLabels); + +- /** +- * Creates a {@link Segmentation} object. +- * +- * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}: +- * +- * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a +- * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each +- * pixel in this mask represents the class to which the pixel in the mask belongs. The pixel +- * values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code i} is +- * associated with {@code coloredLabels.get(i)}. +- * +- * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which +- * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated with +- * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with +- * shape (height, width), in row major order. The value of each pixel in these masks represents +- * the confidence score for this particular class. +- * +- * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br> +- * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code +- * Orientation} flag of the input FrameBuffer, <br> +- * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer +- * dimensions. +- * +- * <p>Example of such post-processing, assuming: <br> +- * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image +- * will be rotated 90° clockwise during preprocessing to make it "upright"), <br> +- * \* a model outputting masks of size 224x224. <br> +- * In order to be directly displayable on top of the input image assumed to be displayed *with* +- * the {@code Orientation} flag taken into account (according to the <a +- * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be: +- * re-scaled to 640 x 480, then rotated 90° clockwise. +- * +- * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the +- * {@code outputType} +- */ +- static Segmentation create( +- OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) { +- outputType.assertMasksMatchColoredLabels(masks, coloredLabels); +- +- return new AutoValue_Segmentation( +- outputType, +- Collections.unmodifiableList(new ArrayList<TensorImage>(masks)), +- Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels))); +- } ++ return new AutoValue_Segmentation(outputType, ++ Collections.unmodifiableList(new ArrayList<TensorImage>(masks)), ++ Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels))); ++ } + +- public abstract OutputType getOutputType(); ++ public abstract OutputType getOutputType(); + +- // As an open source project, we've been trying avoiding depending on common java libraries, +- // such as Guava, because it may introduce conflicts with clients who also happen to use those +- // libraries. Therefore, instead of using ImmutableList here, we convert the List into +- // unmodifiableList in create() to make it less vulnerable. +- public abstract List<TensorImage> getMasks(); ++ // As an open source project, we've been trying avoiding depending on common java libraries, ++ // such as Guava, because it may introduce conflicts with clients who also happen to use those ++ // libraries. Therefore, instead of using ImmutableList here, we convert the List into ++ // unmodifiableList in create() to make it less vulnerable. ++ public abstract List<TensorImage> getMasks(); + +- public abstract List<ColoredLabel> getColoredLabels(); ++ public abstract List<ColoredLabel> getColoredLabels(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java +index edbb5d82db2c1..903f7913219bf 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.audio; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.mockito.ArgumentMatchers.any; + import static org.mockito.ArgumentMatchers.anyInt; +@@ -25,6 +26,7 @@ import static org.mockito.Mockito.when; + + import android.media.AudioFormat; + import android.media.AudioRecord; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.junit.runners.Suite; +@@ -35,249 +37,249 @@ import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat; + /** Test for {@link TensorAudio}. */ + @RunWith(Suite.class) + @SuiteClasses({ +- TensorAudioTest.General.class, ++ TensorAudioTest.General.class, + }) + public class TensorAudioTest { +- +- /** General tests of TensorAudio. */ +- @RunWith(RobolectricTestRunner.class) +- public static final class General extends TensorAudioTest { +- @Test +- public void createSucceedsWithTensorAudioFormat() throws Exception { +- TensorAudio tensor = +- TensorAudio.create( +- TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100); +- assertThat(tensor.getFormat().getChannels()).isEqualTo(1); +- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2); +- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100); +- } +- +- @Test +- public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception { +- TensorAudio tensor = +- TensorAudio.create( +- TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100); +- assertThat(tensor.getFormat().getChannels()).isEqualTo(5); +- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2); +- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500); +- } +- +- @Test +- public void createSucceededsWithDefaultArguments() throws Exception { +- TensorAudio tensor = +- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000); +- // Number of channels defaults to 1. +- assertThat(tensor.getFormat().getChannels()).isEqualTo(1); +- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20); +- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000); +- } +- +- @Test +- public void createSucceedsWithAudioFormat() throws Exception { +- AudioFormat format = +- new AudioFormat.Builder() +- .setChannelMask(AudioFormat.CHANNEL_IN_STEREO) +- .setEncoding(AudioFormat.ENCODING_PCM_16BIT) +- .setSampleRate(16000) +- .build(); +- TensorAudio tensor = TensorAudio.create(format, 100); +- // STEREO has 2 channels +- assertThat(tensor.getFormat().getChannels()).isEqualTo(2); +- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000); +- // flatSize = channelCount * sampleCount +- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200); +- } +- +- @Test +- public void createFailedWithInvalidSampleRate() throws Exception { +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> TensorAudio.create(TensorAudioFormat.builder().setSampleRate(0).build(), 100)); +- // Sample rate 0 is not allowed +- assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate"); +- } +- +- @Test +- public void createFailedWithInvalidChannels() throws Exception { +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> +- TensorAudio.create( +- TensorAudioFormat.builder().setSampleRate(1).setChannels(-1).build(), 100)); +- // Negative channels is not allowed +- assertThat(exception).hasMessageThat().ignoringCase().contains("channels"); +- } +- +- @Test +- public void loadSucceedsFromArray() throws Exception { +- TensorAudioFormat format = +- TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build(); +- TensorAudio tensor = TensorAudio.create(format, 2); +- assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]); +- +- tensor.load(new float[] {2.f, 0}); +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .usingTolerance(0.001f) +- .containsExactly(new float[] {0, 0, 2.f, 0}); +- +- tensor.load(new float[] {2.f, 3.f}, 0, 2); +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .usingTolerance(0.001f) +- .containsExactly(new float[] {2.f, 0, 2.f, 3.f}); +- +- tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE}); +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .usingTolerance(0.001f) +- .containsExactly(new float[] {2.f, 3.f, 1.f, -1.f}); +- +- tensor.load(new short[] {1, 2, 3, 0, 1, Short.MIN_VALUE, 3, 4, 5}, 3, 6); +- // The entire sequence becomes {2.f, 0, 2.f, 3.f, 1.f, -1.f, 0, 0, -1.f, 0, 0, 0} but the ring +- // buffer is only keep the last 4 results. +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .usingTolerance(0.001f) +- .containsExactly(new float[] {-1.f, 0, 0, 0}); +- } +- +- @Test +- public void loadFailsWithIndexOutOfRange() throws Exception { +- TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build(); +- TensorAudio tensor = TensorAudio.create(format, 5); +- +- assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2)); +- +- assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2)); +- } +- +- @Test +- public void loadFailsWithIncompatibleInputSize() throws Exception { +- TensorAudioFormat format = +- TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build(); +- TensorAudio tensor = TensorAudio.create(format, 5); +- +- assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1])); +- +- assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2])); +- +- assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1)); +- +- assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4)); +- } +- +- @Test +- public void loadAudioRecordSucceeds() throws Exception { +- TensorAudio tensor = +- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); +- tensor.load(new float[] {1, 2, 3, 4, 5}); +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f}); +- +- AudioRecord record = mock(AudioRecord.class); +- when(record.getBufferSizeInFrames()).thenReturn(5); +- when(record.getChannelCount()).thenReturn(1); +- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT); +- when(record.getFormat()) +- .thenReturn( +- new AudioFormat.Builder() +- .setChannelMask(AudioFormat.CHANNEL_IN_MONO) +- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) +- .setSampleRate(16000) +- .build()); +- // Unused +- when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) +- .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); +- // Used +- when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) +- .thenReturn(1); +- assertThat(tensor.load(record)).isEqualTo(1); +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .isEqualTo(new float[] {3.f, 4.f, 5.f, 0}); +- +- record = mock(AudioRecord.class); +- when(record.getBufferSizeInFrames()).thenReturn(5); +- when(record.getChannelCount()).thenReturn(1); +- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT); +- when(record.getFormat()) +- .thenReturn( +- new AudioFormat.Builder() +- .setChannelMask(AudioFormat.CHANNEL_IN_MONO) +- .setEncoding(AudioFormat.ENCODING_PCM_16BIT) +- .setSampleRate(16000) +- .build()); +- // Used +- when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) +- .thenReturn(2); +- // Unused +- when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) +- .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); +- assertThat(tensor.load(record)).isEqualTo(2); +- assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[] {5.f, 0, 0, 0}); +- } +- +- @Test +- public void loadAudioRecordFailsWithErrorState() throws Exception { +- TensorAudio tensor = +- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); +- tensor.load(new float[] {1, 2, 3, 4, 5}); +- assertThat(tensor.getTensorBuffer().getFloatArray()) +- .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f}); +- +- AudioRecord record = mock(AudioRecord.class); +- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT); +- when(record.getFormat()) +- .thenReturn( +- new AudioFormat.Builder() +- .setChannelMask(AudioFormat.CHANNEL_IN_MONO) +- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) +- .setSampleRate(16000) +- .build()); +- // Unused +- when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) +- .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); +- // Used +- when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) +- .thenReturn(AudioRecord.ERROR_DEAD_OBJECT); +- IllegalStateException exception = +- assertThrows(IllegalStateException.class, () -> tensor.load(record)); +- assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT"); +- } +- +- @Test +- public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception { +- TensorAudio tensor = +- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); +- AudioRecord record = mock(AudioRecord.class); +- when(record.getFormat()) +- .thenReturn( +- new AudioFormat.Builder() +- .setChannelMask(AudioFormat.CHANNEL_IN_MONO) +- .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported +- .setSampleRate(16000) +- .build()); +- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT); +- +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> tensor.load(record)); +- assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding"); +- } +- +- @Test +- public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception { +- TensorAudio tensor = +- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); +- AudioRecord record = mock(AudioRecord.class); +- when(record.getFormat()) +- .thenReturn( +- new AudioFormat.Builder() +- .setChannelMask(AudioFormat.CHANNEL_IN_MONO) +- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) +- .setSampleRate(44100) // Mismatch +- .build()); +- +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> tensor.load(record)); +- assertThat(exception).hasMessageThat().ignoringCase().contains("Incompatible audio format"); ++ /** General tests of TensorAudio. */ ++ @RunWith(RobolectricTestRunner.class) ++ public static final class General extends TensorAudioTest { ++ @Test ++ public void createSucceedsWithTensorAudioFormat() throws Exception { ++ TensorAudio tensor = TensorAudio.create( ++ TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100); ++ assertThat(tensor.getFormat().getChannels()).isEqualTo(1); ++ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2); ++ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100); ++ } ++ ++ @Test ++ public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception { ++ TensorAudio tensor = TensorAudio.create( ++ TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100); ++ assertThat(tensor.getFormat().getChannels()).isEqualTo(5); ++ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2); ++ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500); ++ } ++ ++ @Test ++ public void createSucceededsWithDefaultArguments() throws Exception { ++ TensorAudio tensor = ++ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000); ++ // Number of channels defaults to 1. ++ assertThat(tensor.getFormat().getChannels()).isEqualTo(1); ++ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20); ++ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000); ++ } ++ ++ @Test ++ public void createSucceedsWithAudioFormat() throws Exception { ++ AudioFormat format = new AudioFormat.Builder() ++ .setChannelMask(AudioFormat.CHANNEL_IN_STEREO) ++ .setEncoding(AudioFormat.ENCODING_PCM_16BIT) ++ .setSampleRate(16000) ++ .build(); ++ TensorAudio tensor = TensorAudio.create(format, 100); ++ // STEREO has 2 channels ++ assertThat(tensor.getFormat().getChannels()).isEqualTo(2); ++ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000); ++ // flatSize = channelCount * sampleCount ++ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200); ++ } ++ ++ @Test ++ public void createFailedWithInvalidSampleRate() throws Exception { ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () ++ -> TensorAudio.create( ++ TensorAudioFormat.builder().setSampleRate(0).build(), 100)); ++ // Sample rate 0 is not allowed ++ assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate"); ++ } ++ ++ @Test ++ public void createFailedWithInvalidChannels() throws Exception { ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () ++ -> TensorAudio.create(TensorAudioFormat.builder() ++ .setSampleRate(1) ++ .setChannels(-1) ++ .build(), ++ 100)); ++ // Negative channels is not allowed ++ assertThat(exception).hasMessageThat().ignoringCase().contains("channels"); ++ } ++ ++ @Test ++ public void loadSucceedsFromArray() throws Exception { ++ TensorAudioFormat format = ++ TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build(); ++ TensorAudio tensor = TensorAudio.create(format, 2); ++ assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]); ++ ++ tensor.load(new float[] {2.f, 0}); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .usingTolerance(0.001f) ++ .containsExactly(new float[] {0, 0, 2.f, 0}); ++ ++ tensor.load(new float[] {2.f, 3.f}, 0, 2); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .usingTolerance(0.001f) ++ .containsExactly(new float[] {2.f, 0, 2.f, 3.f}); ++ ++ tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE}); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .usingTolerance(0.001f) ++ .containsExactly(new float[] {2.f, 3.f, 1.f, -1.f}); ++ ++ tensor.load(new short[] {1, 2, 3, 0, 1, Short.MIN_VALUE, 3, 4, 5}, 3, 6); ++ // The entire sequence becomes {2.f, 0, 2.f, 3.f, 1.f, -1.f, 0, 0, -1.f, 0, 0, 0} but ++ // the ring buffer is only keep the last 4 results. ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .usingTolerance(0.001f) ++ .containsExactly(new float[] {-1.f, 0, 0, 0}); ++ } ++ ++ @Test ++ public void loadFailsWithIndexOutOfRange() throws Exception { ++ TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build(); ++ TensorAudio tensor = TensorAudio.create(format, 5); ++ ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2)); ++ ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2)); ++ } ++ ++ @Test ++ public void loadFailsWithIncompatibleInputSize() throws Exception { ++ TensorAudioFormat format = ++ TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build(); ++ TensorAudio tensor = TensorAudio.create(format, 5); ++ ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1])); ++ ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2])); ++ ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1)); ++ ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4)); ++ } ++ ++ @Test ++ public void loadAudioRecordSucceeds() throws Exception { ++ TensorAudio tensor = ++ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); ++ tensor.load(new float[] {1, 2, 3, 4, 5}); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f}); ++ ++ AudioRecord record = mock(AudioRecord.class); ++ when(record.getBufferSizeInFrames()).thenReturn(5); ++ when(record.getChannelCount()).thenReturn(1); ++ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT); ++ when(record.getFormat()) ++ .thenReturn(new AudioFormat.Builder() ++ .setChannelMask(AudioFormat.CHANNEL_IN_MONO) ++ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) ++ .setSampleRate(16000) ++ .build()); ++ // Unused ++ when(record.read( ++ any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) ++ .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); ++ // Used ++ when(record.read( ++ any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) ++ .thenReturn(1); ++ assertThat(tensor.load(record)).isEqualTo(1); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .isEqualTo(new float[] {3.f, 4.f, 5.f, 0}); ++ ++ record = mock(AudioRecord.class); ++ when(record.getBufferSizeInFrames()).thenReturn(5); ++ when(record.getChannelCount()).thenReturn(1); ++ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT); ++ when(record.getFormat()) ++ .thenReturn(new AudioFormat.Builder() ++ .setChannelMask(AudioFormat.CHANNEL_IN_MONO) ++ .setEncoding(AudioFormat.ENCODING_PCM_16BIT) ++ .setSampleRate(16000) ++ .build()); ++ // Used ++ when(record.read( ++ any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) ++ .thenReturn(2); ++ // Unused ++ when(record.read( ++ any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) ++ .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); ++ assertThat(tensor.load(record)).isEqualTo(2); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .isEqualTo(new float[] {5.f, 0, 0, 0}); ++ } ++ ++ @Test ++ public void loadAudioRecordFailsWithErrorState() throws Exception { ++ TensorAudio tensor = ++ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); ++ tensor.load(new float[] {1, 2, 3, 4, 5}); ++ assertThat(tensor.getTensorBuffer().getFloatArray()) ++ .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f}); ++ ++ AudioRecord record = mock(AudioRecord.class); ++ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT); ++ when(record.getFormat()) ++ .thenReturn(new AudioFormat.Builder() ++ .setChannelMask(AudioFormat.CHANNEL_IN_MONO) ++ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) ++ .setSampleRate(16000) ++ .build()); ++ // Unused ++ when(record.read( ++ any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) ++ .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); ++ // Used ++ when(record.read( ++ any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) ++ .thenReturn(AudioRecord.ERROR_DEAD_OBJECT); ++ IllegalStateException exception = ++ assertThrows(IllegalStateException.class, () -> tensor.load(record)); ++ assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT"); ++ } ++ ++ @Test ++ public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception { ++ TensorAudio tensor = ++ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); ++ AudioRecord record = mock(AudioRecord.class); ++ when(record.getFormat()) ++ .thenReturn(new AudioFormat.Builder() ++ .setChannelMask(AudioFormat.CHANNEL_IN_MONO) ++ .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported ++ .setSampleRate(16000) ++ .build()); ++ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT); ++ ++ IllegalArgumentException exception = ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(record)); ++ assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding"); ++ } ++ ++ @Test ++ public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception { ++ TensorAudio tensor = ++ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); ++ AudioRecord record = mock(AudioRecord.class); ++ when(record.getFormat()) ++ .thenReturn(new AudioFormat.Builder() ++ .setChannelMask(AudioFormat.CHANNEL_IN_MONO) ++ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) ++ .setSampleRate(44100) // Mismatch ++ .build()); ++ ++ IllegalArgumentException exception = ++ assertThrows(IllegalArgumentException.class, () -> tensor.load(record)); ++ assertThat(exception).hasMessageThat().ignoringCase().contains( ++ "Incompatible audio format"); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java +index d97665d1ed771..1d26476733c98 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java +@@ -18,78 +18,81 @@ package org.tensorflow.lite.support.common; + import static com.google.common.truth.Truth.assertThat; + + import android.content.Context; ++ + import androidx.test.core.app.ApplicationProvider; ++ ++import org.junit.Assert; ++import org.junit.Test; ++import org.junit.runner.RunWith; ++import org.robolectric.RobolectricTestRunner; ++ + import java.io.ByteArrayInputStream; + import java.io.IOException; + import java.io.InputStream; + import java.nio.MappedByteBuffer; + import java.nio.charset.Charset; + import java.util.List; +-import org.junit.Assert; +-import org.junit.Test; +-import org.junit.runner.RunWith; +-import org.robolectric.RobolectricTestRunner; + + /** Tests of {@link org.tensorflow.lite.support.common.FileUtil}. */ + @RunWith(RobolectricTestRunner.class) + public final class FileUtilTest { +- private final Context context = ApplicationProvider.getApplicationContext(); +- private static final String LABEL_PATH = "flower_labels.txt"; +- +- @Test +- public void testLoadLabels() throws IOException { +- List<String> labels = FileUtil.loadLabels(context, LABEL_PATH); +- assertThat(labels) +- .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips") +- .inOrder(); +- } +- +- @Test +- public void testLoadLabelsFromInputStream() throws IOException { +- InputStream inputStream = context.getAssets().open(LABEL_PATH); +- assertThat(FileUtil.loadLabels(inputStream)) +- .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips") +- .inOrder(); +- } +- +- @Test +- public void whitespaceLabelsShouldNotCount() throws IOException { +- String s = "a\nb\n \n\n\nc"; +- InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset())); +- assertThat(FileUtil.loadLabels(stream)).hasSize(3); +- } +- +- @Test +- public void testLoadLabelsNullContext() throws IOException { +- Context nullContext = null; +- Assert.assertThrows( +- NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH)); +- } +- +- @Test +- public void testLoadLabelsNullFilePath() throws IOException { +- String nullFilePath = null; +- Assert.assertThrows( +- NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath)); +- } +- +- @Test +- public void testLoadMappedFile() throws IOException { +- MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH); +- assertThat(byteModel).isNotNull(); +- } +- +- @Test +- public void testLoadMappedFileWithNullContext() throws IOException { +- Context nullContext = null; +- Assert.assertThrows( +- NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH)); +- } +- +- @Test +- public void loadMappedFileWithNullFilePath() throws IOException { +- String nullFilePath = null; +- Assert.assertThrows( +- NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath)); +- } ++ private final Context context = ApplicationProvider.getApplicationContext(); ++ private static final String LABEL_PATH = "flower_labels.txt"; ++ ++ @Test ++ public void testLoadLabels() throws IOException { ++ List<String> labels = FileUtil.loadLabels(context, LABEL_PATH); ++ assertThat(labels) ++ .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips") ++ .inOrder(); ++ } ++ ++ @Test ++ public void testLoadLabelsFromInputStream() throws IOException { ++ InputStream inputStream = context.getAssets().open(LABEL_PATH); ++ assertThat(FileUtil.loadLabels(inputStream)) ++ .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips") ++ .inOrder(); ++ } ++ ++ @Test ++ public void whitespaceLabelsShouldNotCount() throws IOException { ++ String s = "a\nb\n \n\n\nc"; ++ InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset())); ++ assertThat(FileUtil.loadLabels(stream)).hasSize(3); ++ } ++ ++ @Test ++ public void testLoadLabelsNullContext() throws IOException { ++ Context nullContext = null; ++ Assert.assertThrows( ++ NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH)); ++ } ++ ++ @Test ++ public void testLoadLabelsNullFilePath() throws IOException { ++ String nullFilePath = null; ++ Assert.assertThrows( ++ NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath)); ++ } ++ ++ @Test ++ public void testLoadMappedFile() throws IOException { ++ MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH); ++ assertThat(byteModel).isNotNull(); ++ } ++ ++ @Test ++ public void testLoadMappedFileWithNullContext() throws IOException { ++ Context nullContext = null; ++ Assert.assertThrows( ++ NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH)); ++ } ++ ++ @Test ++ public void loadMappedFileWithNullFilePath() throws IOException { ++ String nullFilePath = null; ++ Assert.assertThrows( ++ NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java +index 43a7f7cd1ce29..82f97f2534cf7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java +@@ -27,59 +27,58 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + /** Tests for {@link TensorProcessor}. */ + @RunWith(RobolectricTestRunner.class) + public final class TensorProcessorTest { ++ private static final int EXAMPLE_NUM_FEATURES = 1000; ++ private static final float MEAN = 127.5f; ++ private static final float STDDEV = 127.5f; + +- private static final int EXAMPLE_NUM_FEATURES = 1000; +- private static final float MEAN = 127.5f; +- private static final float STDDEV = 127.5f; +- +- @Test +- public void testBuild() { +- TensorProcessor processor = +- new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); +- assertThat(processor).isNotNull(); +- } ++ @Test ++ public void testBuild() { ++ TensorProcessor processor = ++ new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); ++ assertThat(processor).isNotNull(); ++ } + +- @Test +- public void testNormalize() { +- TensorBuffer input = createExampleTensorBuffer(); +- TensorProcessor processor = +- new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); +- TensorBuffer output = processor.process(input); ++ @Test ++ public void testNormalize() { ++ TensorBuffer input = createExampleTensorBuffer(); ++ TensorProcessor processor = ++ new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); ++ TensorBuffer output = processor.process(input); + +- float[] pixels = output.getFloatArray(); +- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES); +- for (float p : pixels) { +- assertThat(p).isAtLeast(-1); +- assertThat(p).isAtMost(1); ++ float[] pixels = output.getFloatArray(); ++ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES); ++ for (float p : pixels) { ++ assertThat(p).isAtLeast(-1); ++ assertThat(p).isAtMost(1); ++ } + } +- } + +- @Test +- public void testMultipleNormalize() { +- TensorBuffer input = createExampleTensorBuffer(); +- TensorProcessor processor = +- new TensorProcessor.Builder() +- .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1] +- .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1] +- .build(); +- TensorBuffer output = processor.process(input); ++ @Test ++ public void testMultipleNormalize() { ++ TensorBuffer input = createExampleTensorBuffer(); ++ TensorProcessor processor = ++ new TensorProcessor.Builder() ++ .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1] ++ .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1] ++ .build(); ++ TensorBuffer output = processor.process(input); + +- float[] pixels = output.getFloatArray(); +- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES); +- for (float p : pixels) { +- assertThat(p).isAtLeast(0); +- assertThat(p).isAtMost(1); ++ float[] pixels = output.getFloatArray(); ++ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES); ++ for (float p : pixels) { ++ assertThat(p).isAtLeast(0); ++ assertThat(p).isAtMost(1); ++ } + } +- } + +- // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255]. +- private static TensorBuffer createExampleTensorBuffer() { +- TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- int[] features = new int[EXAMPLE_NUM_FEATURES]; +- for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) { +- features[i] = i % 256; ++ // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255]. ++ private static TensorBuffer createExampleTensorBuffer() { ++ TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ int[] features = new int[EXAMPLE_NUM_FEATURES]; ++ for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) { ++ features[i] = i % 256; ++ } ++ buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES}); ++ return buffer; + } +- buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES}); +- return buffer; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java +index a159c71863322..e8ba24d27550b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java +@@ -27,56 +27,55 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + /** Tests of {@link CastOp}. */ + @RunWith(RobolectricTestRunner.class) + public final class CastOpTest { ++ private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f}; ++ private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f}; ++ private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9}; ++ private static final int[] SHAPE = new int[] {5}; + +- private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f}; +- private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f}; +- private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9}; +- private static final int[] SHAPE = new int[] {5}; +- +- @Test +- public void castFloat32ToUint8ShouldSuccess() { +- TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- floatBuffer.loadArray(FLOAT_ARRAY, SHAPE); +- CastOp op = new CastOp(DataType.UINT8); +- TensorBuffer uint8Buffer = op.apply(floatBuffer); +- assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8); +- assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY); +- } ++ @Test ++ public void castFloat32ToUint8ShouldSuccess() { ++ TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ floatBuffer.loadArray(FLOAT_ARRAY, SHAPE); ++ CastOp op = new CastOp(DataType.UINT8); ++ TensorBuffer uint8Buffer = op.apply(floatBuffer); ++ assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8); ++ assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY); ++ } + +- @Test +- public void castUint8ToFloat32ShouldSuccess() { +- TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8); +- uint8Buffer.loadArray(INT_ARRAY, SHAPE); +- CastOp op = new CastOp(DataType.FLOAT32); +- TensorBuffer floatBuffer = op.apply(uint8Buffer); +- assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32); +- assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY); +- } ++ @Test ++ public void castUint8ToFloat32ShouldSuccess() { ++ TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8); ++ uint8Buffer.loadArray(INT_ARRAY, SHAPE); ++ CastOp op = new CastOp(DataType.FLOAT32); ++ TensorBuffer floatBuffer = op.apply(uint8Buffer); ++ assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32); ++ assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY); ++ } + +- @Test +- public void castFloat32ToFloat32ShouldNotRecreate() { +- TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- floatBuffer.loadArray(FLOAT_ARRAY, SHAPE); +- CastOp op = new CastOp(DataType.FLOAT32); +- TensorBuffer newBuffer = op.apply(floatBuffer); +- assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32); +- assertThat(newBuffer).isSameInstanceAs(floatBuffer); +- } ++ @Test ++ public void castFloat32ToFloat32ShouldNotRecreate() { ++ TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ floatBuffer.loadArray(FLOAT_ARRAY, SHAPE); ++ CastOp op = new CastOp(DataType.FLOAT32); ++ TensorBuffer newBuffer = op.apply(floatBuffer); ++ assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32); ++ assertThat(newBuffer).isSameInstanceAs(floatBuffer); ++ } + +- @Test +- public void castUint8ToUint8ShouldNotRecreate() { +- TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8); +- uint8Buffer.loadArray(INT_ARRAY, SHAPE); +- CastOp op = new CastOp(DataType.UINT8); +- TensorBuffer newBuffer = op.apply(uint8Buffer); +- assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8); +- assertThat(newBuffer).isSameInstanceAs(uint8Buffer); +- } ++ @Test ++ public void castUint8ToUint8ShouldNotRecreate() { ++ TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8); ++ uint8Buffer.loadArray(INT_ARRAY, SHAPE); ++ CastOp op = new CastOp(DataType.UINT8); ++ TensorBuffer newBuffer = op.apply(uint8Buffer); ++ assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8); ++ assertThat(newBuffer).isSameInstanceAs(uint8Buffer); ++ } + +- @Test +- public void castToUnsupportedDataTypeShouldThrow() { +- for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) { +- Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type)); ++ @Test ++ public void castToUnsupportedDataTypeShouldThrow() { ++ for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) { ++ Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type)); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java +index 99ded56ce069a..a69bcd7ec0296 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java +@@ -26,16 +26,15 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + /** Tests of {@link DequantizeOp}. */ + @RunWith(RobolectricTestRunner.class) + public final class DequantizeOpTest { +- +- @Test +- public void dequantizeShouldSucess() { +- int[] originalData = new int[] {191, 159, 63, 127, 255, 0}; +- DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128); +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8); +- input.loadArray(originalData); +- TensorBuffer dequantized = op.apply(input); +- assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32); +- assertThat(dequantized.getFloatArray()) +- .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}); +- } ++ @Test ++ public void dequantizeShouldSucess() { ++ int[] originalData = new int[] {191, 159, 63, 127, 255, 0}; ++ DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128); ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8); ++ input.loadArray(originalData); ++ TensorBuffer dequantized = op.apply(input); ++ assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32); ++ assertThat(dequantized.getFloatArray()) ++ .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java +index 09ef275a826bc..aabc6be926106 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.common.ops; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.tensorflow.lite.DataType.FLOAT32; + import static org.tensorflow.lite.DataType.UINT8; + +@@ -31,122 +32,120 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + */ + @RunWith(RobolectricTestRunner.class) + public final class NormalizeOpTest { ++ private static final float MEAN = 50; ++ private static final float STDDEV = 50; ++ private static final int NUM_ELEMENTS = 100; ++ ++ @Test ++ public void testNormalizeIntBuffer() { ++ int[] inputArr = new int[NUM_ELEMENTS]; ++ for (int i = 0; i < NUM_ELEMENTS; i++) { ++ inputArr[i] = i; ++ } ++ TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8); ++ input.loadArray(inputArr, new int[] {inputArr.length}); ++ NormalizeOp op = new NormalizeOp(MEAN, STDDEV); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(FLOAT32); ++ float[] outputArr = output.getFloatArray(); ++ for (int i = 0; i < NUM_ELEMENTS; i++) { ++ assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV); ++ } ++ } + +- private static final float MEAN = 50; +- private static final float STDDEV = 50; +- private static final int NUM_ELEMENTS = 100; ++ @Test ++ public void testNormalizeFloatBuffer() { ++ float[] inputArr = new float[NUM_ELEMENTS]; ++ for (int i = 0; i < NUM_ELEMENTS; i++) { ++ inputArr[i] = i; ++ } ++ TensorBuffer input = TensorBuffer.createDynamic(FLOAT32); ++ input.loadArray(inputArr, new int[] {inputArr.length}); ++ NormalizeOp op = new NormalizeOp(MEAN, STDDEV); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(FLOAT32); ++ float[] outputArr = output.getFloatArray(); ++ for (int i = 0; i < NUM_ELEMENTS; i++) { ++ assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV); ++ } ++ } + +- @Test +- public void testNormalizeIntBuffer() { +- int[] inputArr = new int[NUM_ELEMENTS]; +- for (int i = 0; i < NUM_ELEMENTS; i++) { +- inputArr[i] = i; ++ @Test ++ public void testZeroStddev() { ++ Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0)); + } +- TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8); +- input.loadArray(inputArr, new int[] {inputArr.length}); +- NormalizeOp op = new NormalizeOp(MEAN, STDDEV); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(FLOAT32); +- float[] outputArr = output.getFloatArray(); +- for (int i = 0; i < NUM_ELEMENTS; i++) { +- assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV); ++ ++ @Test ++ public void testIdentityShortcut() { ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); ++ NormalizeOp op = new NormalizeOp(0, 1); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(UINT8); ++ assertThat(output).isSameInstanceAs(input); + } +- } + +- @Test +- public void testNormalizeFloatBuffer() { +- float[] inputArr = new float[NUM_ELEMENTS]; +- for (int i = 0; i < NUM_ELEMENTS; i++) { +- inputArr[i] = i; ++ @Test ++ public void testNormalizeOp_zeroMeanAndZeroStddev() { ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); ++ NormalizeOp op = new NormalizeOp(0, 0); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(UINT8); ++ assertThat(output).isSameInstanceAs(input); + } +- TensorBuffer input = TensorBuffer.createDynamic(FLOAT32); +- input.loadArray(inputArr, new int[] {inputArr.length}); +- NormalizeOp op = new NormalizeOp(MEAN, STDDEV); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(FLOAT32); +- float[] outputArr = output.getFloatArray(); +- for (int i = 0; i < NUM_ELEMENTS; i++) { +- assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV); ++ ++ @Test ++ public void testNormalizeOp_zeroMeanAndInifityStddev() { ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); ++ NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(UINT8); ++ assertThat(output).isSameInstanceAs(input); + } +- } +- +- @Test +- public void testZeroStddev() { +- Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0)); +- } +- +- @Test +- public void testIdentityShortcut() { +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); +- NormalizeOp op = new NormalizeOp(0, 1); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(UINT8); +- assertThat(output).isSameInstanceAs(input); +- } +- +- @Test +- public void testNormalizeOp_zeroMeanAndZeroStddev() { +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); +- NormalizeOp op = new NormalizeOp(0, 0); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(UINT8); +- assertThat(output).isSameInstanceAs(input); +- } +- +- @Test +- public void testNormalizeOp_zeroMeanAndInifityStddev() { +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); +- NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(UINT8); +- assertThat(output).isSameInstanceAs(input); +- } +- +- @Test +- public void testMultiChannelNormalize() { +- float[] inputArr = new float[NUM_ELEMENTS]; +- for (int i = 0; i < NUM_ELEMENTS; i++) { +- inputArr[i] = i; ++ ++ @Test ++ public void testMultiChannelNormalize() { ++ float[] inputArr = new float[NUM_ELEMENTS]; ++ for (int i = 0; i < NUM_ELEMENTS; i++) { ++ inputArr[i] = i; ++ } ++ TensorBuffer input = TensorBuffer.createDynamic(FLOAT32); ++ input.loadArray(inputArr, new int[] {20, 5}); ++ float[] means = new float[] {1, 2, 3, 4, 5}; ++ float[] stddevs = new float[] {6, 7, 8, 9, 10}; ++ NormalizeOp op = new NormalizeOp(means, stddevs); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(FLOAT32); ++ float[] outputArr = output.getFloatArray(); ++ for (int i = 0; i < NUM_ELEMENTS; i++) { ++ assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]); ++ } + } +- TensorBuffer input = TensorBuffer.createDynamic(FLOAT32); +- input.loadArray(inputArr, new int[] {20, 5}); +- float[] means = new float[] {1, 2, 3, 4, 5}; +- float[] stddevs = new float[] {6, 7, 8, 9, 10}; +- NormalizeOp op = new NormalizeOp(means, stddevs); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(FLOAT32); +- float[] outputArr = output.getFloatArray(); +- for (int i = 0; i < NUM_ELEMENTS; i++) { +- assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]); ++ ++ @Test ++ public void testMultiChannelShortcut() { ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); ++ NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1}); ++ TensorBuffer output = op.apply(input); ++ assertThat(output.getDataType()).isEqualTo(UINT8); ++ assertThat(output).isSameInstanceAs(input); ++ } ++ ++ @Test ++ public void testMismatchedNumbersOfMeansAndStddevs() { ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> new NormalizeOp(new float[] {2, 3}, new float[] {1})); ++ } ++ ++ @Test ++ public void testMismatchedInputTensorChannelNum() { ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); ++ NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2}); ++ Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input)); ++ } ++ ++ @Test ++ public void testAnyChannelInvalidStddev() { ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0})); + } +- } +- +- @Test +- public void testMultiChannelShortcut() { +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); +- NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1}); +- TensorBuffer output = op.apply(input); +- assertThat(output.getDataType()).isEqualTo(UINT8); +- assertThat(output).isSameInstanceAs(input); +- } +- +- @Test +- public void testMismatchedNumbersOfMeansAndStddevs() { +- Assert.assertThrows( +- IllegalArgumentException.class, () -> new NormalizeOp(new float[] {2, 3}, new float[] {1})); +- } +- +- @Test +- public void testMismatchedInputTensorChannelNum() { +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); +- NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2}); +- Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input)); +- } +- +- @Test +- public void testAnyChannelInvalidStddev() { +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0})); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java +index 8ef72f92e0696..519cd287e1575 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java +@@ -26,15 +26,14 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + /** Tests of {@link QuantizeOp}. */ + @RunWith(RobolectricTestRunner.class) + public final class QuantizeOpTest { +- +- @Test +- public void quantizeShouldSuccess() { +- float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128 +- QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128); +- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32); +- input.loadArray(originalData); +- TensorBuffer quantized = op.apply(input); +- assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32); +- assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0}); +- } ++ @Test ++ public void quantizeShouldSuccess() { ++ float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128 ++ QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128); ++ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32); ++ input.loadArray(originalData); ++ TensorBuffer quantized = op.apply(input); ++ assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32); ++ assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0}); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java +index 7f16c8e95628d..e8edb588c61c6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java +@@ -18,7 +18,7 @@ package org.tensorflow.lite.support.image; + import static com.google.common.truth.Truth.assertThat; + + import android.graphics.RectF; +-import java.util.List; ++ + import org.junit.Assert; + import org.junit.Before; + import org.junit.Test; +@@ -28,213 +28,142 @@ import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.image.BoundingBoxUtil.CoordinateType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.List; ++ + /** Tests of {@link BoundingBoxUtil}. */ + @RunWith(RobolectricTestRunner.class) + public class BoundingBoxUtilTest { +- +- private TensorBuffer tensorBuffer; +- +- @Before +- public void setUp() { +- // 2 bounding boxes with additional batch dimension. +- tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32); +- } +- +- @Test +- public void convertDefaultRatioBoundaries() { +- tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.BOUNDARIES, +- CoordinateType.RATIO, +- 500, +- 400); +- +- assertThat(boxList).hasSize(2); +- assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400)); +- assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500)); +- } +- +- @Test +- public void convertComplexTensor() { +- tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32); +- tensorBuffer.loadArray( +- new float[] { +- // sub tensor 0 +- 0, 1, 10, 11, 20, 21, 30, 31, +- // sub tensor 1 +- 100, 101, 110, 111, 120, 121, 130, 131, +- // sub tensor 2 +- 200, 201, 210, 211, 220, 221, 230, 231 +- }); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- 1, +- BoundingBoxUtil.Type.BOUNDARIES, +- CoordinateType.PIXEL, +- 0, +- 0); +- +- assertThat(boxList).hasSize(6); +- assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30)); +- assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31)); +- assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130)); +- assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131)); +- } +- +- @Test +- public void convertIndexedRatioBoundaries() { +- tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {1, 0, 3, 2}, +- -1, +- BoundingBoxUtil.Type.BOUNDARIES, +- CoordinateType.RATIO, +- 500, +- 400); +- +- assertThat(boxList).hasSize(2); +- assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375)); +- assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500)); +- } +- +- @Test +- public void convertPixelBoundaries() { +- tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.BOUNDARIES, +- CoordinateType.PIXEL, +- 500, +- 400); +- +- assertThat(boxList) +- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) +- .inOrder(); +- } +- +- @Test +- public void convertRatioUpperLeft() { +- tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.UPPER_LEFT, +- CoordinateType.RATIO, +- 500, +- 400); +- +- assertThat(boxList).hasSize(2); +- assertThat(boxList) +- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) +- .inOrder(); +- } +- +- @Test +- public void convertPixelUpperLeft() { +- tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.UPPER_LEFT, +- CoordinateType.PIXEL, +- 500, +- 400); +- +- assertThat(boxList) +- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) +- .inOrder(); +- } +- +- @Test +- public void convertRatioCenter() { +- tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.CENTER, +- CoordinateType.RATIO, +- 500, +- 400); +- +- assertThat(boxList) +- .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500)) +- .inOrder(); +- } +- +- @Test +- public void convertPixelCenter() { +- tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500}); +- +- List<RectF> boxList = +- BoundingBoxUtil.convert( +- tensorBuffer, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.CENTER, +- CoordinateType.PIXEL, +- 500, +- 400); +- +- assertThat(boxList) +- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) +- .inOrder(); +- } +- +- @Test +- public void convertTensorWithUnexpectedShapeShouldThrow() { +- TensorBuffer badShapeTensor = TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32); +- +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> +- BoundingBoxUtil.convert( +- badShapeTensor, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.BOUNDARIES, +- CoordinateType.RATIO, +- 300, +- 400)); +- } +- +- @Test +- public void convertIntTensorShouldThrow() { +- TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8); +- +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> +- BoundingBoxUtil.convert( +- badTypeTensor, +- new int[] {0, 1, 2, 3}, +- -1, +- BoundingBoxUtil.Type.BOUNDARIES, +- CoordinateType.RATIO, +- 300, +- 400)); +- } ++ private TensorBuffer tensorBuffer; ++ ++ @Before ++ public void setUp() { ++ // 2 bounding boxes with additional batch dimension. ++ tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32); ++ } ++ ++ @Test ++ public void convertDefaultRatioBoundaries() { ++ tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400); ++ ++ assertThat(boxList).hasSize(2); ++ assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400)); ++ assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500)); ++ } ++ ++ @Test ++ public void convertComplexTensor() { ++ tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32); ++ tensorBuffer.loadArray(new float[] {// sub tensor 0 ++ 0, 1, 10, 11, 20, 21, 30, 31, ++ // sub tensor 1 ++ 100, 101, 110, 111, 120, 121, 130, 131, ++ // sub tensor 2 ++ 200, 201, 210, 211, 220, 221, 230, 231}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, 1, ++ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 0, 0); ++ ++ assertThat(boxList).hasSize(6); ++ assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30)); ++ assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31)); ++ assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130)); ++ assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131)); ++ } ++ ++ @Test ++ public void convertIndexedRatioBoundaries() { ++ tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {1, 0, 3, 2}, -1, ++ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400); ++ ++ assertThat(boxList).hasSize(2); ++ assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375)); ++ assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500)); ++ } ++ ++ @Test ++ public void convertPixelBoundaries() { ++ tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 500, 400); ++ ++ assertThat(boxList) ++ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) ++ .inOrder(); ++ } ++ ++ @Test ++ public void convertRatioUpperLeft() { ++ tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.RATIO, 500, 400); ++ ++ assertThat(boxList).hasSize(2); ++ assertThat(boxList) ++ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) ++ .inOrder(); ++ } ++ ++ @Test ++ public void convertPixelUpperLeft() { ++ tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.PIXEL, 500, 400); ++ ++ assertThat(boxList) ++ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) ++ .inOrder(); ++ } ++ ++ @Test ++ public void convertRatioCenter() { ++ tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.CENTER, CoordinateType.RATIO, 500, 400); ++ ++ assertThat(boxList) ++ .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500)) ++ .inOrder(); ++ } ++ ++ @Test ++ public void convertPixelCenter() { ++ tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500}); ++ ++ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.CENTER, CoordinateType.PIXEL, 500, 400); ++ ++ assertThat(boxList) ++ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) ++ .inOrder(); ++ } ++ ++ @Test ++ public void convertTensorWithUnexpectedShapeShouldThrow() { ++ TensorBuffer badShapeTensor = ++ TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32); ++ ++ Assert.assertThrows(IllegalArgumentException.class, ++ () ++ -> BoundingBoxUtil.convert(badShapeTensor, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400)); ++ } ++ ++ @Test ++ public void convertIntTensorShouldThrow() { ++ TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8); ++ ++ Assert.assertThrows(IllegalArgumentException.class, ++ () ++ -> BoundingBoxUtil.convert(badTypeTensor, new int[] {0, 1, 2, 3}, -1, ++ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java +index c41508308291a..329b5aa370744 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java +@@ -15,10 +15,12 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap; + import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer; + + import android.graphics.Bitmap; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.junit.runners.JUnit4; +@@ -27,22 +29,21 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + @RunWith(JUnit4.class) + public final class ColorSpaceTypeInstrumentedTest { +- +- @Test +- public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() { +- TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false); +- Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer); +- +- Bitmap expectedBitmap = createGrayscaleBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() { +- TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false); +- Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer); +- +- Bitmap expectedBitmap = createGrayscaleBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } ++ @Test ++ public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() { ++ TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false); ++ Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer); ++ ++ Bitmap expectedBitmap = createGrayscaleBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() { ++ TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false); ++ Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer); ++ ++ Bitmap expectedBitmap = createGrayscaleBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java +index 46977fdb2bdfa..92612255269f6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap; + import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer; +@@ -23,8 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensor + import android.graphics.Bitmap; + import android.graphics.Bitmap.Config; + import android.graphics.ImageFormat; +-import java.util.Arrays; +-import java.util.Collection; ++ + import org.junit.Rule; + import org.junit.Test; + import org.junit.rules.ErrorCollector; +@@ -38,386 +38,353 @@ import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.Arrays; ++import java.util.Collection; ++ + /** Tests of {@link ImageConversions}. */ + @RunWith(Suite.class) +-@SuiteClasses({ +- ColorSpaceTypeTest.ValidShapeTest.class, +- ColorSpaceTypeTest.InvalidShapeTest.class, +- ColorSpaceTypeTest.BitmapConfigTest.class, +- ColorSpaceTypeTest.ImageFormatTest.class, +- ColorSpaceTypeTest.YuvImageTest.class, +- ColorSpaceTypeTest.AssertNumElementsTest.class, +- ColorSpaceTypeTest.General.class +-}) ++@SuiteClasses({ColorSpaceTypeTest.ValidShapeTest.class, ColorSpaceTypeTest.InvalidShapeTest.class, ++ ColorSpaceTypeTest.BitmapConfigTest.class, ColorSpaceTypeTest.ImageFormatTest.class, ++ ColorSpaceTypeTest.YuvImageTest.class, ColorSpaceTypeTest.AssertNumElementsTest.class, ++ ColorSpaceTypeTest.General.class}) + public class ColorSpaceTypeTest { +- +- /** Parameterized tests for valid shapes. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class ValidShapeTest extends ColorSpaceTypeTest { +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- /** The shape that matches the colorSpaceType. */ +- @Parameter(1) +- public int[] validShape; +- +- /** The height of validShape. */ +- @Parameter(2) +- public int expectedHeight; +- +- /** The width of validShape. */ +- @Parameter(3) +- public int expectedWidth; +- +- @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20}, +- {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20}, +- {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20}, +- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20}, +- }); +- } +- +- @Test +- public void getHeightSucceedsWithValidShape() { +- assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight); ++ /** Parameterized tests for valid shapes. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class ValidShapeTest extends ColorSpaceTypeTest { ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ /** The shape that matches the colorSpaceType. */ ++ @Parameter(1) ++ public int[] validShape; ++ ++ /** The height of validShape. */ ++ @Parameter(2) ++ public int expectedHeight; ++ ++ /** The width of validShape. */ ++ @Parameter(3) ++ public int expectedWidth; ++ ++ @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20}, ++ {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20}, ++ {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20}, ++ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20}, ++ }); ++ } ++ ++ @Test ++ public void getHeightSucceedsWithValidShape() { ++ assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight); ++ } ++ ++ @Test ++ public void getWidthSucceedsWithValidShape() { ++ assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth); ++ } + } + +- @Test +- public void getWidthSucceedsWithValidShape() { +- assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth); +- } +- } +- +- /** Parameterized tests for invalid shapes. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class InvalidShapeTest extends ColorSpaceTypeTest { +- +- private static final String RGB_ASSERT_SHAPE_MESSAGE = +- "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" +- + " representing R, G, B in order. The provided image shape is "; +- private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE = +- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" +- + " shape is "; +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- /** The shape that does not match the colorSpaceType. */ +- @Parameter(1) +- public int[] invalidShape; +- +- @Parameter(2) +- public String errorMessage; +- +- @Parameters(name = "colorSpaceType={0}; invalidShape={1}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- }); ++ /** Parameterized tests for invalid shapes. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class InvalidShapeTest extends ColorSpaceTypeTest { ++ private static final String RGB_ASSERT_SHAPE_MESSAGE = ++ "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" ++ + " representing R, G, B in order. The provided image shape is "; ++ private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE = ++ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" ++ + " shape is "; ++ ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ /** The shape that does not match the colorSpaceType. */ ++ @Parameter(1) ++ public int[] invalidShape; ++ ++ @Parameter(2) ++ public String errorMessage; ++ ++ @Parameters(name = "colorSpaceType={0}; invalidShape={1}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ }); ++ } ++ ++ @Test ++ public void assertShapeFaislsWithInvalidShape() { ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape)); ++ assertThat(exception).hasMessageThat().contains( ++ errorMessage + Arrays.toString(invalidShape)); ++ } ++ ++ @Test ++ public void getHeightFaislsWithInvalidShape() { ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape)); ++ assertThat(exception).hasMessageThat().contains( ++ errorMessage + Arrays.toString(invalidShape)); ++ } ++ ++ @Test ++ public void getWidthFaislsWithInvalidShape() { ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape)); ++ assertThat(exception).hasMessageThat().contains( ++ errorMessage + Arrays.toString(invalidShape)); ++ } + } + +- @Test +- public void assertShapeFaislsWithInvalidShape() { +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape)); +- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape)); ++ /** Parameterized tests for Bitmap Config. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class BitmapConfigTest extends ColorSpaceTypeTest { ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ /** The Bitmap configuration match the colorSpaceType. */ ++ @Parameter(1) ++ public Config config; ++ ++ @Parameters(name = "colorSpaceType={0}; config={1}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.RGB, Config.ARGB_8888}, ++ {ColorSpaceType.GRAYSCALE, Config.ALPHA_8}, ++ }); ++ } ++ ++ @Test ++ public void fromBitmapConfigSucceedsWithSupportedConfig() { ++ assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType); ++ } ++ ++ @Test ++ public void toBitmapConfigSucceedsWithSupportedConfig() { ++ assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config); ++ } + } + +- @Test +- public void getHeightFaislsWithInvalidShape() { +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape)); +- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape)); ++ /** Parameterized tests for ImageFormat. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class ImageFormatTest extends ColorSpaceTypeTest { ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ /** The ImageFormat that matches the colorSpaceType. */ ++ @Parameter(1) ++ public int imageFormat; ++ ++ @Parameters(name = "colorSpaceType={0}; imageFormat={1}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.NV21, ImageFormat.NV21}, ++ {ColorSpaceType.YV12, ImageFormat.YV12}, ++ {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888}, ++ }); ++ } ++ ++ @Test ++ public void fromImageFormatSucceedsWithSupportedImageFormat() { ++ assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType); ++ } + } + +- @Test +- public void getWidthFaislsWithInvalidShape() { +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape)); +- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape)); +- } +- } +- +- /** Parameterized tests for Bitmap Config. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class BitmapConfigTest extends ColorSpaceTypeTest { +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- /** The Bitmap configuration match the colorSpaceType. */ +- @Parameter(1) +- public Config config; +- +- @Parameters(name = "colorSpaceType={0}; config={1}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.RGB, Config.ARGB_8888}, +- {ColorSpaceType.GRAYSCALE, Config.ALPHA_8}, +- }); ++ /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class YuvImageTest extends ColorSpaceTypeTest { ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ @Parameters(name = "colorSpaceType={0}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.NV12}, ++ {ColorSpaceType.NV21}, ++ {ColorSpaceType.YV12}, ++ {ColorSpaceType.YV21}, ++ {ColorSpaceType.YUV_420_888}, ++ }); ++ } ++ ++ @Test ++ public void convertTensorBufferToBitmapShouldFail() { ++ UnsupportedOperationException exception = ++ assertThrows(UnsupportedOperationException.class, ++ () ++ -> colorSpaceType.convertTensorBufferToBitmap( ++ TensorBuffer.createDynamic(DataType.FLOAT32))); ++ assertThat(exception).hasMessageThat().contains( ++ "convertTensorBufferToBitmap() is unsupported for the color space type " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void getWidthShouldFail() { ++ UnsupportedOperationException exception = ++ assertThrows(UnsupportedOperationException.class, ++ () -> colorSpaceType.getWidth(new int[] {})); ++ assertThat(exception).hasMessageThat().contains( ++ "getWidth() only supports RGB and GRAYSCALE formats, but not " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void getHeightShouldFail() { ++ UnsupportedOperationException exception = ++ assertThrows(UnsupportedOperationException.class, ++ () -> colorSpaceType.getHeight(new int[] {})); ++ assertThat(exception).hasMessageThat().contains( ++ "getHeight() only supports RGB and GRAYSCALE formats, but not " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void assertShapeShouldFail() { ++ UnsupportedOperationException exception = ++ assertThrows(UnsupportedOperationException.class, ++ () -> colorSpaceType.assertShape(new int[] {})); ++ assertThat(exception).hasMessageThat().contains( ++ "assertShape() only supports RGB and GRAYSCALE formats, but not " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void getChannelValueShouldFail() { ++ UnsupportedOperationException exception = assertThrows( ++ UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue()); ++ assertThat(exception).hasMessageThat().contains( ++ "getChannelValue() is unsupported for the color space type " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void getNormalizedShapeShouldFail() { ++ UnsupportedOperationException exception = ++ assertThrows(UnsupportedOperationException.class, ++ () -> colorSpaceType.getNormalizedShape(new int[] {})); ++ assertThat(exception).hasMessageThat().contains( ++ "getNormalizedShape() is unsupported for the color space type " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void getShapeInfoMessageShouldFail() { ++ UnsupportedOperationException exception = ++ assertThrows(UnsupportedOperationException.class, ++ () -> colorSpaceType.getShapeInfoMessage()); ++ assertThat(exception).hasMessageThat().contains( ++ "getShapeInfoMessage() is unsupported for the color space type " ++ + colorSpaceType.name()); ++ } ++ ++ @Test ++ public void toBitmapConfigShouldFail() { ++ UnsupportedOperationException exception = assertThrows( ++ UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig()); ++ assertThat(exception).hasMessageThat().contains( ++ "toBitmapConfig() is unsupported for the color space type " ++ + colorSpaceType.name()); ++ } + } + +- @Test +- public void fromBitmapConfigSucceedsWithSupportedConfig() { +- assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType); +- } +- +- @Test +- public void toBitmapConfigSucceedsWithSupportedConfig() { +- assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config); +- } +- } +- +- /** Parameterized tests for ImageFormat. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class ImageFormatTest extends ColorSpaceTypeTest { +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- /** The ImageFormat that matches the colorSpaceType. */ +- @Parameter(1) +- public int imageFormat; +- +- @Parameters(name = "colorSpaceType={0}; imageFormat={1}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.NV21, ImageFormat.NV21}, +- {ColorSpaceType.YV12, ImageFormat.YV12}, +- {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888}, +- }); +- } +- +- @Test +- public void fromImageFormatSucceedsWithSupportedImageFormat() { +- assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType); +- } +- } +- +- /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class YuvImageTest extends ColorSpaceTypeTest { +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- @Parameters(name = "colorSpaceType={0}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.NV12}, +- {ColorSpaceType.NV21}, +- {ColorSpaceType.YV12}, +- {ColorSpaceType.YV21}, +- {ColorSpaceType.YUV_420_888}, +- }); +- } +- +- @Test +- public void convertTensorBufferToBitmapShouldFail() { +- UnsupportedOperationException exception = +- assertThrows( +- UnsupportedOperationException.class, +- () -> +- colorSpaceType.convertTensorBufferToBitmap( +- TensorBuffer.createDynamic(DataType.FLOAT32))); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "convertTensorBufferToBitmap() is unsupported for the color space type " +- + colorSpaceType.name()); +- } +- +- @Test +- public void getWidthShouldFail() { +- UnsupportedOperationException exception = +- assertThrows( +- UnsupportedOperationException.class, () -> colorSpaceType.getWidth(new int[] {})); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "getWidth() only supports RGB and GRAYSCALE formats, but not " +- + colorSpaceType.name()); +- } +- +- @Test +- public void getHeightShouldFail() { +- UnsupportedOperationException exception = +- assertThrows( +- UnsupportedOperationException.class, () -> colorSpaceType.getHeight(new int[] {})); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "getHeight() only supports RGB and GRAYSCALE formats, but not " +- + colorSpaceType.name()); +- } +- +- @Test +- public void assertShapeShouldFail() { +- UnsupportedOperationException exception = +- assertThrows( +- UnsupportedOperationException.class, () -> colorSpaceType.assertShape(new int[] {})); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "assertShape() only supports RGB and GRAYSCALE formats, but not " +- + colorSpaceType.name()); +- } +- +- @Test +- public void getChannelValueShouldFail() { +- UnsupportedOperationException exception = +- assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue()); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "getChannelValue() is unsupported for the color space type " + colorSpaceType.name()); +- } +- +- @Test +- public void getNormalizedShapeShouldFail() { +- UnsupportedOperationException exception = +- assertThrows( +- UnsupportedOperationException.class, +- () -> colorSpaceType.getNormalizedShape(new int[] {})); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "getNormalizedShape() is unsupported for the color space type " +- + colorSpaceType.name()); +- } +- +- @Test +- public void getShapeInfoMessageShouldFail() { +- UnsupportedOperationException exception = +- assertThrows( +- UnsupportedOperationException.class, () -> colorSpaceType.getShapeInfoMessage()); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "getShapeInfoMessage() is unsupported for the color space type " +- + colorSpaceType.name()); +- } +- +- @Test +- public void toBitmapConfigShouldFail() { +- UnsupportedOperationException exception = +- assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig()); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "toBitmapConfig() is unsupported for the color space type " + colorSpaceType.name()); +- } +- } +- +- /** Parameterized tests for assertNumElements/getNumElements with all image formats. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class AssertNumElementsTest extends ColorSpaceTypeTest { +- private static final int HEIGHT = 2; +- private static final int WIDTH = 3; +- private static final int LESS_NUM_ELEMENTS = 5; // less than expected +- private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK. +- @Rule public ErrorCollector errorCollector = new ErrorCollector(); +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- @Parameter(1) +- public int expectedNumElements; +- +- @Parameters(name = "colorSpaceType={0};expectedNumElements={1}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.RGB, 18}, +- {ColorSpaceType.GRAYSCALE, 6}, +- {ColorSpaceType.NV12, 10}, +- {ColorSpaceType.NV21, 10}, +- {ColorSpaceType.YV12, 10}, +- {ColorSpaceType.YV21, 10}, +- }); +- } +- +- @Test +- public void getNumElementsShouldSucceedWithExpectedNumElements() { +- assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements); +- } +- +- @Test +- public void assertNumElementsShouldSucceedWithMoreNumElements() { +- errorCollector.checkSucceeds( +- () -> { +- colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH); +- return null; +- }); +- } +- +- @Test +- public void assertNumElementsShouldFailWithLessNumElements() { +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- String.format( +- "The given number of elements (%d) does not match the image (%s) in %d x %d. The" +- + " expected number of elements should be at least %d.", +- LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements)); +- } +- } +- +- /** General tests of ColorSpaceTypeTest. */ +- @RunWith(RobolectricTestRunner.class) +- public static final class General extends ColorSpaceTypeTest { +- +- @Test +- public void convertTensorBufferToBitmapShouldSuccessWithRGB() { +- TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false); +- Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer); +- +- Bitmap expectedBitmap = createRgbBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ /** Parameterized tests for assertNumElements/getNumElements with all image formats. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class AssertNumElementsTest extends ColorSpaceTypeTest { ++ private static final int HEIGHT = 2; ++ private static final int WIDTH = 3; ++ private static final int LESS_NUM_ELEMENTS = 5; // less than expected ++ private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK. ++ @Rule ++ public ErrorCollector errorCollector = new ErrorCollector(); ++ ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ @Parameter(1) ++ public int expectedNumElements; ++ ++ @Parameters(name = "colorSpaceType={0};expectedNumElements={1}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.RGB, 18}, ++ {ColorSpaceType.GRAYSCALE, 6}, ++ {ColorSpaceType.NV12, 10}, ++ {ColorSpaceType.NV21, 10}, ++ {ColorSpaceType.YV12, 10}, ++ {ColorSpaceType.YV21, 10}, ++ }); ++ } ++ ++ @Test ++ public void getNumElementsShouldSucceedWithExpectedNumElements() { ++ assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements); ++ } ++ ++ @Test ++ public void assertNumElementsShouldSucceedWithMoreNumElements() { ++ errorCollector.checkSucceeds(() -> { ++ colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH); ++ return null; ++ }); ++ } ++ ++ @Test ++ public void assertNumElementsShouldFailWithLessNumElements() { ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH)); ++ assertThat(exception).hasMessageThat().contains(String.format( ++ "The given number of elements (%d) does not match the image (%s) in %d x %d. The" ++ + " expected number of elements should be at least %d.", ++ LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements)); ++ } + } + +- @Test +- public void fromBitmapConfigFailsWithUnsupportedConfig() { +- Config unsupportedConfig = Config.ARGB_4444; +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig)); +- assertThat(exception) +- .hasMessageThat() +- .contains("Bitmap configuration: " + unsupportedConfig + ", is not supported yet."); ++ /** General tests of ColorSpaceTypeTest. */ ++ @RunWith(RobolectricTestRunner.class) ++ public static final class General extends ColorSpaceTypeTest { ++ @Test ++ public void convertTensorBufferToBitmapShouldSuccessWithRGB() { ++ TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false); ++ Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer); ++ ++ Bitmap expectedBitmap = createRgbBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void fromBitmapConfigFailsWithUnsupportedConfig() { ++ Config unsupportedConfig = Config.ARGB_4444; ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig)); ++ assertThat(exception).hasMessageThat().contains( ++ "Bitmap configuration: " + unsupportedConfig + ", is not supported yet."); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java +index 1a4d367bf0fe1..49efc4273911c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java +@@ -21,7 +21,9 @@ import static android.graphics.Color.BLUE; + import static android.graphics.Color.GREEN; + import static android.graphics.Color.RED; + import static android.graphics.Color.WHITE; ++ + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.tensorflow.lite.support.image.ImageConversions.convertGrayscaleTensorBufferToBitmap; + +@@ -30,10 +32,10 @@ import android.content.res.AssetManager; + import android.graphics.Bitmap; + import android.graphics.BitmapFactory; + import android.util.Log; ++ + import androidx.test.core.app.ApplicationProvider; + import androidx.test.ext.junit.runners.AndroidJUnit4; +-import java.io.IOException; +-import java.util.Arrays; ++ + import org.junit.Assert; + import org.junit.Before; + import org.junit.Test; +@@ -43,192 +45,190 @@ import org.junit.runners.Suite.SuiteClasses; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.io.IOException; ++import java.util.Arrays; ++ + /** Instrumented unit test for {@link ImageConversions}. */ + @RunWith(Suite.class) +-@SuiteClasses({ +- ImageConversionsInstrumentedTest.TensorBufferToBitmap.class, +- ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class +-}) ++@SuiteClasses({ImageConversionsInstrumentedTest.TensorBufferToBitmap.class, ++ ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class}) + public class ImageConversionsInstrumentedTest { ++ /** Tests for the TensorBuffer data type and normalized form. */ ++ // Note that parameterized test with android_library_instrumentation_tests is currently not ++ // supported internally. ++ @RunWith(AndroidJUnit4.class) ++ public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest { ++ @Test ++ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() { ++ DataType dataType = DataType.FLOAT32; ++ boolean isNormalized = true; ++ ++ TensorBuffer buffer = ++ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); ++ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); ++ ++ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } + +- /** Tests for the TensorBuffer data type and normalized form. */ +- // Note that parameterized test with android_library_instrumentation_tests is currently not +- // supported internally. +- @RunWith(AndroidJUnit4.class) +- public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest { +- +- @Test +- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() { +- DataType dataType = DataType.FLOAT32; +- boolean isNormalized = true; ++ @Test ++ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() { ++ DataType dataType = DataType.FLOAT32; ++ boolean isNormalized = false; + +- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); +- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); ++ TensorBuffer buffer = ++ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); ++ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + +- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() { +- DataType dataType = DataType.FLOAT32; +- boolean isNormalized = false; ++ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } + +- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); +- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); ++ @Test ++ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() { ++ DataType dataType = DataType.UINT8; ++ boolean isNormalized = true; + +- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } ++ TensorBuffer buffer = ++ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); ++ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + +- @Test +- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() { +- DataType dataType = DataType.UINT8; +- boolean isNormalized = true; +- +- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); +- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); ++ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } + +- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } ++ @Test ++ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() { ++ DataType dataType = DataType.UINT8; ++ boolean isNormalized = false; + +- @Test +- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() { +- DataType dataType = DataType.UINT8; +- boolean isNormalized = false; ++ TensorBuffer buffer = ++ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); ++ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + +- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); +- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); ++ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } + +- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } ++ @Test ++ public void ++ convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() { ++ DataType dataType = DataType.FLOAT32; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> convertGrayscaleTensorBufferToBitmap(buffer)); ++ assertThat(exception).hasMessageThat().contains( ++ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" ++ + " shape is " + Arrays.toString(buffer.getShape())); ++ } + +- @Test +- public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() { +- DataType dataType = DataType.FLOAT32; +- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" +- + " shape is " +- + Arrays.toString(buffer.getShape())); ++ @Test ++ public void ++ convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() { ++ DataType dataType = DataType.UINT8; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> convertGrayscaleTensorBufferToBitmap(buffer)); ++ assertThat(exception).hasMessageThat().contains( ++ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" ++ + " shape is " + Arrays.toString(buffer.getShape())); ++ } + } + +- @Test +- public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() { +- DataType dataType = DataType.UINT8; +- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" +- + " shape is " +- + Arrays.toString(buffer.getShape())); +- } +- } +- +- /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */ +- @RunWith(AndroidJUnit4.class) +- public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest { +- +- private Bitmap greyGrid; +- private Bitmap colorGrid; +- private TensorBuffer buffer; +- +- static final String GREY_GRID_PATH = "grey_grid.png"; +- static final String COLOR_GRID_PATH = "color_grid.png"; +- +- @Before +- public void loadAssets() { +- Context context = ApplicationProvider.getApplicationContext(); +- AssetManager assetManager = context.getAssets(); +- try { +- greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH)); +- colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH)); +- } catch (IOException e) { +- Log.e("Test", "Cannot load asset files"); +- } +- Assert.assertEquals(ARGB_8888, greyGrid.getConfig()); +- Assert.assertEquals(ARGB_8888, colorGrid.getConfig()); +- buffer = TensorBuffer.createDynamic(DataType.UINT8); +- } ++ /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */ ++ @RunWith(AndroidJUnit4.class) ++ public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest { ++ private Bitmap greyGrid; ++ private Bitmap colorGrid; ++ private TensorBuffer buffer; ++ ++ static final String GREY_GRID_PATH = "grey_grid.png"; ++ static final String COLOR_GRID_PATH = "color_grid.png"; ++ ++ @Before ++ public void loadAssets() { ++ Context context = ApplicationProvider.getApplicationContext(); ++ AssetManager assetManager = context.getAssets(); ++ try { ++ greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH)); ++ colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH)); ++ } catch (IOException e) { ++ Log.e("Test", "Cannot load asset files"); ++ } ++ Assert.assertEquals(ARGB_8888, greyGrid.getConfig()); ++ Assert.assertEquals(ARGB_8888, colorGrid.getConfig()); ++ buffer = TensorBuffer.createDynamic(DataType.UINT8); ++ } + +- @Test +- public void testBitmapDimensionLayout() { +- // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but +- // also for us to better understand how Android Bitmap is storing pixels - height first or +- // width first. +- // We use a black image which has a white corner to understand what happens. By setting up the +- // correct loop to pass the test, we can reveal the order of pixels returned from `getPixels`. +- // The result shows that Android stores bitmap in an h-first manner. The returned array of +- // `getPixels` is like [ 1st row, 2nd row, ... ] which is the same with TFLite. +- Assert.assertEquals(100, greyGrid.getWidth()); +- Assert.assertEquals(100, greyGrid.getHeight()); +- Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top +- Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top +- Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom +- Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom +- +- ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer); +- Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape()); +- Assert.assertEquals(DataType.UINT8, buffer.getDataType()); +- +- int[] pixels = buffer.getIntArray(); +- int index = 0; +- for (int h = 0; h < 100; h++) { +- for (int w = 0; w < 100; w++) { +- int expected = (w < 50 && h >= 50) ? 255 : 0; +- Assert.assertEquals(expected, pixels[index++]); +- Assert.assertEquals(expected, pixels[index++]); +- Assert.assertEquals(expected, pixels[index++]); ++ @Test ++ public void testBitmapDimensionLayout() { ++ // This test is not only for proving the correctness of bitmap -> TensorBuffer ++ // conversion, but also for us to better understand how Android Bitmap is storing pixels ++ // - height first or width first. We use a black image which has a white corner to ++ // understand what happens. By setting up the correct loop to pass the test, we can ++ // reveal the order of pixels returned from `getPixels`. The result shows that Android ++ // stores bitmap in an h-first manner. The returned array of `getPixels` is like [ 1st ++ // row, 2nd row, ... ] which is the same with TFLite. ++ Assert.assertEquals(100, greyGrid.getWidth()); ++ Assert.assertEquals(100, greyGrid.getHeight()); ++ Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top ++ Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top ++ Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom ++ Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom ++ ++ ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer); ++ Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape()); ++ Assert.assertEquals(DataType.UINT8, buffer.getDataType()); ++ ++ int[] pixels = buffer.getIntArray(); ++ int index = 0; ++ for (int h = 0; h < 100; h++) { ++ for (int w = 0; w < 100; w++) { ++ int expected = (w < 50 && h >= 50) ? 255 : 0; ++ Assert.assertEquals(expected, pixels[index++]); ++ Assert.assertEquals(expected, pixels[index++]); ++ Assert.assertEquals(expected, pixels[index++]); ++ } ++ } + } +- } +- } + +- @Test +- public void testBitmapARGB8888ChannelLayout() { +- // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but +- // also for us to better understand how Android Bitmap is storing pixels - RGB channel or +- // other possible ordering. +- // We use an colored grid image to understand what happens. It's a simple grid image with 4 +- // grid in different colors. Passed through our Bitmap -> TensorBuffer conversion which simply +- // unpack channels from an integer returned from `getPixel`, its channel sequence could be +- // revealed directly. +- // The result shows that Android Bitmap has no magic when loading channels. If loading from +- // PNG images, channel order still remains R-G-B. +- Assert.assertEquals(100, colorGrid.getWidth()); +- Assert.assertEquals(100, colorGrid.getHeight()); +- Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top +- Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top +- Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom +- Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom +- +- ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer); +- Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape()); +- Assert.assertEquals(DataType.UINT8, buffer.getDataType()); +- +- int[] pixels = buffer.getIntArray(); +- Assert.assertArrayEquals(new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top +- Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top +- Assert.assertArrayEquals(new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom +- Assert.assertArrayEquals(new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom +- } ++ @Test ++ public void testBitmapARGB8888ChannelLayout() { ++ // This test is not only for proving the correctness of bitmap -> TensorBuffer ++ // conversion, but also for us to better understand how Android Bitmap is storing pixels ++ // - RGB channel or other possible ordering. We use an colored grid image to understand ++ // what happens. It's a simple grid image with 4 grid in different colors. Passed ++ // through our Bitmap -> TensorBuffer conversion which simply unpack channels from an ++ // integer returned from `getPixel`, its channel sequence could be revealed directly. ++ // The result shows that Android Bitmap has no magic when loading channels. If loading ++ // from PNG images, channel order still remains R-G-B. ++ Assert.assertEquals(100, colorGrid.getWidth()); ++ Assert.assertEquals(100, colorGrid.getHeight()); ++ Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top ++ Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top ++ Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom ++ Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom ++ ++ ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer); ++ Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape()); ++ Assert.assertEquals(DataType.UINT8, buffer.getDataType()); ++ ++ int[] pixels = buffer.getIntArray(); ++ Assert.assertArrayEquals( ++ new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top ++ Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top ++ Assert.assertArrayEquals( ++ new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom ++ Assert.assertArrayEquals( ++ new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom ++ } + +- /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */ +- private static int[] getChannels(int[] pixels, int h, int w) { +- int id = (h * 100 + w) * 3; +- return new int[] {pixels[id++], pixels[id++], pixels[id]}; ++ /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */ ++ private static int[] getChannels(int[] pixels, int h, int w) { ++ int id = (h * 100 + w) * 3; ++ return new int[] {pixels[id++], pixels[id++], pixels[id]}; ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java +index b3300872c2357..c91db9d184f63 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java +@@ -16,13 +16,13 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.tensorflow.lite.support.image.ImageConversions.convertBitmapToTensorBuffer; + import static org.tensorflow.lite.support.image.ImageConversions.convertRgbTensorBufferToBitmap; + + import android.graphics.Bitmap; +-import java.util.Arrays; +-import java.util.Collection; ++ + import org.junit.Assert; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -35,93 +35,93 @@ import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.Arrays; ++import java.util.Collection; ++ + /** Tests of {@link ImageConversions}. */ + @RunWith(Suite.class) + @SuiteClasses({ImageConversionsTest.TensorBufferToBitmap.class, ImageConversionsTest.General.class}) + public class ImageConversionsTest { +- +- /** Parameterized tests for the TensorBuffer data type and normalized form. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class TensorBufferToBitmap extends ImageConversionsTest { +- +- /** The data type that used to create the TensorBuffer. */ +- @Parameter(0) +- public DataType dataType; +- +- /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */ +- @Parameter(1) +- public boolean isNormalized; +- +- @Parameters(name = "dataType={0}; isNormalized={1}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {DataType.FLOAT32, true}, {DataType.UINT8, true}, +- {DataType.FLOAT32, false}, {DataType.UINT8, false}, +- }); +- } +- +- @Test +- public void convertRgbTensorBufferToBitmapShouldSuccess() { +- TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized); +- Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer); +- +- Bitmap expectedBitmap = TestImageCreator.createRgbBitmap(); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() { +- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" +- + " representing R, G, B in order. The provided image shape is " +- + Arrays.toString(buffer.getShape())); +- } +- } +- +- /** General tests of ImageConversionsTest. */ +- @RunWith(RobolectricTestRunner.class) +- public static final class General extends ImageConversionsTest { +- +- private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap(); +- private static final TensorBuffer rgbTensorBuffer = +- TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false); +- +- @Test +- public void convertBitmapToTensorBufferShouldSuccess() { +- TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8); +- convertBitmapToTensorBuffer(rgbBitmap, intBuffer); +- assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue(); +- } +- +- @Test +- public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() { +- TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8); +- Assert.assertThrows( +- IllegalArgumentException.class, () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer)); ++ /** Parameterized tests for the TensorBuffer data type and normalized form. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class TensorBufferToBitmap extends ImageConversionsTest { ++ /** The data type that used to create the TensorBuffer. */ ++ @Parameter(0) ++ public DataType dataType; ++ ++ /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */ ++ @Parameter(1) ++ public boolean isNormalized; ++ ++ @Parameters(name = "dataType={0}; isNormalized={1}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {DataType.FLOAT32, true}, ++ {DataType.UINT8, true}, ++ {DataType.FLOAT32, false}, ++ {DataType.UINT8, false}, ++ }); ++ } ++ ++ @Test ++ public void convertRgbTensorBufferToBitmapShouldSuccess() { ++ TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized); ++ Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer); ++ ++ Bitmap expectedBitmap = TestImageCreator.createRgbBitmap(); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() { ++ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType); ++ ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer)); ++ assertThat(exception).hasMessageThat().contains( ++ "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" ++ + " representing R, G, B in order. The provided image shape is " ++ + Arrays.toString(buffer.getShape())); ++ } + } + +- @Test +- public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() { +- TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- convertBitmapToTensorBuffer(rgbBitmap, floatBuffer); +- assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue(); ++ /** General tests of ImageConversionsTest. */ ++ @RunWith(RobolectricTestRunner.class) ++ public static final class General extends ImageConversionsTest { ++ private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap(); ++ private static final TensorBuffer rgbTensorBuffer = ++ TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false); ++ ++ @Test ++ public void convertBitmapToTensorBufferShouldSuccess() { ++ TensorBuffer intBuffer = ++ TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8); ++ convertBitmapToTensorBuffer(rgbBitmap, intBuffer); ++ assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue(); ++ } ++ ++ @Test ++ public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() { ++ TensorBuffer intBuffer = ++ TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8); ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer)); ++ } ++ ++ @Test ++ public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() { ++ TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ convertBitmapToTensorBuffer(rgbBitmap, floatBuffer); ++ assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue(); ++ } + } +- } + +- private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) { +- if (!Arrays.equals(tb1.getShape(), tb2.getShape())) { +- return false; ++ private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) { ++ if (!Arrays.equals(tb1.getShape(), tb2.getShape())) { ++ return false; ++ } ++ int[] arr1 = tb1.getIntArray(); ++ int[] arr2 = tb2.getIntArray(); ++ return Arrays.equals(arr1, arr2); + } +- int[] arr1 = tb1.getIntArray(); +- int[] arr2 = tb2.getIntArray(); +- return Arrays.equals(arr1, arr2); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java +index 8ac27fdb07ad1..e9cbfc1dc50bd 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java +@@ -16,10 +16,13 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.graphics.Bitmap; ++ + import androidx.test.ext.junit.runners.AndroidJUnit4; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -30,120 +33,114 @@ import org.tensorflow.lite.support.image.ops.Rot90Op; + /** Instrumented unit test for {@link ImageProcessor}. */ + @RunWith(AndroidJUnit4.class) + public final class ImageProcessorInstrumentedTest { ++ private Bitmap exampleBitmap; ++ private TensorImage input; ++ private ImageProcessor processor; ++ ++ private static final int EXAMPLE_WIDTH = 10; ++ private static final int EXAMPLE_HEIGHT = 15; ++ ++ @Before ++ public void setUp() { ++ // The default number of rotation is once. ++ processor = new ImageProcessor.Builder().add(new Rot90Op()).build(); ++ exampleBitmap = createExampleBitmap(); ++ input = new TensorImage(DataType.UINT8); ++ input.load(exampleBitmap); ++ } ++ ++ @Test ++ public void updateNumberOfRotations_rotateTwice() { ++ int numberOfRotations = 2; ++ ++ processor.updateNumberOfRotations(numberOfRotations); ++ TensorImage output = processor.process(input); ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertExampleBitmapWithTwoRotations(outputBitmap); ++ } ++ ++ @Test ++ public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() { ++ int numberOfRotations = 2; ++ int occurrence = 0; ++ ++ processor.updateNumberOfRotations(numberOfRotations, occurrence); ++ TensorImage output = processor.process(input); ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertExampleBitmapWithTwoRotations(outputBitmap); ++ } ++ ++ @Test ++ public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() { ++ int numberOfRotations = 2; ++ int negativeOpIndex = -1; ++ ++ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, ++ () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex)); ++ assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative"); ++ } ++ ++ @Test ++ public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() { ++ int numberOfRotations = 2; ++ int occurrence = 1; ++ ++ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, ++ () -> processor.updateNumberOfRotations(numberOfRotations, occurrence)); ++ assertThat(exception).hasMessageThat().isEqualTo( ++ "occurrence (1) must be less than size (1)"); ++ } ++ ++ @Test ++ public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() { ++ int numberOfRotations = 2; ++ int occurrence = 1; ++ // Add an op other than Rot90Op into ImageProcessor. ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build(); ++ ++ IllegalStateException exception = assertThrows(IllegalStateException.class, ++ () -> processor.updateNumberOfRotations(numberOfRotations, occurrence)); ++ assertThat(exception).hasMessageThat().isEqualTo( ++ "The Rot90Op has not been added to the ImageProcessor."); ++ } ++ ++ @Test ++ public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() { ++ // The overall effect of the two rotations is equivalent to rotating for twice. ++ int numberOfRotations0 = 5; ++ int numberOfRotations1 = 1; ++ ++ // Add two Rot90Ops into ImageProcessor. ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build(); ++ processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/0); ++ processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/1); ++ ++ TensorImage output = processor.process(input); ++ Bitmap outputBitmap = output.getBitmap(); ++ assertExampleBitmapWithTwoRotations(outputBitmap); ++ } + +- private Bitmap exampleBitmap; +- private TensorImage input; +- private ImageProcessor processor; +- +- private static final int EXAMPLE_WIDTH = 10; +- private static final int EXAMPLE_HEIGHT = 15; +- +- @Before +- public void setUp() { +- // The default number of rotation is once. +- processor = new ImageProcessor.Builder().add(new Rot90Op()).build(); +- exampleBitmap = createExampleBitmap(); +- input = new TensorImage(DataType.UINT8); +- input.load(exampleBitmap); +- } +- +- @Test +- public void updateNumberOfRotations_rotateTwice() { +- int numberOfRotations = 2; +- +- processor.updateNumberOfRotations(numberOfRotations); +- TensorImage output = processor.process(input); +- +- Bitmap outputBitmap = output.getBitmap(); +- assertExampleBitmapWithTwoRotations(outputBitmap); +- } +- +- @Test +- public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() { +- int numberOfRotations = 2; +- int occurrence = 0; +- +- processor.updateNumberOfRotations(numberOfRotations, occurrence); +- TensorImage output = processor.process(input); +- +- Bitmap outputBitmap = output.getBitmap(); +- assertExampleBitmapWithTwoRotations(outputBitmap); +- } +- +- @Test +- public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() { +- int numberOfRotations = 2; +- int negativeOpIndex = -1; +- +- IndexOutOfBoundsException exception = +- assertThrows( +- IndexOutOfBoundsException.class, +- () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex)); +- assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative"); +- } +- +- @Test +- public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() { +- int numberOfRotations = 2; +- int occurrence = 1; +- +- IndexOutOfBoundsException exception = +- assertThrows( +- IndexOutOfBoundsException.class, +- () -> processor.updateNumberOfRotations(numberOfRotations, occurrence)); +- assertThat(exception).hasMessageThat().isEqualTo("occurrence (1) must be less than size (1)"); +- } +- +- @Test +- public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() { +- int numberOfRotations = 2; +- int occurrence = 1; +- // Add an op other than Rot90Op into ImageProcessor. +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build(); +- +- IllegalStateException exception = +- assertThrows( +- IllegalStateException.class, +- () -> processor.updateNumberOfRotations(numberOfRotations, occurrence)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo("The Rot90Op has not been added to the ImageProcessor."); +- } +- +- @Test +- public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() { +- // The overall effect of the two rotations is equivalent to rotating for twice. +- int numberOfRotations0 = 5; +- int numberOfRotations1 = 1; +- +- // Add two Rot90Ops into ImageProcessor. +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build(); +- processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/ 0); +- processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/ 1); +- +- TensorImage output = processor.process(input); +- Bitmap outputBitmap = output.getBitmap(); +- assertExampleBitmapWithTwoRotations(outputBitmap); +- } +- +- private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) { +- assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH); +- assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT); +- for (int i = 0; i < exampleBitmap.getWidth(); i++) { +- for (int j = 0; j < exampleBitmap.getHeight(); j++) { +- assertThat(exampleBitmap.getPixel(i, j)) +- .isEqualTo(bitmapRotated.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j)); +- } ++ private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) { ++ assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH); ++ assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT); ++ for (int i = 0; i < exampleBitmap.getWidth(); i++) { ++ for (int j = 0; j < exampleBitmap.getHeight(); j++) { ++ assertThat(exampleBitmap.getPixel(i, j)) ++ .isEqualTo(bitmapRotated.getPixel( ++ EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j)); ++ } ++ } + } +- } + +- private static Bitmap createExampleBitmap() { +- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; +- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { +- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; ++ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { ++ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ } ++ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java +index a655f4a506900..a93ba5465125c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java +@@ -16,10 +16,12 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.graphics.Bitmap; + import android.graphics.RectF; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; +@@ -34,115 +36,112 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + /** Tests for {@link ImageProcessor}. */ + @RunWith(RobolectricTestRunner.class) + public final class ImageProcessorTest { ++ private static final int EXAMPLE_WIDTH = 10; ++ private static final int EXAMPLE_HEIGHT = 15; ++ private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH; ++ private static final int EXAMPLE_NUM_CHANNELS = 3; ++ private static final float MEAN = 127.5f; ++ private static final float STDDEV = 127.5f; ++ ++ @Test ++ public void testBuild() { ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); ++ assertThat(processor).isNotNull(); ++ } + +- private static final int EXAMPLE_WIDTH = 10; +- private static final int EXAMPLE_HEIGHT = 15; +- private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH; +- private static final int EXAMPLE_NUM_CHANNELS = 3; +- private static final float MEAN = 127.5f; +- private static final float STDDEV = 127.5f; +- +- @Test +- public void testBuild() { +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); +- assertThat(processor).isNotNull(); +- } +- +- @Test +- public void testNormalize() { +- TensorImage input = new TensorImage(DataType.FLOAT32); +- input.load(createExampleBitmap()); +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); +- TensorImage output = processor.process(input); +- +- float[] pixels = output.getTensorBuffer().getFloatArray(); +- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS); +- for (float p : pixels) { +- assertThat(p).isAtLeast(-1); +- assertThat(p).isAtMost(1); ++ @Test ++ public void testNormalize() { ++ TensorImage input = new TensorImage(DataType.FLOAT32); ++ input.load(createExampleBitmap()); ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); ++ TensorImage output = processor.process(input); ++ ++ float[] pixels = output.getTensorBuffer().getFloatArray(); ++ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS); ++ for (float p : pixels) { ++ assertThat(p).isAtLeast(-1); ++ assertThat(p).isAtMost(1); ++ } + } +- } +- +- @Test +- public void testMultipleNormalize() { +- TensorImage input = new TensorImage(DataType.FLOAT32); +- input.load(createExampleBitmap()); +- ImageProcessor processor = +- new ImageProcessor.Builder() +- .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1] +- .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1] +- .build(); +- TensorImage output = processor.process(input); +- +- float[] pixels = output.getTensorBuffer().getFloatArray(); +- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS); +- for (float p : pixels) { +- assertThat(p).isAtLeast(0); +- assertThat(p).isAtMost(1); ++ ++ @Test ++ public void testMultipleNormalize() { ++ TensorImage input = new TensorImage(DataType.FLOAT32); ++ input.load(createExampleBitmap()); ++ ImageProcessor processor = ++ new ImageProcessor.Builder() ++ .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1] ++ .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1] ++ .build(); ++ TensorImage output = processor.process(input); ++ ++ float[] pixels = output.getTensorBuffer().getFloatArray(); ++ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS); ++ for (float p : pixels) { ++ assertThat(p).isAtLeast(0); ++ assertThat(p).isAtMost(1); ++ } + } +- } +- +- @Test +- public void inverseTransformRectCorrectly() { +- ImageProcessor processor = +- new ImageProcessor.Builder() +- .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)) +- .add(new ResizeWithCropOrPadOp(100, 200)) +- .add(new Rot90Op(1)) +- .add(new NormalizeOp(127, 128)) +- .build(); +- RectF transformed = new RectF(0, 50, 100, 150); +- RectF original = processor.inverseTransform(transformed, 400, 600); +- assertThat(original.top).isEqualTo(100); +- assertThat(original.left).isEqualTo(200); +- assertThat(original.right).isEqualTo(400); +- assertThat(original.bottom).isEqualTo(300); +- } +- +- @Test +- public void resizeShouldFailWithNonRgbImages() { +- int[] data = new int[] {1, 2, 3}; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); +- tensorBuffer.loadArray(data, new int[] {1, 3}); +- TensorImage image = new TensorImage(); +- image.load(tensorBuffer, ColorSpaceType.GRAYSCALE); +- +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)).build(); +- +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> processor.process(image)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "Only RGB images are supported in ResizeOp, but not " ++ ++ @Test ++ public void inverseTransformRectCorrectly() { ++ ImageProcessor processor = new ImageProcessor.Builder() ++ .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)) ++ .add(new ResizeWithCropOrPadOp(100, 200)) ++ .add(new Rot90Op(1)) ++ .add(new NormalizeOp(127, 128)) ++ .build(); ++ RectF transformed = new RectF(0, 50, 100, 150); ++ RectF original = processor.inverseTransform(transformed, 400, 600); ++ assertThat(original.top).isEqualTo(100); ++ assertThat(original.left).isEqualTo(200); ++ assertThat(original.right).isEqualTo(400); ++ assertThat(original.bottom).isEqualTo(300); ++ } ++ ++ @Test ++ public void resizeShouldFailWithNonRgbImages() { ++ int[] data = new int[] {1, 2, 3}; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); ++ tensorBuffer.loadArray(data, new int[] {1, 3}); ++ TensorImage image = new TensorImage(); ++ image.load(tensorBuffer, ColorSpaceType.GRAYSCALE); ++ ++ ImageProcessor processor = new ImageProcessor.Builder() ++ .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)) ++ .build(); ++ ++ IllegalArgumentException exception = ++ assertThrows(IllegalArgumentException.class, () -> processor.process(image)); ++ assertThat(exception).hasMessageThat().contains( ++ "Only RGB images are supported in ResizeOp, but not " + + image.getColorSpaceType().name()); +- } +- +- @Test +- public void normalizeShouldSuccessWithNonRgbImages() { +- int[] data = new int[] {1, 2, 3}; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); +- tensorBuffer.loadArray(data, new int[] {1, 3}); +- TensorImage image = new TensorImage(); +- image.load(tensorBuffer, ColorSpaceType.GRAYSCALE); +- +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build(); +- TensorImage output = processor.process(image); +- +- float[] pixels = output.getTensorBuffer().getFloatArray(); +- assertThat(pixels).isEqualTo(new float[]{0.5f, 1.5f, 2.5f}); +- } +- +- private static Bitmap createExampleBitmap() { +- int[] colors = new int[EXAMPLE_NUM_PIXELS]; +- for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) { +- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); + } + +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } ++ @Test ++ public void normalizeShouldSuccessWithNonRgbImages() { ++ int[] data = new int[] {1, 2, 3}; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); ++ tensorBuffer.loadArray(data, new int[] {1, 3}); ++ TensorImage image = new TensorImage(); ++ image.load(tensorBuffer, ColorSpaceType.GRAYSCALE); ++ ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build(); ++ TensorImage output = processor.process(image); ++ ++ float[] pixels = output.getTensorBuffer().getFloatArray(); ++ assertThat(pixels).isEqualTo(new float[] {0.5f, 1.5f, 2.5f}); ++ } ++ ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[EXAMPLE_NUM_PIXELS]; ++ for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) { ++ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ } ++ ++ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java +index 7e61aa8d3ce58..e8caefcab8a04 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java +@@ -16,20 +16,19 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.mockito.Mockito.when; + + import android.graphics.Bitmap; + import android.media.Image; ++ + import com.google.android.odml.image.BitmapMlImageBuilder; + import com.google.android.odml.image.ByteBufferMlImageBuilder; + import com.google.android.odml.image.MediaMlImageBuilder; + import com.google.android.odml.image.MlImage; + import com.google.android.odml.image.MlImage.ImageFormat; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.util.Arrays; +-import java.util.Collection; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -42,139 +41,141 @@ import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; + import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; + import org.robolectric.RobolectricTestRunner; + ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.util.Arrays; ++import java.util.Collection; ++ + /** Test for {@link MlImageAdapter}. */ + @RunWith(Suite.class) + @SuiteClasses({ +- MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class, +- MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class, +- MlImageAdapterTest.General.class, ++ MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class, ++ MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class, ++ MlImageAdapterTest.General.class, + }) + public class MlImageAdapterTest { +- +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class CreateTensorImageFromSupportedByteBufferMlImage +- extends MlImageAdapterTest { +- +- @Parameter(0) +- @ImageFormat +- public int imageFormat; +- +- @Parameter(1) +- public ColorSpaceType colorSpaceType; +- +- @Parameters(name = "imageFormat={0}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB}, +- {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE}, +- {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21}, +- {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12}, +- {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12}, +- {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21}, +- }); +- } +- +- @Test +- public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException { +- ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer(); +- buffer.rewind(); +- MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build(); +- +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- +- assertThat(tensorImage.getWidth()).isEqualTo(1); +- assertThat(tensorImage.getHeight()).isEqualTo(2); +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType); +- assertThat(tensorImage.getBuffer().position()).isEqualTo(0); +- assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer); +- } +- } +- +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class CreateTensorImageFromUnsupportedByteBufferMlImage +- extends MlImageAdapterTest { +- @Parameter(0) +- @ImageFormat +- public int imageFormat; +- +- @Parameters(name = "imageFormat={0}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {MlImage.IMAGE_FORMAT_RGBA}, +- {MlImage.IMAGE_FORMAT_JPEG}, +- {MlImage.IMAGE_FORMAT_YUV_420_888}, +- {MlImage.IMAGE_FORMAT_UNKNOWN}, +- }); +- } +- +- @Test +- public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException { +- ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer(); +- buffer.rewind(); +- MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build(); +- +- assertThrows( +- IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image)); +- } +- } +- +- @RunWith(RobolectricTestRunner.class) +- public static final class General extends MlImageAdapterTest { +- +- @Mock Image mediaImageMock; +- +- @Before +- public void setUp() { +- MockitoAnnotations.openMocks(this); +- } +- +- @Test +- public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException { +- Bitmap bitmap = +- Bitmap.createBitmap(new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888); +- MlImage image = new BitmapMlImageBuilder(bitmap).build(); +- ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6); +- for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) { +- expectedBuffer.put(b); +- } +- expectedBuffer.rewind(); +- +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- +- assertThat(tensorImage.getWidth()).isEqualTo(1); +- assertThat(tensorImage.getHeight()).isEqualTo(2); +- assertThat(tensorImage.getBuffer().position()).isEqualTo(0); +- assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer); +- } +- +- @Test +- public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException { +- setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2); +- MlImage image = new MediaMlImageBuilder(mediaImageMock).build(); +- +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- +- assertThat(tensorImage.getWidth()).isEqualTo(1); +- assertThat(tensorImage.getHeight()).isEqualTo(2); +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888); ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class CreateTensorImageFromSupportedByteBufferMlImage ++ extends MlImageAdapterTest { ++ @Parameter(0) ++ @ImageFormat ++ public int imageFormat; ++ ++ @Parameter(1) ++ public ColorSpaceType colorSpaceType; ++ ++ @Parameters(name = "imageFormat={0}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB}, ++ {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE}, ++ {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21}, ++ {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12}, ++ {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12}, ++ {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21}, ++ }); ++ } ++ ++ @Test ++ public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException { ++ ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer(); ++ buffer.rewind(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build(); ++ ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ ++ assertThat(tensorImage.getWidth()).isEqualTo(1); ++ assertThat(tensorImage.getHeight()).isEqualTo(2); ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType); ++ assertThat(tensorImage.getBuffer().position()).isEqualTo(0); ++ assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer); ++ } + } + +- @Test +- public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws() throws IOException { +- setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2); +- MlImage image = new MediaMlImageBuilder(mediaImageMock).build(); +- +- assertThrows( +- IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image)); ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class CreateTensorImageFromUnsupportedByteBufferMlImage ++ extends MlImageAdapterTest { ++ @Parameter(0) ++ @ImageFormat ++ public int imageFormat; ++ ++ @Parameters(name = "imageFormat={0}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {MlImage.IMAGE_FORMAT_RGBA}, ++ {MlImage.IMAGE_FORMAT_JPEG}, ++ {MlImage.IMAGE_FORMAT_YUV_420_888}, ++ {MlImage.IMAGE_FORMAT_UNKNOWN}, ++ }); ++ } ++ ++ @Test ++ public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException { ++ ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer(); ++ buffer.rewind(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build(); ++ ++ assertThrows(IllegalArgumentException.class, ++ () -> MlImageAdapter.createTensorImageFrom(image)); ++ } + } + +- private static void setUpMediaImageMock( +- Image mediaImageMock, int imageFormat, int width, int height) { +- when(mediaImageMock.getFormat()).thenReturn(imageFormat); +- when(mediaImageMock.getWidth()).thenReturn(width); +- when(mediaImageMock.getHeight()).thenReturn(height); ++ @RunWith(RobolectricTestRunner.class) ++ public static final class General extends MlImageAdapterTest { ++ @Mock ++ Image mediaImageMock; ++ ++ @Before ++ public void setUp() { ++ MockitoAnnotations.openMocks(this); ++ } ++ ++ @Test ++ public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException { ++ Bitmap bitmap = Bitmap.createBitmap( ++ new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888); ++ MlImage image = new BitmapMlImageBuilder(bitmap).build(); ++ ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6); ++ for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) { ++ expectedBuffer.put(b); ++ } ++ expectedBuffer.rewind(); ++ ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ ++ assertThat(tensorImage.getWidth()).isEqualTo(1); ++ assertThat(tensorImage.getHeight()).isEqualTo(2); ++ assertThat(tensorImage.getBuffer().position()).isEqualTo(0); ++ assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer); ++ } ++ ++ @Test ++ public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException { ++ setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2); ++ MlImage image = new MediaMlImageBuilder(mediaImageMock).build(); ++ ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ ++ assertThat(tensorImage.getWidth()).isEqualTo(1); ++ assertThat(tensorImage.getHeight()).isEqualTo(2); ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888); ++ } ++ ++ @Test ++ public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws() ++ throws IOException { ++ setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2); ++ MlImage image = new MediaMlImageBuilder(mediaImageMock).build(); ++ ++ assertThrows(IllegalArgumentException.class, ++ () -> MlImageAdapter.createTensorImageFrom(image)); ++ } ++ ++ private static void setUpMediaImageMock( ++ Image mediaImageMock, int imageFormat, int width, int height) { ++ when(mediaImageMock.getFormat()).thenReturn(imageFormat); ++ when(mediaImageMock.getWidth()).thenReturn(width); ++ when(mediaImageMock.getHeight()).thenReturn(height); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java +index ca5f7dc7551be..83b54d0a8db78 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java +@@ -15,6 +15,7 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.tensorflow.lite.DataType.FLOAT32; + import static org.tensorflow.lite.DataType.UINT8; + import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap; +@@ -23,6 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap + import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer; + + import android.graphics.Bitmap; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.junit.runners.JUnit4; +@@ -31,110 +33,110 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + + @RunWith(JUnit4.class) + public final class TensorImageInstrumentedTest { ++ /** ++ * Difference between the pair of float and uint8 values. It is used to test the data ++ * conversion. ++ */ ++ private static final float DELTA = 0.1f; ++ ++ // Note that parameterized test with android_library_instrumentation_tests is currently not ++ // supported in internally. ++ @Test ++ public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() { ++ DataType tensorBufferDataType = FLOAT32; ++ DataType tensorImageDataType = FLOAT32; ++ boolean isNormalized = true; ++ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; ++ ++ TensorBuffer tensorBuffer = ++ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); ++ TensorImage tensorImage = new TensorImage(tensorImageDataType); ++ ++ tensorImage.load(tensorBuffer, colorSpaceType); ++ Bitmap bitmap = tensorImage.getBitmap(); ++ ++ Bitmap expectedBitmap = createBitmap(colorSpaceType); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() { ++ DataType tensorBufferDataType = FLOAT32; ++ DataType tensorImageDataType = UINT8; ++ boolean isNormalized = false; ++ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; ++ ++ TensorBuffer tensorBuffer = ++ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); ++ TensorImage tensorImage = new TensorImage(tensorImageDataType); + +- /** +- * Difference between the pair of float and uint8 values. It is used to test the data conversion. +- */ +- private static final float DELTA = 0.1f; +- +- // Note that parameterized test with android_library_instrumentation_tests is currently not +- // supported in internally. +- @Test +- public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() { +- DataType tensorBufferDataType = FLOAT32; +- DataType tensorImageDataType = FLOAT32; +- boolean isNormalized = true; +- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; +- +- TensorBuffer tensorBuffer = +- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); +- TensorImage tensorImage = new TensorImage(tensorImageDataType); +- +- tensorImage.load(tensorBuffer, colorSpaceType); +- Bitmap bitmap = tensorImage.getBitmap(); +- +- Bitmap expectedBitmap = createBitmap(colorSpaceType); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() { +- DataType tensorBufferDataType = FLOAT32; +- DataType tensorImageDataType = UINT8; +- boolean isNormalized = false; +- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; +- +- TensorBuffer tensorBuffer = +- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); +- TensorImage tensorImage = new TensorImage(tensorImageDataType); +- +- tensorImage.load(tensorBuffer, colorSpaceType); +- Bitmap bitmap = tensorImage.getBitmap(); +- +- Bitmap expectedBitmap = createBitmap(colorSpaceType); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() { +- DataType tensorBufferDataType = UINT8; +- DataType tensorImageDataType = FLOAT32; +- boolean isNormalized = true; +- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; +- +- TensorBuffer tensorBuffer = +- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); +- TensorImage tensorImage = new TensorImage(tensorImageDataType); +- +- tensorImage.load(tensorBuffer, colorSpaceType); +- Bitmap bitmap = tensorImage.getBitmap(); +- +- Bitmap expectedBitmap = createBitmap(colorSpaceType); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() { +- DataType tensorBufferDataType = UINT8; +- DataType tensorImageDataType = UINT8; +- boolean isNormalized = false; +- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; +- +- TensorBuffer tensorBuffer = +- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); +- TensorImage tensorImage = new TensorImage(tensorImageDataType); +- +- tensorImage.load(tensorBuffer, colorSpaceType); +- Bitmap bitmap = tensorImage.getBitmap(); +- +- Bitmap expectedBitmap = createBitmap(colorSpaceType); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- private static TensorBuffer createTensorBuffer( +- DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) { +- switch (colorSpaceType) { +- case RGB: +- return createRgbTensorBuffer(dataType, isNormalized, delta); +- case GRAYSCALE: +- return createGrayscaleTensorBuffer(dataType, isNormalized, delta); +- default: +- break; ++ tensorImage.load(tensorBuffer, colorSpaceType); ++ Bitmap bitmap = tensorImage.getBitmap(); ++ ++ Bitmap expectedBitmap = createBitmap(colorSpaceType); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } +- throw new IllegalArgumentException( +- "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); +- } +- +- private static Bitmap createBitmap(ColorSpaceType colorSpaceType) { +- switch (colorSpaceType) { +- case RGB: +- return createRgbBitmap(); +- case GRAYSCALE: +- return createGrayscaleBitmap(); +- default: +- break; ++ ++ @Test ++ public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() { ++ DataType tensorBufferDataType = UINT8; ++ DataType tensorImageDataType = FLOAT32; ++ boolean isNormalized = true; ++ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; ++ ++ TensorBuffer tensorBuffer = ++ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); ++ TensorImage tensorImage = new TensorImage(tensorImageDataType); ++ ++ tensorImage.load(tensorBuffer, colorSpaceType); ++ Bitmap bitmap = tensorImage.getBitmap(); ++ ++ Bitmap expectedBitmap = createBitmap(colorSpaceType); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() { ++ DataType tensorBufferDataType = UINT8; ++ DataType tensorImageDataType = UINT8; ++ boolean isNormalized = false; ++ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; ++ ++ TensorBuffer tensorBuffer = ++ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); ++ TensorImage tensorImage = new TensorImage(tensorImageDataType); ++ ++ tensorImage.load(tensorBuffer, colorSpaceType); ++ Bitmap bitmap = tensorImage.getBitmap(); ++ ++ Bitmap expectedBitmap = createBitmap(colorSpaceType); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ private static TensorBuffer createTensorBuffer( ++ DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) { ++ switch (colorSpaceType) { ++ case RGB: ++ return createRgbTensorBuffer(dataType, isNormalized, delta); ++ case GRAYSCALE: ++ return createGrayscaleTensorBuffer(dataType, isNormalized, delta); ++ default: ++ break; ++ } ++ throw new IllegalArgumentException( ++ "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); ++ } ++ ++ private static Bitmap createBitmap(ColorSpaceType colorSpaceType) { ++ switch (colorSpaceType) { ++ case RGB: ++ return createRgbBitmap(); ++ case GRAYSCALE: ++ return createGrayscaleBitmap(); ++ default: ++ break; ++ } ++ throw new IllegalArgumentException( ++ "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); + } +- throw new IllegalArgumentException( +- "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java +index f27edef4de779..b3130f4f2073c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertArrayEquals; + import static org.junit.Assert.assertThrows; + import static org.mockito.Mockito.when; +@@ -31,9 +32,7 @@ import android.graphics.Bitmap.Config; + import android.graphics.Color; + import android.graphics.ImageFormat; + import android.media.Image; +-import java.nio.ByteBuffer; +-import java.util.Arrays; +-import java.util.Collection; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -48,713 +47,689 @@ import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.nio.ByteBuffer; ++import java.util.Arrays; ++import java.util.Collection; ++ + /** Tests of {@link org.tensorflow.lite.support.image.TensorImage}. */ + @RunWith(Suite.class) +-@SuiteClasses({ +- TensorImageTest.General.class, +- TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class, +- TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class, +- TensorImageTest.LoadTensorBufferWithYUV.class, +- TensorImageTest.LoadTensorBufferWithImageProperties.class +-}) ++@SuiteClasses( ++ {TensorImageTest.General.class, TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class, ++ TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class, ++ TensorImageTest.LoadTensorBufferWithYUV.class, ++ TensorImageTest.LoadTensorBufferWithImageProperties.class}) + public class TensorImageTest { +- +- @RunWith(RobolectricTestRunner.class) +- public static final class General extends TensorImageTest { +- +- private static final Bitmap exampleBitmap = createExampleBitmap(); +- private static final float[] exampleFloatPixels = createExampleFloatPixels(); +- private static final int[] exampleUint8Pixels = createExampleUint8Pixels(); +- +- private static final int EXAMPLE_WIDTH = 5; +- private static final int EXAMPLE_HEIGHT = 10; +- private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH; +- private static final int EXAMPLE_NUM_CHANNELS = 3; +- private static final int[] EXAMPLE_SHAPE = { +- EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS +- }; +- private static final float MEAN = 127.5f; +- private static final float STDDEV = 127.5f; +- +- @Mock Image imageMock; +- +- @Before +- public void setUp() { +- MockitoAnnotations.initMocks(this); +- } +- +- @Test +- public void defaultConstructorCreatesUint8TensorImage() { +- TensorImage image = new TensorImage(); +- assertThat(image.getDataType()).isEqualTo(UINT8); +- } +- +- @Test +- public void createFromSucceedsWithUint8TensorImage() { +- TensorImage uint8Image = new TensorImage(UINT8); +- uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3}); +- +- TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32); +- float[] pixels = floatImage.getTensorBuffer().getFloatArray(); +- assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f}); +- } +- +- @Test +- public void createFromSucceedsWithFloatTensorImage() { +- TensorImage floatImage = new TensorImage(FLOAT32); +- floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3}); +- +- TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8); +- int[] pixels = uint8Image.getTensorBuffer().getIntArray(); +- assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255}); +- } +- +- @Test +- public void loadBitmapSucceedsWithUint8TensorImage() { +- Bitmap rgbBitmap = createRgbBitmap(); +- TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f); +- TensorImage uint8Image = new TensorImage(UINT8); +- +- uint8Image.load(rgbBitmap); +- assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue(); +- assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer); +- assertThat(uint8Image.getDataType()).isEqualTo(UINT8); +- } +- +- @Test +- public void loadBitmapSucceedsWithFloatTensorImage() { +- Bitmap rgbBitmap = createRgbBitmap(); +- TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f); +- TensorImage floatImage = new TensorImage(FLOAT32); +- +- floatImage.load(rgbBitmap); +- assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue(); +- assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer); +- assertThat(floatImage.getDataType()).isEqualTo(FLOAT32); +- } +- +- @Test +- public void loadFloatArrayWithUint8TensorImage() { +- TensorImage uint8Image = new TensorImage(UINT8); +- +- uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE); +- assertThat(uint8Image.getBitmap()).isNotNull(); +- assertThat(uint8Image.getTensorBuffer().getFloatArray()) +- .isEqualTo( +- new float +- [exampleFloatPixels +- .length]); // All zero because of normalization and casting when loading. +- } +- +- @Test +- public void loadFloatArrayWithFloatTensorImage() { +- TensorImage floatImage = new TensorImage(FLOAT32); +- +- floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE); +- assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels); +- } +- +- @Test +- public void loadUint8ArrayWithUint8TensorImage() { +- TensorImage uint8Image = new TensorImage(UINT8); +- +- uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE); +- assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); +- assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); +- } +- +- @Test +- public void loadUint8ArrayWithFloatTensorImage() { +- TensorImage floatImage = new TensorImage(FLOAT32); +- +- floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE); +- assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); +- } +- +- @Test +- public void loadTensorBufferWithUint8TensorImage() { +- TensorImage uint8Image = new TensorImage(UINT8); +- +- uint8Image.load(exampleBitmap); +- TensorBuffer buffer = uint8Image.getTensorBuffer(); +- uint8Image.load(buffer); +- assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); +- } +- +- @Test +- public void loadTensorBufferWithFloatTensorImage() { +- TensorImage floatImage = new TensorImage(FLOAT32); +- +- floatImage.load(exampleBitmap); +- TensorBuffer buffer = floatImage.getTensorBuffer(); +- floatImage.load(buffer); +- assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); +- } +- +- @Test +- public void loadAndGetMediaImageSucceedsWithYuv420888Format() { +- setUpImageMock(imageMock, ImageFormat.YUV_420_888); +- TensorImage tensorImage = new TensorImage(UINT8); +- +- tensorImage.load(imageMock); +- Image imageReturned = tensorImage.getMediaImage(); +- +- assertThat(imageReturned).isEqualTo(imageMock); +- } +- +- @Test +- public void loadMediaImageFailsWithNonYuv420888Format() { +- setUpImageMock(imageMock, ImageFormat.YUV_422_888); +- TensorImage tensorImage = new TensorImage(UINT8); +- +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock)); +- assertThat(exception).hasMessageThat().contains("Only supports loading YUV_420_888 Image."); +- } +- +- @Test +- public void getBitmapWithUint8TensorImage() { +- TensorImage uint8Image = new TensorImage(UINT8); +- +- uint8Image.load(exampleBitmap); +- assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); +- // Also check zero copy is effective here (input and output are references of the same +- // object). +- assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap); +- // Also check we don't create new Bitmap only with reading operations. +- assertThat(uint8Image.getBuffer().limit()) +- .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS); +- assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap); +- +- uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE); +- assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap); +- } +- +- @Test +- public void getBitmapWithFloatTensorImage() { +- TensorImage floatImage = new TensorImage(FLOAT32); +- +- floatImage.load(exampleBitmap); +- assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap); +- } +- +- @Test +- public void getBitmapWithEmptyTensorImage() { +- TensorImage uint8Image = new TensorImage(UINT8); +- +- assertThrows(IllegalStateException.class, uint8Image::getBitmap); +- } +- +- @Test +- public void getMediaImageFailsWithBackedBitmap() { +- TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap); +- +- UnsupportedOperationException exception = +- assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage()); +- assertThat(exception) +- .hasMessageThat() +- .contains("Converting from Bitmap to android.media.Image is unsupported."); +- } +- +- @Test +- public void getMediaImageFailsWithBackedTensorBuffer() { +- TensorImage tensorImage = new TensorImage(UINT8); +- tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE); +- +- UnsupportedOperationException exception = +- assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage()); +- assertThat(exception) +- .hasMessageThat() +- .contains("Converting from TensorBuffer to android.media.Image is unsupported."); +- } +- +- @Test +- public void getShapeOfInternalBitmapShouldSuccess() { +- Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888); +- TensorImage image = TensorImage.fromBitmap(bitmap); +- +- int width = image.getWidth(); +- int height = image.getHeight(); +- +- assertThat(width).isEqualTo(300); +- assertThat(height).isEqualTo(400); +- } +- +- @Test +- public void getShapeOfInternalTensorBufferShouldSuccess() { +- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8); +- TensorImage image = new TensorImage(); +- image.load(buffer); +- +- int width = image.getWidth(); +- int height = image.getHeight(); +- +- assertThat(width).isEqualTo(300); +- assertThat(height).isEqualTo(400); +- } +- +- @Test +- public void getShapeOfNullImageShouldThrow() { +- TensorImage image = new TensorImage(); +- +- assertThrows(IllegalStateException.class, image::getHeight); +- } +- +- @Test +- public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() { +- int[] data = new int[] {1, 2, 3, 4, 5, 6}; +- TensorBuffer buffer = TensorBuffer.createDynamic(UINT8); +- buffer.loadArray(data, new int[] {1, 1, 2, 3}); +- TensorImage image = new TensorImage(); +- image.load(buffer); +- // Reload data but with an invalid shape, which leads to `buffer` corrupted. +- int[] newData = new int[] {1, 2, 3}; +- buffer.loadArray(newData, new int[] {1, 1, 1, 3}); +- +- assertThrows(IllegalArgumentException.class, image::getHeight); +- } +- +- @Test +- public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() { +- Bitmap rgbBitmap = createRgbBitmap(); +- TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap); +- +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); +- } +- +- @Test +- public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() { +- TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false); +- TensorImage tensorImage = new TensorImage(); +- tensorImage.load(rgbBuffer); +- +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); +- } +- +- @Test +- public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() { +- TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false); +- TensorImage tensorImage = new TensorImage(); +- tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE); +- +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); +- } +- +- @Test +- public void getColorSpaceTypeSucceedsWithRepeatedLoading() { +- TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false); +- TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false); +- Bitmap rgbBitmap = createRgbBitmap(); +- TensorImage tensorImage = new TensorImage(); +- +- tensorImage.load(rgbBuffer); +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); +- tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE); +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); +- tensorImage.load(rgbBitmap); +- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); +- } +- +- @Test +- public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() { +- TensorImage tensorImage = new TensorImage(); +- +- IllegalStateException exception = +- assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType); +- assertThat(exception).hasMessageThat().contains("No image has been loaded yet."); +- } +- +- /** +- * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i] = +- * {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index +- */ +- private static Bitmap createExampleBitmap() { +- int[] colors = new int[EXAMPLE_NUM_PIXELS]; +- for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) { +- colors[i] = Color.rgb(i, i + 1, i + 2); +- } +- +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } +- +- private static float[] createExampleFloatPixels() { +- float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS]; +- for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) { +- pixels[j++] = (i - MEAN) / STDDEV; +- pixels[j++] = (i + 1 - MEAN) / STDDEV; +- pixels[j++] = (i + 2 - MEAN) / STDDEV; +- } +- return pixels; +- } +- +- private static int[] createExampleUint8Pixels() { +- int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS]; +- for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) { +- pixels[j++] = i; +- pixels[j++] = i + 1; +- pixels[j++] = i + 2; +- } +- return pixels; +- } +- } +- +- /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest { +- +- /** +- * Difference between the pair of float and uint8 values. It is used to test the data +- * conversion. +- */ +- private static final float DELTA = 0.1f; +- +- /** The data type that used to create the TensorBuffer. */ +- @Parameter(0) +- public DataType tensorBufferDataType; +- +- /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */ +- @Parameter(1) +- public boolean isNormalized; +- +- /** The color space type of the TensorBuffer. */ +- @Parameter(2) +- public ColorSpaceType colorSpaceType; +- +- /** The data type that used to create the TensorImage. */ +- @Parameter(3) +- public DataType tensorImageDataType; +- +- @Parameters( +- name = +- "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};" +- + " tensorImageDataType={3}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {FLOAT32, true, ColorSpaceType.RGB, FLOAT32}, +- {FLOAT32, false, ColorSpaceType.RGB, UINT8}, +- {UINT8, true, ColorSpaceType.RGB, FLOAT32}, +- {UINT8, false, ColorSpaceType.RGB, UINT8}, +- }); +- } +- +- @Test +- public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() { +- TensorBuffer tensorBuffer = +- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); +- TensorImage tensorImage = new TensorImage(tensorImageDataType); +- +- tensorImage.load(tensorBuffer, colorSpaceType); +- Bitmap bitmap = tensorImage.getBitmap(); +- +- Bitmap expectedBitmap = createBitmap(colorSpaceType); +- assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); +- } +- +- @Test +- public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() { +- TensorBuffer tensorBuffer = +- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); +- TensorImage tensorImage = new TensorImage(tensorImageDataType); +- +- tensorImage.load(tensorBuffer, colorSpaceType); +- TensorBuffer buffer = tensorImage.getTensorBuffer(); +- +- // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta. +- float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA; +- TensorBuffer expectedTensorBuffer = +- createTensorBuffer(tensorImageDataType, isNormalized, colorSpaceType, expectedResidual); +- assertEqualTensorBuffers(buffer, expectedTensorBuffer); +- } +- +- private static TensorBuffer createTensorBuffer( +- DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) { +- switch (colorSpaceType) { +- case RGB: +- return createRgbTensorBuffer(dataType, isNormalized, delta); +- case GRAYSCALE: +- return createGrayscaleTensorBuffer(dataType, isNormalized, delta); +- default: +- break; +- } +- throw new IllegalArgumentException( +- "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); +- } +- +- private static Bitmap createBitmap(ColorSpaceType colorSpaceType) { +- switch (colorSpaceType) { +- case RGB: +- return createRgbBitmap(); +- case GRAYSCALE: +- return createGrayscaleBitmap(); +- default: +- break; +- } +- throw new IllegalArgumentException( +- "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); +- } +- } +- +- /** Parameterized tests for loading TensorBuffers with YUV images. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class LoadTensorBufferWithYUV extends TensorImageTest { +- +- private static final int HEIGHT = 2; +- private static final int WIDTH = 3; +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- @Parameters(name = "colorSpaceType={0}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.NV12}, +- {ColorSpaceType.NV21}, +- {ColorSpaceType.YV12}, +- {ColorSpaceType.YV21}, +- }); +- } +- +- @Test +- public void loadTensorBufferWithColorSpaceShouldFail() { +- TensorImage tensorImage = new TensorImage(); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> tensorImage.load(TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" +- + " `load(TensorBuffer, ImageProperties)` for other color space types."); +- } +- +- @Test +- public void loadTensorBufferAndGetBitmapShouldFail() { +- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadArray(data, new int[] {data.length}); +- +- ImageProperties imageProperties = +- ImageProperties.builder() +- .setHeight(HEIGHT) +- .setWidth(WIDTH) +- .setColorSpaceType(colorSpaceType) +- .build(); +- +- TensorImage tensorImage = new TensorImage(DataType.FLOAT32); +- tensorImage.load(tensorBuffer, imageProperties); +- +- UnsupportedOperationException exception = +- assertThrows(UnsupportedOperationException.class, () -> tensorImage.getBitmap()); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "convertTensorBufferToBitmap() is unsupported for the color space type " +- + colorSpaceType.name()); +- } +- } +- +- /** Parameterized tests for loading TensorBuffers with ImageProperties. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class LoadTensorBufferWithImageProperties extends TensorImageTest { +- +- private static final int HEIGHT = 2; +- private static final int WIDTH = 3; +- private static final int WRONG_WIDTH = 10; +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- @Parameters(name = "colorSpaceType={0}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.RGB}, +- {ColorSpaceType.GRAYSCALE}, +- {ColorSpaceType.NV12}, +- {ColorSpaceType.NV21}, +- {ColorSpaceType.YV12}, +- {ColorSpaceType.YV21}, +- }); +- } +- +- @Test +- public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() { +- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadArray(data, new int[] {data.length}); +- +- ImageProperties imageProperties = +- ImageProperties.builder() +- .setHeight(HEIGHT) +- .setWidth(WIDTH) +- .setColorSpaceType(colorSpaceType) +- .build(); +- +- TensorImage tensorImage = new TensorImage(DataType.FLOAT32); +- tensorImage.load(tensorBuffer, imageProperties); +- +- assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer); +- } +- +- @Test +- public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() { +- // Should allow buffer to be greater than the size specified by height and width. +- int moreElements = 1; +- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements]; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadArray(data, new int[] {data.length}); +- +- ImageProperties imageProperties = +- ImageProperties.builder() +- .setHeight(HEIGHT) +- .setWidth(WIDTH) +- .setColorSpaceType(colorSpaceType) +- .build(); +- +- TensorImage tensorImage = new TensorImage(DataType.FLOAT32); +- tensorImage.load(tensorBuffer, imageProperties); +- +- assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer); +- } +- +- @Test +- public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() { +- ByteBuffer byteBuffer = ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH)); +- +- ImageProperties imageProperties = +- ImageProperties.builder() +- .setHeight(HEIGHT) +- .setWidth(WIDTH) +- .setColorSpaceType(colorSpaceType) +- .build(); +- +- TensorImage tensorImage = new TensorImage(DataType.UINT8); +- tensorImage.load(byteBuffer, imageProperties); +- +- assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer); +- } +- +- @Test +- public void loadTensorBufferWithShouldFailWithWrongImageShape() { +- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadArray(data, new int[] {data.length}); +- +- ImageProperties imageProperties = +- ImageProperties.builder() +- .setHeight(HEIGHT) +- .setWidth(WRONG_WIDTH) +- .setColorSpaceType(colorSpaceType) +- .build(); +- +- TensorImage tensorImage = new TensorImage(DataType.FLOAT32); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> tensorImage.load(tensorBuffer, imageProperties)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- String.format( +- "The given number of elements (%d) does not match the image (%s) in %d x %d. The" +- + " expected number of elements should be at least %d.", +- data.length, +- colorSpaceType.name(), +- HEIGHT, +- WRONG_WIDTH, +- colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH))); +- } +- +- @Test +- public void getShapeOfInternalTensorBufferShouldSuccess() { +- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadArray(data, new int[] {data.length}); +- +- ImageProperties imageProperties = +- ImageProperties.builder() +- .setHeight(HEIGHT) +- .setWidth(WIDTH) +- .setColorSpaceType(colorSpaceType) +- .build(); +- +- TensorImage tensorImage = new TensorImage(DataType.FLOAT32); +- tensorImage.load(tensorBuffer, imageProperties); +- +- assertThat(tensorImage.getWidth()).isEqualTo(WIDTH); +- assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT); +- } +- } +- +- /** Parameterized tests for loading TensorBuffer with invalid shapes. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest { +- +- private static final String RGB_ASSERT_SHAPE_MESSAGE = +- "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" +- + " representing R, G, B in order. The provided image shape is "; +- private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE = +- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" +- + " shape is "; +- +- @Parameter(0) +- public ColorSpaceType colorSpaceType; +- +- /** The shape that does not match the colorSpaceType. */ +- @Parameter(1) +- public int[] invalidShape; +- +- @Parameter(2) +- public String errorMessage; +- +- @Parameters(name = "colorSpaceType={0}; invalidShape={1}") +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, +- }); +- } +- +- @Test +- public void loadTensorBufferWithInvalidShape() { +- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8); +- TensorImage tensorImage = new TensorImage(); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> tensorImage.load(tensorBuffer, colorSpaceType)); +- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape)); ++ @RunWith(RobolectricTestRunner.class) ++ public static final class General extends TensorImageTest { ++ private static final Bitmap exampleBitmap = createExampleBitmap(); ++ private static final float[] exampleFloatPixels = createExampleFloatPixels(); ++ private static final int[] exampleUint8Pixels = createExampleUint8Pixels(); ++ ++ private static final int EXAMPLE_WIDTH = 5; ++ private static final int EXAMPLE_HEIGHT = 10; ++ private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH; ++ private static final int EXAMPLE_NUM_CHANNELS = 3; ++ private static final int[] EXAMPLE_SHAPE = { ++ EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS}; ++ private static final float MEAN = 127.5f; ++ private static final float STDDEV = 127.5f; ++ ++ @Mock ++ Image imageMock; ++ ++ @Before ++ public void setUp() { ++ MockitoAnnotations.initMocks(this); ++ } ++ ++ @Test ++ public void defaultConstructorCreatesUint8TensorImage() { ++ TensorImage image = new TensorImage(); ++ assertThat(image.getDataType()).isEqualTo(UINT8); ++ } ++ ++ @Test ++ public void createFromSucceedsWithUint8TensorImage() { ++ TensorImage uint8Image = new TensorImage(UINT8); ++ uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3}); ++ ++ TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32); ++ float[] pixels = floatImage.getTensorBuffer().getFloatArray(); ++ assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f}); ++ } ++ ++ @Test ++ public void createFromSucceedsWithFloatTensorImage() { ++ TensorImage floatImage = new TensorImage(FLOAT32); ++ floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3}); ++ ++ TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8); ++ int[] pixels = uint8Image.getTensorBuffer().getIntArray(); ++ assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255}); ++ } ++ ++ @Test ++ public void loadBitmapSucceedsWithUint8TensorImage() { ++ Bitmap rgbBitmap = createRgbBitmap(); ++ TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f); ++ TensorImage uint8Image = new TensorImage(UINT8); ++ ++ uint8Image.load(rgbBitmap); ++ assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue(); ++ assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer); ++ assertThat(uint8Image.getDataType()).isEqualTo(UINT8); ++ } ++ ++ @Test ++ public void loadBitmapSucceedsWithFloatTensorImage() { ++ Bitmap rgbBitmap = createRgbBitmap(); ++ TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f); ++ TensorImage floatImage = new TensorImage(FLOAT32); ++ ++ floatImage.load(rgbBitmap); ++ assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue(); ++ assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer); ++ assertThat(floatImage.getDataType()).isEqualTo(FLOAT32); ++ } ++ ++ @Test ++ public void loadFloatArrayWithUint8TensorImage() { ++ TensorImage uint8Image = new TensorImage(UINT8); ++ ++ uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE); ++ assertThat(uint8Image.getBitmap()).isNotNull(); ++ assertThat(uint8Image.getTensorBuffer().getFloatArray()) ++ .isEqualTo(new float[exampleFloatPixels.length]); // All zero because of ++ // normalization and casting ++ // when loading. ++ } ++ ++ @Test ++ public void loadFloatArrayWithFloatTensorImage() { ++ TensorImage floatImage = new TensorImage(FLOAT32); ++ ++ floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE); ++ assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels); ++ } ++ ++ @Test ++ public void loadUint8ArrayWithUint8TensorImage() { ++ TensorImage uint8Image = new TensorImage(UINT8); ++ ++ uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE); ++ assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); ++ assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); ++ } ++ ++ @Test ++ public void loadUint8ArrayWithFloatTensorImage() { ++ TensorImage floatImage = new TensorImage(FLOAT32); ++ ++ floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE); ++ assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); ++ } ++ ++ @Test ++ public void loadTensorBufferWithUint8TensorImage() { ++ TensorImage uint8Image = new TensorImage(UINT8); ++ ++ uint8Image.load(exampleBitmap); ++ TensorBuffer buffer = uint8Image.getTensorBuffer(); ++ uint8Image.load(buffer); ++ assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void loadTensorBufferWithFloatTensorImage() { ++ TensorImage floatImage = new TensorImage(FLOAT32); ++ ++ floatImage.load(exampleBitmap); ++ TensorBuffer buffer = floatImage.getTensorBuffer(); ++ floatImage.load(buffer); ++ assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); ++ } ++ ++ @Test ++ public void loadAndGetMediaImageSucceedsWithYuv420888Format() { ++ setUpImageMock(imageMock, ImageFormat.YUV_420_888); ++ TensorImage tensorImage = new TensorImage(UINT8); ++ ++ tensorImage.load(imageMock); ++ Image imageReturned = tensorImage.getMediaImage(); ++ ++ assertThat(imageReturned).isEqualTo(imageMock); ++ } ++ ++ @Test ++ public void loadMediaImageFailsWithNonYuv420888Format() { ++ setUpImageMock(imageMock, ImageFormat.YUV_422_888); ++ TensorImage tensorImage = new TensorImage(UINT8); ++ ++ IllegalArgumentException exception = ++ assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock)); ++ assertThat(exception).hasMessageThat().contains( ++ "Only supports loading YUV_420_888 Image."); ++ } ++ ++ @Test ++ public void getBitmapWithUint8TensorImage() { ++ TensorImage uint8Image = new TensorImage(UINT8); ++ ++ uint8Image.load(exampleBitmap); ++ assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); ++ // Also check zero copy is effective here (input and output are references of the same ++ // object). ++ assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap); ++ // Also check we don't create new Bitmap only with reading operations. ++ assertThat(uint8Image.getBuffer().limit()) ++ .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS); ++ assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap); ++ ++ uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE); ++ assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap); ++ } ++ ++ @Test ++ public void getBitmapWithFloatTensorImage() { ++ TensorImage floatImage = new TensorImage(FLOAT32); ++ ++ floatImage.load(exampleBitmap); ++ assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap); ++ } ++ ++ @Test ++ public void getBitmapWithEmptyTensorImage() { ++ TensorImage uint8Image = new TensorImage(UINT8); ++ ++ assertThrows(IllegalStateException.class, uint8Image::getBitmap); ++ } ++ ++ @Test ++ public void getMediaImageFailsWithBackedBitmap() { ++ TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap); ++ ++ UnsupportedOperationException exception = assertThrows( ++ UnsupportedOperationException.class, () -> tensorImage.getMediaImage()); ++ assertThat(exception).hasMessageThat().contains( ++ "Converting from Bitmap to android.media.Image is unsupported."); ++ } ++ ++ @Test ++ public void getMediaImageFailsWithBackedTensorBuffer() { ++ TensorImage tensorImage = new TensorImage(UINT8); ++ tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE); ++ ++ UnsupportedOperationException exception = assertThrows( ++ UnsupportedOperationException.class, () -> tensorImage.getMediaImage()); ++ assertThat(exception).hasMessageThat().contains( ++ "Converting from TensorBuffer to android.media.Image is unsupported."); ++ } ++ ++ @Test ++ public void getShapeOfInternalBitmapShouldSuccess() { ++ Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888); ++ TensorImage image = TensorImage.fromBitmap(bitmap); ++ ++ int width = image.getWidth(); ++ int height = image.getHeight(); ++ ++ assertThat(width).isEqualTo(300); ++ assertThat(height).isEqualTo(400); ++ } ++ ++ @Test ++ public void getShapeOfInternalTensorBufferShouldSuccess() { ++ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8); ++ TensorImage image = new TensorImage(); ++ image.load(buffer); ++ ++ int width = image.getWidth(); ++ int height = image.getHeight(); ++ ++ assertThat(width).isEqualTo(300); ++ assertThat(height).isEqualTo(400); ++ } ++ ++ @Test ++ public void getShapeOfNullImageShouldThrow() { ++ TensorImage image = new TensorImage(); ++ ++ assertThrows(IllegalStateException.class, image::getHeight); ++ } ++ ++ @Test ++ public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() { ++ int[] data = new int[] {1, 2, 3, 4, 5, 6}; ++ TensorBuffer buffer = TensorBuffer.createDynamic(UINT8); ++ buffer.loadArray(data, new int[] {1, 1, 2, 3}); ++ TensorImage image = new TensorImage(); ++ image.load(buffer); ++ // Reload data but with an invalid shape, which leads to `buffer` corrupted. ++ int[] newData = new int[] {1, 2, 3}; ++ buffer.loadArray(newData, new int[] {1, 1, 1, 3}); ++ ++ assertThrows(IllegalArgumentException.class, image::getHeight); ++ } ++ ++ @Test ++ public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() { ++ Bitmap rgbBitmap = createRgbBitmap(); ++ TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap); ++ ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); ++ } ++ ++ @Test ++ public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() { ++ TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false); ++ TensorImage tensorImage = new TensorImage(); ++ tensorImage.load(rgbBuffer); ++ ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); ++ } ++ ++ @Test ++ public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() { ++ TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false); ++ TensorImage tensorImage = new TensorImage(); ++ tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE); ++ ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); ++ } ++ ++ @Test ++ public void getColorSpaceTypeSucceedsWithRepeatedLoading() { ++ TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false); ++ TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false); ++ Bitmap rgbBitmap = createRgbBitmap(); ++ TensorImage tensorImage = new TensorImage(); ++ ++ tensorImage.load(rgbBuffer); ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); ++ tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE); ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); ++ tensorImage.load(rgbBitmap); ++ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); ++ } ++ ++ @Test ++ public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() { ++ TensorImage tensorImage = new TensorImage(); ++ ++ IllegalStateException exception = ++ assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType); ++ assertThat(exception).hasMessageThat().contains("No image has been loaded yet."); ++ } ++ ++ /** ++ * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i] ++ * = {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index ++ */ ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[EXAMPLE_NUM_PIXELS]; ++ for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) { ++ colors[i] = Color.rgb(i, i + 1, i + 2); ++ } ++ ++ return Bitmap.createBitmap( ++ colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); ++ } ++ ++ private static float[] createExampleFloatPixels() { ++ float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS]; ++ for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) { ++ pixels[j++] = (i - MEAN) / STDDEV; ++ pixels[j++] = (i + 1 - MEAN) / STDDEV; ++ pixels[j++] = (i + 2 - MEAN) / STDDEV; ++ } ++ return pixels; ++ } ++ ++ private static int[] createExampleUint8Pixels() { ++ int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS]; ++ for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) { ++ pixels[j++] = i; ++ pixels[j++] = i + 1; ++ pixels[j++] = i + 2; ++ } ++ return pixels; ++ } ++ } ++ ++ /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest { ++ /** ++ * Difference between the pair of float and uint8 values. It is used to test the data ++ * conversion. ++ */ ++ private static final float DELTA = 0.1f; ++ ++ /** The data type that used to create the TensorBuffer. */ ++ @Parameter(0) ++ public DataType tensorBufferDataType; ++ ++ /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */ ++ @Parameter(1) ++ public boolean isNormalized; ++ ++ /** The color space type of the TensorBuffer. */ ++ @Parameter(2) ++ public ColorSpaceType colorSpaceType; ++ ++ /** The data type that used to create the TensorImage. */ ++ @Parameter(3) ++ public DataType tensorImageDataType; ++ ++ @Parameters(name = "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};" ++ + " tensorImageDataType={3}") ++ public static Collection<Object[]> ++ data() { ++ return Arrays.asList(new Object[][] { ++ {FLOAT32, true, ColorSpaceType.RGB, FLOAT32}, ++ {FLOAT32, false, ColorSpaceType.RGB, UINT8}, ++ {UINT8, true, ColorSpaceType.RGB, FLOAT32}, ++ {UINT8, false, ColorSpaceType.RGB, UINT8}, ++ }); ++ } ++ ++ @Test ++ public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() { ++ TensorBuffer tensorBuffer = ++ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); ++ TensorImage tensorImage = new TensorImage(tensorImageDataType); ++ ++ tensorImage.load(tensorBuffer, colorSpaceType); ++ Bitmap bitmap = tensorImage.getBitmap(); ++ ++ Bitmap expectedBitmap = createBitmap(colorSpaceType); ++ assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); ++ } ++ ++ @Test ++ public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() { ++ TensorBuffer tensorBuffer = ++ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); ++ TensorImage tensorImage = new TensorImage(tensorImageDataType); ++ ++ tensorImage.load(tensorBuffer, colorSpaceType); ++ TensorBuffer buffer = tensorImage.getTensorBuffer(); ++ ++ // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta. ++ float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA; ++ TensorBuffer expectedTensorBuffer = createTensorBuffer( ++ tensorImageDataType, isNormalized, colorSpaceType, expectedResidual); ++ assertEqualTensorBuffers(buffer, expectedTensorBuffer); ++ } ++ ++ private static TensorBuffer createTensorBuffer(DataType dataType, boolean isNormalized, ++ ColorSpaceType colorSpaceType, float delta) { ++ switch (colorSpaceType) { ++ case RGB: ++ return createRgbTensorBuffer(dataType, isNormalized, delta); ++ case GRAYSCALE: ++ return createGrayscaleTensorBuffer(dataType, isNormalized, delta); ++ default: ++ break; ++ } ++ throw new IllegalArgumentException( ++ "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); ++ } ++ ++ private static Bitmap createBitmap(ColorSpaceType colorSpaceType) { ++ switch (colorSpaceType) { ++ case RGB: ++ return createRgbBitmap(); ++ case GRAYSCALE: ++ return createGrayscaleBitmap(); ++ default: ++ break; ++ } ++ throw new IllegalArgumentException( ++ "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); ++ } ++ } ++ ++ /** Parameterized tests for loading TensorBuffers with YUV images. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class LoadTensorBufferWithYUV extends TensorImageTest { ++ private static final int HEIGHT = 2; ++ private static final int WIDTH = 3; ++ ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ @Parameters(name = "colorSpaceType={0}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.NV12}, ++ {ColorSpaceType.NV21}, ++ {ColorSpaceType.YV12}, ++ {ColorSpaceType.YV21}, ++ }); ++ } ++ ++ @Test ++ public void loadTensorBufferWithColorSpaceShouldFail() { ++ TensorImage tensorImage = new TensorImage(); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () ++ -> tensorImage.load( ++ TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType)); ++ assertThat(exception).hasMessageThat().contains( ++ "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" ++ + " `load(TensorBuffer, ImageProperties)` for other color space types."); ++ } ++ ++ @Test ++ public void loadTensorBufferAndGetBitmapShouldFail() { ++ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadArray(data, new int[] {data.length}); ++ ++ ImageProperties imageProperties = ImageProperties.builder() ++ .setHeight(HEIGHT) ++ .setWidth(WIDTH) ++ .setColorSpaceType(colorSpaceType) ++ .build(); ++ ++ TensorImage tensorImage = new TensorImage(DataType.FLOAT32); ++ tensorImage.load(tensorBuffer, imageProperties); ++ ++ UnsupportedOperationException exception = assertThrows( ++ UnsupportedOperationException.class, () -> tensorImage.getBitmap()); ++ assertThat(exception).hasMessageThat().contains( ++ "convertTensorBufferToBitmap() is unsupported for the color space type " ++ + colorSpaceType.name()); ++ } ++ } ++ ++ /** Parameterized tests for loading TensorBuffers with ImageProperties. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class LoadTensorBufferWithImageProperties extends TensorImageTest { ++ private static final int HEIGHT = 2; ++ private static final int WIDTH = 3; ++ private static final int WRONG_WIDTH = 10; ++ ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ @Parameters(name = "colorSpaceType={0}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.RGB}, ++ {ColorSpaceType.GRAYSCALE}, ++ {ColorSpaceType.NV12}, ++ {ColorSpaceType.NV21}, ++ {ColorSpaceType.YV12}, ++ {ColorSpaceType.YV21}, ++ }); ++ } ++ ++ @Test ++ public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() { ++ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadArray(data, new int[] {data.length}); ++ ++ ImageProperties imageProperties = ImageProperties.builder() ++ .setHeight(HEIGHT) ++ .setWidth(WIDTH) ++ .setColorSpaceType(colorSpaceType) ++ .build(); ++ ++ TensorImage tensorImage = new TensorImage(DataType.FLOAT32); ++ tensorImage.load(tensorBuffer, imageProperties); ++ ++ assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer); ++ } ++ ++ @Test ++ public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() { ++ // Should allow buffer to be greater than the size specified by height and width. ++ int moreElements = 1; ++ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements]; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadArray(data, new int[] {data.length}); ++ ++ ImageProperties imageProperties = ImageProperties.builder() ++ .setHeight(HEIGHT) ++ .setWidth(WIDTH) ++ .setColorSpaceType(colorSpaceType) ++ .build(); ++ ++ TensorImage tensorImage = new TensorImage(DataType.FLOAT32); ++ tensorImage.load(tensorBuffer, imageProperties); ++ ++ assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer); ++ } ++ ++ @Test ++ public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() { ++ ByteBuffer byteBuffer = ++ ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH)); ++ ++ ImageProperties imageProperties = ImageProperties.builder() ++ .setHeight(HEIGHT) ++ .setWidth(WIDTH) ++ .setColorSpaceType(colorSpaceType) ++ .build(); ++ ++ TensorImage tensorImage = new TensorImage(DataType.UINT8); ++ tensorImage.load(byteBuffer, imageProperties); ++ ++ assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer); ++ } ++ ++ @Test ++ public void loadTensorBufferWithShouldFailWithWrongImageShape() { ++ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadArray(data, new int[] {data.length}); ++ ++ ImageProperties imageProperties = ImageProperties.builder() ++ .setHeight(HEIGHT) ++ .setWidth(WRONG_WIDTH) ++ .setColorSpaceType(colorSpaceType) ++ .build(); ++ ++ TensorImage tensorImage = new TensorImage(DataType.FLOAT32); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> tensorImage.load(tensorBuffer, imageProperties)); ++ assertThat(exception).hasMessageThat().contains(String.format( ++ "The given number of elements (%d) does not match the image (%s) in %d x %d. The" ++ + " expected number of elements should be at least %d.", ++ data.length, colorSpaceType.name(), HEIGHT, WRONG_WIDTH, ++ colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH))); ++ } ++ ++ @Test ++ public void getShapeOfInternalTensorBufferShouldSuccess() { ++ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadArray(data, new int[] {data.length}); ++ ++ ImageProperties imageProperties = ImageProperties.builder() ++ .setHeight(HEIGHT) ++ .setWidth(WIDTH) ++ .setColorSpaceType(colorSpaceType) ++ .build(); ++ ++ TensorImage tensorImage = new TensorImage(DataType.FLOAT32); ++ tensorImage.load(tensorBuffer, imageProperties); ++ ++ assertThat(tensorImage.getWidth()).isEqualTo(WIDTH); ++ assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT); ++ } ++ } ++ ++ /** Parameterized tests for loading TensorBuffer with invalid shapes. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest { ++ private static final String RGB_ASSERT_SHAPE_MESSAGE = ++ "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" ++ + " representing R, G, B in order. The provided image shape is "; ++ private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE = ++ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" ++ + " shape is "; ++ ++ @Parameter(0) ++ public ColorSpaceType colorSpaceType; ++ ++ /** The shape that does not match the colorSpaceType. */ ++ @Parameter(1) ++ public int[] invalidShape; ++ ++ @Parameter(2) ++ public String errorMessage; ++ ++ @Parameters(name = "colorSpaceType={0}; invalidShape={1}") ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] { ++ {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, ++ GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, ++ }); ++ } ++ ++ @Test ++ public void loadTensorBufferWithInvalidShape() { ++ TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8); ++ TensorImage tensorImage = new TensorImage(); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> tensorImage.load(tensorBuffer, colorSpaceType)); ++ assertThat(exception).hasMessageThat().contains( ++ errorMessage + Arrays.toString(invalidShape)); ++ } ++ } ++ ++ private static void assertEqualTensorBuffers( ++ TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) { ++ assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer()); ++ assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape()); ++ } ++ ++ private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) { ++ buffer1.rewind(); ++ buffer2.rewind(); ++ assertThat(buffer1.equals(buffer2)).isTrue(); ++ } ++ ++ private static void setUpImageMock(Image imageMock, int imageFormat) { ++ when(imageMock.getFormat()).thenReturn(imageFormat); + } +- } +- +- private static void assertEqualTensorBuffers( +- TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) { +- assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer()); +- assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape()); +- } +- +- private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) { +- buffer1.rewind(); +- buffer2.rewind(); +- assertThat(buffer1.equals(buffer2)).isTrue(); +- } +- +- private static void setUpImageMock(Image imageMock, int imageFormat) { +- when(imageMock.getFormat()).thenReturn(imageFormat); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java +index 7a5d0e9a9ea33..4ac2eca0b8cc6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java +@@ -17,109 +17,112 @@ package org.tensorflow.lite.support.image; + + import android.graphics.Bitmap; + import android.graphics.Color; +-import java.nio.ByteBuffer; ++ + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.nio.ByteBuffer; ++ + /** Creates test images for other test files. */ + final class TestImageCreator { +- /** +- * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br> +- * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index. +- */ +- static Bitmap createRgbBitmap() { +- int[] colors = new int[100]; +- for (int i = 0; i < 100; i++) { +- colors[i] = Color.rgb(i, i + 1, i + 2); ++ /** ++ * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br> ++ * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index. ++ */ ++ static Bitmap createRgbBitmap() { ++ int[] colors = new int[100]; ++ for (int i = 0; i < 100; i++) { ++ colors[i] = Color.rgb(i, i + 1, i + 2); ++ } ++ return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888); + } +- return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888); +- } + +- /** +- * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap. +- * +- * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is +- * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...]. +- * +- * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3) +- */ +- static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) { +- return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f); +- } +- +- /** +- * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap. +- * +- * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w) +- * @param delta the delta that applied to the float values, such that the float array is [0 + + +- * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...] +- */ +- static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized, float delta) { +- float[] rgbValues = new float[300]; +- for (int i = 0, j = 0; i < 100; i++) { +- rgbValues[j++] = i + delta; +- rgbValues[j++] = i + 1 + delta; +- rgbValues[j++] = i + 2 + delta; ++ /** ++ * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap. ++ * ++ * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is ++ * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...]. ++ * ++ * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3) ++ */ ++ static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) { ++ return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/0.1f); + } + +- int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType); +- // If dataType is UINT8, rgbValues will be converted into uint8, such as from +- // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...]. +- buffer.loadArray(rgbValues, shape); +- return buffer; +- } ++ /** ++ * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap. ++ * ++ * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w) ++ * @param delta the delta that applied to the float values, such that the float array is [0 + + ++ * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...] ++ */ ++ static TensorBuffer createRgbTensorBuffer( ++ DataType dataType, boolean isNormalized, float delta) { ++ float[] rgbValues = new float[300]; ++ for (int i = 0, j = 0; i < 100; i++) { ++ rgbValues[j++] = i + delta; ++ rgbValues[j++] = i + 1 + delta; ++ rgbValues[j++] = i + 2 + delta; ++ } ++ ++ int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType); ++ // If dataType is UINT8, rgbValues will be converted into uint8, such as from ++ // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...]. ++ buffer.loadArray(rgbValues, shape); ++ return buffer; ++ } + +- /** +- * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br> +- * pixel[i] = i, where i is the flatten index. +- */ +- static Bitmap createGrayscaleBitmap() { +- byte[] grayValues = new byte[100]; +- for (int i = 0; i < 100; i++) { +- grayValues[i] = (byte) i; ++ /** ++ * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br> ++ * pixel[i] = i, where i is the flatten index. ++ */ ++ static Bitmap createGrayscaleBitmap() { ++ byte[] grayValues = new byte[100]; ++ for (int i = 0; i < 100; i++) { ++ grayValues[i] = (byte) i; ++ } ++ ByteBuffer buffer = ByteBuffer.wrap(grayValues); ++ Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8); ++ buffer.rewind(); ++ bitmap.copyPixelsFromBuffer(buffer); ++ return bitmap; + } +- ByteBuffer buffer = ByteBuffer.wrap(grayValues); +- Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8); +- buffer.rewind(); +- bitmap.copyPixelsFromBuffer(buffer); +- return bitmap; +- } + +- /** +- * Creates a 10*10 float or uint8 TensorBuffer representing the same image in +- * createGrayscaleBitmap. +- * +- * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is +- * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...]. +- * +- * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w) +- */ +- static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) { +- return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f); +- } ++ /** ++ * Creates a 10*10 float or uint8 TensorBuffer representing the same image in ++ * createGrayscaleBitmap. ++ * ++ * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is ++ * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...]. ++ * ++ * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w) ++ */ ++ static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) { ++ return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/0.1f); ++ } + +- /** +- * Creates a 10*10 float or uint8 TensorBuffer representing the same image in +- * createGrayscaleBitmap. +- * +- * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w) +- * @param delta the delta that applied to the float values, such that the float array is [0 + +- * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...] +- */ +- static TensorBuffer createGrayscaleTensorBuffer( +- DataType dataType, boolean isNormalized, float delta) { +- float[] grayValues = new float[100]; +- for (int i = 0; i < 100; i++) { +- grayValues[i] = i + delta; ++ /** ++ * Creates a 10*10 float or uint8 TensorBuffer representing the same image in ++ * createGrayscaleBitmap. ++ * ++ * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w) ++ * @param delta the delta that applied to the float values, such that the float array is [0 + ++ * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...] ++ */ ++ static TensorBuffer createGrayscaleTensorBuffer( ++ DataType dataType, boolean isNormalized, float delta) { ++ float[] grayValues = new float[100]; ++ for (int i = 0; i < 100; i++) { ++ grayValues[i] = i + delta; ++ } ++ int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType); ++ // If dataType is UINT8, grayValues will be converted into uint8, such as from ++ // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...]. ++ buffer.loadArray(grayValues, shape); ++ return buffer; + } +- int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType); +- // If dataType is UINT8, grayValues will be converted into uint8, such as from +- // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...]. +- buffer.loadArray(grayValues, shape); +- return buffer; +- } + +- private TestImageCreator() {} ++ private TestImageCreator() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java +index a34f47d44c0ac..070e17893ad76 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java +@@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat; + + import android.graphics.Bitmap; + import android.graphics.PointF; ++ + import androidx.test.ext.junit.runners.AndroidJUnit4; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -31,63 +33,62 @@ import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; + /** Instrumented unit test for {@link ResizeOp}. */ + @RunWith(AndroidJUnit4.class) + public class ResizeOpInstrumentedTest { ++ private static final int EXAMPLE_WIDTH = 10; ++ private static final int EXAMPLE_HEIGHT = 15; + +- private static final int EXAMPLE_WIDTH = 10; +- private static final int EXAMPLE_HEIGHT = 15; +- +- private Bitmap exampleBitmap; +- private TensorImage input; ++ private Bitmap exampleBitmap; ++ private TensorImage input; + +- @Before +- public void setUp() { +- exampleBitmap = createExampleBitmap(); +- input = new TensorImage(DataType.UINT8); +- input.load(exampleBitmap); +- } ++ @Before ++ public void setUp() { ++ exampleBitmap = createExampleBitmap(); ++ input = new TensorImage(DataType.UINT8); ++ input.load(exampleBitmap); ++ } + +- @Test +- public void resizeShouldSuccess() { +- int targetWidth = EXAMPLE_WIDTH * 2; +- int targetHeight = EXAMPLE_HEIGHT * 2; +- ImageProcessor processor = +- new ImageProcessor.Builder() +- .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR)) +- .build(); +- TensorImage output = processor.process(input); ++ @Test ++ public void resizeShouldSuccess() { ++ int targetWidth = EXAMPLE_WIDTH * 2; ++ int targetHeight = EXAMPLE_HEIGHT * 2; ++ ImageProcessor processor = ++ new ImageProcessor.Builder() ++ .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR)) ++ .build(); ++ TensorImage output = processor.process(input); + +- Bitmap outputBitmap = output.getBitmap(); +- assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); +- assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); +- for (int i = 0; i < outputBitmap.getWidth(); i++) { +- for (int j = 0; j < outputBitmap.getHeight(); j++) { +- int expected = exampleBitmap.getPixel(i / 2, j / 2); +- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); +- } ++ Bitmap outputBitmap = output.getBitmap(); ++ assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); ++ assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); ++ for (int i = 0; i < outputBitmap.getWidth(); i++) { ++ for (int j = 0; j < outputBitmap.getHeight(); j++) { ++ int expected = exampleBitmap.getPixel(i / 2, j / 2); ++ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); ++ } ++ } + } +- } + +- @Test +- public void inverseTransformPointShouldSuccess() { +- ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR); +- PointF transformed = new PointF(32.0f, 42.0f); +- // The original image size is 900x400 assumed +- PointF original = op.inverseTransform(transformed, 400, 900); +- assertThat(original.x).isEqualTo(96); +- assertThat(original.y).isEqualTo(84); +- PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900); +- assertThat(outside.x).isEqualTo(1500); +- assertThat(outside.y).isEqualTo(2000); +- } ++ @Test ++ public void inverseTransformPointShouldSuccess() { ++ ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR); ++ PointF transformed = new PointF(32.0f, 42.0f); ++ // The original image size is 900x400 assumed ++ PointF original = op.inverseTransform(transformed, 400, 900); ++ assertThat(original.x).isEqualTo(96); ++ assertThat(original.y).isEqualTo(84); ++ PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900); ++ assertThat(outside.x).isEqualTo(1500); ++ assertThat(outside.y).isEqualTo(2000); ++ } + +- /** +- * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A: +- * 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index +- */ +- private static Bitmap createExampleBitmap() { +- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; +- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { +- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ /** ++ * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = ++ * {A: 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index ++ */ ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; ++ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { ++ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ } ++ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java +index 5c483780b30f4..85c777904f2ec 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java +@@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat; + + import android.graphics.Bitmap; + import android.graphics.PointF; ++ + import androidx.test.ext.junit.runners.AndroidJUnit4; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -30,131 +32,128 @@ import org.tensorflow.lite.support.image.TensorImage; + /** Instrumented unit test for {@link ResizeWithCropOrPadOp}. */ + @RunWith(AndroidJUnit4.class) + public class ResizeWithCropOrPadOpInstrumentedTest { ++ private Bitmap exampleBitmap; ++ private TensorImage input; + +- private Bitmap exampleBitmap; +- private TensorImage input; +- +- private static final int EXAMPLE_WIDTH = 10; +- private static final int EXAMPLE_HEIGHT = 15; +- +- @Before +- public void setUp() { +- exampleBitmap = createExampleBitmap(); +- input = new TensorImage(DataType.UINT8); +- input.load(exampleBitmap); +- } +- +- @Test +- public void testResizeWithCrop() { +- int targetWidth = 6; +- int targetHeight = 5; +- ImageProcessor processor = +- new ImageProcessor.Builder() +- .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth)) +- .build(); +- TensorImage output = processor.process(input); +- +- Bitmap outputBitmap = output.getBitmap(); +- assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); +- assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); +- for (int i = 0; i < outputBitmap.getWidth(); i++) { +- for (int j = 0; j < outputBitmap.getHeight(); j++) { +- int expected = +- exampleBitmap.getPixel( +- i + (EXAMPLE_WIDTH - targetWidth) / 2, j + (EXAMPLE_HEIGHT - targetHeight) / 2); +- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); +- } ++ private static final int EXAMPLE_WIDTH = 10; ++ private static final int EXAMPLE_HEIGHT = 15; ++ ++ @Before ++ public void setUp() { ++ exampleBitmap = createExampleBitmap(); ++ input = new TensorImage(DataType.UINT8); ++ input.load(exampleBitmap); + } +- } +- +- @Test +- public void testResizeWithPad() { +- int targetWidth = 15; +- int targetHeight = 20; +- ImageProcessor processor = +- new ImageProcessor.Builder() +- .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth)) +- .build(); +- TensorImage output = processor.process(input); +- // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right +- +- Bitmap outputBitmap = output.getBitmap(); +- assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); +- assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); +- int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2; +- int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2; +- for (int i = 0; i < outputBitmap.getWidth(); i++) { +- for (int j = 0; j < outputBitmap.getHeight(); j++) { +- int expected = 0; // ZERO padding +- if (i >= leftPad +- && i < leftPad + EXAMPLE_WIDTH +- && j >= topPad +- && j < topPad + EXAMPLE_HEIGHT) { +- expected = exampleBitmap.getPixel(i - leftPad, j - topPad); ++ ++ @Test ++ public void testResizeWithCrop() { ++ int targetWidth = 6; ++ int targetHeight = 5; ++ ImageProcessor processor = ++ new ImageProcessor.Builder() ++ .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth)) ++ .build(); ++ TensorImage output = processor.process(input); ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); ++ assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); ++ for (int i = 0; i < outputBitmap.getWidth(); i++) { ++ for (int j = 0; j < outputBitmap.getHeight(); j++) { ++ int expected = exampleBitmap.getPixel(i + (EXAMPLE_WIDTH - targetWidth) / 2, ++ j + (EXAMPLE_HEIGHT - targetHeight) / 2); ++ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); ++ } ++ } ++ } ++ ++ @Test ++ public void testResizeWithPad() { ++ int targetWidth = 15; ++ int targetHeight = 20; ++ ImageProcessor processor = ++ new ImageProcessor.Builder() ++ .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth)) ++ .build(); ++ TensorImage output = processor.process(input); ++ // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); ++ assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); ++ int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2; ++ int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2; ++ for (int i = 0; i < outputBitmap.getWidth(); i++) { ++ for (int j = 0; j < outputBitmap.getHeight(); j++) { ++ int expected = 0; // ZERO padding ++ if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH && j >= topPad ++ && j < topPad + EXAMPLE_HEIGHT) { ++ expected = exampleBitmap.getPixel(i - leftPad, j - topPad); ++ } ++ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); ++ } + } +- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); +- } + } +- } +- +- @Test +- public void testResizeWithCropAndPad() { +- int targetSize = 12; +- // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(targetSize, targetSize)).build(); +- TensorImage output = processor.process(input); +- +- Bitmap outputBitmap = output.getBitmap(); +- assertThat(outputBitmap.getWidth()).isEqualTo(targetSize); +- assertThat(outputBitmap.getHeight()).isEqualTo(targetSize); +- +- int leftPad = (targetSize - EXAMPLE_WIDTH) / 2; +- int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2; +- for (int i = 0; i < outputBitmap.getWidth(); i++) { +- for (int j = 0; j < outputBitmap.getHeight(); j++) { +- int expected = 0; +- if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) { +- expected = exampleBitmap.getPixel(i - leftPad, j + topCrop); ++ ++ @Test ++ public void testResizeWithCropAndPad() { ++ int targetSize = 12; ++ // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom ++ ImageProcessor processor = new ImageProcessor.Builder() ++ .add(new ResizeWithCropOrPadOp(targetSize, targetSize)) ++ .build(); ++ TensorImage output = processor.process(input); ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertThat(outputBitmap.getWidth()).isEqualTo(targetSize); ++ assertThat(outputBitmap.getHeight()).isEqualTo(targetSize); ++ ++ int leftPad = (targetSize - EXAMPLE_WIDTH) / 2; ++ int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2; ++ for (int i = 0; i < outputBitmap.getWidth(); i++) { ++ for (int j = 0; j < outputBitmap.getHeight(); j++) { ++ int expected = 0; ++ if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) { ++ expected = exampleBitmap.getPixel(i - leftPad, j + topCrop); ++ } ++ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); ++ } + } +- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); +- } + } +- } +- +- @Test +- public void inverseTransformCorrectlyWhenCropped() { +- ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300); +- // The point (100, 50) is transformed from 600x500 image +- PointF original = op.inverseTransform(new PointF(100, 50), 500, 600); +- assertThat(original.x).isEqualTo(250); +- assertThat(original.y).isEqualTo(150); +- PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600); +- assertThat(cropped.x).isEqualTo(140); +- assertThat(cropped.y).isEqualTo(90); +- } +- +- @Test +- public void inverseTransformCorrectlyWhenPadded() { +- ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300); +- // The point (100, 50) is transformed from 100x200 image +- PointF original = op.inverseTransform(new PointF(100, 50), 200, 100); +- assertThat(original.x).isEqualTo(0); +- assertThat(original.y).isEqualTo(0); +- PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100); +- assertThat(outside.x).isEqualTo(-50); +- assertThat(outside.y).isEqualTo(-40); +- } +- +- /** +- * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A: +- * 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index +- */ +- private static Bitmap createExampleBitmap() { +- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; +- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { +- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ ++ @Test ++ public void inverseTransformCorrectlyWhenCropped() { ++ ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300); ++ // The point (100, 50) is transformed from 600x500 image ++ PointF original = op.inverseTransform(new PointF(100, 50), 500, 600); ++ assertThat(original.x).isEqualTo(250); ++ assertThat(original.y).isEqualTo(150); ++ PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600); ++ assertThat(cropped.x).isEqualTo(140); ++ assertThat(cropped.y).isEqualTo(90); ++ } ++ ++ @Test ++ public void inverseTransformCorrectlyWhenPadded() { ++ ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300); ++ // The point (100, 50) is transformed from 100x200 image ++ PointF original = op.inverseTransform(new PointF(100, 50), 200, 100); ++ assertThat(original.x).isEqualTo(0); ++ assertThat(original.y).isEqualTo(0); ++ PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100); ++ assertThat(outside.x).isEqualTo(-50); ++ assertThat(outside.y).isEqualTo(-40); ++ } ++ ++ /** ++ * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = ++ * {A: 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index ++ */ ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; ++ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { ++ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ } ++ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java +index eb54788764f1e..d00fe0e44422e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java +@@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat; + + import android.graphics.Bitmap; + import android.graphics.PointF; ++ + import androidx.test.ext.junit.runners.AndroidJUnit4; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -30,68 +32,68 @@ import org.tensorflow.lite.support.image.TensorImage; + /** Instrumented unit test for {@link Rot90Op}. */ + @RunWith(AndroidJUnit4.class) + public class Rot90OpInstrumentedTest { ++ private Bitmap exampleBitmap; ++ private TensorImage input; ++ ++ private static final int EXAMPLE_WIDTH = 10; ++ private static final int EXAMPLE_HEIGHT = 15; + +- private Bitmap exampleBitmap; +- private TensorImage input; +- +- private static final int EXAMPLE_WIDTH = 10; +- private static final int EXAMPLE_HEIGHT = 15; +- +- @Before +- public void setUp() { +- exampleBitmap = createExampleBitmap(); +- input = new TensorImage(DataType.UINT8); +- input.load(exampleBitmap); +- } +- +- @Test +- public void testRot90() { +- ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build(); +- TensorImage output = processor.process(input); +- +- Bitmap outputBitmap = output.getBitmap(); +- assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT); +- assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH); +- for (int i = 0; i < exampleBitmap.getWidth(); i++) { +- for (int j = 0; j < exampleBitmap.getHeight(); j++) { +- assertThat(exampleBitmap.getPixel(i, j)) +- .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i)); +- } ++ @Before ++ public void setUp() { ++ exampleBitmap = createExampleBitmap(); ++ input = new TensorImage(DataType.UINT8); ++ input.load(exampleBitmap); + } +- } +- +- @Test +- public void testRot90Twice() { +- ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build(); +- TensorImage output = processor.process(input); +- +- Bitmap outputBitmap = output.getBitmap(); +- assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH); +- assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT); +- for (int i = 0; i < exampleBitmap.getWidth(); i++) { +- for (int j = 0; j < exampleBitmap.getHeight(); j++) { +- assertThat(exampleBitmap.getPixel(i, j)) +- .isEqualTo(outputBitmap.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j)); +- } ++ ++ @Test ++ public void testRot90() { ++ ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build(); ++ TensorImage output = processor.process(input); ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT); ++ assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH); ++ for (int i = 0; i < exampleBitmap.getWidth(); i++) { ++ for (int j = 0; j < exampleBitmap.getHeight(); j++) { ++ assertThat(exampleBitmap.getPixel(i, j)) ++ .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i)); ++ } ++ } + } +- } +- +- @Test +- public void inverseTransformCorrectlyWhenRotated() { +- Rot90Op op = new Rot90Op(3); +- PointF original = op.inverseTransform(new PointF(20, 10), 200, 100); +- assertThat(original.x).isEqualTo(10); +- assertThat(original.y).isEqualTo(180); +- PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100); +- assertThat(outside.x).isEqualTo(110); +- assertThat(outside.y).isEqualTo(210); +- } +- +- private static Bitmap createExampleBitmap() { +- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; +- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { +- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ ++ @Test ++ public void testRot90Twice() { ++ ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build(); ++ TensorImage output = processor.process(input); ++ ++ Bitmap outputBitmap = output.getBitmap(); ++ assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH); ++ assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT); ++ for (int i = 0; i < exampleBitmap.getWidth(); i++) { ++ for (int j = 0; j < exampleBitmap.getHeight(); j++) { ++ assertThat(exampleBitmap.getPixel(i, j)) ++ .isEqualTo(outputBitmap.getPixel( ++ EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j)); ++ } ++ } ++ } ++ ++ @Test ++ public void inverseTransformCorrectlyWhenRotated() { ++ Rot90Op op = new Rot90Op(3); ++ PointF original = op.inverseTransform(new PointF(20, 10), 200, 100); ++ assertThat(original.x).isEqualTo(10); ++ assertThat(original.y).isEqualTo(180); ++ PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100); ++ assertThat(outside.x).isEqualTo(110); ++ assertThat(outside.y).isEqualTo(210); ++ } ++ ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; ++ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { ++ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); ++ } ++ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java +index 46713fd486fa7..f024f68911d27 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java +@@ -16,6 +16,7 @@ limitations under the License. + package org.tensorflow.lite.support.image.ops; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.mockito.Mockito.doReturn; + import static org.tensorflow.lite.DataType.UINT8; +@@ -24,7 +25,9 @@ import android.graphics.Bitmap; + import android.graphics.Color; + import android.graphics.ImageFormat; + import android.media.Image; ++ + import androidx.test.ext.junit.runners.AndroidJUnit4; ++ + import org.junit.Before; + import org.junit.Rule; + import org.junit.Test; +@@ -40,54 +43,55 @@ import org.tensorflow.lite.support.image.TensorImage; + /** Instrumented unit test for {@link TransformToGrayscaleOp}. */ + @RunWith(AndroidJUnit4.class) + public class TransformToGrayScaleOpInstrumentedTest { +- +- @Rule public final MockitoRule mockito = MockitoJUnit.rule(); +- +- private TensorImage input; +- +- private static final int EXAMPLE_WIDTH = 2; +- private static final int EXAMPLE_HEIGHT = 3; +- @Mock Image imageMock; +- +- @Before +- public void setUp() { +- Bitmap exampleBitmap = createExampleBitmap(); +- input = new TensorImage(DataType.UINT8); +- input.load(exampleBitmap); +- } +- +- @Test +- public void apply_onRgb_succeeds() { +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build(); +- +- TensorImage output = processor.process(input); +- int[] pixels = output.getTensorBuffer().getIntArray(); +- +- assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH); +- assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT); +- assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); +- assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179}); +- } +- +- @Test +- public void apply_onYuv_throws() { +- setUpImageMock(imageMock, ImageFormat.YUV_420_888); +- TensorImage tensorImage = new TensorImage(UINT8); +- tensorImage.load(imageMock); +- ImageProcessor processor = +- new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build(); +- +- assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage)); +- } +- +- private static Bitmap createExampleBitmap() { +- int[] colors = +- new int[] {Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN}; +- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); +- } +- +- private static void setUpImageMock(Image imageMock, int imageFormat) { +- doReturn(imageFormat).when(imageMock).getFormat(); +- } ++ @Rule ++ public final MockitoRule mockito = MockitoJUnit.rule(); ++ ++ private TensorImage input; ++ ++ private static final int EXAMPLE_WIDTH = 2; ++ private static final int EXAMPLE_HEIGHT = 3; ++ @Mock ++ Image imageMock; ++ ++ @Before ++ public void setUp() { ++ Bitmap exampleBitmap = createExampleBitmap(); ++ input = new TensorImage(DataType.UINT8); ++ input.load(exampleBitmap); ++ } ++ ++ @Test ++ public void apply_onRgb_succeeds() { ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build(); ++ ++ TensorImage output = processor.process(input); ++ int[] pixels = output.getTensorBuffer().getIntArray(); ++ ++ assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH); ++ assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT); ++ assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); ++ assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179}); ++ } ++ ++ @Test ++ public void apply_onYuv_throws() { ++ setUpImageMock(imageMock, ImageFormat.YUV_420_888); ++ TensorImage tensorImage = new TensorImage(UINT8); ++ tensorImage.load(imageMock); ++ ImageProcessor processor = ++ new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build(); ++ ++ assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage)); ++ } ++ ++ private static Bitmap createExampleBitmap() { ++ int[] colors = new int[] { ++ Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN}; ++ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); ++ } ++ ++ private static void setUpImageMock(Image imageMock, int imageFormat) { ++ doReturn(imageFormat).when(imageMock).getFormat(); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java +index 28620dd941e9c..98d1f92f56c6d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java +@@ -24,114 +24,98 @@ import org.robolectric.RobolectricTestRunner; + /** Tests of {@link org.tensorflow.lite.support.label.Category}. */ + @RunWith(RobolectricTestRunner.class) + public final class CategoryTest { +- private static final String APPLE_LABEL = "apple"; +- private static final String DEFAULT_DISPLAY_NAME = ""; +- private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish. +- private static final float APPLE_SCORE = 0.5f; +- private static final int APPLE_INDEX = 10; +- +- @Test +- public void createShouldSucceed() { +- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); +- +- assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); +- assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME); +- assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); +- } +- +- @Test +- public void createWithIndexShouldSucceed() { +- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); +- +- assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); +- assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME); +- assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); +- assertThat(category.getIndex()).isEqualTo(APPLE_INDEX); +- } +- +- @Test +- public void constructorShouldSucceed() { +- Category category = new Category(APPLE_LABEL, APPLE_SCORE); +- +- assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); +- // Using the constructor, displayName will be default to an empty string. +- assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME); +- assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); +- } +- +- @Test +- public void toStringWithCreateShouldProvideReadableResult() { +- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); +- String categoryString = category.toString(); +- +- assertThat(categoryString) +- .isEqualTo( +- "<Category \"" +- + APPLE_LABEL +- + "\" (displayName=" +- + APPLE_DISPLAY_NAME +- + " score=" +- + APPLE_SCORE +- + " index=-1" +- + ")>"); +- } +- +- @Test +- public void toStringWithCreateIndexShouldProvideReadableResult() { +- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); +- String categoryString = category.toString(); +- +- assertThat(categoryString) +- .isEqualTo( +- "<Category \"" +- + APPLE_LABEL +- + "\" (displayName=" +- + APPLE_DISPLAY_NAME +- + " score=" +- + APPLE_SCORE +- + " index=" +- + APPLE_INDEX +- + ")>"); +- } +- +- @Test +- public void toStringWithConstuctorShouldProvideReadableResult() { +- Category category = new Category(APPLE_LABEL, APPLE_SCORE); +- String categoryString = category.toString(); +- +- assertThat(categoryString) +- .isEqualTo( +- "<Category \"" +- + APPLE_LABEL +- + "\" (displayName=" +- + DEFAULT_DISPLAY_NAME +- + " score=" +- + APPLE_SCORE +- + " index=-1" +- + ")>"); +- } +- +- @Test +- public void equalsShouldSucceedWithCreate() { +- Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); +- Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); +- +- assertThat(categoryA).isEqualTo(categoryB); +- } +- +- @Test +- public void equalsShouldSucceedWithCreateIndex() { +- Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); +- Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); +- +- assertThat(categoryA).isEqualTo(categoryB); +- } +- +- @Test +- public void equalsShouldSucceedWithConstructor() { +- Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE); +- Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE); +- +- assertThat(categoryA).isEqualTo(categoryB); +- } ++ private static final String APPLE_LABEL = "apple"; ++ private static final String DEFAULT_DISPLAY_NAME = ""; ++ private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish. ++ private static final float APPLE_SCORE = 0.5f; ++ private static final int APPLE_INDEX = 10; ++ ++ @Test ++ public void createShouldSucceed() { ++ Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); ++ ++ assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); ++ assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME); ++ assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); ++ } ++ ++ @Test ++ public void createWithIndexShouldSucceed() { ++ Category category = ++ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); ++ ++ assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); ++ assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME); ++ assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); ++ assertThat(category.getIndex()).isEqualTo(APPLE_INDEX); ++ } ++ ++ @Test ++ public void constructorShouldSucceed() { ++ Category category = new Category(APPLE_LABEL, APPLE_SCORE); ++ ++ assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); ++ // Using the constructor, displayName will be default to an empty string. ++ assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME); ++ assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); ++ } ++ ++ @Test ++ public void toStringWithCreateShouldProvideReadableResult() { ++ Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); ++ String categoryString = category.toString(); ++ ++ assertThat(categoryString) ++ .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME ++ + " score=" + APPLE_SCORE + " index=-1" ++ + ")>"); ++ } ++ ++ @Test ++ public void toStringWithCreateIndexShouldProvideReadableResult() { ++ Category category = ++ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); ++ String categoryString = category.toString(); ++ ++ assertThat(categoryString) ++ .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME ++ + " score=" + APPLE_SCORE + " index=" + APPLE_INDEX + ")>"); ++ } ++ ++ @Test ++ public void toStringWithConstuctorShouldProvideReadableResult() { ++ Category category = new Category(APPLE_LABEL, APPLE_SCORE); ++ String categoryString = category.toString(); ++ ++ assertThat(categoryString) ++ .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + DEFAULT_DISPLAY_NAME ++ + " score=" + APPLE_SCORE + " index=-1" ++ + ")>"); ++ } ++ ++ @Test ++ public void equalsShouldSucceedWithCreate() { ++ Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); ++ Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); ++ ++ assertThat(categoryA).isEqualTo(categoryB); ++ } ++ ++ @Test ++ public void equalsShouldSucceedWithCreateIndex() { ++ Category categoryA = ++ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); ++ Category categoryB = ++ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); ++ ++ assertThat(categoryA).isEqualTo(categoryB); ++ } ++ ++ @Test ++ public void equalsShouldSucceedWithConstructor() { ++ Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE); ++ Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE); ++ ++ assertThat(categoryA).isEqualTo(categoryB); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java +index caa468bb0a9ec..91c81c4932b81 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java +@@ -17,35 +17,38 @@ package org.tensorflow.lite.support.label; + + import static com.google.common.truth.Truth.assertThat; + +-import java.util.Arrays; +-import java.util.List; + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.Arrays; ++import java.util.List; ++ + /** Tests of {@link org.tensorflow.lite.support.label.LabelUtil}. */ + @RunWith(RobolectricTestRunner.class) + public class LabelUtilTest { +- +- @Test +- public void mapIndexToStringsWithInvalidValues() { +- String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"}; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); +- tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6}); +- List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1); +- assertThat(categories.toArray()) +- .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""}); +- } +- +- @Test +- public void mapFloatIndexShouldCast() { +- String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"}; +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6}); +- List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1); +- assertThat(categories.toArray()) +- .isEqualTo(new String[] {"background", "apple", "apple", "banana", "banana", "banana"}); +- } ++ @Test ++ public void mapIndexToStringsWithInvalidValues() { ++ String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"}; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); ++ tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6}); ++ List<String> categories = ++ LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1); ++ assertThat(categories.toArray()) ++ .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""}); ++ } ++ ++ @Test ++ public void mapFloatIndexShouldCast() { ++ String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"}; ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6}); ++ List<String> categories = ++ LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1); ++ assertThat(categories.toArray()) ++ .isEqualTo(new String[] { ++ "background", "apple", "apple", "banana", "banana", "banana"}); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java +index 4f296b7476c2d..857a77a2a4bd4 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java +@@ -17,10 +17,6 @@ package org.tensorflow.lite.support.label; + + import static com.google.common.truth.Truth.assertThat; + +-import java.util.Arrays; +-import java.util.HashMap; +-import java.util.List; +-import java.util.Map; + import org.junit.Assert; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -28,169 +24,180 @@ import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.DataType; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.util.Arrays; ++import java.util.HashMap; ++import java.util.List; ++import java.util.Map; ++ + /** Tests of {@link org.tensorflow.lite.support.label.TensorLabel}. */ + @RunWith(RobolectricTestRunner.class) + public final class TensorLabelTest { +- @Test +- public void createTensorLabelWithNullAxisLabelsShouldFail() { +- int[] shape = {2}; +- int[] arr = {1, 2}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- buffer.loadArray(arr, shape); +- Map<Integer, List<String>> nullAxisLabels = null; +- +- Assert.assertThrows(NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer)); +- } +- +- @Test +- public void createTensorLabelWithNullTensorBufferShouldFail() { +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); +- TensorBuffer nullTensorBuffer = null; +- +- Assert.assertThrows( +- NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer)); +- } +- +- @Test +- public void createTensorLabelWithStringListShouldSuccess() { +- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32); +- +- TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer); +- +- assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull(); +- assertThat(tensorLabel.getMapWithTensorBuffer().keySet()).contains("c"); // randomly pick one +- } +- +- @Test +- public void createTensorLabelWithEmptyShapeShouldFail() { +- int[] shape = new int[] {}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); +- +- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); +- } +- +- @Test +- public void createTensorLabelWithMismatchedAxisShouldFail() { +- int[] shape = {1, 4}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- axisLabels.put(0, Arrays.asList("a", "b", "c", "d")); +- +- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); +- } +- +- @Test +- public void createTensorLabelWithMismatchedShapeShouldFail() { +- int[] shape = {1, 3}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); +- +- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); +- } +- +- @Test +- public void getMapWithFloatBufferValuesShouldSuccess() { +- int numberLabel = 4; +- float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f}; +- int[] shape = {1, numberLabel}; +- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- input.loadArray(inputArr, shape); +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- int labelAxis = 1; +- axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d")); +- +- TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); +- Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); +- +- for (int i = 0; i < numberLabel; i++) { +- String label = axisLabels.get(labelAxis).get(i); +- assertThat(map).containsKey(label); +- float[] array = map.get(label).getFloatArray(); +- assertThat(array).hasLength(1); +- assertThat(array[0]).isEqualTo(inputArr[i]); ++ @Test ++ public void createTensorLabelWithNullAxisLabelsShouldFail() { ++ int[] shape = {2}; ++ int[] arr = {1, 2}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ buffer.loadArray(arr, shape); ++ Map<Integer, List<String>> nullAxisLabels = null; ++ ++ Assert.assertThrows( ++ NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer)); ++ } ++ ++ @Test ++ public void createTensorLabelWithNullTensorBufferShouldFail() { ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); ++ TensorBuffer nullTensorBuffer = null; ++ ++ Assert.assertThrows( ++ NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer)); ++ } ++ ++ @Test ++ public void createTensorLabelWithStringListShouldSuccess() { ++ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32); ++ ++ TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer); ++ ++ assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull(); ++ assertThat(tensorLabel.getMapWithTensorBuffer().keySet()) ++ .contains("c"); // randomly pick one ++ } ++ ++ @Test ++ public void createTensorLabelWithEmptyShapeShouldFail() { ++ int[] shape = new int[] {}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); ++ ++ Assert.assertThrows( ++ IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); + } +- } +- +- @Test +- public void getMapWithIntBufferValuesShouldSuccess() { +- int numberLabel = 3; +- int[] inputArr = {1, 2, 0}; +- int[] shape = {1, 1, numberLabel}; +- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- input.loadArray(inputArr, shape); +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- int labelAxis = 2; +- axisLabels.put(labelAxis, Arrays.asList("x", "y", "z")); +- +- TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); +- Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); +- +- for (int i = 0; i < numberLabel; i++) { +- String label = axisLabels.get(labelAxis).get(i); +- assertThat(map).containsKey(label); +- int[] array = map.get(label).getIntArray(); +- assertThat(array).hasLength(1); +- assertThat(array[0]).isEqualTo(inputArr[i]); ++ ++ @Test ++ public void createTensorLabelWithMismatchedAxisShouldFail() { ++ int[] shape = {1, 4}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ axisLabels.put(0, Arrays.asList("a", "b", "c", "d")); ++ ++ Assert.assertThrows( ++ IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); + } +- } +- +- @Test +- public void getFloatMapShouldSuccess() { +- int[] shape = {1, 3}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f}); +- +- TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer); +- Map<String, Float> map = tensorLabeled.getMapWithFloatValue(); +- +- assertThat(map).hasSize(3); +- assertThat(map).containsEntry("a", 1.0f); +- assertThat(map).containsEntry("b", 2.0f); +- assertThat(map).containsEntry("c", 3.0f); +- } +- +- @Test +- public void getMapFromMultiDimensionalTensorBufferShouldSuccess() { +- int numberLabel = 2; +- int numDim = 3; +- float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f}; +- int[] shape = {numberLabel, numDim}; +- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- input.loadArray(inputArr, shape); +- Map<Integer, List<String>> axisLabels = new HashMap<>(); +- int labelAxis = 0; +- axisLabels.put(labelAxis, Arrays.asList("pos", "neg")); +- +- TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); +- Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); +- +- for (int i = 0; i < numberLabel; i++) { +- String label = axisLabels.get(labelAxis).get(i); +- assertThat(map).containsKey(label); +- +- float[] array = map.get(label).getFloatArray(); +- assertThat(array).hasLength(numDim); +- for (int j = 0; j < numDim; j++) { +- assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]); +- } ++ ++ @Test ++ public void createTensorLabelWithMismatchedShapeShouldFail() { ++ int[] shape = {1, 3}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); ++ ++ Assert.assertThrows( ++ IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); ++ } ++ ++ @Test ++ public void getMapWithFloatBufferValuesShouldSuccess() { ++ int numberLabel = 4; ++ float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f}; ++ int[] shape = {1, numberLabel}; ++ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ input.loadArray(inputArr, shape); ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ int labelAxis = 1; ++ axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d")); ++ ++ TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); ++ Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); ++ ++ for (int i = 0; i < numberLabel; i++) { ++ String label = axisLabels.get(labelAxis).get(i); ++ assertThat(map).containsKey(label); ++ float[] array = map.get(label).getFloatArray(); ++ assertThat(array).hasLength(1); ++ assertThat(array[0]).isEqualTo(inputArr[i]); ++ } + } +- } + +- @Test +- public void getCategoryListShouldSuccess() { +- int[] shape = {1, 3}; +- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f}); ++ @Test ++ public void getMapWithIntBufferValuesShouldSuccess() { ++ int numberLabel = 3; ++ int[] inputArr = {1, 2, 0}; ++ int[] shape = {1, 1, numberLabel}; ++ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ input.loadArray(inputArr, shape); ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ int labelAxis = 2; ++ axisLabels.put(labelAxis, Arrays.asList("x", "y", "z")); ++ ++ TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); ++ Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); ++ ++ for (int i = 0; i < numberLabel; i++) { ++ String label = axisLabels.get(labelAxis).get(i); ++ assertThat(map).containsKey(label); ++ int[] array = map.get(label).getIntArray(); ++ assertThat(array).hasLength(1); ++ assertThat(array[0]).isEqualTo(inputArr[i]); ++ } ++ } + +- TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer); +- List<Category> categories = tensorLabeled.getCategoryList(); ++ @Test ++ public void getFloatMapShouldSuccess() { ++ int[] shape = {1, 3}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f}); + +- assertThat(categories).hasSize(3); +- assertThat(categories) +- .containsExactly(new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f)); +- } ++ TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer); ++ Map<String, Float> map = tensorLabeled.getMapWithFloatValue(); ++ ++ assertThat(map).hasSize(3); ++ assertThat(map).containsEntry("a", 1.0f); ++ assertThat(map).containsEntry("b", 2.0f); ++ assertThat(map).containsEntry("c", 3.0f); ++ } ++ ++ @Test ++ public void getMapFromMultiDimensionalTensorBufferShouldSuccess() { ++ int numberLabel = 2; ++ int numDim = 3; ++ float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f}; ++ int[] shape = {numberLabel, numDim}; ++ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ input.loadArray(inputArr, shape); ++ Map<Integer, List<String>> axisLabels = new HashMap<>(); ++ int labelAxis = 0; ++ axisLabels.put(labelAxis, Arrays.asList("pos", "neg")); ++ ++ TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); ++ Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); ++ ++ for (int i = 0; i < numberLabel; i++) { ++ String label = axisLabels.get(labelAxis).get(i); ++ assertThat(map).containsKey(label); ++ ++ float[] array = map.get(label).getFloatArray(); ++ assertThat(array).hasLength(numDim); ++ for (int j = 0; j < numDim; j++) { ++ assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]); ++ } ++ } ++ } ++ ++ @Test ++ public void getCategoryListShouldSuccess() { ++ int[] shape = {1, 3}; ++ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f}); ++ ++ TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer); ++ List<Category> categories = tensorLabeled.getCategoryList(); ++ ++ assertThat(categories).hasSize(3); ++ assertThat(categories) ++ .containsExactly( ++ new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java +index 8fa8860a09ef5..c1afe99f34f34 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java +@@ -18,11 +18,9 @@ package org.tensorflow.lite.support.label.ops; + import static com.google.common.truth.Truth.assertThat; + + import android.content.Context; ++ + import androidx.test.core.app.ApplicationProvider; +-import java.io.IOException; +-import java.util.Arrays; +-import java.util.List; +-import java.util.Map; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; +@@ -31,90 +29,94 @@ import org.tensorflow.lite.support.common.FileUtil; + import org.tensorflow.lite.support.label.TensorLabel; + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + ++import java.io.IOException; ++import java.util.Arrays; ++import java.util.List; ++import java.util.Map; ++ + /** Tests of {@link org.tensorflow.lite.support.label.ops.LabelAxisOp}. */ + @RunWith(RobolectricTestRunner.class) + public final class LabelAxisOpTest { ++ private final Context context = ApplicationProvider.getApplicationContext(); ++ private static final String LABEL_PATH = "flower_labels.txt"; ++ ++ @Test ++ public void testAddAxisLabelByStringList() { ++ int numberLabel = 2; ++ float[] inputArr = {0.7f, 0.3f}; ++ ++ int[] shape = {numberLabel}; ++ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ input.loadArray(inputArr, shape); ++ ++ List<String> labels = Arrays.asList("pos", "neg"); ++ LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build(); ++ TensorLabel output = op.apply(input); ++ Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); ++ ++ assertThat(map).containsKey("pos"); ++ float[] array = map.get("pos").getFloatArray(); ++ assertThat(array).hasLength(1); ++ assertThat(array[0]).isEqualTo(0.7f); ++ ++ assertThat(map).containsKey("neg"); ++ array = map.get("neg").getFloatArray(); ++ assertThat(array).hasLength(1); ++ assertThat(array[0]).isEqualTo(0.3f); ++ } ++ ++ @Test ++ public void testAddAxisLabelWithMultiDimensionTensor() throws IOException { ++ int numberLabel = 2; ++ int numDim = 3; ++ float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f}; ++ ++ int[] shape = {1, numberLabel, numDim}; ++ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ input.loadArray(inputArr, shape); + +- private final Context context = ApplicationProvider.getApplicationContext(); +- private static final String LABEL_PATH = "flower_labels.txt"; +- +- @Test +- public void testAddAxisLabelByStringList() { +- int numberLabel = 2; +- float[] inputArr = {0.7f, 0.3f}; +- +- int[] shape = {numberLabel}; +- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- input.loadArray(inputArr, shape); +- +- List<String> labels = Arrays.asList("pos", "neg"); +- LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build(); +- TensorLabel output = op.apply(input); +- Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); +- +- assertThat(map).containsKey("pos"); +- float[] array = map.get("pos").getFloatArray(); +- assertThat(array).hasLength(1); +- assertThat(array[0]).isEqualTo(0.7f); +- +- assertThat(map).containsKey("neg"); +- array = map.get("neg").getFloatArray(); +- assertThat(array).hasLength(1); +- assertThat(array[0]).isEqualTo(0.3f); +- } +- +- @Test +- public void testAddAxisLabelWithMultiDimensionTensor() throws IOException { +- int numberLabel = 2; +- int numDim = 3; +- float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f}; +- +- int[] shape = {1, numberLabel, numDim}; +- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- input.loadArray(inputArr, shape); +- +- List<String> labels = Arrays.asList("pos", "neg"); +- LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build(); +- +- TensorLabel output = op.apply(input); +- Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); +- +- assertThat(map).containsKey("pos"); +- float[] array = map.get("pos").getFloatArray(); +- assertThat(array).hasLength(numDim); +- assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f}); +- +- assertThat(map).containsKey("neg"); +- array = map.get("neg").getFloatArray(); +- assertThat(array).hasLength(numDim); +- assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f}); +- } +- +- @Test +- public void testAddAxisLabelByFilePath() throws IOException { +- int numberLabel = 5; +- int[] inputArr = new int[numberLabel]; +- for (int i = 0; i < numberLabel; i++) { +- inputArr[i] = i; ++ List<String> labels = Arrays.asList("pos", "neg"); ++ LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build(); ++ ++ TensorLabel output = op.apply(input); ++ Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); ++ ++ assertThat(map).containsKey("pos"); ++ float[] array = map.get("pos").getFloatArray(); ++ assertThat(array).hasLength(numDim); ++ assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f}); ++ ++ assertThat(map).containsKey("neg"); ++ array = map.get("neg").getFloatArray(); ++ assertThat(array).hasLength(numDim); ++ assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f}); + } + +- int[] shape = {numberLabel}; +- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- input.loadArray(inputArr, shape); ++ @Test ++ public void testAddAxisLabelByFilePath() throws IOException { ++ int numberLabel = 5; ++ int[] inputArr = new int[numberLabel]; ++ for (int i = 0; i < numberLabel; i++) { ++ inputArr[i] = i; ++ } ++ ++ int[] shape = {numberLabel}; ++ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ input.loadArray(inputArr, shape); + +- LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build(); +- TensorLabel output = op.apply(input); +- Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); ++ LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build(); ++ TensorLabel output = op.apply(input); ++ Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); + +- List<String> labels = FileUtil.loadLabels(context, LABEL_PATH); +- for (int i = 0; i < numberLabel; i++) { +- String label = labels.get(i); ++ List<String> labels = FileUtil.loadLabels(context, LABEL_PATH); ++ for (int i = 0; i < numberLabel; i++) { ++ String label = labels.get(i); + +- assertThat(map).containsKey(label); ++ assertThat(map).containsKey(label); + +- int[] array = map.get(label).getIntArray(); +- assertThat(array).hasLength(1); +- assertThat(array[0]).isEqualTo(inputArr[i]); ++ int[] array = map.get(label).getIntArray(); ++ assertThat(array).hasLength(1); ++ assertThat(array[0]).isEqualTo(inputArr[i]); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java +index bd59051ce4ccb..d7449187cb54c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java +@@ -17,6 +17,7 @@ package org.tensorflow.lite.support.model; + import static com.google.common.truth.Truth.assertThat; + + import androidx.test.ext.junit.runners.AndroidJUnit4; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + +@@ -27,13 +28,12 @@ import org.junit.runner.RunWith; + */ + @RunWith(AndroidJUnit4.class) + public final class GpuDelegateProxyInstrumentedTest { +- +- @Test +- public void createGpuDelegateProxyShouldSuccess() { +- GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance(); +- +- assertThat(proxy).isNotNull(); +- proxy.getNativeHandle(); +- proxy.close(); +- } ++ @Test ++ public void createGpuDelegateProxyShouldSuccess() { ++ GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance(); ++ ++ assertThat(proxy).isNotNull(); ++ proxy.getNativeHandle(); ++ proxy.close(); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java +index c1bbcc223a895..4eb2e2920c3bc 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java +@@ -23,11 +23,10 @@ import org.robolectric.RobolectricTestRunner; + /** Tests of {@link org.tensorflow.lite.support.model.GpuDelegateProxy}. */ + @RunWith(RobolectricTestRunner.class) + public final class GpuDelegateProxyTest { ++ @Test ++ public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() { ++ GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance(); + +- @Test +- public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() { +- GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance(); +- +- assertThat(proxy).isNull(); +- } ++ assertThat(proxy).isNull(); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java +index 86e4f72769216..342e82b2de3bb 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java +@@ -16,143 +16,145 @@ limitations under the License. + package org.tensorflow.lite.support.model; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.fail; + + import android.content.Context; ++ + import androidx.test.core.app.ApplicationProvider; +-import java.io.IOException; +-import java.nio.MappedByteBuffer; +-import java.util.HashMap; +-import java.util.Map; ++ ++import org.junit.Ignore; + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.support.model.Model.Device; + import org.tensorflow.lite.support.model.Model.Options; + +-import org.junit.Ignore; ++import java.io.IOException; ++import java.nio.MappedByteBuffer; ++import java.util.HashMap; ++import java.util.Map; + + /** Tests of {@link org.tensorflow.lite.support.model.Model}. */ + @RunWith(RobolectricTestRunner.class) + public final class ModelTest { ++ private final Context context = ApplicationProvider.getApplicationContext(); ++ private static final String MODEL_PATH = "add.tflite"; ++ ++ @Ignore ++ @Test ++ public void testLoadLocalModel() throws IOException { ++ MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData(); ++ assertThat(byteModel).isNotNull(); ++ } ++ ++ @Ignore ++ @Test ++ public void testBuildMultiThreadModel() throws IOException { ++ MappedByteBuffer byteModel = ++ new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData(); ++ assertThat(byteModel).isNotNull(); ++ } ++ ++ @Ignore ++ @Test ++ public void buildModelWithOptionsShouldSuccess() throws IOException { ++ Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build(); ++ Model model = Model.createModel(context, MODEL_PATH, options); ++ assertThat(model.getData()).isNotNull(); ++ } + +- private final Context context = ApplicationProvider.getApplicationContext(); +- private static final String MODEL_PATH = "add.tflite"; +- +- @Ignore +- @Test +- public void testLoadLocalModel() throws IOException { +- MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData(); +- assertThat(byteModel).isNotNull(); +- } +- +- @Ignore +- @Test +- public void testBuildMultiThreadModel() throws IOException { +- MappedByteBuffer byteModel = +- new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData(); +- assertThat(byteModel).isNotNull(); +- } +- +- @Ignore +- @Test +- public void buildModelWithOptionsShouldSuccess() throws IOException { +- Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build(); +- Model model = Model.createModel(context, MODEL_PATH, options); +- assertThat(model.getData()).isNotNull(); +- } +- +- @Ignore +- @Test +- public void testGetModelPath() throws IOException { +- String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath(); +- assertThat(modelPath).isEqualTo(MODEL_PATH); +- } +- +- @Test +- public void testNonExistingLocalModel() { +- try { +- new Model.Builder(context, "non_exist_model_file").build(); +- fail(); +- } catch (IOException e) { +- assertThat(e).hasMessageThat().contains("non_exist_model_file"); ++ @Ignore ++ @Test ++ public void testGetModelPath() throws IOException { ++ String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath(); ++ assertThat(modelPath).isEqualTo(MODEL_PATH); + } +- } +- +- @Test +- public void testNullLocalModelPath() throws IOException { +- try { +- new Model.Builder(context, null).build(); +- fail(); +- } catch (NullPointerException e) { +- assertThat(e).hasMessageThat().contains("File path cannot be null."); ++ ++ @Test ++ public void testNonExistingLocalModel() { ++ try { ++ new Model.Builder(context, "non_exist_model_file").build(); ++ fail(); ++ } catch (IOException e) { ++ assertThat(e).hasMessageThat().contains("non_exist_model_file"); ++ } + } +- } +- +- @Test +- public void testNullContext() throws IOException { +- try { +- new Model.Builder(null, MODEL_PATH).build(); +- fail(); +- } catch (NullPointerException e) { +- assertThat(e).hasMessageThat().contains("Context should not be null."); ++ ++ @Test ++ public void testNullLocalModelPath() throws IOException { ++ try { ++ new Model.Builder(context, null).build(); ++ fail(); ++ } catch (NullPointerException e) { ++ assertThat(e).hasMessageThat().contains("File path cannot be null."); ++ } ++ } ++ ++ @Test ++ public void testNullContext() throws IOException { ++ try { ++ new Model.Builder(null, MODEL_PATH).build(); ++ fail(); ++ } catch (NullPointerException e) { ++ assertThat(e).hasMessageThat().contains("Context should not be null."); ++ } ++ } ++ ++ @Ignore ++ @Test ++ public void testGetInputTensor() throws IOException { ++ Options options = new Options.Builder().build(); ++ Model model = Model.createModel(context, MODEL_PATH, options); ++ assertThat(model.getInputTensor(0)).isNotNull(); ++ } ++ ++ @Ignore ++ @Test ++ public void testGetOutputTensor() throws IOException { ++ Options options = new Options.Builder().build(); ++ Model model = Model.createModel(context, MODEL_PATH, options); ++ assertThat(model.getOutputTensor(0)).isNotNull(); ++ } ++ ++ @Ignore ++ @Test ++ public void testRun() throws IOException { ++ Context context = ApplicationProvider.getApplicationContext(); ++ Model model = new Model.Builder(context, MODEL_PATH).build(); ++ runModel(model); ++ } ++ ++ @Ignore ++ @Test ++ public void testMultiThreadingRun() throws IOException { ++ Context context = ApplicationProvider.getApplicationContext(); ++ Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build(); ++ runModel(model); ++ } ++ ++ @Ignore ++ @Test ++ public void testNnApiRun() throws IOException { ++ Context context = ApplicationProvider.getApplicationContext(); ++ Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build(); ++ runModel(model); ++ } ++ ++ private static void runModel(Model model) throws IOException { ++ // Creates the inputs. ++ float[] x = {1.5f}; ++ float[] y = {0.5f}; ++ float[] expectedSum = {2.0f}; ++ Object[] inputs = {x, y}; ++ ++ // Creates the outputs buffer. ++ float[] sum = new float[1]; ++ Map<Integer, Object> outputs = new HashMap<>(); ++ outputs.put(0, sum); ++ ++ // Runs inference. ++ model.run(inputs, outputs); ++ assertThat(sum).isEqualTo(expectedSum); + } +- } +- +- @Ignore +- @Test +- public void testGetInputTensor() throws IOException { +- Options options = new Options.Builder().build(); +- Model model = Model.createModel(context, MODEL_PATH, options); +- assertThat(model.getInputTensor(0)).isNotNull(); +- } +- +- @Ignore +- @Test +- public void testGetOutputTensor() throws IOException { +- Options options = new Options.Builder().build(); +- Model model = Model.createModel(context, MODEL_PATH, options); +- assertThat(model.getOutputTensor(0)).isNotNull(); +- } +- +- @Ignore +- @Test +- public void testRun() throws IOException { +- Context context = ApplicationProvider.getApplicationContext(); +- Model model = new Model.Builder(context, MODEL_PATH).build(); +- runModel(model); +- } +- +- @Ignore +- @Test +- public void testMultiThreadingRun() throws IOException { +- Context context = ApplicationProvider.getApplicationContext(); +- Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build(); +- runModel(model); +- } +- +- @Ignore +- @Test +- public void testNnApiRun() throws IOException { +- Context context = ApplicationProvider.getApplicationContext(); +- Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build(); +- runModel(model); +- } +- +- private static void runModel(Model model) throws IOException { +- // Creates the inputs. +- float[] x = {1.5f}; +- float[] y = {0.5f}; +- float[] expectedSum = {2.0f}; +- Object[] inputs = {x, y}; +- +- // Creates the outputs buffer. +- float[] sum = new float[1]; +- Map<Integer, Object> outputs = new HashMap<>(); +- outputs.put(0, sum); +- +- // Runs inference. +- model.run(inputs, outputs); +- assertThat(sum).isEqualTo(expectedSum); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java +index 3a4d09d8e5701..82b59b36155f3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java +@@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType; + /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat}. */ + @RunWith(RobolectricTestRunner.class) + public final class TensorBufferFloatTest { +- @Test +- public void testCreateDynamic() { +- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(); +- assertThat(tensorBufferFloat).isNotNull(); +- } ++ @Test ++ public void testCreateDynamic() { ++ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(); ++ assertThat(tensorBufferFloat).isNotNull(); ++ } + +- @Test +- public void testCreateFixedSize() { +- int[] shape = new int[] {1, 2, 3}; +- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); +- assertThat(tensorBufferFloat).isNotNull(); +- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6); +- } ++ @Test ++ public void testCreateFixedSize() { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); ++ assertThat(tensorBufferFloat).isNotNull(); ++ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6); ++ } + +- @Test +- public void testCreateFixedSizeWithScalarShape() { +- int[] shape = new int[] {}; +- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); +- assertThat(tensorBufferFloat).isNotNull(); +- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1); +- } ++ @Test ++ public void testCreateFixedSizeWithScalarShape() { ++ int[] shape = new int[] {}; ++ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); ++ assertThat(tensorBufferFloat).isNotNull(); ++ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1); ++ } + +- @Test +- public void testCreateWithNullShape() { +- int[] shape = null; +- Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape)); +- } ++ @Test ++ public void testCreateWithNullShape() { ++ int[] shape = null; ++ Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape)); ++ } + +- @Test +- public void testCreateWithInvalidShape() { +- int[] shape = new int[] {1, -1, 2}; +- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape)); +- } ++ @Test ++ public void testCreateWithInvalidShape() { ++ int[] shape = new int[] {1, -1, 2}; ++ Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape)); ++ } + +- @Test +- public void testCreateUsingShapeWithZero() { +- int[] shape = new int[] {1, 0, 2}; +- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); +- assertThat(tensorBufferFloat).isNotNull(); +- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0); +- } ++ @Test ++ public void testCreateUsingShapeWithZero() { ++ int[] shape = new int[] {1, 0, 2}; ++ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); ++ assertThat(tensorBufferFloat).isNotNull(); ++ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0); ++ } + +- @Test +- public void testGetDataType() { +- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(); +- assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32); +- } ++ @Test ++ public void testGetDataType() { ++ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(); ++ assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java +index c55affe733eac..763356f493390 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java +@@ -16,877 +16,878 @@ limitations under the License. + package org.tensorflow.lite.support.tensorbuffer; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.FloatBuffer; +-import java.util.ArrayList; + import org.junit.Assert; + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + import org.tensorflow.lite.DataType; + ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.FloatBuffer; ++import java.util.ArrayList; ++ + /** Test helper class for inserting and retrieving arrays. */ + class ArrayTestRunner { +- // List of TensorBuffer types to be tested. +- private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8}; +- // List of source arrays to be loaded into TensorBuffer during the tests. +- private final ArrayList<Object> srcArrays; +- // List of array data type with respect to srcArrays. +- private final ArrayList<DataType> arrDataTypes; +- // List of array shape with respect to srcArrays. +- private final ArrayList<int[]> arrShapes; +- private final int[] tensorBufferShape; +- private final ExpectedResults expectedResForFloatBuf; +- private final ExpectedResults expectedResForByteBuf; +- +- public ArrayTestRunner(Builder builder) { +- if (builder.srcArrays.size() != builder.arrDataTypes.size()) { +- throw new IllegalArgumentException( +- "Number of source arrays and number of data types do not match."); +- } +- +- this.srcArrays = builder.srcArrays; +- this.arrDataTypes = builder.arrDataTypes; +- this.arrShapes = builder.arrShapes; +- this.tensorBufferShape = builder.tensorBufferShape; +- this.expectedResForFloatBuf = builder.expectedResForFloatBuf; +- this.expectedResForByteBuf = builder.expectedResForByteBuf; +- } +- +- static class ExpectedResults { +- public float[] floatArr; +- public int[] intArr; +- public int[] shape; +- } +- +- public static class Builder { +- private final ArrayList<Object> srcArrays = new ArrayList<>(); +- private final ArrayList<DataType> arrDataTypes = new ArrayList<>(); +- private final ArrayList<int[]> arrShapes = new ArrayList<>(); +- private int[] tensorBufferShape; +- private final ExpectedResults expectedResForFloatBuf = new ExpectedResults(); +- private final ExpectedResults expectedResForByteBuf = new ExpectedResults(); +- +- public static Builder newInstance() { +- return new Builder(); +- } +- +- private Builder() {} +- +- /** Loads a test array into the test runner. */ +- public Builder addSrcArray(Object src, int[] shape) { +- // src should be a primitive 1D array. +- DataType dataType = dataTypeOfArray(src); +- switch (dataType) { +- case INT32: +- case FLOAT32: +- srcArrays.add(src); +- arrDataTypes.add(dataType); +- arrShapes.add(shape); +- return this; +- default: +- throw new AssertionError("Cannot resolve srouce arrays in the DataType of " + dataType); +- } +- } +- +- public Builder setTensorBufferShape(int[] tensorBufferShape) { +- this.tensorBufferShape = tensorBufferShape; +- return this; +- } +- +- public Builder setExpectedResults( +- DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) { +- ExpectedResults er; +- switch (bufferType) { +- case UINT8: +- er = expectedResForByteBuf; +- break; +- case FLOAT32: +- er = expectedResForFloatBuf; +- break; +- default: +- throw new AssertionError("Cannot test TensorBuffer in the DataType of " + bufferType); +- } +- +- er.floatArr = expectedFloatArr; +- er.intArr = expectedIntArr; +- return this; +- } +- +- public ArrayTestRunner build() { +- int[] expectedShape; +- if (arrShapes.isEmpty()) { +- // If no array will be loaded, the array is an empty array. +- expectedShape = new int[] {0}; +- } else { +- expectedShape = arrShapes.get(arrShapes.size() - 1); +- } +- expectedResForByteBuf.shape = expectedShape; +- expectedResForFloatBuf.shape = expectedShape; +- return new ArrayTestRunner(this); +- } +- } +- +- public static DataType[] getBufferTypeList() { +- return BUFFER_TYPE_LIST; +- } +- +- /** +- * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null, +- * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in +- * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types in +- * TensorBuffer, such as int array and float array for now. Check if the results are correct. 4. +- * Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST. +- */ +- public void run() { +- for (DataType bufferDataType : BUFFER_TYPE_LIST) { +- TensorBuffer tensorBuffer; +- if (tensorBufferShape == null) { +- tensorBuffer = TensorBuffer.createDynamic(bufferDataType); +- } else { +- tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType); +- } +- for (int i = 0; i < srcArrays.size(); i++) { +- switch (arrDataTypes.get(i)) { +- case INT32: +- int[] arrInt = (int[]) srcArrays.get(i); +- tensorBuffer.loadArray(arrInt, arrShapes.get(i)); +- break; +- case FLOAT32: +- float[] arrFloat = (float[]) srcArrays.get(i); +- tensorBuffer.loadArray(arrFloat, arrShapes.get(i)); +- break; +- default: +- break; ++ // List of TensorBuffer types to be tested. ++ private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8}; ++ // List of source arrays to be loaded into TensorBuffer during the tests. ++ private final ArrayList<Object> srcArrays; ++ // List of array data type with respect to srcArrays. ++ private final ArrayList<DataType> arrDataTypes; ++ // List of array shape with respect to srcArrays. ++ private final ArrayList<int[]> arrShapes; ++ private final int[] tensorBufferShape; ++ private final ExpectedResults expectedResForFloatBuf; ++ private final ExpectedResults expectedResForByteBuf; ++ ++ public ArrayTestRunner(Builder builder) { ++ if (builder.srcArrays.size() != builder.arrDataTypes.size()) { ++ throw new IllegalArgumentException( ++ "Number of source arrays and number of data types do not match."); + } +- } +- checkResults(tensorBuffer); +- } +- } +- +- private void checkResults(TensorBuffer tensorBuffer) { +- ExpectedResults er; +- switch (tensorBuffer.getDataType()) { +- case UINT8: +- er = expectedResForByteBuf; +- break; +- case FLOAT32: +- er = expectedResForFloatBuf; +- break; +- default: +- throw new AssertionError( +- "Cannot test TensorBuffer in the DataType of " + tensorBuffer.getDataType()); +- } +- +- // Checks getIntArray() and getFloatArray(). +- int[] resIntArr = tensorBuffer.getIntArray(); +- assertThat(resIntArr).isEqualTo(er.intArr); +- float[] resFloatArr = tensorBuffer.getFloatArray(); +- assertThat(resFloatArr).isEqualTo(er.floatArr); +- assertThat(tensorBuffer.getShape()).isEqualTo(er.shape); +- +- // Checks getIntValue(int index) and getFloatValue(int index). +- int flatSize = tensorBuffer.getFlatSize(); +- float[] resFloatValues = new float[flatSize]; +- int[] resIntValues = new int[flatSize]; +- for (int i = 0; i < flatSize; i++) { +- resFloatValues[i] = tensorBuffer.getFloatValue(i); +- resIntValues[i] = tensorBuffer.getIntValue(i); +- } +- assertThat(resFloatValues).isEqualTo(er.floatArr); +- assertThat(resIntValues).isEqualTo(er.intArr); +- } +- +- /** Gets the data type of an 1D array. */ +- private static DataType dataTypeOfArray(Object arr) { +- if (arr != null) { +- Class<?> c = arr.getClass(); +- if (c.isArray()) { +- c = c.getComponentType(); +- if (float.class.equals(c)) { +- return DataType.FLOAT32; +- } else if (int.class.equals(c)) { +- return DataType.INT32; +- } else if (byte.class.equals(c)) { +- return DataType.UINT8; +- } else if (long.class.equals(c)) { +- return DataType.INT64; +- } else if (String.class.equals(c)) { +- return DataType.STRING; ++ ++ this.srcArrays = builder.srcArrays; ++ this.arrDataTypes = builder.arrDataTypes; ++ this.arrShapes = builder.arrShapes; ++ this.tensorBufferShape = builder.tensorBufferShape; ++ this.expectedResForFloatBuf = builder.expectedResForFloatBuf; ++ this.expectedResForByteBuf = builder.expectedResForByteBuf; ++ } ++ ++ static class ExpectedResults { ++ public float[] floatArr; ++ public int[] intArr; ++ public int[] shape; ++ } ++ ++ public static class Builder { ++ private final ArrayList<Object> srcArrays = new ArrayList<>(); ++ private final ArrayList<DataType> arrDataTypes = new ArrayList<>(); ++ private final ArrayList<int[]> arrShapes = new ArrayList<>(); ++ private int[] tensorBufferShape; ++ private final ExpectedResults expectedResForFloatBuf = new ExpectedResults(); ++ private final ExpectedResults expectedResForByteBuf = new ExpectedResults(); ++ ++ public static Builder newInstance() { ++ return new Builder(); ++ } ++ ++ private Builder() {} ++ ++ /** Loads a test array into the test runner. */ ++ public Builder addSrcArray(Object src, int[] shape) { ++ // src should be a primitive 1D array. ++ DataType dataType = dataTypeOfArray(src); ++ switch (dataType) { ++ case INT32: ++ case FLOAT32: ++ srcArrays.add(src); ++ arrDataTypes.add(dataType); ++ arrShapes.add(shape); ++ return this; ++ default: ++ throw new AssertionError( ++ "Cannot resolve srouce arrays in the DataType of " + dataType); ++ } ++ } ++ ++ public Builder setTensorBufferShape(int[] tensorBufferShape) { ++ this.tensorBufferShape = tensorBufferShape; ++ return this; + } +- } ++ ++ public Builder setExpectedResults( ++ DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) { ++ ExpectedResults er; ++ switch (bufferType) { ++ case UINT8: ++ er = expectedResForByteBuf; ++ break; ++ case FLOAT32: ++ er = expectedResForFloatBuf; ++ break; ++ default: ++ throw new AssertionError( ++ "Cannot test TensorBuffer in the DataType of " + bufferType); ++ } ++ ++ er.floatArr = expectedFloatArr; ++ er.intArr = expectedIntArr; ++ return this; ++ } ++ ++ public ArrayTestRunner build() { ++ int[] expectedShape; ++ if (arrShapes.isEmpty()) { ++ // If no array will be loaded, the array is an empty array. ++ expectedShape = new int[] {0}; ++ } else { ++ expectedShape = arrShapes.get(arrShapes.size() - 1); ++ } ++ expectedResForByteBuf.shape = expectedShape; ++ expectedResForFloatBuf.shape = expectedShape; ++ return new ArrayTestRunner(this); ++ } ++ } ++ ++ public static DataType[] getBufferTypeList() { ++ return BUFFER_TYPE_LIST; ++ } ++ ++ /** ++ * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null, ++ * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in ++ * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types ++ * in TensorBuffer, such as int array and float array for now. Check if the results are ++ * correct. 4. Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST. ++ */ ++ public void run() { ++ for (DataType bufferDataType : BUFFER_TYPE_LIST) { ++ TensorBuffer tensorBuffer; ++ if (tensorBufferShape == null) { ++ tensorBuffer = TensorBuffer.createDynamic(bufferDataType); ++ } else { ++ tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType); ++ } ++ for (int i = 0; i < srcArrays.size(); i++) { ++ switch (arrDataTypes.get(i)) { ++ case INT32: ++ int[] arrInt = (int[]) srcArrays.get(i); ++ tensorBuffer.loadArray(arrInt, arrShapes.get(i)); ++ break; ++ case FLOAT32: ++ float[] arrFloat = (float[]) srcArrays.get(i); ++ tensorBuffer.loadArray(arrFloat, arrShapes.get(i)); ++ break; ++ default: ++ break; ++ } ++ } ++ checkResults(tensorBuffer); ++ } ++ } ++ ++ private void checkResults(TensorBuffer tensorBuffer) { ++ ExpectedResults er; ++ switch (tensorBuffer.getDataType()) { ++ case UINT8: ++ er = expectedResForByteBuf; ++ break; ++ case FLOAT32: ++ er = expectedResForFloatBuf; ++ break; ++ default: ++ throw new AssertionError("Cannot test TensorBuffer in the DataType of " ++ + tensorBuffer.getDataType()); ++ } ++ ++ // Checks getIntArray() and getFloatArray(). ++ int[] resIntArr = tensorBuffer.getIntArray(); ++ assertThat(resIntArr).isEqualTo(er.intArr); ++ float[] resFloatArr = tensorBuffer.getFloatArray(); ++ assertThat(resFloatArr).isEqualTo(er.floatArr); ++ assertThat(tensorBuffer.getShape()).isEqualTo(er.shape); ++ ++ // Checks getIntValue(int index) and getFloatValue(int index). ++ int flatSize = tensorBuffer.getFlatSize(); ++ float[] resFloatValues = new float[flatSize]; ++ int[] resIntValues = new int[flatSize]; ++ for (int i = 0; i < flatSize; i++) { ++ resFloatValues[i] = tensorBuffer.getFloatValue(i); ++ resIntValues[i] = tensorBuffer.getIntValue(i); ++ } ++ assertThat(resFloatValues).isEqualTo(er.floatArr); ++ assertThat(resIntValues).isEqualTo(er.intArr); ++ } ++ ++ /** Gets the data type of an 1D array. */ ++ private static DataType dataTypeOfArray(Object arr) { ++ if (arr != null) { ++ Class<?> c = arr.getClass(); ++ if (c.isArray()) { ++ c = c.getComponentType(); ++ if (float.class.equals(c)) { ++ return DataType.FLOAT32; ++ } else if (int.class.equals(c)) { ++ return DataType.INT32; ++ } else if (byte.class.equals(c)) { ++ return DataType.UINT8; ++ } else if (long.class.equals(c)) { ++ return DataType.INT64; ++ } else if (String.class.equals(c)) { ++ return DataType.STRING; ++ } ++ } ++ } ++ throw new IllegalArgumentException( ++ "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName()); + } +- throw new IllegalArgumentException( +- "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName()); +- } + } + + /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. */ + @RunWith(RobolectricTestRunner.class) + public final class TensorBufferTest { +- // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other. +- private static final int[] ARRAY1_SHAPE = new int[] {2, 3}; +- private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f}; +- private static final float[] FLOAT_ARRAY1_ROUNDED = +- new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f}; +- // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted into +- // uint8. +- private static final float[] FLOAT_ARRAY1_CAPPED = +- new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f}; +- private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6}; +- private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6}; +- // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other. +- private static final int[] ARRAY2_SHAPE = new int[] {2, 1}; +- private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f}; +- private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f}; +- private static final int[] INT_ARRAY2 = new int[] {6, 7}; +- // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size. +- private static final int[] ARRAY3_SHAPE = new int[] {2, 1}; +- private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f}; +- private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f}; +- // INT_ARRAY2 and INT_ARRAY3 have the same size. +- private static final int[] INT_ARRAY3 = new int[] {8, 9}; +- private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0}; +- private static final int[] EMPTY_INT_ARRAY = new int[0]; +- private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; +- // Single element array which represents a scalar. +- private static final int[] SCALAR_ARRAY_SHAPE = new int[] {}; +- private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f}; +- private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f}; +- private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f}; +- private static final int[] INT_SCALAR_ARRAY = new int[] {800}; +- private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255}; +- // Several different ByteBuffer. +- private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0); +- private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24); +- +- static { +- FLOAT_BYTE_BUFFER1.rewind(); +- +- FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer(); +- floatBuffer.put(FLOAT_ARRAY1); +- } +- +- private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2); +- +- static { +- INT_BYTE_BUFFER2.rewind(); +- +- for (int a : INT_ARRAY2) { +- INT_BYTE_BUFFER2.put((byte) a); +- } +- } +- +- @Test +- public void testCreateFixedSizeTensorBufferFloat() { +- int[] shape = new int[] {1, 2, 3}; +- TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- assertThat(tensorBufferFloat).isNotNull(); +- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6); +- } +- +- @Test +- public void testCreateFixedSizeTensorBufferUint8() { +- int[] shape = new int[] {1, 2, 3}; +- TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- assertThat(tensorBufferUint8).isNotNull(); +- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6); +- } +- +- @Test +- public void testCreateDynamicTensorBufferFloat() { +- TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32); +- assertThat(tensorBufferFloat).isNotNull(); +- } +- +- @Test +- public void testCreateDynamicTensorBufferUint8() { +- TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8); +- assertThat(tensorBufferUint8).isNotNull(); +- } +- +- @Test +- public void testCreateTensorBufferFromFixedSize() { +- int[] shape = new int[] {1, 2, 3}; +- TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); +- assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3}); +- } +- +- @Test +- public void testCreateTensorBufferFromDynamicSize() { +- int[] shape = new int[] {1, 2, 3}; +- TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8); +- src.resize(shape); +- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); +- assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3}); +- } +- +- @Test +- public void testCreateTensorBufferUInt8FromUInt8() { +- int[] shape = new int[] {INT_ARRAY1.length}; +- TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- src.loadArray(INT_ARRAY1); +- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8); +- int[] data = dst.getIntArray(); +- assertThat(data).isEqualTo(INT_ARRAY1_CAPPED); +- } +- +- @Test +- public void testCreateTensorBufferUInt8FromFloat32() { +- TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32); +- src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE); +- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8); +- int[] data = dst.getIntArray(); +- assertThat(data).isEqualTo(INT_ARRAY1_CAPPED); +- } +- +- @Test +- public void testCreateTensorBufferFloat32FromUInt8() { +- TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8); +- src.loadArray(INT_ARRAY1, ARRAY1_SHAPE); +- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); +- float[] data = dst.getFloatArray(); +- assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED); +- } +- +- @Test +- public void testCreateTensorBufferFloat32FromFloat32() { +- int[] shape = new int[] {FLOAT_ARRAY1.length}; +- TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); +- src.loadArray(FLOAT_ARRAY1); +- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); +- float[] data = dst.getFloatArray(); +- assertThat(data).isEqualTo(FLOAT_ARRAY1); +- } +- +- @Test +- public void testGetBuffer() throws IOException { +- int[] shape = new int[] {1, 2, 3}; +- TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8); +- assertThat(tensorBufferUint8.getBuffer()).isNotNull(); +- } +- +- @Test +- public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE) +- .setTensorBufferShape(SCALAR_ARRAY_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_ROUNDED, +- /*expectedIntArr=*/ INT_SCALAR_ARRAY) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED, +- /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED) +- .build() +- .run(); +- } +- +- @Test +- public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE) +- .setTensorBufferShape(SCALAR_ARRAY_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY, +- /*expectedIntArr=*/ INT_SCALAR_ARRAY) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED, +- /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED) +- .build() +- .run(); +- } +- +- @Test +- public void testLoadAndGetIntArrayWithFixedSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) +- .setTensorBufferShape(ARRAY1_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY1) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED, +- /*expectedIntArr=*/ INT_ARRAY1_CAPPED) +- .build() +- .run(); +- } +- +- @Test +- public void testLoadAndGetFloatArrayWithFixedSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) +- .setTensorBufferShape(ARRAY1_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY1, +- /*expectedIntArr=*/ INT_ARRAY1) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED, +- /*expectedIntArr=*/ INT_ARRAY1_CAPPED) +- .build() +- .run(); +- } +- +- @Test +- public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE) +- .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE) +- .setTensorBufferShape(ARRAY2_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY3) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY3) +- .build() +- .run(); +- } +- +- @Test +- public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) +- .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE) +- .setTensorBufferShape(ARRAY2_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY3, +- /*expectedIntArr=*/ INT_ARRAY3) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY3) +- .build() +- .run(); +- } +- +- @Test +- public void testRepeatedLoadIntArrayWithDifferentFixedSize() { +- int[] srcArr1 = INT_ARRAY1; +- int[] srcArr2 = INT_ARRAY2; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = +- TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType); +- tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length}); +- // Load srcArr2 which had different size as srcArr1. +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length})); +- } +- } +- +- @Test +- public void testRepeatedLoadFloatArrayWithDifferentFixedSize() { +- float[] srcArr1 = FLOAT_ARRAY1; +- float[] srcArr2 = FLOAT_ARRAY2; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = +- TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType); +- tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length}); +- // Load srcArr2 which had different size as srcArr1. +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length})); +- } +- } +- +- @Test +- public void testLoadAndGetIntArrayWithDynamicSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY1) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED, +- /*expectedIntArr=*/ INT_ARRAY1_CAPPED) +- .build() +- .run(); +- } +- +- @Test +- public void testLoadAndGetFloatArrayWithDynamicSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY1, +- /*expectedIntArr=*/ INT_ARRAY1) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED, +- /*expectedIntArr=*/ INT_ARRAY1_CAPPED) +- .build() +- .run(); +- } +- +- @Test +- public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) +- .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY2) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY2) +- .build() +- .run(); +- } +- +- @Test +- public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) +- .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ FLOAT_ARRAY2, +- /*expectedIntArr=*/ INT_ARRAY2) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED, +- /*expectedIntArr=*/ INT_ARRAY2) +- .build() +- .run(); +- } +- +- @Test +- public void testGetForEmptyArrayWithFixedSizeBuffer() { +- ArrayTestRunner.Builder.newInstance() +- .setTensorBufferShape(EMPTY_ARRAY_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY, +- /*expectedIntArr=*/ EMPTY_INT_ARRAY) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY, +- /*expectedIntArr=*/ EMPTY_INT_ARRAY) +- .build() +- .run(); +- } +- +- @Test +- public void testGetForEmptyArrayWithDynamicBuffer() { +- ArrayTestRunner.Builder.newInstance() +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY, +- /*expectedIntArr=*/ EMPTY_INT_ARRAY) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY, +- /*expectedIntArr=*/ EMPTY_INT_ARRAY) +- .build() +- .run(); +- } +- +- @Test +- public void testRepeatedLoadAndGetForEmptyArray() { +- ArrayTestRunner.Builder.newInstance() +- .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE) +- .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) +- .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE) +- .setExpectedResults( +- /*bufferType = */ DataType.FLOAT32, +- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY, +- /*expectedIntArr=*/ EMPTY_INT_ARRAY) +- .setExpectedResults( +- /*bufferType = */ DataType.UINT8, +- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY, +- /*expectedIntArr=*/ EMPTY_INT_ARRAY) +- .build() +- .run(); +- } +- +- @Test +- public void testLoadNullIntArrays() { +- int[] nullArray = null; +- int[] shape = new int[] {}; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- Assert.assertThrows( +- NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape)); +- } +- } +- +- @Test +- public void testLoadNullFloatArrays() { +- float[] nullArray = null; +- int[] shape = new int[] {}; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- Assert.assertThrows( +- NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape)); +- } +- } +- +- @Test +- public void testLoadFloatArraysWithNullShape() { +- float[] arr = new float[] {1.0f}; +- int[] nullShape = null; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape)); +- } +- } +- +- @Test +- public void testLoadIntArraysWithNullShape() { +- int[] arr = new int[] {1}; +- int[] nullShape = null; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape)); +- } +- } +- +- @Test +- public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() { +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType); +- Assert.assertThrows( +- IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2)); +- } +- } +- +- @Test +- public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() { +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType); +- Assert.assertThrows( +- IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2)); +- } +- } +- +- @Test +- public void testLoadByteBufferForNullBuffer() { +- ByteBuffer byteBuffer = null; +- int[] shape = new int[] {}; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- Assert.assertThrows( +- NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape)); +- } +- } +- +- @Test +- public void testLoadByteBufferForEmptyBuffer() { +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE); +- assertThat(tensorBuffer.getFlatSize()).isEqualTo(0); +- } +- } +- +- @Test +- public void testLoadByteBufferWithDifferentFixedSize() { +- // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5. +- int[] tensorBufferShape = new int[] {2}; +- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32); +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE)); +- } +- +- @Test +- public void testLoadByteBufferWithMisMatchDataType() { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- int[] wrongShape = new int[] {1}; +- // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape. +- Assert.assertThrows( +- IllegalArgumentException.class, +- () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape)); +- } +- +- @Test +- public void testLoadByteBufferForTensorBufferFloat() { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); +- tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE); +- assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1); +- assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE); +- } +- +- @Test +- public void testLoadByteBufferForTensorBufferUint8() { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); +- tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE); +- assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2); +- assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE); +- } +- +- @Test +- public void testGetFloatValueWithInvalidIndex() { +- float[] arrayWithSixElements = FLOAT_ARRAY1; +- int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE; +- int[] invalidIndexes = {-1, 7}; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements); +- for (int invalidIndex : invalidIndexes) { +- Assert.assertThrows( +- IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex)); +- } +- } +- } +- +- @Test +- public void testGetFloatValueFromScalarWithInvalidIndex() { +- int[] shape = new int[] {}; +- float[] arr = new float[] {10.0f}; +- int[] invalidIndexes = +- new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize. +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- tensorBuffer.loadArray(arr, shape); +- for (int invalidIndex : invalidIndexes) { +- Assert.assertThrows( +- IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex)); +- } +- } +- } +- +- @Test +- public void testGetIntValueWithInvalidIndex() { +- float[] arrayWithSixElements = FLOAT_ARRAY1; +- int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE; +- int[] invalidIndexes = {-1, 7}; +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements); +- for (int invalidIndex : invalidIndexes) { +- Assert.assertThrows( +- IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex)); +- } +- } +- } +- +- @Test +- public void testGetIntValueFromScalarWithInvalidIndex() { +- int[] shape = new int[] {}; +- float[] arr = new float[] {10.0f}; +- int[] invalidIndexes = +- new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize. +- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); +- tensorBuffer.loadArray(arr, shape); +- for (int invalidIndex : invalidIndexes) { +- Assert.assertThrows( +- IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex)); +- } +- } +- } +- +- @Test +- public void testLoadByteBufferSliceForTensorBufferFloat() { +- TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32); +- original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6}); +- ByteBuffer buffer = original.getBuffer(); +- // Slice original buffer to 3 sub-buffer, each of which has 2 element +- int numBuffers = 3; +- int numElements = 2; +- int subArrayLength = numElements * original.getTypeSize(); +- TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType()); +- for (int i = 0; i < numBuffers; i++) { +- buffer.position(i * subArrayLength); +- ByteBuffer subBuffer = buffer.slice(); +- // ByteBuffer.slice doesn't keep order. +- subBuffer.order(buffer.order()).limit(subArrayLength); +- tensorSlice.loadBuffer(subBuffer, new int[] {numElements}); +- float[] arraySlice = tensorSlice.getFloatArray(); +- assertThat(arraySlice.length).isEqualTo(numElements); +- assertThat(arraySlice[0]).isEqualTo(i * numElements + 1); +- assertThat(arraySlice[1]).isEqualTo(i * numElements + 2); +- } +- } +- +- @Test +- public void testLoadByteBufferSliceForTensorBufferUInt8() { +- TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8); +- original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6}); +- ByteBuffer buffer = original.getBuffer(); +- // Slice original buffer to 3 sub-buffer, each of which has 2 element +- int numBuffers = 3; +- int numElements = 2; +- int subArrayLength = numElements * original.getTypeSize(); +- TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType()); +- for (int i = 0; i < numBuffers; i++) { +- buffer.position(i * subArrayLength); +- ByteBuffer subBuffer = buffer.slice(); +- // ByteBuffer.slice doesn't keep order. +- subBuffer.order(buffer.order()).limit(subArrayLength); +- tensorSlice.loadBuffer(subBuffer, new int[] {numElements}); +- int[] arraySlice = tensorSlice.getIntArray(); +- assertThat(arraySlice.length).isEqualTo(numElements); +- assertThat(arraySlice[0]).isEqualTo(i * numElements + 1); +- assertThat(arraySlice[1]).isEqualTo(i * numElements + 2); +- } +- } +- +- @Test +- public void getShapeFailsAfterByteBufferChanged() { +- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32); +- ByteBuffer byteBuffer = tensorBuffer.getBuffer(); +- byteBuffer.limit(5); +- +- IllegalStateException exception = +- assertThrows(IllegalStateException.class, tensorBuffer::getShape); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The" ++ // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other. ++ private static final int[] ARRAY1_SHAPE = new int[] {2, 3}; ++ private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f}; ++ private static final float[] FLOAT_ARRAY1_ROUNDED = ++ new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f}; ++ // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted ++ // into uint8. ++ private static final float[] FLOAT_ARRAY1_CAPPED = ++ new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f}; ++ private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6}; ++ private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6}; ++ // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other. ++ private static final int[] ARRAY2_SHAPE = new int[] {2, 1}; ++ private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f}; ++ private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f}; ++ private static final int[] INT_ARRAY2 = new int[] {6, 7}; ++ // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size. ++ private static final int[] ARRAY3_SHAPE = new int[] {2, 1}; ++ private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f}; ++ private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f}; ++ // INT_ARRAY2 and INT_ARRAY3 have the same size. ++ private static final int[] INT_ARRAY3 = new int[] {8, 9}; ++ private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0}; ++ private static final int[] EMPTY_INT_ARRAY = new int[0]; ++ private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; ++ // Single element array which represents a scalar. ++ private static final int[] SCALAR_ARRAY_SHAPE = new int[] {}; ++ private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f}; ++ private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f}; ++ private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f}; ++ private static final int[] INT_SCALAR_ARRAY = new int[] {800}; ++ private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255}; ++ // Several different ByteBuffer. ++ private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0); ++ private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24); ++ ++ static { ++ FLOAT_BYTE_BUFFER1.rewind(); ++ ++ FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer(); ++ floatBuffer.put(FLOAT_ARRAY1); ++ } ++ ++ private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2); ++ ++ static { ++ INT_BYTE_BUFFER2.rewind(); ++ ++ for (int a : INT_ARRAY2) { ++ INT_BYTE_BUFFER2.put((byte) a); ++ } ++ } ++ ++ @Test ++ public void testCreateFixedSizeTensorBufferFloat() { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ assertThat(tensorBufferFloat).isNotNull(); ++ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6); ++ } ++ ++ @Test ++ public void testCreateFixedSizeTensorBufferUint8() { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ assertThat(tensorBufferUint8).isNotNull(); ++ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6); ++ } ++ ++ @Test ++ public void testCreateDynamicTensorBufferFloat() { ++ TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32); ++ assertThat(tensorBufferFloat).isNotNull(); ++ } ++ ++ @Test ++ public void testCreateDynamicTensorBufferUint8() { ++ TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8); ++ assertThat(tensorBufferUint8).isNotNull(); ++ } ++ ++ @Test ++ public void testCreateTensorBufferFromFixedSize() { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); ++ assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3}); ++ } ++ ++ @Test ++ public void testCreateTensorBufferFromDynamicSize() { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8); ++ src.resize(shape); ++ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); ++ assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3}); ++ } ++ ++ @Test ++ public void testCreateTensorBufferUInt8FromUInt8() { ++ int[] shape = new int[] {INT_ARRAY1.length}; ++ TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ src.loadArray(INT_ARRAY1); ++ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8); ++ int[] data = dst.getIntArray(); ++ assertThat(data).isEqualTo(INT_ARRAY1_CAPPED); ++ } ++ ++ @Test ++ public void testCreateTensorBufferUInt8FromFloat32() { ++ TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32); ++ src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE); ++ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8); ++ int[] data = dst.getIntArray(); ++ assertThat(data).isEqualTo(INT_ARRAY1_CAPPED); ++ } ++ ++ @Test ++ public void testCreateTensorBufferFloat32FromUInt8() { ++ TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8); ++ src.loadArray(INT_ARRAY1, ARRAY1_SHAPE); ++ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); ++ float[] data = dst.getFloatArray(); ++ assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED); ++ } ++ ++ @Test ++ public void testCreateTensorBufferFloat32FromFloat32() { ++ int[] shape = new int[] {FLOAT_ARRAY1.length}; ++ TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); ++ src.loadArray(FLOAT_ARRAY1); ++ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); ++ float[] data = dst.getFloatArray(); ++ assertThat(data).isEqualTo(FLOAT_ARRAY1); ++ } ++ ++ @Test ++ public void testGetBuffer() throws IOException { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8); ++ assertThat(tensorBufferUint8.getBuffer()).isNotNull(); ++ } ++ ++ @Test ++ public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE) ++ .setTensorBufferShape(SCALAR_ARRAY_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_ROUNDED, ++ /*expectedIntArr=*/INT_SCALAR_ARRAY) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED, ++ /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE) ++ .setTensorBufferShape(SCALAR_ARRAY_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY, ++ /*expectedIntArr=*/INT_SCALAR_ARRAY) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED, ++ /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testLoadAndGetIntArrayWithFixedSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) ++ .setTensorBufferShape(ARRAY1_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY1) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, ++ /*expectedIntArr=*/INT_ARRAY1_CAPPED) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testLoadAndGetFloatArrayWithFixedSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) ++ .setTensorBufferShape(ARRAY1_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY1, ++ /*expectedIntArr=*/INT_ARRAY1) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, ++ /*expectedIntArr=*/INT_ARRAY1_CAPPED) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE) ++ .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE) ++ .setTensorBufferShape(ARRAY2_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY3) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY3) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) ++ .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE) ++ .setTensorBufferShape(ARRAY2_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY3, ++ /*expectedIntArr=*/INT_ARRAY3) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY3) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testRepeatedLoadIntArrayWithDifferentFixedSize() { ++ int[] srcArr1 = INT_ARRAY1; ++ int[] srcArr2 = INT_ARRAY2; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = ++ TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType); ++ tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length}); ++ // Load srcArr2 which had different size as srcArr1. ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length})); ++ } ++ } ++ ++ @Test ++ public void testRepeatedLoadFloatArrayWithDifferentFixedSize() { ++ float[] srcArr1 = FLOAT_ARRAY1; ++ float[] srcArr2 = FLOAT_ARRAY2; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = ++ TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType); ++ tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length}); ++ // Load srcArr2 which had different size as srcArr1. ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length})); ++ } ++ } ++ ++ @Test ++ public void testLoadAndGetIntArrayWithDynamicSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY1) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, ++ /*expectedIntArr=*/INT_ARRAY1_CAPPED) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testLoadAndGetFloatArrayWithDynamicSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY1, ++ /*expectedIntArr=*/INT_ARRAY1) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, ++ /*expectedIntArr=*/INT_ARRAY1_CAPPED) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) ++ .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY2) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY2) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) ++ .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/FLOAT_ARRAY2, ++ /*expectedIntArr=*/INT_ARRAY2) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED, ++ /*expectedIntArr=*/INT_ARRAY2) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testGetForEmptyArrayWithFixedSizeBuffer() { ++ ArrayTestRunner.Builder.newInstance() ++ .setTensorBufferShape(EMPTY_ARRAY_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, ++ /*expectedIntArr=*/EMPTY_INT_ARRAY) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, ++ /*expectedIntArr=*/EMPTY_INT_ARRAY) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testGetForEmptyArrayWithDynamicBuffer() { ++ ArrayTestRunner.Builder.newInstance() ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, ++ /*expectedIntArr=*/EMPTY_INT_ARRAY) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, ++ /*expectedIntArr=*/EMPTY_INT_ARRAY) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testRepeatedLoadAndGetForEmptyArray() { ++ ArrayTestRunner.Builder.newInstance() ++ .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE) ++ .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) ++ .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE) ++ .setExpectedResults( ++ /*bufferType = */ DataType.FLOAT32, ++ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, ++ /*expectedIntArr=*/EMPTY_INT_ARRAY) ++ .setExpectedResults( ++ /*bufferType = */ DataType.UINT8, ++ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, ++ /*expectedIntArr=*/EMPTY_INT_ARRAY) ++ .build() ++ .run(); ++ } ++ ++ @Test ++ public void testLoadNullIntArrays() { ++ int[] nullArray = null; ++ int[] shape = new int[] {}; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ Assert.assertThrows( ++ NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape)); ++ } ++ } ++ ++ @Test ++ public void testLoadNullFloatArrays() { ++ float[] nullArray = null; ++ int[] shape = new int[] {}; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ Assert.assertThrows( ++ NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape)); ++ } ++ } ++ ++ @Test ++ public void testLoadFloatArraysWithNullShape() { ++ float[] arr = new float[] {1.0f}; ++ int[] nullShape = null; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ Assert.assertThrows( ++ NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape)); ++ } ++ } ++ ++ @Test ++ public void testLoadIntArraysWithNullShape() { ++ int[] arr = new int[] {1}; ++ int[] nullShape = null; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ Assert.assertThrows( ++ NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape)); ++ } ++ } ++ ++ @Test ++ public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() { ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType); ++ Assert.assertThrows( ++ IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2)); ++ } ++ } ++ ++ @Test ++ public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() { ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType); ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2)); ++ } ++ } ++ ++ @Test ++ public void testLoadByteBufferForNullBuffer() { ++ ByteBuffer byteBuffer = null; ++ int[] shape = new int[] {}; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ Assert.assertThrows( ++ NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape)); ++ } ++ } ++ ++ @Test ++ public void testLoadByteBufferForEmptyBuffer() { ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE); ++ assertThat(tensorBuffer.getFlatSize()).isEqualTo(0); ++ } ++ } ++ ++ @Test ++ public void testLoadByteBufferWithDifferentFixedSize() { ++ // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5. ++ int[] tensorBufferShape = new int[] {2}; ++ TensorBuffer tensorBuffer = ++ TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32); ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE)); ++ } ++ ++ @Test ++ public void testLoadByteBufferWithMisMatchDataType() { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ int[] wrongShape = new int[] {1}; ++ // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape. ++ Assert.assertThrows(IllegalArgumentException.class, ++ () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape)); ++ } ++ ++ @Test ++ public void testLoadByteBufferForTensorBufferFloat() { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); ++ tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE); ++ assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1); ++ assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE); ++ } ++ ++ @Test ++ public void testLoadByteBufferForTensorBufferUint8() { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); ++ tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE); ++ assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2); ++ assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE); ++ } ++ ++ @Test ++ public void testGetFloatValueWithInvalidIndex() { ++ float[] arrayWithSixElements = FLOAT_ARRAY1; ++ int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE; ++ int[] invalidIndexes = {-1, 7}; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements); ++ for (int invalidIndex : invalidIndexes) { ++ Assert.assertThrows(IndexOutOfBoundsException.class, ++ () -> tensorBuffer.getFloatValue(invalidIndex)); ++ } ++ } ++ } ++ ++ @Test ++ public void testGetFloatValueFromScalarWithInvalidIndex() { ++ int[] shape = new int[] {}; ++ float[] arr = new float[] {10.0f}; ++ int[] invalidIndexes = ++ new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize. ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ tensorBuffer.loadArray(arr, shape); ++ for (int invalidIndex : invalidIndexes) { ++ Assert.assertThrows(IndexOutOfBoundsException.class, ++ () -> tensorBuffer.getFloatValue(invalidIndex)); ++ } ++ } ++ } ++ ++ @Test ++ public void testGetIntValueWithInvalidIndex() { ++ float[] arrayWithSixElements = FLOAT_ARRAY1; ++ int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE; ++ int[] invalidIndexes = {-1, 7}; ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements); ++ for (int invalidIndex : invalidIndexes) { ++ Assert.assertThrows(IndexOutOfBoundsException.class, ++ () -> tensorBuffer.getIntValue(invalidIndex)); ++ } ++ } ++ } ++ ++ @Test ++ public void testGetIntValueFromScalarWithInvalidIndex() { ++ int[] shape = new int[] {}; ++ float[] arr = new float[] {10.0f}; ++ int[] invalidIndexes = ++ new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize. ++ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); ++ tensorBuffer.loadArray(arr, shape); ++ for (int invalidIndex : invalidIndexes) { ++ Assert.assertThrows(IndexOutOfBoundsException.class, ++ () -> tensorBuffer.getIntValue(invalidIndex)); ++ } ++ } ++ } ++ ++ @Test ++ public void testLoadByteBufferSliceForTensorBufferFloat() { ++ TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32); ++ original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6}); ++ ByteBuffer buffer = original.getBuffer(); ++ // Slice original buffer to 3 sub-buffer, each of which has 2 element ++ int numBuffers = 3; ++ int numElements = 2; ++ int subArrayLength = numElements * original.getTypeSize(); ++ TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType()); ++ for (int i = 0; i < numBuffers; i++) { ++ buffer.position(i * subArrayLength); ++ ByteBuffer subBuffer = buffer.slice(); ++ // ByteBuffer.slice doesn't keep order. ++ subBuffer.order(buffer.order()).limit(subArrayLength); ++ tensorSlice.loadBuffer(subBuffer, new int[] {numElements}); ++ float[] arraySlice = tensorSlice.getFloatArray(); ++ assertThat(arraySlice.length).isEqualTo(numElements); ++ assertThat(arraySlice[0]).isEqualTo(i * numElements + 1); ++ assertThat(arraySlice[1]).isEqualTo(i * numElements + 2); ++ } ++ } ++ ++ @Test ++ public void testLoadByteBufferSliceForTensorBufferUInt8() { ++ TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8); ++ original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6}); ++ ByteBuffer buffer = original.getBuffer(); ++ // Slice original buffer to 3 sub-buffer, each of which has 2 element ++ int numBuffers = 3; ++ int numElements = 2; ++ int subArrayLength = numElements * original.getTypeSize(); ++ TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType()); ++ for (int i = 0; i < numBuffers; i++) { ++ buffer.position(i * subArrayLength); ++ ByteBuffer subBuffer = buffer.slice(); ++ // ByteBuffer.slice doesn't keep order. ++ subBuffer.order(buffer.order()).limit(subArrayLength); ++ tensorSlice.loadBuffer(subBuffer, new int[] {numElements}); ++ int[] arraySlice = tensorSlice.getIntArray(); ++ assertThat(arraySlice.length).isEqualTo(numElements); ++ assertThat(arraySlice[0]).isEqualTo(i * numElements + 1); ++ assertThat(arraySlice[1]).isEqualTo(i * numElements + 2); ++ } ++ } ++ ++ @Test ++ public void getShapeFailsAfterByteBufferChanged() { ++ TensorBuffer tensorBuffer = ++ TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32); ++ ByteBuffer byteBuffer = tensorBuffer.getBuffer(); ++ byteBuffer.limit(5); ++ ++ IllegalStateException exception = ++ assertThrows(IllegalStateException.class, tensorBuffer::getShape); ++ assertThat(exception).hasMessageThat().contains( ++ "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The" + + " ByteBuffer may have been changed."); +- } +- +- @Test +- public void getFlatSizeFailsAfterByteBufferChanged() { +- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32); +- ByteBuffer byteBuffer = tensorBuffer.getBuffer(); +- byteBuffer.limit(5); +- +- IllegalStateException exception = +- assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The" ++ } ++ ++ @Test ++ public void getFlatSizeFailsAfterByteBufferChanged() { ++ TensorBuffer tensorBuffer = ++ TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32); ++ ByteBuffer byteBuffer = tensorBuffer.getBuffer(); ++ byteBuffer.limit(5); ++ ++ IllegalStateException exception = ++ assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize); ++ assertThat(exception).hasMessageThat().contains( ++ "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The" + + " ByteBuffer may have been changed."); +- } +- +- @Test +- public void loadReadOnlyBuffersCopiesOnWrite() { +- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); +- ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1); +- originalByteBuffer.put(new byte[]{99}); +- originalByteBuffer.rewind(); +- ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer(); +- +- tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[]{1}); +- assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer); +- +- tensorBuffer.loadArray(new int[]{42}); +- assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer); +- assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated +- assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed +- } ++ } ++ ++ @Test ++ public void loadReadOnlyBuffersCopiesOnWrite() { ++ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); ++ ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1); ++ originalByteBuffer.put(new byte[] {99}); ++ originalByteBuffer.rewind(); ++ ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer(); ++ ++ tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[] {1}); ++ assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer); ++ ++ tensorBuffer.loadArray(new int[] {42}); ++ assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer); ++ assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated ++ assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java +index e843133275d61..1921f4e467d01 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java +@@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType; + /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferUint8}. */ + @RunWith(RobolectricTestRunner.class) + public final class TensorBufferUint8Test { +- @Test +- public void testCreateDynamic() { +- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(); +- assertThat(tensorBufferUint8).isNotNull(); +- } ++ @Test ++ public void testCreateDynamic() { ++ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(); ++ assertThat(tensorBufferUint8).isNotNull(); ++ } + +- @Test +- public void testCreateFixedSize() { +- int[] shape = new int[] {1, 2, 3}; +- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); +- assertThat(tensorBufferUint8).isNotNull(); +- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6); +- } ++ @Test ++ public void testCreateFixedSize() { ++ int[] shape = new int[] {1, 2, 3}; ++ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); ++ assertThat(tensorBufferUint8).isNotNull(); ++ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6); ++ } + +- @Test +- public void testCreateFixedSizeWithScalarShape() { +- int[] shape = new int[] {}; +- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); +- assertThat(tensorBufferUint8).isNotNull(); +- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1); +- } ++ @Test ++ public void testCreateFixedSizeWithScalarShape() { ++ int[] shape = new int[] {}; ++ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); ++ assertThat(tensorBufferUint8).isNotNull(); ++ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1); ++ } + +- @Test +- public void testCreateWithNullShape() { +- int[] shape = null; +- Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape)); +- } ++ @Test ++ public void testCreateWithNullShape() { ++ int[] shape = null; ++ Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape)); ++ } + +- @Test +- public void testCreateWithInvalidShape() { +- int[] shape = new int[] {1, -1, 2}; +- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape)); +- } ++ @Test ++ public void testCreateWithInvalidShape() { ++ int[] shape = new int[] {1, -1, 2}; ++ Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape)); ++ } + +- @Test +- public void testCreateUsingShapeWithZero() { +- int[] shape = new int[] {1, 0, 2}; +- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); +- assertThat(tensorBufferUint8).isNotNull(); +- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0); +- } ++ @Test ++ public void testCreateUsingShapeWithZero() { ++ int[] shape = new int[] {1, 0, 2}; ++ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); ++ assertThat(tensorBufferUint8).isNotNull(); ++ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0); ++ } + +- @Test +- public void testGetDataType() { +- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(); +- assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8); +- } ++ @Test ++ public void testGetDataType() { ++ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(); ++ assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc +index d62da546a484b..c3c21fa43ab49 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc +@@ -134,7 +134,8 @@ jobject ConvertToClassificationResults(JNIEnv* env, + } + + // Creates an AudioClassifierOptions proto based on the Java class. +-AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options, ++AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, ++ jobject java_options, + jlong base_options_handle) { + AudioClassifierOptions proto_options; + +@@ -214,7 +215,9 @@ jlong CreateAudioClassifierFromOptions(JNIEnv* env, + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<AudioClassifier*>(native_handle); + } + +@@ -223,9 +226,13 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni( + // values will be ignored. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelFdAndOptions( +- JNIEnv* env, jclass thiz, jint file_descriptor, +- jlong file_descriptor_length, jlong file_descriptor_offset, +- jobject java_options, jlong base_options_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jint file_descriptor, ++ jlong file_descriptor_length, ++ jlong file_descriptor_offset, ++ jobject java_options, ++ jlong base_options_handle) { + AudioClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() +@@ -243,7 +250,10 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelF + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options, ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jobject java_options, + jlong base_options_handle) { + AudioClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); +@@ -262,7 +272,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBu + // caching it in JAVA layer. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSampleRateNative( +- JNIEnv* env, jclass thiz, jlong native_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle) { + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + StatusOr<AudioBuffer::AudioFormat> format_or = + classifier->GetRequiredAudioFormat(); +@@ -279,7 +291,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSample + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChannelsNative( +- JNIEnv* env, jclass thiz, jlong native_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle) { + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + StatusOr<AudioBuffer::AudioFormat> format_or = + classifier->GetRequiredAudioFormat(); +@@ -296,15 +310,21 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChanne + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredInputBufferSizeNative( +- JNIEnv* env, jclass thiz, jlong native_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle) { + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + return classifier->GetRequiredInputBufferSize(); + } + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_classifyNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jbyteArray java_array, +- jint channels, jint sample_rate) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jbyteArray java_array, ++ jint channels, ++ jint sample_rate) { + // Get the primitive native array. Depending on the JAVA runtime, the returned + // array might be a copy of the JAVA array (or not). + jbyte* native_array = env->GetByteArrayElements(java_array, nullptr); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc +index 2fd1d7ca9a593..75f93d6f2e458 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc +@@ -30,7 +30,10 @@ using ::tflite::task::core::BaseOptions; + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_core_TaskJniUtils_createProtoBaseOptions( +- JNIEnv* env, jclass thiz, jint delegate, jint num_threads) { ++ JNIEnv* env, ++ jclass thiz, ++ jint delegate, ++ jint num_threads) { + StatusOr<Delegate> delegate_proto_or = ConvertToProtoDelegate(delegate); + if (!delegate_proto_or.ok()) { + ThrowException(env, kIllegalStateException, +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc +index 6657ef4ca2d95..2daacdf893903 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc +@@ -32,7 +32,9 @@ using ::tflite::task::text::BertNLClassifierOptions; + using ::tflite::task::text::nlclassifier::RunClassifier; + + BertNLClassifierOptions ConvertJavaBertNLClassifierOptions( +- JNIEnv* env, jobject java_options, jlong base_options_handle) { ++ JNIEnv* env, ++ jobject java_options, ++ jlong base_options_handle) { + BertNLClassifierOptions proto_options; + + if (base_options_handle != kInvalidPointer) { +@@ -47,13 +49,18 @@ BertNLClassifierOptions ConvertJavaBertNLClassifierOptions( + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<BertNLClassifier*>(native_handle); + } + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options, ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jobject java_options, + jlong base_options_handle) { + BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions( + env, java_options, base_options_handle); +@@ -76,7 +83,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor( +- JNIEnv* env, jclass thiz, jint fd, jobject java_options, ++ JNIEnv* env, ++ jclass thiz, ++ jint fd, ++ jobject java_options, + jlong base_options_handle) { + BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions( + env, java_options, base_options_handle); +@@ -100,6 +110,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFile + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative( +- JNIEnv* env, jclass clazz, jlong native_handle, jstring text) { ++ JNIEnv* env, ++ jclass clazz, ++ jlong native_handle, ++ jstring text) { + return RunClassifier(env, native_handle, text); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc +index f6d34a5f74e2b..4c71a80ea1528 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc +@@ -94,14 +94,19 @@ NLClassifierOptions ConvertToProtoOptions(JNIEnv* env, + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<NLClassifier*>(native_handle); + } + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject nl_classifier_options, +- jobject model_buffer, jlong base_options_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jobject nl_classifier_options, ++ jobject model_buffer, ++ jlong base_options_handle) { + auto model = GetMappedFileBuffer(env, model_buffer); + tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or; + +@@ -125,7 +130,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuff + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor( +- JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd, ++ JNIEnv* env, ++ jclass thiz, ++ jobject nl_classifier_options, ++ jint fd, + jlong base_options_handle) { + tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or; + +@@ -151,6 +159,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDesc + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jstring text) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jstring text) { + return RunClassifier(env, native_handle, text); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc +index c392c9a5a972f..401e6fbda3d9b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc +@@ -52,14 +52,19 @@ BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) { + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<QuestionAnswerer*>(native_handle); + } + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor( +- JNIEnv* env, jclass thiz, jint file_descriptor, +- jlong file_descriptor_length, jlong file_descriptor_offset, ++ JNIEnv* env, ++ jclass thiz, ++ jint file_descriptor, ++ jlong file_descriptor_length, ++ jlong file_descriptor_offset, + jlong base_options_handle) { + BertQuestionAnswererOptions proto_options = + ConvertToProtoOptions(base_options_handle); +@@ -89,7 +94,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescri + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers( +- JNIEnv* env, jclass thiz, jobjectArray model_buffers) { ++ JNIEnv* env, ++ jclass thiz, ++ jobjectArray model_buffers) { + absl::string_view model = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); + absl::string_view vocab = +@@ -111,7 +118,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBu + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers( +- JNIEnv* env, jclass thiz, jobjectArray model_buffers) { ++ JNIEnv* env, ++ jclass thiz, ++ jobjectArray model_buffers) { + absl::string_view model = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); + absl::string_view sp_model = +@@ -133,7 +142,10 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByte + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jstring context, ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jstring context, + jstring question) { + auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc +index 18e2ee1a7d4ab..2a713cf8b63cf 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc +@@ -54,7 +54,8 @@ using ::tflite::task::vision::ImageClassifier; + using ::tflite::task::vision::ImageClassifierOptions; + + // Creates an ImageClassifierOptions proto based on the Java class. +-ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options, ++ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, ++ jobject java_options, + jlong base_options_handle) { + ImageClassifierOptions proto_options; + +@@ -175,7 +176,9 @@ jlong CreateImageClassifierFromOptions(JNIEnv* env, + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<ImageClassifier*>(native_handle); + } + +@@ -184,9 +187,13 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni( + // values will be ignored. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions( +- JNIEnv* env, jclass thiz, jint file_descriptor, +- jlong file_descriptor_length, jlong file_descriptor_offset, +- jobject java_options, jlong base_options_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jint file_descriptor, ++ jlong file_descriptor_length, ++ jlong file_descriptor_offset, ++ jobject java_options, ++ jlong base_options_handle) { + ImageClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() +@@ -204,7 +211,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModel + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options, ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jobject java_options, + jlong base_options_handle) { + ImageClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); +@@ -220,7 +230,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteB + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle, ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jlong frame_buffer_handle, + jintArray jroi) { + auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle); + // frame_buffer will be deleted after inference is done in +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc +index 84bff227f2543..2cda1b500aeb5 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc +@@ -31,8 +31,13 @@ using ::tflite::task::vision::FrameBuffer; + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromByteBuffer( +- JNIEnv* env, jclass thiz, jobject jimage_byte_buffer, jint width, +- jint height, jint jorientation, jint jcolor_space_type) { ++ JNIEnv* env, ++ jclass thiz, ++ jobject jimage_byte_buffer, ++ jint width, ++ jint height, ++ jint jorientation, ++ jint jcolor_space_type) { + auto frame_buffer_or = CreateFrameBufferFromByteBuffer( + env, jimage_byte_buffer, width, height, jorientation, jcolor_space_type); + if (frame_buffer_or.ok()) { +@@ -49,8 +54,14 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromBytes( +- JNIEnv* env, jclass thiz, jbyteArray jimage_bytes, jint width, jint height, +- jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jbyteArray jimage_bytes, ++ jint width, ++ jint height, ++ jint jorientation, ++ jint jcolor_space_type, ++ jlongArray jbyte_array_handle) { + auto frame_buffer_or = + CreateFrameBufferFromBytes(env, jimage_bytes, width, height, jorientation, + jcolor_space_type, jbyte_array_handle); +@@ -68,9 +79,17 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromPlanes( +- JNIEnv* env, jclass thiz, jobject jy_plane, jobject ju_plane, +- jobject jv_plane, jint width, jint height, jint row_stride_y, +- jint row_stride_uv, jint pixel_stride_uv, jint orientation) { ++ JNIEnv* env, ++ jclass thiz, ++ jobject jy_plane, ++ jobject ju_plane, ++ jobject jv_plane, ++ jint width, ++ jint height, ++ jint row_stride_y, ++ jint row_stride_uv, ++ jint pixel_stride_uv, ++ jint orientation) { + auto frame_buffer_or = CreateFrameBufferFromYuvPlanes( + env, jy_plane, ju_plane, jv_plane, width, height, row_stride_y, + row_stride_uv, pixel_stride_uv, orientation); +@@ -88,8 +107,11 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_deleteFrameBuffer( +- JNIEnv* env, jobject thiz, jlong frame_buffer_handle, +- jlong byte_array_handle, jbyteArray jbyte_array) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong frame_buffer_handle, ++ jlong byte_array_handle, ++ jbyteArray jbyte_array) { + delete reinterpret_cast<FrameBuffer*>(frame_buffer_handle); + jbyte* bytes_ptr = reinterpret_cast<jbyte*>(byte_array_handle); + if (bytes_ptr != NULL) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc +index ddb0b72a25b65..f720795263791 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc +@@ -54,7 +54,8 @@ using ::tflite::task::vision::ObjectDetector; + using ::tflite::task::vision::ObjectDetectorOptions; + + // Creates an ObjectDetectorOptions proto based on the Java class. +-ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options, ++ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, ++ jobject java_options, + jlong base_options_handle) { + ObjectDetectorOptions proto_options; + +@@ -183,7 +184,9 @@ jlong CreateObjectDetectorFromOptions(JNIEnv* env, + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<ObjectDetector*>(native_handle); + } + +@@ -192,9 +195,13 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni( + // values will be ignored. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions( +- JNIEnv* env, jclass thiz, jint file_descriptor, +- jlong file_descriptor_length, jlong file_descriptor_offset, +- jobject java_options, jlong base_options_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jint file_descriptor, ++ jlong file_descriptor_length, ++ jlong file_descriptor_offset, ++ jobject java_options, ++ jlong base_options_handle) { + ObjectDetectorOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() +@@ -212,7 +219,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdA + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options, ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jobject java_options, + jlong base_options_handle) { + ObjectDetectorOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); +@@ -224,7 +234,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuff + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jlong frame_buffer_handle) { + auto* detector = reinterpret_cast<ObjectDetector*>(native_handle); + // frame_buffer will be deleted after inference is done in + // base_vision_api_jni.cc. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc +index 1b08e56ed509b..e0c94e2ec72c6 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc +@@ -135,8 +135,12 @@ StatusOr<FrameBuffer::Format> GetYUVImageFormat(const uint8* u_buffer, + } + + StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer( +- JNIEnv* env, jobject jimage_byte_buffer, jint width, jint height, +- jint jorientation, jint jcolor_space_type) { ++ JNIEnv* env, ++ jobject jimage_byte_buffer, ++ jint width, ++ jint height, ++ jint jorientation, ++ jint jcolor_space_type) { + absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer); + return CreateFromRawBuffer( + reinterpret_cast<const uint8*>(image.data()), +@@ -146,8 +150,13 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer( + } + + StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes( +- JNIEnv* env, jbyteArray jimage_bytes, jint width, jint height, +- jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) { ++ JNIEnv* env, ++ jbyteArray jimage_bytes, ++ jint width, ++ jint height, ++ jint jorientation, ++ jint jcolor_space_type, ++ jlongArray jbyte_array_handle) { + jbyte* jimage_ptr = env->GetByteArrayElements(jimage_bytes, NULL); + // Free jimage_ptr together with frame_buffer after inference is finished. + jlong jimage_ptr_handle = reinterpret_cast<jlong>(jimage_ptr); +@@ -168,9 +177,16 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes( + } + + StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromYuvPlanes( +- JNIEnv* env, jobject jy_plane, jobject ju_plane, jobject jv_plane, +- jint width, jint height, jint row_stride_y, jint row_stride_uv, +- jint pixel_stride_uv, jint jorientation) { ++ JNIEnv* env, ++ jobject jy_plane, ++ jobject ju_plane, ++ jobject jv_plane, ++ jint width, ++ jint height, ++ jint row_stride_y, ++ jint row_stride_uv, ++ jint pixel_stride_uv, ++ jint jorientation) { + const uint8* y_plane = + reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, jy_plane).data()); + const uint8* u_plane = +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h +index dbe32f8a3f2a5..4d7ec17a1c042 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h +@@ -34,23 +34,35 @@ FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env, + + // Creates FrameBuffer from a direct ByteBuffer. + ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> +-CreateFrameBufferFromByteBuffer(JNIEnv* env, jobject jimage_byte_buffer, +- jint width, jint height, jint jorientation, ++CreateFrameBufferFromByteBuffer(JNIEnv* env, ++ jobject jimage_byte_buffer, ++ jint width, ++ jint height, ++ jint jorientation, + jint jcolor_space_type); + + // Creates FrameBuffer from a byte array. + ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> +-CreateFrameBufferFromBytes(JNIEnv* env, jbyteArray jimage_bytes, jint width, +- jint height, jint jorientation, ++CreateFrameBufferFromBytes(JNIEnv* env, ++ jbyteArray jimage_bytes, ++ jint width, ++ jint height, ++ jint jorientation, + jint jcolor_space_type, + jlongArray jbyte_array_handle); + + // Creates FrameBuffer from YUV planes. + ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> +-CreateFrameBufferFromYuvPlanes(JNIEnv* env, jobject jy_plane, jobject ju_plane, +- jobject jv_plane, jint width, jint height, +- jint row_stride_y, jint row_stride_uv, +- jint pixel_stride_uv, jint jorientation); ++CreateFrameBufferFromYuvPlanes(JNIEnv* env, ++ jobject jy_plane, ++ jobject ju_plane, ++ jobject jv_plane, ++ jint width, ++ jint height, ++ jint row_stride_y, ++ jint row_stride_uv, ++ jint pixel_stride_uv, ++ jint jorientation); + } // namespace vision + } // namespace task + } // namespace tflite +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc +index 40fa4472d37e1..8d8c8eec34295 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc +@@ -194,7 +194,9 @@ jlong CreateImageSegmenterFromOptions(JNIEnv* env, + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<ImageSegmenter*>(native_handle); + } + +@@ -203,9 +205,14 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni( + // values will be ignored. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions( +- JNIEnv* env, jclass thiz, jint file_descriptor, +- jlong file_descriptor_length, jlong file_descriptor_offset, +- jstring display_names_locale, jint output_type, jlong base_options_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jint file_descriptor, ++ jlong file_descriptor_length, ++ jlong file_descriptor_offset, ++ jstring display_names_locale, ++ jint output_type, ++ jlong base_options_handle) { + ImageSegmenterOptions proto_options = ConvertToProtoOptions( + env, display_names_locale, output_type, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() +@@ -223,8 +230,12 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFd + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, +- jstring display_names_locale, jint output_type, jlong base_options_handle) { ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jstring display_names_locale, ++ jint output_type, ++ jlong base_options_handle) { + ImageSegmenterOptions proto_options = ConvertToProtoOptions( + env, display_names_locale, output_type, base_options_handle); + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( +@@ -235,8 +246,13 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuf + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle, +- jobject jmask_buffers, jintArray jmask_shape, jobject jcolored_labels) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jlong frame_buffer_handle, ++ jobject jmask_buffers, ++ jintArray jmask_shape, ++ jobject jcolored_labels) { + auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle); + // frame_buffer will be deleted after inference is done in + // base_vision_api_jni.cc. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +index 7a9843d61d63c..3aae0aa0ec5c7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +@@ -17,11 +17,11 @@ limitations under the License. + + #include <functional> + +-#include "absl/memory/memory.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers +-#include "lib/zip.h" // from @org_libzip ++#include "lib/zip.h" // from @org_libzip + #include "tensorflow/lite/schema/schema_generated.h" + #include "tensorflow_lite_support/cc/common.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -48,7 +48,8 @@ class SimpleCleanUp { + : callback_(std::move(callback)) {} + + ~SimpleCleanUp() { +- if (callback_ != nullptr) callback_(); ++ if (callback_ != nullptr) ++ callback_(); + } + + // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever +@@ -63,7 +64,8 @@ class SimpleCleanUp { + // Util to get item from src_vector specified by index. + template <typename T> + const T* GetItemFromVector( +- const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) { ++ const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, ++ int index) { + if (src_vector == nullptr || index < 0 || index >= src_vector->size()) { + return nullptr; + } +@@ -111,7 +113,8 @@ ModelMetadataExtractor::FindFirstProcessUnit( + /* static */ + std::string ModelMetadataExtractor::FindFirstAssociatedFileName( + const tflite::TensorMetadata& tensor_metadata, +- tflite::AssociatedFileType type, absl::string_view locale) { ++ tflite::AssociatedFileType type, ++ absl::string_view locale) { + if (tensor_metadata.associated_files() == nullptr) { + return std::string(); + } +@@ -128,7 +131,8 @@ std::string ModelMetadataExtractor::FindFirstAssociatedFileName( + } + + absl::Status ModelMetadataExtractor::InitFromModelBuffer( +- const char* buffer_data, size_t buffer_size) { ++ const char* buffer_data, ++ size_t buffer_size) { + // Rely on the simplest, base flatbuffers verifier. Here is not the place to + // e.g. use an OpResolver: we just want to make sure the buffer is valid to + // access the metadata. +@@ -187,7 +191,8 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( + } + + absl::Status ModelMetadataExtractor::ExtractAssociatedFiles( +- const char* buffer_data, size_t buffer_size) { ++ const char* buffer_data, ++ size_t buffer_size) { + // Setup libzip error reporting. + zip_error_t error; + zip_error_init(&error); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h +index bff8cdf5ef43e..dc9a992aee2be 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h +@@ -16,8 +16,8 @@ limitations under the License. + #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ + + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/string_view.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow/lite/schema/schema_generated.h" + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h +index a18e19bdb7973..9037f5853744b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h +@@ -17,8 +17,8 @@ limitations under the License. + #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_ + + #include "absl/container/flat_hash_map.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "flatbuffers/flatbuffers.h" // from @flatbuffers ++#include "absl/status/status.h" // from @com_google_absl ++#include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/schema/schema_generated.h" + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +@@ -77,7 +77,8 @@ class ModelMetadataPopulator { + // Zips and appends associated files to the provided model buffer. Called + // internally by `Populate()`. + tflite::support::StatusOr<std::string> AppendAssociatedFiles( +- const char* model_buffer_data, size_t model_buffer_size); ++ const char* model_buffer_data, ++ size_t model_buffer_size); + + // The unpacked model FlatBuffer. + tflite::ModelT model_t_; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc +index fb3e01e00b76d..ed75b656e70a2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc +@@ -22,7 +22,7 @@ limitations under the License. + #include <string> + #include <vector> + +-#include "absl/strings/str_join.h" // from @com_google_absl ++#include "absl/strings/str_join.h" // from @com_google_absl + #include "absl/strings/str_split.h" // from @com_google_absl + #include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "tensorflow/lite/c/common.h" +@@ -134,7 +134,8 @@ template <typename T> + void UpdateMinimumVersionForArray( + const flatbuffers::Vector<flatbuffers::Offset<T>>* array, + Version* min_version) { +- if (array == nullptr) return; ++ if (array == nullptr) ++ return; + + for (int i = 0; i < array->size(); ++i) { + UpdateMinimumVersionForTable<T>(array->Get(i), min_version); +@@ -143,8 +144,10 @@ void UpdateMinimumVersionForArray( + + template <> + void UpdateMinimumVersionForTable<tflite::AssociatedFile>( +- const tflite::AssociatedFile* table, Version* min_version) { +- if (table == nullptr) return; ++ const tflite::AssociatedFile* table, ++ Version* min_version) { ++ if (table == nullptr) ++ return; + + if (table->type() == AssociatedFileType_VOCABULARY) { + UpdateMinimumVersion( +@@ -155,8 +158,10 @@ void UpdateMinimumVersionForTable<tflite::AssociatedFile>( + + template <> + void UpdateMinimumVersionForTable<tflite::ProcessUnit>( +- const tflite::ProcessUnit* table, Version* min_version) { +- if (table == nullptr) return; ++ const tflite::ProcessUnit* table, ++ Version* min_version) { ++ if (table == nullptr) ++ return; + + tflite::ProcessUnitOptions process_unit_type = table->options_type(); + if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) { +@@ -182,7 +187,8 @@ void UpdateMinimumVersionForTable<tflite::ProcessUnit>( + template <> + void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table, + Version* min_version) { +- if (table == nullptr) return; ++ if (table == nullptr) ++ return; + + // Checks the ContenProperties field. + if (table->content_properties_type() == ContentProperties_AudioProperties) { +@@ -194,8 +200,10 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table, + + template <> + void UpdateMinimumVersionForTable<tflite::TensorMetadata>( +- const tflite::TensorMetadata* table, Version* min_version) { +- if (table == nullptr) return; ++ const tflite::TensorMetadata* table, ++ Version* min_version) { ++ if (table == nullptr) ++ return; + + // Checks the associated_files field. + UpdateMinimumVersionForArray<tflite::AssociatedFile>( +@@ -211,8 +219,10 @@ void UpdateMinimumVersionForTable<tflite::TensorMetadata>( + + template <> + void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>( +- const tflite::SubGraphMetadata* table, Version* min_version) { +- if (table == nullptr) return; ++ const tflite::SubGraphMetadata* table, ++ Version* min_version) { ++ if (table == nullptr) ++ return; + + // Checks in the input/output metadata arrays. + UpdateMinimumVersionForArray<tflite::TensorMetadata>( +@@ -259,7 +269,8 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>( + + template <> + void UpdateMinimumVersionForTable<tflite::ModelMetadata>( +- const tflite::ModelMetadata* table, Version* min_version) { ++ const tflite::ModelMetadata* table, ++ Version* min_version) { + if (table == nullptr) { + // Should never happen, because VerifyModelMetadataBuffer has verified it. + TFLITE_LOG(FATAL) << "The ModelMetadata object is null."; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc +index 6185722504f69..8e00452bea983 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc +@@ -14,7 +14,7 @@ limitations under the License. + ==============================================================================*/ + + #include "flatbuffers/flatbuffers.h" // from @flatbuffers +-#include "flatbuffers/idl.h" // from @flatbuffers ++#include "flatbuffers/idl.h" // from @flatbuffers + #include "pybind11/pybind11.h" + #include "pybind11/pytypes.h" + #include "pybind11/stl.h" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java +index 6c3d23270f3f0..15bcb45c1a4b1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java +@@ -33,84 +33,84 @@ import java.nio.ByteBuffer; + * synchronized as well. + */ + final class BoundedInputStream extends InputStream { +- private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1); +- private final long end; // The valid data for the stream is between [start, end). +- private long position; +- private final SeekableByteChannelCompat channel; +- +- /** +- * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}. +- * +- * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link +- * BoundedInputStream} +- * @param start the starting position of this {@link BoundedInputStream} in the given {@link +- * SeekableByteChannelCompat} +- * @param remaining the length of this {@link BoundedInputStream} +- * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative +- */ +- BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) { +- checkArgument( +- remaining >= 0 && start >= 0, +- String.format("Invalid length of stream at offset=%d, length=%d", start, remaining)); +- +- end = start + remaining; +- this.channel = channel; +- position = start; +- } +- +- @Override +- public int available() throws IOException { +- return (int) (Math.min(end, channel.size()) - position); +- } +- +- @Override +- public int read() throws IOException { +- if (position >= end) { +- return -1; ++ private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1); ++ private final long end; // The valid data for the stream is between [start, end). ++ private long position; ++ private final SeekableByteChannelCompat channel; ++ ++ /** ++ * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}. ++ * ++ * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link ++ * BoundedInputStream} ++ * @param start the starting position of this {@link BoundedInputStream} in the given {@link ++ * SeekableByteChannelCompat} ++ * @param remaining the length of this {@link BoundedInputStream} ++ * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative ++ */ ++ BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) { ++ checkArgument(remaining >= 0 && start >= 0, ++ String.format( ++ "Invalid length of stream at offset=%d, length=%d", start, remaining)); ++ ++ end = start + remaining; ++ this.channel = channel; ++ position = start; + } + +- singleByteBuffer.rewind(); +- int count = read(position, singleByteBuffer); +- if (count < 0) { +- return count; ++ @Override ++ public int available() throws IOException { ++ return (int) (Math.min(end, channel.size()) - position); + } + +- position++; +- return singleByteBuffer.get() & 0xff; +- } ++ @Override ++ public int read() throws IOException { ++ if (position >= end) { ++ return -1; ++ } + +- @Override +- public int read(byte[] b, int off, int len) throws IOException { +- checkNotNull(b); +- checkElementIndex(off, b.length, "The start offset"); +- checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read"); ++ singleByteBuffer.rewind(); ++ int count = read(position, singleByteBuffer); ++ if (count < 0) { ++ return count; ++ } + +- if (len == 0) { +- return 0; ++ position++; ++ return singleByteBuffer.get() & 0xff; + } + +- if (len > end - position) { +- if (position >= end) { +- return -1; +- } +- len = (int) (end - position); ++ @Override ++ public int read(byte[] b, int off, int len) throws IOException { ++ checkNotNull(b); ++ checkElementIndex(off, b.length, "The start offset"); ++ checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read"); ++ ++ if (len == 0) { ++ return 0; ++ } ++ ++ if (len > end - position) { ++ if (position >= end) { ++ return -1; ++ } ++ len = (int) (end - position); ++ } ++ ++ ByteBuffer buf = ByteBuffer.wrap(b, off, len); ++ int count = read(position, buf); ++ if (count > 0) { ++ position += count; ++ } ++ return count; + } + +- ByteBuffer buf = ByteBuffer.wrap(b, off, len); +- int count = read(position, buf); +- if (count > 0) { +- position += count; ++ private int read(long position, ByteBuffer buf) throws IOException { ++ int count; ++ synchronized (channel) { ++ channel.position(position); ++ count = channel.read(buf); ++ } ++ buf.flip(); ++ return count; + } +- return count; +- } +- +- private int read(long position, ByteBuffer buf) throws IOException { +- int count; +- synchronized (channel) { +- channel.position(position); +- count = channel.read(buf); +- } +- buf.flip(); +- return count; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java +index e5d54a415edc4..354119b02822e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java +@@ -15,116 +15,114 @@ limitations under the License. + + package org.tensorflow.lite.support.metadata; + +-import static java.lang.Math.min; + import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument; + import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull; + ++import static java.lang.Math.min; ++ + import java.nio.ByteBuffer; + import java.nio.channels.NonWritableChannelException; + + /** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */ + final class ByteBufferChannel implements SeekableByteChannelCompat { ++ /** The ByteBuffer that holds the data. */ ++ private final ByteBuffer buffer; ++ ++ /** ++ * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}. ++ * ++ * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel} ++ * @throws NullPointerException if {@code buffer} is null ++ */ ++ public ByteBufferChannel(ByteBuffer buffer) { ++ checkNotNull(buffer, "The ByteBuffer cannot be null."); ++ this.buffer = buffer; ++ } ++ ++ @Override ++ public void close() {} + +- /** The ByteBuffer that holds the data. */ +- private final ByteBuffer buffer; +- +- /** +- * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}. +- * +- * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel} +- * @throws NullPointerException if {@code buffer} is null +- */ +- public ByteBufferChannel(ByteBuffer buffer) { +- checkNotNull(buffer, "The ByteBuffer cannot be null."); +- this.buffer = buffer; +- } +- +- @Override +- public void close() {} +- +- @Override +- public boolean isOpen() { +- return true; +- } +- +- @Override +- public long position() { +- return buffer.position(); +- } +- +- /** +- * Sets this channel's position. +- * +- * @param newPosition the new position, a non-negative integer counting the number of bytes from +- * the beginning of the entity +- * @return this channel +- * @throws IllegalArgumentException if the new position is negative, or greater than the size of +- * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE +- */ +- @Override +- public synchronized ByteBufferChannel position(long newPosition) { +- checkArgument( +- (newPosition >= 0 && newPosition <= Integer.MAX_VALUE), +- "The new position should be non-negative and be less than Integer.MAX_VALUE."); +- buffer.position((int) newPosition); +- return this; +- } +- +- /** +- * {@inheritDoc} +- * +- * <p>Bytes are read starting at this channel's current position, and then the position is updated +- * with the number of bytes actually read. Otherwise this method behaves exactly as specified in +- * the {@link ReadableByteChannel} interface. +- */ +- @Override +- public synchronized int read(ByteBuffer dst) { +- if (buffer.remaining() == 0) { +- return -1; ++ @Override ++ public boolean isOpen() { ++ return true; + } + +- int count = min(dst.remaining(), buffer.remaining()); +- if (count > 0) { +- ByteBuffer tempBuffer = buffer.slice(); +- tempBuffer.order(buffer.order()).limit(count); +- dst.put(tempBuffer); +- buffer.position(buffer.position() + count); ++ @Override ++ public long position() { ++ return buffer.position(); + } +- return count; +- } +- +- @Override +- public long size() { +- return buffer.limit(); +- } +- +- @Override +- public synchronized ByteBufferChannel truncate(long size) { +- checkArgument( +- (size >= 0 && size <= Integer.MAX_VALUE), +- "The new size should be non-negative and be less than Integer.MAX_VALUE."); +- +- if (size < buffer.limit()) { +- buffer.limit((int) size); +- if (buffer.position() > size) { +- buffer.position((int) size); +- } ++ ++ /** ++ * Sets this channel's position. ++ * ++ * @param newPosition the new position, a non-negative integer counting the number of bytes from ++ * the beginning of the entity ++ * @return this channel ++ * @throws IllegalArgumentException if the new position is negative, or greater than the size of ++ * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE ++ */ ++ @Override ++ public synchronized ByteBufferChannel position(long newPosition) { ++ checkArgument((newPosition >= 0 && newPosition <= Integer.MAX_VALUE), ++ "The new position should be non-negative and be less than Integer.MAX_VALUE."); ++ buffer.position((int) newPosition); ++ return this; ++ } ++ ++ /** ++ * {@inheritDoc} ++ * ++ * <p>Bytes are read starting at this channel's current position, and then the position is ++ * updated with the number of bytes actually read. Otherwise this method behaves exactly as ++ * specified in the {@link ReadableByteChannel} interface. ++ */ ++ @Override ++ public synchronized int read(ByteBuffer dst) { ++ if (buffer.remaining() == 0) { ++ return -1; ++ } ++ ++ int count = min(dst.remaining(), buffer.remaining()); ++ if (count > 0) { ++ ByteBuffer tempBuffer = buffer.slice(); ++ tempBuffer.order(buffer.order()).limit(count); ++ dst.put(tempBuffer); ++ buffer.position(buffer.position() + count); ++ } ++ return count; ++ } ++ ++ @Override ++ public long size() { ++ return buffer.limit(); + } +- return this; +- } + +- @Override +- public synchronized int write(ByteBuffer src) { +- if (buffer.isReadOnly()) { +- throw new NonWritableChannelException(); ++ @Override ++ public synchronized ByteBufferChannel truncate(long size) { ++ checkArgument((size >= 0 && size <= Integer.MAX_VALUE), ++ "The new size should be non-negative and be less than Integer.MAX_VALUE."); ++ ++ if (size < buffer.limit()) { ++ buffer.limit((int) size); ++ if (buffer.position() > size) { ++ buffer.position((int) size); ++ } ++ } ++ return this; + } + +- int count = min(src.remaining(), buffer.remaining()); +- if (count > 0) { +- ByteBuffer tempBuffer = src.slice(); +- tempBuffer.order(buffer.order()).limit(count); +- buffer.put(tempBuffer); ++ @Override ++ public synchronized int write(ByteBuffer src) { ++ if (buffer.isReadOnly()) { ++ throw new NonWritableChannelException(); ++ } ++ ++ int count = min(src.remaining(), buffer.remaining()); ++ if (count > 0) { ++ ByteBuffer tempBuffer = src.slice(); ++ tempBuffer.order(buffer.order()).limit(count); ++ buffer.put(tempBuffer); ++ } ++ return count; + } +- return count; +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java +index 183d416481156..3fb3c48118748 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java +@@ -17,15 +17,16 @@ package org.tensorflow.lite.support.metadata; + + import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument; + ++import org.checkerframework.checker.nullness.qual.Nullable; ++import org.tensorflow.lite.schema.Tensor; ++import org.tensorflow.lite.support.metadata.schema.ModelMetadata; ++import org.tensorflow.lite.support.metadata.schema.TensorMetadata; ++ + import java.io.IOException; + import java.io.InputStream; + import java.nio.ByteBuffer; + import java.util.Set; + import java.util.zip.ZipException; +-import org.checkerframework.checker.nullness.qual.Nullable; +-import org.tensorflow.lite.schema.Tensor; +-import org.tensorflow.lite.support.metadata.schema.ModelMetadata; +-import org.tensorflow.lite.support.metadata.schema.TensorMetadata; + + /** + * Loads metadata from TFLite Model FlatBuffer. +@@ -53,328 +54,329 @@ import org.tensorflow.lite.support.metadata.schema.TensorMetadata; + * MetadataExtractor} omits subgraph index as an input in its methods. + */ + public class MetadataExtractor { ++ /** The helper class to load metadata from TFLite model FlatBuffer. */ ++ private final ModelInfo modelInfo; ++ ++ /** The helper class to load metadata from TFLite metadata FlatBuffer. */ ++ @Nullable ++ private final ModelMetadataInfo metadataInfo; ++ ++ /** The handler to load associated files through zip. */ ++ @Nullable ++ private final ZipFile zipFile; ++ ++ /** ++ * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer. ++ * ++ * @param buffer the TFLite model FlatBuffer ++ * @throws IllegalArgumentException if the number of input or output tensors in the model does ++ * not ++ * match that in the metadata ++ * @throws IOException if an error occurs while reading the model as a Zip file ++ */ ++ public MetadataExtractor(ByteBuffer buffer) throws IOException { ++ modelInfo = new ModelInfo(buffer); ++ ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer(); ++ if (metadataBuffer != null) { ++ metadataInfo = new ModelMetadataInfo(metadataBuffer); ++ ++ // Prints warning message if the minimum parser version is not satisfied. ++ if (!isMinimumParserVersionSatisfied()) { ++ System.err.printf( ++ "<Warning> Some fields in the metadata belong to a future schema. The minimum parser" ++ + " version required is %s, but the version of the current metadata parser is %s", ++ metadataInfo.getMininumParserVersion(), MetadataParser.VERSION); ++ } ++ ++ checkArgument(modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(), ++ String.format( ++ "The number of input tensors in the model is %d. The number of input tensors that" ++ + " recorded in the metadata is %d. These two values does not match.", ++ modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount())); ++ checkArgument(modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(), ++ String.format( ++ "The number of output tensors in the model is %d. The number of output tensors that" ++ + " recorded in the metadata is %d. These two values does not match.", ++ modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount())); ++ } else { ++ // It is allowed to pass in a model FlatBuffer without TFLite metadata. However, ++ // invoking methods that read from TFLite metadata will cause runtime errors. ++ metadataInfo = null; ++ } ++ ++ zipFile = createZipFile(buffer); ++ } + +- /** The helper class to load metadata from TFLite model FlatBuffer. */ +- private final ModelInfo modelInfo; +- +- /** The helper class to load metadata from TFLite metadata FlatBuffer. */ +- @Nullable private final ModelMetadataInfo metadataInfo; +- +- /** The handler to load associated files through zip. */ +- @Nullable private final ZipFile zipFile; +- +- /** +- * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer. +- * +- * @param buffer the TFLite model FlatBuffer +- * @throws IllegalArgumentException if the number of input or output tensors in the model does not +- * match that in the metadata +- * @throws IOException if an error occurs while reading the model as a Zip file +- */ +- public MetadataExtractor(ByteBuffer buffer) throws IOException { +- modelInfo = new ModelInfo(buffer); +- ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer(); +- if (metadataBuffer != null) { +- metadataInfo = new ModelMetadataInfo(metadataBuffer); +- +- // Prints warning message if the minimum parser version is not satisfied. +- if (!isMinimumParserVersionSatisfied()) { +- System.err.printf( +- "<Warning> Some fields in the metadata belong to a future schema. The minimum parser" +- + " version required is %s, but the version of the current metadata parser is %s", +- metadataInfo.getMininumParserVersion(), MetadataParser.VERSION); +- } +- +- checkArgument( +- modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(), +- String.format( +- "The number of input tensors in the model is %d. The number of input tensors that" +- + " recorded in the metadata is %d. These two values does not match.", +- modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount())); +- checkArgument( +- modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(), +- String.format( +- "The number of output tensors in the model is %d. The number of output tensors that" +- + " recorded in the metadata is %d. These two values does not match.", +- modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount())); +- } else { +- // It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking +- // methods that read from TFLite metadata will cause runtime errors. +- metadataInfo = null; ++ /** ++ * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the ++ * <a ++ * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite ++ * Model schema file.</a> ++ * ++ * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} ++ * and ++ * {@code zero_point} are both single values instead of arrays. ++ * ++ * <p>For tensor that are not quantized, the values of scale and zero_point are both 0. ++ * ++ * <p>Given a quantized value q, the corresponding float value f should be: <br> ++ * f = scale * (q - zero_point) <br> ++ */ ++ public static class QuantizationParams { ++ /** The scale value used in quantization. */ ++ private final float scale; ++ /** The zero point value used in quantization. */ ++ private final int zeroPoint; ++ ++ /** ++ * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}. ++ * ++ * @param scale The scale value used in quantization. ++ * @param zeroPoint The zero point value used in quantization. ++ */ ++ public QuantizationParams(final float scale, final int zeroPoint) { ++ this.scale = scale; ++ this.zeroPoint = zeroPoint; ++ } ++ ++ /** Returns the scale value. */ ++ public float getScale() { ++ return scale; ++ } ++ ++ /** Returns the zero point value. */ ++ public int getZeroPoint() { ++ return zeroPoint; ++ } + } + +- zipFile = createZipFile(buffer); +- } +- +- /** +- * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the +- * <a +- * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite +- * Model schema file.</a> +- * +- * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and +- * {@code zero_point} are both single values instead of arrays. +- * +- * <p>For tensor that are not quantized, the values of scale and zero_point are both 0. +- * +- * <p>Given a quantized value q, the corresponding float value f should be: <br> +- * f = scale * (q - zero_point) <br> +- */ +- public static class QuantizationParams { +- /** The scale value used in quantization. */ +- private final float scale; +- /** The zero point value used in quantization. */ +- private final int zeroPoint; ++ /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */ ++ public boolean hasMetadata() { ++ return metadataInfo != null; ++ } + + /** +- * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}. ++ * Gets the packed associated file with the specified {@code fileName}. + * +- * @param scale The scale value used in quantization. +- * @param zeroPoint The zero point value used in quantization. ++ * @param fileName the name of the associated file ++ * @return the raw input stream containing specified file ++ * @throws IllegalStateException if the model is not a zip file ++ * @throws IllegalArgumentException if the specified file does not exist in the model + */ +- public QuantizationParams(final float scale, final int zeroPoint) { +- this.scale = scale; +- this.zeroPoint = zeroPoint; ++ public InputStream getAssociatedFile(String fileName) { ++ assertZipFile(); ++ return zipFile.getRawInputStream(fileName); + } + +- /** Returns the scale value. */ +- public float getScale() { +- return scale; ++ /** ++ * Gets the file names of the associated files. ++ * ++ * @return the file names of the associated files ++ * @throws IllegalStateException if the model is not a zip file ++ */ ++ public Set<String> getAssociatedFileNames() { ++ assertZipFile(); ++ return zipFile.getFileNames(); + } + +- /** Returns the zero point value. */ +- public int getZeroPoint() { +- return zeroPoint; ++ /** Gets the count of input tensors in the model. */ ++ public int getInputTensorCount() { ++ return modelInfo.getInputTensorCount(); + } +- } +- +- /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */ +- public boolean hasMetadata() { +- return metadataInfo != null; +- } +- +- /** +- * Gets the packed associated file with the specified {@code fileName}. +- * +- * @param fileName the name of the associated file +- * @return the raw input stream containing specified file +- * @throws IllegalStateException if the model is not a zip file +- * @throws IllegalArgumentException if the specified file does not exist in the model +- */ +- public InputStream getAssociatedFile(String fileName) { +- assertZipFile(); +- return zipFile.getRawInputStream(fileName); +- } +- +- /** +- * Gets the file names of the associated files. +- * +- * @return the file names of the associated files +- * @throws IllegalStateException if the model is not a zip file +- */ +- public Set<String> getAssociatedFileNames() { +- assertZipFile(); +- return zipFile.getFileNames(); +- } +- +- /** Gets the count of input tensors in the model. */ +- public int getInputTensorCount() { +- return modelInfo.getInputTensorCount(); +- } +- +- /** +- * Gets the metadata for the input tensor specified by {@code inputIndex}. +- * +- * @param inputIndex the index of the desired input tensor +- * @throws IllegalStateException if this model does not contain model metadata +- */ +- @Nullable +- public TensorMetadata getInputTensorMetadata(int inputIndex) { +- assertMetadataInfo(); +- return metadataInfo.getInputTensorMetadata(inputIndex); +- } +- +- /** +- * Gets the quantization parameters for the input tensor specified by {@code inputIndex}. +- * +- * @param inputIndex the index of the desired input tensor +- */ +- public QuantizationParams getInputTensorQuantizationParams(int inputIndex) { +- Tensor tensor = modelInfo.getInputTensor(inputIndex); +- return modelInfo.getQuantizationParams(tensor); +- } +- +- /** +- * Gets the shape of the input tensor with {@code inputIndex}. +- * +- * @param inputIndex the index of the desired input tensor +- */ +- public int[] getInputTensorShape(int inputIndex) { +- return modelInfo.getInputTensorShape(inputIndex); +- } +- +- /** +- * Gets the {@link TensorType} of the input tensor with {@code inputIndex}. +- * +- * @param inputIndex the index of the desired input tensor +- */ +- public byte getInputTensorType(int inputIndex) { +- return modelInfo.getInputTensorType(inputIndex); +- } +- +- /** +- * Gets the root handler for the model metadata. +- * +- * @throws IllegalStateException if this model does not contain model metadata +- */ +- public ModelMetadata getModelMetadata() { +- assertMetadataInfo(); +- return metadataInfo.getModelMetadata(); +- } +- +- /** Gets the count of output tensors in the model. */ +- public int getOutputTensorCount() { +- return modelInfo.getOutputTensorCount(); +- } +- +- /** +- * Gets the metadata for the output tensor specified by {@code outputIndex}. +- * +- * @param outputIndex the index of the desired output tensor +- * @throws IllegalStateException if this model does not contain model metadata +- */ +- @Nullable +- public TensorMetadata getOutputTensorMetadata(int outputIndex) { +- assertMetadataInfo(); +- return metadataInfo.getOutputTensorMetadata(outputIndex); +- } +- +- /** +- * Gets the quantization parameters for the output tensor specified by {@code outputIndex}. +- * +- * @param outputIndex the index of the desired output tensor +- */ +- public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) { +- Tensor tensor = modelInfo.getOutputTensor(outputIndex); +- return modelInfo.getQuantizationParams(tensor); +- } +- +- /** +- * Gets the shape of the output tensor with {@code outputIndex}. +- * +- * @param outputIndex the index of the desired output tensor +- */ +- public int[] getOutputTensorShape(int outputIndex) { +- return modelInfo.getOutputTensorShape(outputIndex); +- } +- +- /** +- * Gets the {@link TensorType} of the output tensor with {@code outputIndex}. +- * +- * @param outputIndex the index of the desired output tensor +- */ +- public byte getOutputTensorType(int outputIndex) { +- return modelInfo.getOutputTensorType(outputIndex); +- } +- +- /** +- * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer +- * precedes or equals to the version of the metadata parser that this MetadataExtractor library is +- * relying on. All fields in the metadata can be parsed correctly with this metadata extractor +- * library in this case. Otherwise, it returns {@code false}. +- * +- * <p>For example, assume the underlying metadata parser version is {@code 1.14.1}, +- * +- * <ul> +- * <li>it returns {@code true}, if the required minimum parser version is the same or older, +- * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions, +- * because some metadata flatbuffers are generated before the first versioned release; <br> +- * <li>it returns {@code false}, if the required minimum parser version is newer, such as {@code +- * 1.14.2}. +- * </ul> +- */ +- public final boolean isMinimumParserVersionSatisfied() { +- String minVersion = metadataInfo.getMininumParserVersion(); +- if (minVersion == null) { +- return true; ++ ++ /** ++ * Gets the metadata for the input tensor specified by {@code inputIndex}. ++ * ++ * @param inputIndex the index of the desired input tensor ++ * @throws IllegalStateException if this model does not contain model metadata ++ */ ++ @Nullable ++ public TensorMetadata getInputTensorMetadata(int inputIndex) { ++ assertMetadataInfo(); ++ return metadataInfo.getInputTensorMetadata(inputIndex); + } +- return compareVersions(minVersion, MetadataParser.VERSION) <= 0; +- } +- +- /** +- * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and this +- * is allowed. However, invoking methods that reads the metadata is not allowed. +- * +- * @throws IllegalStateException if this model does not contain model metadata +- */ +- private void assertMetadataInfo() { +- if (metadataInfo == null) { +- throw new IllegalStateException("This model does not contain model metadata."); ++ ++ /** ++ * Gets the quantization parameters for the input tensor specified by {@code inputIndex}. ++ * ++ * @param inputIndex the index of the desired input tensor ++ */ ++ public QuantizationParams getInputTensorQuantizationParams(int inputIndex) { ++ Tensor tensor = modelInfo.getInputTensor(inputIndex); ++ return modelInfo.getQuantizationParams(tensor); + } +- } +- +- /** +- * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus +- * are not Zip files. This is allowed. However, invoking methods that reads those associated files +- * is not allowed. +- * +- * @throws IllegalStateException if this model is not a Zip file +- */ +- private void assertZipFile() { +- if (zipFile == null) { +- throw new IllegalStateException( +- "This model does not contain associated files, and is not a Zip file."); ++ ++ /** ++ * Gets the shape of the input tensor with {@code inputIndex}. ++ * ++ * @param inputIndex the index of the desired input tensor ++ */ ++ public int[] getInputTensorShape(int inputIndex) { ++ return modelInfo.getInputTensorShape(inputIndex); + } +- } +- +- /** +- * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e. +- * it does not have associated files, return a null handler. +- * +- * @param buffer the TFLite model FlatBuffer +- * @throws IOException if an error occurs while reading the model as a Zip file +- */ +- @Nullable +- private static ZipFile createZipFile(ByteBuffer buffer) throws IOException { +- try { +- // Creates the handler to hold the associated files through the Zip. +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer); +- return ZipFile.createFrom(byteBufferChannel); +- } catch (ZipException e) { +- // Some models may not have associate files. Therefore, Those models are not zip files. +- // However, invoking methods that read associated files later will lead into errors. +- return null; ++ ++ /** ++ * Gets the {@link TensorType} of the input tensor with {@code inputIndex}. ++ * ++ * @param inputIndex the index of the desired input tensor ++ */ ++ public byte getInputTensorType(int inputIndex) { ++ return modelInfo.getInputTensorType(inputIndex); + } +- } +- +- /** +- * Compares two semantic version numbers. +- * +- * <p>Examples of comparing two versions: <br> +- * {@code 1.9} precedes {@code 1.14}; <br> +- * {@code 1.14} precedes {@code 1.14.1}; <br> +- * {@code 1.14} and {@code 1.14.0} are euqal; +- * +- * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if +- * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code +- * version2} precedes {@code version1}. +- */ +- private static int compareVersions(String version1, String version2) { +- // Using String.split instead of the recommanded Guava Splitter because we've been avoiding +- // depending on other third party libraries in this project. +- String[] levels1 = version1.split("\\.", 0); +- String[] levels2 = version2.split("\\.", 0); +- +- int length = Math.max(levels1.length, levels2.length); +- for (int i = 0; i < length; i++) { +- Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0; +- Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0; +- int compare = v1.compareTo(v2); +- if (compare != 0) { +- return compare; +- } ++ ++ /** ++ * Gets the root handler for the model metadata. ++ * ++ * @throws IllegalStateException if this model does not contain model metadata ++ */ ++ public ModelMetadata getModelMetadata() { ++ assertMetadataInfo(); ++ return metadataInfo.getModelMetadata(); ++ } ++ ++ /** Gets the count of output tensors in the model. */ ++ public int getOutputTensorCount() { ++ return modelInfo.getOutputTensorCount(); + } + +- return 0; +- } ++ /** ++ * Gets the metadata for the output tensor specified by {@code outputIndex}. ++ * ++ * @param outputIndex the index of the desired output tensor ++ * @throws IllegalStateException if this model does not contain model metadata ++ */ ++ @Nullable ++ public TensorMetadata getOutputTensorMetadata(int outputIndex) { ++ assertMetadataInfo(); ++ return metadataInfo.getOutputTensorMetadata(outputIndex); ++ } ++ ++ /** ++ * Gets the quantization parameters for the output tensor specified by {@code outputIndex}. ++ * ++ * @param outputIndex the index of the desired output tensor ++ */ ++ public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) { ++ Tensor tensor = modelInfo.getOutputTensor(outputIndex); ++ return modelInfo.getQuantizationParams(tensor); ++ } ++ ++ /** ++ * Gets the shape of the output tensor with {@code outputIndex}. ++ * ++ * @param outputIndex the index of the desired output tensor ++ */ ++ public int[] getOutputTensorShape(int outputIndex) { ++ return modelInfo.getOutputTensorShape(outputIndex); ++ } ++ ++ /** ++ * Gets the {@link TensorType} of the output tensor with {@code outputIndex}. ++ * ++ * @param outputIndex the index of the desired output tensor ++ */ ++ public byte getOutputTensorType(int outputIndex) { ++ return modelInfo.getOutputTensorType(outputIndex); ++ } ++ ++ /** ++ * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer ++ * precedes or equals to the version of the metadata parser that this MetadataExtractor library ++ * is relying on. All fields in the metadata can be parsed correctly with this metadata ++ * extractor library in this case. Otherwise, it returns {@code false}. ++ * ++ * <p>For example, assume the underlying metadata parser version is {@code 1.14.1}, ++ * ++ * <ul> ++ * <li>it returns {@code true}, if the required minimum parser version is the same or older, ++ * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions, ++ * because some metadata flatbuffers are generated before the first versioned release; ++ * <br> <li>it returns {@code false}, if the required minimum parser version is newer, such as ++ * {@code 1.14.2}. ++ * </ul> ++ */ ++ public final boolean isMinimumParserVersionSatisfied() { ++ String minVersion = metadataInfo.getMininumParserVersion(); ++ if (minVersion == null) { ++ return true; ++ } ++ return compareVersions(minVersion, MetadataParser.VERSION) <= 0; ++ } ++ ++ /** ++ * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and ++ * this is allowed. However, invoking methods that reads the metadata is not allowed. ++ * ++ * @throws IllegalStateException if this model does not contain model metadata ++ */ ++ private void assertMetadataInfo() { ++ if (metadataInfo == null) { ++ throw new IllegalStateException("This model does not contain model metadata."); ++ } ++ } ++ ++ /** ++ * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, ++ * thus are not Zip files. This is allowed. However, invoking methods that reads those ++ * associated files is not allowed. ++ * ++ * @throws IllegalStateException if this model is not a Zip file ++ */ ++ private void assertZipFile() { ++ if (zipFile == null) { ++ throw new IllegalStateException( ++ "This model does not contain associated files, and is not a Zip file."); ++ } ++ } ++ ++ /** ++ * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e. ++ * it does not have associated files, return a null handler. ++ * ++ * @param buffer the TFLite model FlatBuffer ++ * @throws IOException if an error occurs while reading the model as a Zip file ++ */ ++ @Nullable ++ private static ZipFile createZipFile(ByteBuffer buffer) throws IOException { ++ try { ++ // Creates the handler to hold the associated files through the Zip. ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer); ++ return ZipFile.createFrom(byteBufferChannel); ++ } catch (ZipException e) { ++ // Some models may not have associate files. Therefore, Those models are not zip files. ++ // However, invoking methods that read associated files later will lead into errors. ++ return null; ++ } ++ } ++ ++ /** ++ * Compares two semantic version numbers. ++ * ++ * <p>Examples of comparing two versions: <br> ++ * {@code 1.9} precedes {@code 1.14}; <br> ++ * {@code 1.14} precedes {@code 1.14.1}; <br> ++ * {@code 1.14} and {@code 1.14.0} are euqal; ++ * ++ * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if ++ * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code ++ * version2} precedes {@code version1}. ++ */ ++ private static int compareVersions(String version1, String version2) { ++ // Using String.split instead of the recommanded Guava Splitter because we've been avoiding ++ // depending on other third party libraries in this project. ++ String[] levels1 = version1.split("\\.", 0); ++ String[] levels2 = version2.split("\\.", 0); ++ ++ int length = Math.max(levels1.length, levels2.length); ++ for (int i = 0; i < length; i++) { ++ Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0; ++ Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0; ++ int compare = v1.compareTo(v2); ++ if (compare != 0) { ++ return compare; ++ } ++ } ++ ++ return 0; ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java +index b6dd4a6216f11..20f556692f8f0 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java +@@ -17,11 +17,11 @@ package org.tensorflow.lite.support.metadata; + + /** Information about the metadata parser that this metadata extractor library is depending on. */ + public final class MetadataParser { +- /** +- * The version of the metadata parser that this metadata extractor library is depending on. The +- * value should match the value of "Schema Semantic version" in metadata_schema.fbs. +- */ +- public static final String VERSION = "1.3.0"; ++ /** ++ * The version of the metadata parser that this metadata extractor library is depending on. The ++ * value should match the value of "Schema Semantic version" in metadata_schema.fbs. ++ */ ++ public static final String VERSION = "1.3.0"; + +- private MetadataParser() {} ++ private MetadataParser() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java +index 309a3dbe77470..863ab83e306fb 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java +@@ -18,10 +18,6 @@ package org.tensorflow.lite.support.metadata; + import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument; + import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull; + +-import java.nio.ByteBuffer; +-import java.util.ArrayList; +-import java.util.Collections; +-import java.util.List; + import org.checkerframework.checker.nullness.qual.Nullable; + import org.tensorflow.lite.schema.Buffer; + import org.tensorflow.lite.schema.Metadata; +@@ -32,235 +28,237 @@ import org.tensorflow.lite.schema.Tensor; + import org.tensorflow.lite.schema.TensorType; + import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams; + ++import java.nio.ByteBuffer; ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.List; ++ + /** Extracts model information out of TFLite model FLatBuffer. */ + final class ModelInfo { +- /** The model that is loaded from TFLite model FlatBuffer. */ +- private final Model model; +- +- /** A list of input tensors. */ +- private final List</* @Nullable */ Tensor> inputTensors; +- +- /** A list of output tensors. */ +- private final List</* @Nullable */ Tensor> outputTensors; +- +- /** Identifier of the TFLite model metadata in the Metadata array. */ +- static final String METADATA_FIELD_NAME = "TFLITE_METADATA"; +- +- /** +- * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}. +- * +- * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports +- * single subgraph so far. See the <a +- * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction +- * of how to specify subgraph during convertion for more information.</a> Therefore, all methods +- * in {@link ModelInfo} retrieves metadata of the first subgrpah as default. +- * +- * @param buffer the TFLite model FlatBuffer +- * @throws NullPointerException if {@code buffer} is null +- * @throws IllegalArgumentException if the model does not contain any subgraph, or the model does +- * not contain the expected identifier +- */ +- ModelInfo(ByteBuffer buffer) { +- assertTFLiteModel(buffer); +- +- model = Model.getRootAsModel(buffer); +- checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph."); +- +- inputTensors = getInputTensors(model); +- outputTensors = getOutputTensors(model); +- } +- +- /** +- * Gets the input tensor with {@code inputIndex}. +- * +- * @param inputIndex The index of the desired input tensor. +- * @throws IllegalArgumentException if the inputIndex specified is invalid. +- */ +- @Nullable +- Tensor getInputTensor(int inputIndex) { +- checkArgument( +- inputIndex >= 0 && inputIndex < inputTensors.size(), +- "The inputIndex specified is invalid."); +- return inputTensors.get(inputIndex); +- } +- +- int getInputTensorCount() { +- return inputTensors.size(); +- } +- +- /** +- * Gets shape of the input tensor with {@code inputIndex}. +- * +- * @param inputIndex The index of the desired intput tensor. +- */ +- int[] getInputTensorShape(int inputIndex) { +- Tensor tensor = getInputTensor(inputIndex); +- return getShape(tensor); +- } +- +- /** +- * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}. +- * +- * @param inputIndex The index of the desired intput tensor. +- */ +- byte getInputTensorType(int inputIndex) { +- return getInputTensor(inputIndex).type(); +- } +- +- /** Gets the metadata FlatBuffer from the model FlatBuffer. */ +- @Nullable +- ByteBuffer getMetadataBuffer() { +- // Some models may not have metadata, and this is allowed. +- if (model.metadataLength() == 0) { +- return null; ++ /** The model that is loaded from TFLite model FlatBuffer. */ ++ private final Model model; ++ ++ /** A list of input tensors. */ ++ private final List</* @Nullable */ Tensor> inputTensors; ++ ++ /** A list of output tensors. */ ++ private final List</* @Nullable */ Tensor> outputTensors; ++ ++ /** Identifier of the TFLite model metadata in the Metadata array. */ ++ static final String METADATA_FIELD_NAME = "TFLITE_METADATA"; ++ ++ /** ++ * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}. ++ * ++ * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only ++ * supports single subgraph so far. See the <a ++ * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction ++ * of how to specify subgraph during convertion for more information.</a> Therefore, all methods ++ * in {@link ModelInfo} retrieves metadata of the first subgrpah as default. ++ * ++ * @param buffer the TFLite model FlatBuffer ++ * @throws NullPointerException if {@code buffer} is null ++ * @throws IllegalArgumentException if the model does not contain any subgraph, or the model ++ * does ++ * not contain the expected identifier ++ */ ++ ModelInfo(ByteBuffer buffer) { ++ assertTFLiteModel(buffer); ++ ++ model = Model.getRootAsModel(buffer); ++ checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph."); ++ ++ inputTensors = getInputTensors(model); ++ outputTensors = getOutputTensors(model); ++ } ++ ++ /** ++ * Gets the input tensor with {@code inputIndex}. ++ * ++ * @param inputIndex The index of the desired input tensor. ++ * @throws IllegalArgumentException if the inputIndex specified is invalid. ++ */ ++ @Nullable ++ Tensor getInputTensor(int inputIndex) { ++ checkArgument(inputIndex >= 0 && inputIndex < inputTensors.size(), ++ "The inputIndex specified is invalid."); ++ return inputTensors.get(inputIndex); ++ } ++ ++ int getInputTensorCount() { ++ return inputTensors.size(); ++ } ++ ++ /** ++ * Gets shape of the input tensor with {@code inputIndex}. ++ * ++ * @param inputIndex The index of the desired intput tensor. ++ */ ++ int[] getInputTensorShape(int inputIndex) { ++ Tensor tensor = getInputTensor(inputIndex); ++ return getShape(tensor); + } + +- for (int i = 0; i < model.metadataLength(); i++) { +- Metadata meta = model.metadata(i); +- if (METADATA_FIELD_NAME.equals(meta.name())) { +- long bufferIndex = meta.buffer(); +- Buffer metadataBuf = model.buffers((int) bufferIndex); +- return metadataBuf.dataAsByteBuffer(); +- } ++ /** ++ * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}. ++ * ++ * @param inputIndex The index of the desired intput tensor. ++ */ ++ byte getInputTensorType(int inputIndex) { ++ return getInputTensor(inputIndex).type(); + } +- return null; +- } +- +- /** +- * Gets the output tensor with {@code outputIndex}. +- * +- * @param outputIndex The index of the desired outtput tensor. +- * @throws IllegalArgumentException if the outputIndex specified is invalid. +- */ +- @Nullable +- Tensor getOutputTensor(int outputIndex) { +- checkArgument( +- outputIndex >= 0 && outputIndex < outputTensors.size(), +- "The outputIndex specified is invalid."); +- return outputTensors.get(outputIndex); +- } +- +- int getOutputTensorCount() { +- return outputTensors.size(); +- } +- +- /** +- * Gets shape of the output tensor with {@code outputIndex}. +- * +- * @param outputIndex The index of the desired outtput tensor. +- */ +- int[] getOutputTensorShape(int outputIndex) { +- Tensor tensor = getOutputTensor(outputIndex); +- return getShape(tensor); +- } +- +- /** +- * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}. +- * +- * @param outputIndex The index of the desired outtput tensor. +- */ +- byte getOutputTensorType(int outputIndex) { +- return getOutputTensor(outputIndex).type(); +- } +- +- /** +- * Gets the quantization parameters of a tensor. +- * +- * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not +- * quantized, the values of scale and zero_point are both 0. +- * +- * @param tensor The tensor whoes quantization parameters is desired. +- * @throws NullPointerException if the tensor is null. +- * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link +- * QuantizationParameters} are not single values. +- */ +- QuantizationParams getQuantizationParams(Tensor tensor) { +- checkNotNull(tensor, "Tensor cannot be null."); +- +- float scale; +- int zeroPoint; +- QuantizationParameters quantization = tensor.quantization(); +- +- // Tensors that are not quantized do not have quantization parameters, which can be null when +- // being extracted from the flatbuffer. +- if (quantization == null) { +- scale = 0.0f; +- zeroPoint = 0; +- return new QuantizationParams(scale, zeroPoint); ++ ++ /** Gets the metadata FlatBuffer from the model FlatBuffer. */ ++ @Nullable ++ ByteBuffer getMetadataBuffer() { ++ // Some models may not have metadata, and this is allowed. ++ if (model.metadataLength() == 0) { ++ return null; ++ } ++ ++ for (int i = 0; i < model.metadataLength(); i++) { ++ Metadata meta = model.metadata(i); ++ if (METADATA_FIELD_NAME.equals(meta.name())) { ++ long bufferIndex = meta.buffer(); ++ Buffer metadataBuf = model.buffers((int) bufferIndex); ++ return metadataBuf.dataAsByteBuffer(); ++ } ++ } ++ return null; ++ } ++ ++ /** ++ * Gets the output tensor with {@code outputIndex}. ++ * ++ * @param outputIndex The index of the desired outtput tensor. ++ * @throws IllegalArgumentException if the outputIndex specified is invalid. ++ */ ++ @Nullable ++ Tensor getOutputTensor(int outputIndex) { ++ checkArgument(outputIndex >= 0 && outputIndex < outputTensors.size(), ++ "The outputIndex specified is invalid."); ++ return outputTensors.get(outputIndex); ++ } ++ ++ int getOutputTensorCount() { ++ return outputTensors.size(); ++ } ++ ++ /** ++ * Gets shape of the output tensor with {@code outputIndex}. ++ * ++ * @param outputIndex The index of the desired outtput tensor. ++ */ ++ int[] getOutputTensorShape(int outputIndex) { ++ Tensor tensor = getOutputTensor(outputIndex); ++ return getShape(tensor); + } + +- // Tensors that are not quantized do not have quantization parameters. +- // quantization.scaleLength() and quantization.zeroPointLength() may both return 0. +- checkArgument( +- quantization.scaleLength() <= 1, +- "Input and output tensors do not support per-channel quantization."); +- checkArgument( +- quantization.zeroPointLength() <= 1, +- "Input and output tensors do not support per-channel quantization."); +- +- // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will +- // both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++ +- // runtime. +- scale = quantization.scale(0); +- // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it +- // consistent with the C++ runtime. +- zeroPoint = (int) quantization.zeroPoint(0); +- +- return new QuantizationParams(scale, zeroPoint); +- } +- +- /** +- * Verifies if the buffer is a valid TFLite model. +- * +- * @param buffer the TFLite model flatbuffer +- * @throws NullPointerException if {@code buffer} is null. +- * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier +- */ +- private static void assertTFLiteModel(ByteBuffer buffer) { +- checkNotNull(buffer, "Model flatbuffer cannot be null."); +- checkArgument( +- Model.ModelBufferHasIdentifier(buffer), +- "The identifier of the model is invalid. The buffer may not be a valid TFLite model" +- + " flatbuffer."); +- } +- +- /** +- * Gets the shape of a tensor. +- * +- * @param tensor The tensor whoes shape is desired. +- * @throws NullPointerException if the tensor is null. +- */ +- private static int[] getShape(Tensor tensor) { +- checkNotNull(tensor, "Tensor cannot be null."); +- int shapeDim = tensor.shapeLength(); +- int[] tensorShape = new int[shapeDim]; +- for (int i = 0; i < shapeDim; i++) { +- tensorShape[i] = tensor.shape(i); ++ /** ++ * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}. ++ * ++ * @param outputIndex The index of the desired outtput tensor. ++ */ ++ byte getOutputTensorType(int outputIndex) { ++ return getOutputTensor(outputIndex).type(); + } +- return tensorShape; +- } +- +- /** Gets input tensors from a model. */ +- private static List<Tensor> getInputTensors(Model model) { +- // TFLite only support one subgraph currently. +- SubGraph subgraph = model.subgraphs(0); +- int tensorNum = subgraph.inputsLength(); +- ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum); +- for (int i = 0; i < tensorNum; i++) { +- inputTensors.add(subgraph.tensors(subgraph.inputs(i))); ++ ++ /** ++ * Gets the quantization parameters of a tensor. ++ * ++ * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not ++ * quantized, the values of scale and zero_point are both 0. ++ * ++ * @param tensor The tensor whoes quantization parameters is desired. ++ * @throws NullPointerException if the tensor is null. ++ * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's ++ * {@link ++ * QuantizationParameters} are not single values. ++ */ ++ QuantizationParams getQuantizationParams(Tensor tensor) { ++ checkNotNull(tensor, "Tensor cannot be null."); ++ ++ float scale; ++ int zeroPoint; ++ QuantizationParameters quantization = tensor.quantization(); ++ ++ // Tensors that are not quantized do not have quantization parameters, which can be null ++ // when being extracted from the flatbuffer. ++ if (quantization == null) { ++ scale = 0.0f; ++ zeroPoint = 0; ++ return new QuantizationParams(scale, zeroPoint); ++ } ++ ++ // Tensors that are not quantized do not have quantization parameters. ++ // quantization.scaleLength() and quantization.zeroPointLength() may both return 0. ++ checkArgument(quantization.scaleLength() <= 1, ++ "Input and output tensors do not support per-channel quantization."); ++ checkArgument(quantization.zeroPointLength() <= 1, ++ "Input and output tensors do not support per-channel quantization."); ++ ++ // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) ++ // will both be the default value in flatbuffer, 0. This behavior is consistent with the ++ // TFlite C++ runtime. ++ scale = quantization.scale(0); ++ // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep ++ // it consistent with the C++ runtime. ++ zeroPoint = (int) quantization.zeroPoint(0); ++ ++ return new QuantizationParams(scale, zeroPoint); + } +- return Collections.unmodifiableList(inputTensors); +- } +- +- /** Gets output tensors from a model. */ +- private static List<Tensor> getOutputTensors(Model model) { +- // TFLite only support one subgraph currently. +- SubGraph subgraph = model.subgraphs(0); +- int tensorNum = subgraph.outputsLength(); +- ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum); +- for (int i = 0; i < tensorNum; i++) { +- outputTensors.add(subgraph.tensors(subgraph.outputs(i))); ++ ++ /** ++ * Verifies if the buffer is a valid TFLite model. ++ * ++ * @param buffer the TFLite model flatbuffer ++ * @throws NullPointerException if {@code buffer} is null. ++ * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier ++ */ ++ private static void assertTFLiteModel(ByteBuffer buffer) { ++ checkNotNull(buffer, "Model flatbuffer cannot be null."); ++ checkArgument(Model.ModelBufferHasIdentifier(buffer), ++ "The identifier of the model is invalid. The buffer may not be a valid TFLite model" ++ + " flatbuffer."); ++ } ++ ++ /** ++ * Gets the shape of a tensor. ++ * ++ * @param tensor The tensor whoes shape is desired. ++ * @throws NullPointerException if the tensor is null. ++ */ ++ private static int[] getShape(Tensor tensor) { ++ checkNotNull(tensor, "Tensor cannot be null."); ++ int shapeDim = tensor.shapeLength(); ++ int[] tensorShape = new int[shapeDim]; ++ for (int i = 0; i < shapeDim; i++) { ++ tensorShape[i] = tensor.shape(i); ++ } ++ return tensorShape; ++ } ++ ++ /** Gets input tensors from a model. */ ++ private static List<Tensor> getInputTensors(Model model) { ++ // TFLite only support one subgraph currently. ++ SubGraph subgraph = model.subgraphs(0); ++ int tensorNum = subgraph.inputsLength(); ++ ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum); ++ for (int i = 0; i < tensorNum; i++) { ++ inputTensors.add(subgraph.tensors(subgraph.inputs(i))); ++ } ++ return Collections.unmodifiableList(inputTensors); ++ } ++ ++ /** Gets output tensors from a model. */ ++ private static List<Tensor> getOutputTensors(Model model) { ++ // TFLite only support one subgraph currently. ++ SubGraph subgraph = model.subgraphs(0); ++ int tensorNum = subgraph.outputsLength(); ++ ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum); ++ for (int i = 0; i < tensorNum; i++) { ++ outputTensors.add(subgraph.tensors(subgraph.outputs(i))); ++ } ++ return Collections.unmodifiableList(outputTensors); + } +- return Collections.unmodifiableList(outputTensors); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java +index 751ed500dc2fc..7ee01df094283 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java +@@ -18,136 +18,133 @@ package org.tensorflow.lite.support.metadata; + import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument; + import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull; + +-import java.nio.ByteBuffer; +-import java.util.ArrayList; +-import java.util.Collections; +-import java.util.List; + import org.checkerframework.checker.nullness.qual.Nullable; + import org.tensorflow.lite.support.metadata.schema.ModelMetadata; + import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata; + import org.tensorflow.lite.support.metadata.schema.TensorMetadata; + ++import java.nio.ByteBuffer; ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.List; ++ + /** Extracts model metadata information out of TFLite metadata FlatBuffer. */ + final class ModelMetadataInfo { +- /** The root handler for the model metadata. */ +- private final ModelMetadata modelMetadata; +- +- /** Metadata array of input tensors. */ +- private final List</* @Nullable */ TensorMetadata> inputsMetadata; +- +- /** Metadata array of output tensors. */ +- private final List</* @Nullable */ TensorMetadata> outputsMetadata; +- +- /** The minimum parser version required to fully understand the metadata flatbuffer. */ +- private final String /* @Nullable */ minVersion; +- +- /** +- * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}. +- * +- * @param buffer the TFLite metadata FlatBuffer +- * @throws NullPointerException if {@code buffer} is null +- * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or +- * it does not contain the expected identifier +- */ +- ModelMetadataInfo(ByteBuffer buffer) { +- assertTFLiteMetadata(buffer); +- +- modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer); +- checkArgument( +- modelMetadata.subgraphMetadataLength() > 0, +- "The metadata flatbuffer does not contain any subgraph metadata."); +- +- inputsMetadata = getInputsMetadata(modelMetadata); +- outputsMetadata = getOutputsMetadata(modelMetadata); +- minVersion = modelMetadata.minParserVersion(); +- } +- +- /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */ +- int getInputTensorCount() { +- return inputsMetadata.size(); +- } +- +- /** +- * Gets the metadata for the input tensor specified by {@code inputIndex}. +- * +- * @param inputIndex The index of the desired intput tensor. +- * @throws IllegalArgumentException if the inputIndex specified is invalid. +- */ +- @Nullable +- TensorMetadata getInputTensorMetadata(int inputIndex) { +- checkArgument( +- inputIndex >= 0 && inputIndex < inputsMetadata.size(), +- "The inputIndex specified is invalid."); +- return inputsMetadata.get(inputIndex); +- } +- +- /** +- * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not +- * populated. +- */ +- @Nullable +- String getMininumParserVersion() { +- return minVersion; +- } +- +- /** Gets the root handler for the model metadata. */ +- ModelMetadata getModelMetadata() { +- return modelMetadata; +- } +- +- /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */ +- int getOutputTensorCount() { +- return outputsMetadata.size(); +- } +- +- /** +- * Gets the metadata for the output tensor specified by {@code outputIndex}. +- * +- * @param outputIndex The index of the desired output tensor. +- * @throws IllegalArgumentException if the outputIndex specified is invalid. +- */ +- @Nullable +- TensorMetadata getOutputTensorMetadata(int outputIndex) { +- checkArgument( +- outputIndex >= 0 && outputIndex < outputsMetadata.size(), +- "The outputIndex specified is invalid."); +- return outputsMetadata.get(outputIndex); +- } +- +- /** +- * Verifies if the buffer is a valid TFLite metadata flatbuffer. +- * +- * @param buffer the TFLite metadata flatbuffer +- * @throws NullPointerException if {@code buffer} is null. +- * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier +- */ +- private static void assertTFLiteMetadata(ByteBuffer buffer) { +- checkNotNull(buffer, "Metadata flatbuffer cannot be null."); +- checkArgument( +- ModelMetadata.ModelMetadataBufferHasIdentifier(buffer), +- "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata" +- + " flatbuffer."); +- } +- +- /** Gets metadata for all input tensors. */ +- private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) { +- SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0); +- int tensorNum = subgraphMetadata.inputTensorMetadataLength(); +- ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum); +- for (int i = 0; i < tensorNum; i++) { +- inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i)); ++ /** The root handler for the model metadata. */ ++ private final ModelMetadata modelMetadata; ++ ++ /** Metadata array of input tensors. */ ++ private final List</* @Nullable */ TensorMetadata> inputsMetadata; ++ ++ /** Metadata array of output tensors. */ ++ private final List</* @Nullable */ TensorMetadata> outputsMetadata; ++ ++ /** The minimum parser version required to fully understand the metadata flatbuffer. */ ++ private final String /* @Nullable */ minVersion; ++ ++ /** ++ * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}. ++ * ++ * @param buffer the TFLite metadata FlatBuffer ++ * @throws NullPointerException if {@code buffer} is null ++ * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or ++ * it does not contain the expected identifier ++ */ ++ ModelMetadataInfo(ByteBuffer buffer) { ++ assertTFLiteMetadata(buffer); ++ ++ modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer); ++ checkArgument(modelMetadata.subgraphMetadataLength() > 0, ++ "The metadata flatbuffer does not contain any subgraph metadata."); ++ ++ inputsMetadata = getInputsMetadata(modelMetadata); ++ outputsMetadata = getOutputsMetadata(modelMetadata); ++ minVersion = modelMetadata.minParserVersion(); ++ } ++ ++ /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */ ++ int getInputTensorCount() { ++ return inputsMetadata.size(); ++ } ++ ++ /** ++ * Gets the metadata for the input tensor specified by {@code inputIndex}. ++ * ++ * @param inputIndex The index of the desired intput tensor. ++ * @throws IllegalArgumentException if the inputIndex specified is invalid. ++ */ ++ @Nullable ++ TensorMetadata getInputTensorMetadata(int inputIndex) { ++ checkArgument(inputIndex >= 0 && inputIndex < inputsMetadata.size(), ++ "The inputIndex specified is invalid."); ++ return inputsMetadata.get(inputIndex); + } +- return Collections.unmodifiableList(inputsMetadata); +- } +- +- /** Gets metadata for all output tensors. */ +- private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) { +- SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0); +- int tensorNum = subgraphMetadata.outputTensorMetadataLength(); +- ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum); +- for (int i = 0; i < tensorNum; i++) { +- outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i)); ++ ++ /** ++ * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not ++ * populated. ++ */ ++ @Nullable ++ String getMininumParserVersion() { ++ return minVersion; ++ } ++ ++ /** Gets the root handler for the model metadata. */ ++ ModelMetadata getModelMetadata() { ++ return modelMetadata; ++ } ++ ++ /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */ ++ int getOutputTensorCount() { ++ return outputsMetadata.size(); ++ } ++ ++ /** ++ * Gets the metadata for the output tensor specified by {@code outputIndex}. ++ * ++ * @param outputIndex The index of the desired output tensor. ++ * @throws IllegalArgumentException if the outputIndex specified is invalid. ++ */ ++ @Nullable ++ TensorMetadata getOutputTensorMetadata(int outputIndex) { ++ checkArgument(outputIndex >= 0 && outputIndex < outputsMetadata.size(), ++ "The outputIndex specified is invalid."); ++ return outputsMetadata.get(outputIndex); ++ } ++ ++ /** ++ * Verifies if the buffer is a valid TFLite metadata flatbuffer. ++ * ++ * @param buffer the TFLite metadata flatbuffer ++ * @throws NullPointerException if {@code buffer} is null. ++ * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier ++ */ ++ private static void assertTFLiteMetadata(ByteBuffer buffer) { ++ checkNotNull(buffer, "Metadata flatbuffer cannot be null."); ++ checkArgument(ModelMetadata.ModelMetadataBufferHasIdentifier(buffer), ++ "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata" ++ + " flatbuffer."); ++ } ++ ++ /** Gets metadata for all input tensors. */ ++ private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) { ++ SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0); ++ int tensorNum = subgraphMetadata.inputTensorMetadataLength(); ++ ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum); ++ for (int i = 0; i < tensorNum; i++) { ++ inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i)); ++ } ++ return Collections.unmodifiableList(inputsMetadata); ++ } ++ ++ /** Gets metadata for all output tensors. */ ++ private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) { ++ SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0); ++ int tensorNum = subgraphMetadata.outputTensorMetadataLength(); ++ ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum); ++ for (int i = 0; i < tensorNum; i++) { ++ outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i)); ++ } ++ return Collections.unmodifiableList(outputsMetadata); + } +- return Collections.unmodifiableList(outputsMetadata); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java +index c2f20fbaacd76..ca3eed3490644 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java +@@ -19,166 +19,170 @@ import org.checkerframework.checker.nullness.qual.Nullable; + + /** Static error checking util methods. */ + final class Preconditions { +- /** +- * Ensures that an object reference passed as a parameter to the calling method is not null. +- * +- * @param reference an object reference +- * @return the non-null reference that was validated +- * @throws NullPointerException if {@code reference} is null +- */ +- public static <T extends Object> T checkNotNull(T reference) { +- if (reference == null) { +- throw new NullPointerException("The object reference is null."); ++ /** ++ * Ensures that an object reference passed as a parameter to the calling method is not null. ++ * ++ * @param reference an object reference ++ * @return the non-null reference that was validated ++ * @throws NullPointerException if {@code reference} is null ++ */ ++ public static <T extends Object> T checkNotNull(T reference) { ++ if (reference == null) { ++ throw new NullPointerException("The object reference is null."); ++ } ++ return reference; + } +- return reference; +- } +- +- /** +- * Ensures that an object reference passed as a parameter to the calling method is not null. +- * +- * @param reference an object reference +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)} +- * @return the non-null reference that was validated +- * @throws NullPointerException if {@code reference} is null +- */ +- public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { +- if (reference == null) { +- throw new NullPointerException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures that an object reference passed as a parameter to the calling method is not null. ++ * ++ * @param reference an object reference ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)} ++ * @return the non-null reference that was validated ++ * @throws NullPointerException if {@code reference} is null ++ */ ++ public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { ++ if (reference == null) { ++ throw new NullPointerException(String.valueOf(errorMessage)); ++ } ++ return reference; ++ } ++ ++ /** ++ * Ensures that the given String is not empty and not null. ++ * ++ * @param string the String to test ++ * @return the non-null non-empty String that was validated ++ * @throws IllegalArgumentException if {@code string} is null or empty ++ */ ++ public static String checkNotEmpty(String string) { ++ if (string == null || string.length() == 0) { ++ throw new IllegalArgumentException("Given String is empty or null."); ++ } ++ return string; + } +- return reference; +- } +- +- /** +- * Ensures that the given String is not empty and not null. +- * +- * @param string the String to test +- * @return the non-null non-empty String that was validated +- * @throws IllegalArgumentException if {@code string} is null or empty +- */ +- public static String checkNotEmpty(String string) { +- if (string == null || string.length() == 0) { +- throw new IllegalArgumentException("Given String is empty or null."); ++ ++ /** ++ * Ensures that the given String is not empty and not null. ++ * ++ * @param string the String to test ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)} ++ * @return the non-null non-empty String that was validated ++ * @throws IllegalArgumentException if {@code string} is null or empty ++ */ ++ public static String checkNotEmpty(String string, Object errorMessage) { ++ if (string == null || string.length() == 0) { ++ throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ } ++ return string; + } +- return string; +- } +- +- /** +- * Ensures that the given String is not empty and not null. +- * +- * @param string the String to test +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)} +- * @return the non-null non-empty String that was validated +- * @throws IllegalArgumentException if {@code string} is null or empty +- */ +- public static String checkNotEmpty(String string, Object errorMessage) { +- if (string == null || string.length() == 0) { +- throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures the truth of an expression involving one or more parameters to the calling method. ++ * ++ * @param expression a boolean expression. ++ * @throws IllegalArgumentException if {@code expression} is false. ++ */ ++ public static void checkArgument(boolean expression) { ++ if (!expression) { ++ throw new IllegalArgumentException(); ++ } + } +- return string; +- } +- +- /** +- * Ensures the truth of an expression involving one or more parameters to the calling method. +- * +- * @param expression a boolean expression. +- * @throws IllegalArgumentException if {@code expression} is false. +- */ +- public static void checkArgument(boolean expression) { +- if (!expression) { +- throw new IllegalArgumentException(); ++ ++ /** ++ * Ensures the truth of an expression involving one or more parameters to the calling method. ++ * ++ * @param expression a boolean expression. ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)}. ++ * @throws IllegalArgumentException if {@code expression} is false. ++ */ ++ public static void checkArgument(boolean expression, @Nullable Object errorMessage) { ++ if (!expression) { ++ throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ } + } +- } +- +- /** +- * Ensures the truth of an expression involving one or more parameters to the calling method. +- * +- * @param expression a boolean expression. +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)}. +- * @throws IllegalArgumentException if {@code expression} is false. +- */ +- public static void checkArgument(boolean expression, @Nullable Object errorMessage) { +- if (!expression) { +- throw new IllegalArgumentException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of ++ * size ++ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. ++ * ++ * @param index a user-supplied index identifying an element of an array, list or string ++ * @param size the size of that array, list or string ++ * @return the value of {@code index} ++ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code ++ * size} ++ * @throws IllegalArgumentException if {@code size} is negative ++ */ ++ public static int checkElementIndex(int index, int size) { ++ return checkElementIndex(index, size, "index"); + } +- } +- +- /** +- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size +- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. +- * +- * @param index a user-supplied index identifying an element of an array, list or string +- * @param size the size of that array, list or string +- * @return the value of {@code index} +- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size} +- * @throws IllegalArgumentException if {@code size} is negative +- */ +- public static int checkElementIndex(int index, int size) { +- return checkElementIndex(index, size, "index"); +- } +- +- /** +- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size +- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. +- * +- * @param index a user-supplied index identifying an element of an array, list or string +- * @param size the size of that array, list or string +- * @param desc the text to use to describe this index in an error message +- * @return the value of {@code index} +- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size} +- * @throws IllegalArgumentException if {@code size} is negative +- */ +- public static int checkElementIndex(int index, int size, @Nullable String desc) { +- // Carefully optimized for execution by hotspot (explanatory comment above) +- if (index < 0 || index >= size) { +- throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); ++ ++ /** ++ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of ++ * size ++ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. ++ * ++ * @param index a user-supplied index identifying an element of an array, list or string ++ * @param size the size of that array, list or string ++ * @param desc the text to use to describe this index in an error message ++ * @return the value of {@code index} ++ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code ++ * size} ++ * @throws IllegalArgumentException if {@code size} is negative ++ */ ++ public static int checkElementIndex(int index, int size, @Nullable String desc) { ++ // Carefully optimized for execution by hotspot (explanatory comment above) ++ if (index < 0 || index >= size) { ++ throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); ++ } ++ return index; + } +- return index; +- } +- +- /** +- * Ensures the truth of an expression involving the state of the calling instance, but not +- * involving any parameters to the calling method. +- * +- * @param expression a boolean expression +- * @throws IllegalStateException if {@code expression} is false +- * @see Verify#verify Verify.verify() +- */ +- public static void checkState(boolean expression) { +- if (!expression) { +- throw new IllegalStateException(); ++ ++ /** ++ * Ensures the truth of an expression involving the state of the calling instance, but not ++ * involving any parameters to the calling method. ++ * ++ * @param expression a boolean expression ++ * @throws IllegalStateException if {@code expression} is false ++ * @see Verify#verify Verify.verify() ++ */ ++ public static void checkState(boolean expression) { ++ if (!expression) { ++ throw new IllegalStateException(); ++ } + } +- } +- +- /** +- * Ensures the truth of an expression involving the state of the calling instance, but not +- * involving any parameters to the calling method. +- * +- * @param expression a boolean expression +- * @param errorMessage the exception message to use if the check fails; will be converted to a +- * string using {@link String#valueOf(Object)} +- * @throws IllegalStateException if {@code expression} is false +- * @see Verify#verify Verify.verify() +- */ +- public static void checkState(boolean expression, @Nullable Object errorMessage) { +- if (!expression) { +- throw new IllegalStateException(String.valueOf(errorMessage)); ++ ++ /** ++ * Ensures the truth of an expression involving the state of the calling instance, but not ++ * involving any parameters to the calling method. ++ * ++ * @param expression a boolean expression ++ * @param errorMessage the exception message to use if the check fails; will be converted to a ++ * string using {@link String#valueOf(Object)} ++ * @throws IllegalStateException if {@code expression} is false ++ * @see Verify#verify Verify.verify() ++ */ ++ public static void checkState(boolean expression, @Nullable Object errorMessage) { ++ if (!expression) { ++ throw new IllegalStateException(String.valueOf(errorMessage)); ++ } + } +- } +- +- private static String badElementIndex(int index, int size, @Nullable String desc) { +- if (index < 0) { +- return String.format("%s (%s) must not be negative", desc, index); +- } else if (size < 0) { +- throw new IllegalArgumentException("negative size: " + size); +- } else { // index >= size +- return String.format("%s (%s) must be less than size (%s)", desc, index, size); ++ ++ private static String badElementIndex(int index, int size, @Nullable String desc) { ++ if (index < 0) { ++ return String.format("%s (%s) must not be negative", desc, index); ++ } else if (size < 0) { ++ throw new IllegalArgumentException("negative size: " + size); ++ } else { // index >= size ++ return String.format("%s (%s) must be less than size (%s)", desc, index, size); ++ } + } +- } + +- private Preconditions() { +- throw new AssertionError("Preconditions is Uninstantiable."); +- } ++ private Preconditions() { ++ throw new AssertionError("Preconditions is Uninstantiable."); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java +index c655786755baa..1408a3a73d86b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java +@@ -29,79 +29,79 @@ import java.nio.channels.Channel; + * the MetadtaExtractor library consistent with the common used Java libraries. + */ + interface SeekableByteChannelCompat extends Channel { +- /** +- * Reads a sequence of bytes from this channel into the given buffer. +- * +- * @param dst The buffer into which bytes are to be transferred +- * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached +- * end-of-stream +- * @throws NonReadableChannelException If this channel was not opened for reading +- * @throws ClosedChannelException If this channel is closed +- * @throws AsynchronousCloseException If another thread closes this channel while the read +- * operation is in progress +- * @throws ClosedByInterruptException If another thread interrupts the current thread while the +- * read operation is in progress, thereby closing the channel and setting the current thread's +- * interrupt status +- * @throws IOException If some other I/O error occurs +- */ +- int read(ByteBuffer dst) throws IOException; ++ /** ++ * Reads a sequence of bytes from this channel into the given buffer. ++ * ++ * @param dst The buffer into which bytes are to be transferred ++ * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached ++ * end-of-stream ++ * @throws NonReadableChannelException If this channel was not opened for reading ++ * @throws ClosedChannelException If this channel is closed ++ * @throws AsynchronousCloseException If another thread closes this channel while the read ++ * operation is in progress ++ * @throws ClosedByInterruptException If another thread interrupts the current thread while the ++ * read operation is in progress, thereby closing the channel and setting the current ++ * thread's interrupt status ++ * @throws IOException If some other I/O error occurs ++ */ ++ int read(ByteBuffer dst) throws IOException; + +- /** +- * Writes a sequence of bytes to this channel from the given buffer. +- * +- * @param src The buffer from which bytes are to be retrieved +- * @return The number of bytes written, possibly zero +- * @throws NonWritableChannelException If this channel was not opened for writing +- * @throws ClosedChannelException If this channel is closed +- * @throws AsynchronousCloseException If another thread closes this channel while the write +- * operation is in progress +- * @throws ClosedByInterruptException If another thread interrupts the current thread while the +- * write operation is in progress, thereby closing the channel and setting the current +- * thread's interrupt status +- * @throws IOException If some other I/O error occurs +- */ +- int write(ByteBuffer src) throws IOException; ++ /** ++ * Writes a sequence of bytes to this channel from the given buffer. ++ * ++ * @param src The buffer from which bytes are to be retrieved ++ * @return The number of bytes written, possibly zero ++ * @throws NonWritableChannelException If this channel was not opened for writing ++ * @throws ClosedChannelException If this channel is closed ++ * @throws AsynchronousCloseException If another thread closes this channel while the write ++ * operation is in progress ++ * @throws ClosedByInterruptException If another thread interrupts the current thread while the ++ * write operation is in progress, thereby closing the channel and setting the current ++ * thread's interrupt status ++ * @throws IOException If some other I/O error occurs ++ */ ++ int write(ByteBuffer src) throws IOException; + +- /** +- * Returns this channel's position. +- * +- * @return This channel's position, a non-negative integer counting the number of bytes from the +- * beginning of the entity to the current position +- * @throws ClosedChannelException If this channel is closed +- * @throws IOException If some other I/O error occurs +- */ +- long position() throws IOException; ++ /** ++ * Returns this channel's position. ++ * ++ * @return This channel's position, a non-negative integer counting the number of bytes from the ++ * beginning of the entity to the current position ++ * @throws ClosedChannelException If this channel is closed ++ * @throws IOException If some other I/O error occurs ++ */ ++ long position() throws IOException; + +- /** +- * Sets this channel's position. +- * +- * @param newPosition The new position, a non-negative integer counting the number of bytes from +- * the beginning of the entity +- * @return This channel +- * @throws ClosedChannelException If this channel is closed +- * @throws IllegalArgumentException If the new position is negative +- * @throws IOException If some other I/O error occurs +- */ +- SeekableByteChannelCompat position(long newPosition) throws IOException; ++ /** ++ * Sets this channel's position. ++ * ++ * @param newPosition The new position, a non-negative integer counting the number of bytes from ++ * the beginning of the entity ++ * @return This channel ++ * @throws ClosedChannelException If this channel is closed ++ * @throws IllegalArgumentException If the new position is negative ++ * @throws IOException If some other I/O error occurs ++ */ ++ SeekableByteChannelCompat position(long newPosition) throws IOException; + +- /** +- * Returns the current size of entity to which this channel is connected. +- * +- * @return The current size, measured in bytes +- * @throws ClosedChannelException If this channel is closed +- * @throws IOException If some other I/O error occurs +- */ +- long size() throws IOException; ++ /** ++ * Returns the current size of entity to which this channel is connected. ++ * ++ * @return The current size, measured in bytes ++ * @throws ClosedChannelException If this channel is closed ++ * @throws IOException If some other I/O error occurs ++ */ ++ long size() throws IOException; + +- /** +- * Truncates the entity, to which this channel is connected, to the given size. +- * +- * @param size The new size, a non-negative byte count +- * @return This channel +- * @throws NonWritableChannelException If this channel was not opened for writing +- * @throws ClosedChannelException If this channel is closed +- * @throws IllegalArgumentException If the new size is negative +- * @throws IOException If some other I/O error occurs +- */ +- SeekableByteChannelCompat truncate(long size) throws IOException; ++ /** ++ * Truncates the entity, to which this channel is connected, to the given size. ++ * ++ * @param size The new size, a non-negative byte count ++ * @return This channel ++ * @throws NonWritableChannelException If this channel was not opened for writing ++ * @throws ClosedChannelException If this channel is closed ++ * @throws IllegalArgumentException If the new size is negative ++ * @throws IOException If some other I/O error occurs ++ */ ++ SeekableByteChannelCompat truncate(long size) throws IOException; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java +index 6b43e724fd814..c8a3fb806d920 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java +@@ -45,393 +45,389 @@ import java.util.zip.ZipException; + * size limit for Zip64, which is 4GB. + */ + final class ZipFile implements Closeable { +- /** Maps String to list of ZipEntrys, name -> actual entries. */ +- private final Map<String, List<ZipEntry>> nameMap; +- +- /** The actual data source. */ +- private final ByteBufferChannel archive; +- +- /** +- * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link +- * ZipFile} does not synchronized over the buffer that is passed into it. +- * +- * @param channel the archive +- * @throws IOException if an error occurs while creating this {@link ZipFile} +- * @throws ZipException if the channel is not a zip archive +- * @throws NullPointerException if the archive is null +- */ +- public static ZipFile createFrom(ByteBufferChannel channel) throws IOException { +- checkNotNull(channel); +- ZipParser zipParser = new ZipParser(channel); +- Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries(); +- return new ZipFile(channel, nameMap); +- } +- +- @Override +- public void close() { +- archive.close(); +- } +- +- /** +- * Exposes the raw stream of the archive entry. +- * +- * <p>Since the associated files will not be compressed when being packed to the zip file, the raw +- * stream represents the non-compressed files. +- * +- * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple +- * threads concurrently reading from the returned {@link InputStream}, it must be synchronized +- * externally. +- * +- * @param name name of the entry to get the stream for +- * @return the raw input stream containing data +- * @throws IllegalArgumentException if the specified file does not exist in the zip file +- */ +- public InputStream getRawInputStream(String name) { +- checkArgument( +- nameMap.containsKey(name), +- String.format("The file, %s, does not exist in the zip file.", name)); +- +- List<ZipEntry> entriesWithTheSameName = nameMap.get(name); +- ZipEntry entry = entriesWithTheSameName.get(0); +- long start = entry.getDataOffset(); +- long remaining = entry.getSize(); +- return new BoundedInputStream(archive, start, remaining); +- } +- +- /** +- * Exposes the file names of the included files. +- * +- * @return the file names of the included files +- */ +- public Set<String> getFileNames() { +- return nameMap.keySet(); +- } +- +- private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) { +- archive = channel; +- this.nameMap = nameMap; +- } +- +- /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */ +- private static class ZipParser { +- private final ByteBufferChannel archive; +- +- // Cached buffers that will only be used locally in the class to reduce garbage collection. +- private final ByteBuffer longBuffer = +- ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); +- private final ByteBuffer intBuffer = +- ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); +- private final ByteBuffer shortBuffer = +- ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); ++ /** Maps String to list of ZipEntrys, name -> actual entries. */ ++ private final Map<String, List<ZipEntry>> nameMap; + +- private ZipParser(ByteBufferChannel archive) { +- this.archive = archive; +- } +- +- /** +- * Parses the underlying {@code archive} and returns the information as a list of {@link +- * ZipEntry}. +- */ +- private Map<String, List<ZipEntry>> parseEntries() throws IOException { +- List<ZipEntry> entries = parseCentralDirectory(); +- return parseLocalFileHeaderData(entries); +- } +- +- /** +- * Checks if the current position contains a central file header signature, {@link +- * ZipConstants#CENSIG}. +- */ +- private boolean foundCentralFileheaderSignature() { +- long signature = (long) getInt(); +- return signature == ZipConstants.CENSIG; +- } +- +- /** +- * Gets the value as a Java int from two bytes starting at the current position of the archive. +- */ +- private int getShort() { +- shortBuffer.rewind(); +- archive.read(shortBuffer); +- shortBuffer.flip(); +- return (int) shortBuffer.getShort(); +- } ++ /** The actual data source. */ ++ private final ByteBufferChannel archive; + + /** +- * Gets the value as a Java long from four bytes starting at the current position of the +- * archive. ++ * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link ++ * ZipFile} does not synchronized over the buffer that is passed into it. ++ * ++ * @param channel the archive ++ * @throws IOException if an error occurs while creating this {@link ZipFile} ++ * @throws ZipException if the channel is not a zip archive ++ * @throws NullPointerException if the archive is null + */ +- private int getInt() { +- intBuffer.rewind(); +- archive.read(intBuffer); +- intBuffer.flip(); +- return intBuffer.getInt(); ++ public static ZipFile createFrom(ByteBufferChannel channel) throws IOException { ++ checkNotNull(channel); ++ ZipParser zipParser = new ZipParser(channel); ++ Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries(); ++ return new ZipFile(channel, nameMap); + } + +- /** +- * Gets the value as a Java long from four bytes starting at the current position of the +- * archive. +- */ +- private long getLong() { +- longBuffer.rewind(); +- archive.read(longBuffer); +- longBuffer.flip(); +- return longBuffer.getLong(); ++ @Override ++ public void close() { ++ archive.close(); + } + + /** +- * Positions the archive at the start of the central directory. ++ * Exposes the raw stream of the archive entry. ++ * ++ * <p>Since the associated files will not be compressed when being packed to the zip file, the ++ * raw stream represents the non-compressed files. + * +- * <p>First, it searches for the signature of the "end of central directory record", {@link +- * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory +- * record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG} +- * should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file. ++ * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple ++ * threads concurrently reading from the returned {@link InputStream}, it must be synchronized ++ * externally. + * +- * <p>Then, parse the "end of central dir record" and position the archive at the start of the +- * central directory. ++ * @param name name of the entry to get the stream for ++ * @return the raw input stream containing data ++ * @throws IllegalArgumentException if the specified file does not exist in the zip file + */ +- private void locateCentralDirectory() throws IOException { +- if (archive.size() < ZipConstants.ENDHDR) { +- throw new ZipException("The archive is not a ZIP archive."); +- } +- +- // Positions the archive at the start of the "end of central directory record". +- long offsetRecord = archive.size() - ZipConstants.ENDHDR; +- archive.position(offsetRecord); +- +- // Checks for the signature, {@link ZipConstants#ENDSIG}. +- long endSig = getLong(); +- if (endSig != ZipConstants.ENDSIG) { +- throw new ZipException("The archive is not a ZIP archive."); +- } +- +- // Positions the archive at the “offset of central directory”. +- skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB); +- // Gets the offset to central directory +- long offsetDirectory = getInt(); +- // Goes to the central directory. +- archive.position(offsetDirectory); ++ public InputStream getRawInputStream(String name) { ++ checkArgument(nameMap.containsKey(name), ++ String.format("The file, %s, does not exist in the zip file.", name)); ++ ++ List<ZipEntry> entriesWithTheSameName = nameMap.get(name); ++ ZipEntry entry = entriesWithTheSameName.get(0); ++ long start = entry.getDataOffset(); ++ long remaining = entry.getSize(); ++ return new BoundedInputStream(archive, start, remaining); + } + + /** +- * Reads the central directory of the given archive and populates the internal tables with +- * {@link ZipEntry} instances. ++ * Exposes the file names of the included files. ++ * ++ * @return the file names of the included files + */ +- private List<ZipEntry> parseCentralDirectory() throws IOException { +- /** List of entries in the order they appear inside the central directory. */ +- List<ZipEntry> entries = new ArrayList<>(); +- locateCentralDirectory(); +- +- while (foundCentralFileheaderSignature()) { +- ZipEntry entry = parseCentralDirectoryEntry(); +- entries.add(entry); +- } +- +- return entries; ++ public Set<String> getFileNames() { ++ return nameMap.keySet(); + } + +- /** +- * Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to +- * the global maps. +- */ +- private ZipEntry parseCentralDirectoryEntry() throws IOException { +- // Positions the archive at the "compressed size" and read the value. +- skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM); +- long compressSize = getInt(); +- +- // Positions the archive at the "filename length" and read the value. +- skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN); +- int fileNameLen = getShort(); +- +- // Reads the extra field length and the comment length. +- int extraLen = getShort(); +- int commentLen = getShort(); +- +- // Positions the archive at the "local file header offset" and read the value. +- skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK); +- long localHeaderOffset = getInt(); +- +- // Reads the file name. +- byte[] fileNameBuf = new byte[fileNameLen]; +- archive.read(ByteBuffer.wrap(fileNameBuf)); +- String fileName = new String(fileNameBuf, Charset.forName("UTF-8")); ++ private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) { ++ archive = channel; ++ this.nameMap = nameMap; ++ } + +- // Skips the extra field and the comment. +- skipBytes(extraLen + commentLen); ++ /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */ ++ private static class ZipParser { ++ private final ByteBufferChannel archive; + +- ZipEntry entry = new ZipEntry(); +- entry.setSize(compressSize); +- entry.setLocalHeaderOffset(localHeaderOffset); +- entry.setName(fileName); ++ // Cached buffers that will only be used locally in the class to reduce garbage collection. ++ private final ByteBuffer longBuffer = ++ ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); ++ private final ByteBuffer intBuffer = ++ ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); ++ private final ByteBuffer shortBuffer = ++ ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); + +- return entry; +- } ++ private ZipParser(ByteBufferChannel archive) { ++ this.archive = archive; ++ } + +- /** Walks through all recorded entries and records the offsets for the entry data. */ +- private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) { +- /** Maps String to list of ZipEntrys, name -> actual entries. */ +- Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>(); +- +- for (ZipEntry entry : entries) { +- long offset = entry.getLocalHeaderOffset(); +- archive.position(offset + ZipConstants.LOCNAM); +- +- // Gets the data offset of this entry. +- int fileNameLen = getShort(); +- int extraFieldLen = getShort(); +- long dataOffset = +- offset +- + ZipConstants.LOCEXT +- + ZipConstants.SHORT_BYTE_SIZE +- + fileNameLen +- + extraFieldLen; +- entry.setDataOffset(dataOffset); +- +- // Puts the entry into the nameMap. +- String name = entry.getName(); +- List<ZipEntry> entriesWithTheSameName; +- if (nameMap.containsKey(name)) { +- entriesWithTheSameName = nameMap.get(name); +- } else { +- entriesWithTheSameName = new ArrayList<>(); +- nameMap.put(name, entriesWithTheSameName); ++ /** ++ * Parses the underlying {@code archive} and returns the information as a list of {@link ++ * ZipEntry}. ++ */ ++ private Map<String, List<ZipEntry>> parseEntries() throws IOException { ++ List<ZipEntry> entries = parseCentralDirectory(); ++ return parseLocalFileHeaderData(entries); + } +- entriesWithTheSameName.add(entry); +- } + +- return nameMap; +- } ++ /** ++ * Checks if the current position contains a central file header signature, {@link ++ * ZipConstants#CENSIG}. ++ */ ++ private boolean foundCentralFileheaderSignature() { ++ long signature = (long) getInt(); ++ return signature == ZipConstants.CENSIG; ++ } + +- /** Skips the given number of bytes or throws an EOFException if skipping failed. */ +- private void skipBytes(int count) throws IOException { +- long currentPosition = archive.position(); +- long newPosition = currentPosition + count; +- if (newPosition > archive.size()) { +- throw new EOFException(); +- } +- archive.position(newPosition); +- } +- } ++ /** ++ * Gets the value as a Java int from two bytes starting at the current position of the ++ * archive. ++ */ ++ private int getShort() { ++ shortBuffer.rewind(); ++ archive.read(shortBuffer); ++ shortBuffer.flip(); ++ return (int) shortBuffer.getShort(); ++ } + +- /** Stores the data offset and the size of an entry in the archive. */ +- private static class ZipEntry { ++ /** ++ * Gets the value as a Java long from four bytes starting at the current position of the ++ * archive. ++ */ ++ private int getInt() { ++ intBuffer.rewind(); ++ archive.read(intBuffer); ++ intBuffer.flip(); ++ return intBuffer.getInt(); ++ } + +- private String name; +- private long dataOffset = -1; +- private long size = -1; +- private long localHeaderOffset = -1; ++ /** ++ * Gets the value as a Java long from four bytes starting at the current position of the ++ * archive. ++ */ ++ private long getLong() { ++ longBuffer.rewind(); ++ archive.read(longBuffer); ++ longBuffer.flip(); ++ return longBuffer.getLong(); ++ } + +- public long getSize() { +- return size; +- } ++ /** ++ * Positions the archive at the start of the central directory. ++ * ++ * <p>First, it searches for the signature of the "end of central directory record", {@link ++ * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory ++ * record". The zip file are created without archive comments, thus {@link ++ * ZipConstants#ENDSIG} should appear exactly at {@link ZipConstants#ENDHDR} from the end of ++ * the zip file. ++ * ++ * <p>Then, parse the "end of central dir record" and position the archive at the start of ++ * the central directory. ++ */ ++ private void locateCentralDirectory() throws IOException { ++ if (archive.size() < ZipConstants.ENDHDR) { ++ throw new ZipException("The archive is not a ZIP archive."); ++ } ++ ++ // Positions the archive at the start of the "end of central directory record". ++ long offsetRecord = archive.size() - ZipConstants.ENDHDR; ++ archive.position(offsetRecord); ++ ++ // Checks for the signature, {@link ZipConstants#ENDSIG}. ++ long endSig = getLong(); ++ if (endSig != ZipConstants.ENDSIG) { ++ throw new ZipException("The archive is not a ZIP archive."); ++ } ++ ++ // Positions the archive at the “offset of central directory”. ++ skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB); ++ // Gets the offset to central directory ++ long offsetDirectory = getInt(); ++ // Goes to the central directory. ++ archive.position(offsetDirectory); ++ } + +- public long getDataOffset() { +- return dataOffset; +- } ++ /** ++ * Reads the central directory of the given archive and populates the internal tables with ++ * {@link ZipEntry} instances. ++ */ ++ private List<ZipEntry> parseCentralDirectory() throws IOException { ++ /** List of entries in the order they appear inside the central directory. */ ++ List<ZipEntry> entries = new ArrayList<>(); ++ locateCentralDirectory(); ++ ++ while (foundCentralFileheaderSignature()) { ++ ZipEntry entry = parseCentralDirectoryEntry(); ++ entries.add(entry); ++ } ++ ++ return entries; ++ } + +- public String getName() { +- return name; +- } ++ /** ++ * Reads an individual entry of the central directory, creats an ZipEntry from it and adds ++ * it to the global maps. ++ */ ++ private ZipEntry parseCentralDirectoryEntry() throws IOException { ++ // Positions the archive at the "compressed size" and read the value. ++ skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM); ++ long compressSize = getInt(); ++ ++ // Positions the archive at the "filename length" and read the value. ++ skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN); ++ int fileNameLen = getShort(); ++ ++ // Reads the extra field length and the comment length. ++ int extraLen = getShort(); ++ int commentLen = getShort(); ++ ++ // Positions the archive at the "local file header offset" and read the value. ++ skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK); ++ long localHeaderOffset = getInt(); ++ ++ // Reads the file name. ++ byte[] fileNameBuf = new byte[fileNameLen]; ++ archive.read(ByteBuffer.wrap(fileNameBuf)); ++ String fileName = new String(fileNameBuf, Charset.forName("UTF-8")); ++ ++ // Skips the extra field and the comment. ++ skipBytes(extraLen + commentLen); ++ ++ ZipEntry entry = new ZipEntry(); ++ entry.setSize(compressSize); ++ entry.setLocalHeaderOffset(localHeaderOffset); ++ entry.setName(fileName); ++ ++ return entry; ++ } + +- public long getLocalHeaderOffset() { +- return localHeaderOffset; +- } ++ /** Walks through all recorded entries and records the offsets for the entry data. */ ++ private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) { ++ /** Maps String to list of ZipEntrys, name -> actual entries. */ ++ Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>(); ++ ++ for (ZipEntry entry : entries) { ++ long offset = entry.getLocalHeaderOffset(); ++ archive.position(offset + ZipConstants.LOCNAM); ++ ++ // Gets the data offset of this entry. ++ int fileNameLen = getShort(); ++ int extraFieldLen = getShort(); ++ long dataOffset = offset + ZipConstants.LOCEXT + ZipConstants.SHORT_BYTE_SIZE ++ + fileNameLen + extraFieldLen; ++ entry.setDataOffset(dataOffset); ++ ++ // Puts the entry into the nameMap. ++ String name = entry.getName(); ++ List<ZipEntry> entriesWithTheSameName; ++ if (nameMap.containsKey(name)) { ++ entriesWithTheSameName = nameMap.get(name); ++ } else { ++ entriesWithTheSameName = new ArrayList<>(); ++ nameMap.put(name, entriesWithTheSameName); ++ } ++ entriesWithTheSameName.add(entry); ++ } ++ ++ return nameMap; ++ } + +- public void setSize(long size) { +- this.size = size; ++ /** Skips the given number of bytes or throws an EOFException if skipping failed. */ ++ private void skipBytes(int count) throws IOException { ++ long currentPosition = archive.position(); ++ long newPosition = currentPosition + count; ++ if (newPosition > archive.size()) { ++ throw new EOFException(); ++ } ++ archive.position(newPosition); ++ } + } + +- public void setDataOffset(long dataOffset) { +- this.dataOffset = dataOffset; +- } ++ /** Stores the data offset and the size of an entry in the archive. */ ++ private static class ZipEntry { ++ private String name; ++ private long dataOffset = -1; ++ private long size = -1; ++ private long localHeaderOffset = -1; + +- public void setName(String name) { +- this.name = name; +- } ++ public long getSize() { ++ return size; ++ } + +- public void setLocalHeaderOffset(long localHeaderOffset) { +- this.localHeaderOffset = localHeaderOffset; +- } +- } ++ public long getDataOffset() { ++ return dataOffset; ++ } + +- /** +- * Various constants for this {@link ZipFile}. +- * +- * <p>Referenced from {@link java.util.zip.ZipConstants}. +- */ +- private static class ZipConstants { +- /** length of Java short in bytes. */ +- static final int SHORT_BYTE_SIZE = Short.SIZE / 8; ++ public String getName() { ++ return name; ++ } + +- /** length of Java int in bytes. */ +- static final int INT_BYTE_SIZE = Integer.SIZE / 8; ++ public long getLocalHeaderOffset() { ++ return localHeaderOffset; ++ } + +- /** length of Java long in bytes. */ +- static final int LONG_BYTE_SIZE = Long.SIZE / 8; ++ public void setSize(long size) { ++ this.size = size; ++ } + +- /* +- * Header signatures +- */ +- static final long LOCSIG = 0x04034b50L; // "PK\003\004" +- static final long EXTSIG = 0x08074b50L; // "PK\007\008" +- static final long CENSIG = 0x02014b50L; // "PK\001\002" +- static final long ENDSIG = 0x06054b50L; // "PK\005\006" ++ public void setDataOffset(long dataOffset) { ++ this.dataOffset = dataOffset; ++ } + +- /* +- * Header sizes in bytes (including signatures) +- */ +- static final int LOCHDR = 30; // LOC header size +- static final int EXTHDR = 16; // EXT header size +- static final int CENHDR = 46; // CEN header size +- static final int ENDHDR = 22; // END header size ++ public void setName(String name) { ++ this.name = name; ++ } + +- /* +- * Local file (LOC) header field offsets +- */ +- static final int LOCVER = 4; // version needed to extract +- static final int LOCFLG = 6; // general purpose bit flag +- static final int LOCHOW = 8; // compression method +- static final int LOCTIM = 10; // modification time +- static final int LOCCRC = 14; // uncompressed file crc-32 value +- static final int LOCSIZ = 18; // compressed size +- static final int LOCLEN = 22; // uncompressed size +- static final int LOCNAM = 26; // filename length +- static final int LOCEXT = 28; // extra field length +- +- /* +- * Extra local (EXT) header field offsets +- */ +- static final int EXTCRC = 4; // uncompressed file crc-32 value +- static final int EXTSIZ = 8; // compressed size +- static final int EXTLEN = 12; // uncompressed size ++ public void setLocalHeaderOffset(long localHeaderOffset) { ++ this.localHeaderOffset = localHeaderOffset; ++ } ++ } + +- /* +- * Central directory (CEN) header field offsets +- */ +- static final int CENVEM = 4; // version made by +- static final int CENVER = 6; // version needed to extract +- static final int CENFLG = 8; // encrypt, decrypt flags +- static final int CENHOW = 10; // compression method +- static final int CENTIM = 12; // modification time +- static final int CENCRC = 16; // uncompressed file crc-32 value +- static final int CENSIZ = 20; // compressed size +- static final int CENLEN = 24; // uncompressed size +- static final int CENNAM = 28; // filename length +- static final int CENEXT = 30; // extra field length +- static final int CENCOM = 32; // comment length +- static final int CENDSK = 34; // disk number start +- static final int CENATT = 36; // internal file attributes +- static final int CENATX = 38; // external file attributes +- static final int CENOFF = 42; // LOC header offset +- +- /* +- * End of central directory (END) header field offsets ++ /** ++ * Various constants for this {@link ZipFile}. ++ * ++ * <p>Referenced from {@link java.util.zip.ZipConstants}. + */ +- static final int ENDSUB = 8; // number of entries on this disk +- static final int ENDTOT = 10; // total number of entries +- static final int ENDSIZ = 12; // central directory size in bytes +- static final int ENDOFF = 16; // offset of first CEN header +- static final int ENDCOM = 20; // zip file comment length +- +- private ZipConstants() {} +- } ++ private static class ZipConstants { ++ /** length of Java short in bytes. */ ++ static final int SHORT_BYTE_SIZE = Short.SIZE / 8; ++ ++ /** length of Java int in bytes. */ ++ static final int INT_BYTE_SIZE = Integer.SIZE / 8; ++ ++ /** length of Java long in bytes. */ ++ static final int LONG_BYTE_SIZE = Long.SIZE / 8; ++ ++ /* ++ * Header signatures ++ */ ++ static final long LOCSIG = 0x04034b50L; // "PK\003\004" ++ static final long EXTSIG = 0x08074b50L; // "PK\007\008" ++ static final long CENSIG = 0x02014b50L; // "PK\001\002" ++ static final long ENDSIG = 0x06054b50L; // "PK\005\006" ++ ++ /* ++ * Header sizes in bytes (including signatures) ++ */ ++ static final int LOCHDR = 30; // LOC header size ++ static final int EXTHDR = 16; // EXT header size ++ static final int CENHDR = 46; // CEN header size ++ static final int ENDHDR = 22; // END header size ++ ++ /* ++ * Local file (LOC) header field offsets ++ */ ++ static final int LOCVER = 4; // version needed to extract ++ static final int LOCFLG = 6; // general purpose bit flag ++ static final int LOCHOW = 8; // compression method ++ static final int LOCTIM = 10; // modification time ++ static final int LOCCRC = 14; // uncompressed file crc-32 value ++ static final int LOCSIZ = 18; // compressed size ++ static final int LOCLEN = 22; // uncompressed size ++ static final int LOCNAM = 26; // filename length ++ static final int LOCEXT = 28; // extra field length ++ ++ /* ++ * Extra local (EXT) header field offsets ++ */ ++ static final int EXTCRC = 4; // uncompressed file crc-32 value ++ static final int EXTSIZ = 8; // compressed size ++ static final int EXTLEN = 12; // uncompressed size ++ ++ /* ++ * Central directory (CEN) header field offsets ++ */ ++ static final int CENVEM = 4; // version made by ++ static final int CENVER = 6; // version needed to extract ++ static final int CENFLG = 8; // encrypt, decrypt flags ++ static final int CENHOW = 10; // compression method ++ static final int CENTIM = 12; // modification time ++ static final int CENCRC = 16; // uncompressed file crc-32 value ++ static final int CENSIZ = 20; // compressed size ++ static final int CENLEN = 24; // uncompressed size ++ static final int CENNAM = 28; // filename length ++ static final int CENEXT = 30; // extra field length ++ static final int CENCOM = 32; // comment length ++ static final int CENDSK = 34; // disk number start ++ static final int CENATT = 36; // internal file attributes ++ static final int CENATX = 38; // external file attributes ++ static final int CENOFF = 42; // LOC header offset ++ ++ /* ++ * End of central directory (END) header field offsets ++ */ ++ static final int ENDSUB = 8; // number of entries on this disk ++ static final int ENDTOT = 10; // total number of entries ++ static final int ENDSIZ = 12; // central directory size in bytes ++ static final int ENDOFF = 16; // offset of first CEN header ++ static final int ENDCOM = 20; // zip file comment length ++ ++ private ZipConstants() {} ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java +index 3847bc1d2ce01..e0825a1fe7862 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java +@@ -16,244 +16,223 @@ limitations under the License. + package org.tensorflow.lite.support.metadata; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertArrayEquals; + import static org.junit.Assert.assertThrows; + +-import java.nio.ByteBuffer; + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + ++import java.nio.ByteBuffer; ++ + /** Tests of {@link BoundedInputStream}. */ + @RunWith(RobolectricTestRunner.class) + public class BoundedInputStreamTest { ++ private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50}; ++ private static final int[] testInts = new int[] {10, 20, 30, 40, 50}; ++ private static final int TEST_BYTES_LENGTH = testBytes.length; ++ ++ @Test ++ public void boundedInputStream_negtiveStart_throwsException() throws Exception { ++ long start = -1; ++ long remaining = 2; ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> createBoundedInputStream(testBytes, start, remaining)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( ++ "Invalid length of stream at offset=%d, length=%d", start, remaining)); ++ } ++ ++ @Test ++ public void boundedInputStream_negtiveRemaining_throwsException() throws Exception { ++ long start = 1; ++ long remaining = -2; ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> createBoundedInputStream(testBytes, start, remaining)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( ++ "Invalid length of stream at offset=%d, length=%d", start, remaining)); ++ } ++ ++ @Test ++ public void available_atStart() throws Exception { ++ int start = 3; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH); ++ ++ int available = boundedInputStream.available(); ++ assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start); ++ } ++ ++ @Test ++ public void available_afterRead() throws Exception { ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH ++ // -1. ++ boundedInputStream.read(); ++ ++ int available = boundedInputStream.available(); ++ assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1); ++ } ++ ++ @Test ++ public void read_repeatedRead() throws Exception { ++ int[] values = new int[TEST_BYTES_LENGTH]; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ for (int i = 0; i < TEST_BYTES_LENGTH; i++) { ++ values[i] = boundedInputStream.read(); ++ } ++ ++ assertArrayEquals(testInts, values); ++ } ++ ++ @Test ++ public void read_reachTheEnd() throws Exception { ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ boundedInputStream.skip(TEST_BYTES_LENGTH); ++ int value = boundedInputStream.read(); ++ ++ assertThat(value).isEqualTo(-1); ++ } ++ ++ @Test ++ public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception { ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1); ++ boundedInputStream.skip(TEST_BYTES_LENGTH); ++ ++ int value = boundedInputStream.read(); ++ ++ assertThat(value).isEqualTo(-1); ++ } + +- private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50}; +- private static final int[] testInts = new int[] {10, 20, 30, 40, 50}; +- private static final int TEST_BYTES_LENGTH = testBytes.length; +- +- @Test +- public void boundedInputStream_negtiveStart_throwsException() throws Exception { +- long start = -1; +- long remaining = 2; +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> createBoundedInputStream(testBytes, start, remaining)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format("Invalid length of stream at offset=%d, length=%d", start, remaining)); +- } +- +- @Test +- public void boundedInputStream_negtiveRemaining_throwsException() throws Exception { +- long start = 1; +- long remaining = -2; +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> createBoundedInputStream(testBytes, start, remaining)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format("Invalid length of stream at offset=%d, length=%d", start, remaining)); +- } +- +- @Test +- public void available_atStart() throws Exception { +- int start = 3; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH); +- +- int available = boundedInputStream.available(); +- assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start); +- } +- +- @Test +- public void available_afterRead() throws Exception { +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH -1. +- boundedInputStream.read(); +- +- int available = boundedInputStream.available(); +- assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1); +- } +- +- @Test +- public void read_repeatedRead() throws Exception { +- int[] values = new int[TEST_BYTES_LENGTH]; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- for (int i = 0; i < TEST_BYTES_LENGTH; i++) { +- values[i] = boundedInputStream.read(); ++ @Test ++ public void readArray_nullArray_throwsException() throws Exception { ++ byte[] array = null; ++ int offset = 0; ++ int length = 1; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ NullPointerException exception = assertThrows( ++ NullPointerException.class, () -> boundedInputStream.read(array, offset, length)); ++ assertThat(exception).hasMessageThat().isEqualTo("The object reference is null."); + } + +- assertArrayEquals(testInts, values); +- } +- +- @Test +- public void read_reachTheEnd() throws Exception { +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- boundedInputStream.skip(TEST_BYTES_LENGTH); +- int value = boundedInputStream.read(); +- +- assertThat(value).isEqualTo(-1); +- } +- +- @Test +- public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception { +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1); +- boundedInputStream.skip(TEST_BYTES_LENGTH); +- +- int value = boundedInputStream.read(); +- +- assertThat(value).isEqualTo(-1); +- } +- +- @Test +- public void readArray_nullArray_throwsException() throws Exception { +- byte[] array = null; +- int offset = 0; +- int length = 1; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- NullPointerException exception = +- assertThrows( +- NullPointerException.class, () -> boundedInputStream.read(array, offset, length)); +- assertThat(exception).hasMessageThat().isEqualTo("The object reference is null."); +- } +- +- @Test +- public void readArray_negativeOffset_throwsException() throws Exception { +- byte[] array = new byte[5]; +- int offset = -1; +- int length = array.length; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- IndexOutOfBoundsException exception = +- assertThrows( +- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo(String.format("The start offset (%s) must not be negative", offset)); +- } +- +- @Test +- public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception { +- byte[] array = new byte[5]; +- int offset = array.length; +- int length = 0; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- IndexOutOfBoundsException exception = +- assertThrows( +- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format( ++ @Test ++ public void readArray_negativeOffset_throwsException() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = -1; ++ int length = array.length; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, ++ () -> boundedInputStream.read(array, offset, length)); ++ assertThat(exception).hasMessageThat().isEqualTo( ++ String.format("The start offset (%s) must not be negative", offset)); ++ } ++ ++ @Test ++ public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = array.length; ++ int length = 0; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, ++ () -> boundedInputStream.read(array, offset, length)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( + "The start offset (%s) must be less than size (%s)", offset, array.length)); +- } +- +- @Test +- public void readArray_negativeLength_throwsException() throws Exception { +- byte[] array = new byte[5]; +- int offset = 0; +- int length = -1; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- IndexOutOfBoundsException exception = +- assertThrows( +- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format( ++ } ++ ++ @Test ++ public void readArray_negativeLength_throwsException() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = 0; ++ int length = -1; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, ++ () -> boundedInputStream.read(array, offset, length)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( + "The maximumn number of bytes to read (%s) must not be negative", length)); +- } +- +- @Test +- public void readArray_exceedEndOfArray_throwsException() throws Exception { +- byte[] array = new byte[5]; +- int offset = 0; +- int length = array.length + 1; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- IndexOutOfBoundsException exception = +- assertThrows( +- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format( +- "The maximumn number of bytes to read (%s) must be less than size (%s)", +- length, array.length - offset + 1)); +- } +- +- @Test +- public void readArray_zeroLength() throws Exception { +- byte[] array = new byte[5]; +- int offset = 0; +- int length = 0; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- int value = boundedInputStream.read(array, offset, length); +- assertThat(value).isEqualTo(0); +- } +- +- @Test +- public void readArray_exceedEndOfStream() throws Exception { +- byte[] array = new byte[5]; +- int offset = 0; +- int length = 1; +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- // Move the position of the stream to the end. +- boundedInputStream.skip(TEST_BYTES_LENGTH); +- +- int value = boundedInputStream.read(array, offset, length); +- +- assertThat(value).isEqualTo(-1); +- } +- +- @Test +- public void readArray_lengthGreaterThanStreamRemaining() throws Exception { +- byte[] array = new byte[5]; +- int offset = 1; +- int length = array.length - 1; // 4 +- BoundedInputStream boundedInputStream = +- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); +- +- // Moves the position of the stream to end-2. +- boundedInputStream.skip(TEST_BYTES_LENGTH - 2); +- +- // Reads the last two bytes of the stream to the array, and put the data at offset 1. +- int value = boundedInputStream.read(array, offset, length); +- +- byte[] expectedArray = new byte[] {0, 40, 50, 0, 0}; +- assertArrayEquals(expectedArray, array); +- assertThat(value).isEqualTo(2); +- +- // Reachs the end of the stream, thus cannot read anymore. +- assertThat(boundedInputStream.read()).isEqualTo(-1); +- } +- +- private static BoundedInputStream createBoundedInputStream( +- final byte[] testBytes, long start, long remaining) { +- ByteBuffer buffer = ByteBuffer.wrap(testBytes); +- SeekableByteChannelCompat channel = new ByteBufferChannel(buffer); +- return new BoundedInputStream(channel, start, remaining); +- } ++ } ++ ++ @Test ++ public void readArray_exceedEndOfArray_throwsException() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = 0; ++ int length = array.length + 1; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, ++ () -> boundedInputStream.read(array, offset, length)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( ++ "The maximumn number of bytes to read (%s) must be less than size (%s)", length, ++ array.length - offset + 1)); ++ } ++ ++ @Test ++ public void readArray_zeroLength() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = 0; ++ int length = 0; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ int value = boundedInputStream.read(array, offset, length); ++ assertThat(value).isEqualTo(0); ++ } ++ ++ @Test ++ public void readArray_exceedEndOfStream() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = 0; ++ int length = 1; ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ // Move the position of the stream to the end. ++ boundedInputStream.skip(TEST_BYTES_LENGTH); ++ ++ int value = boundedInputStream.read(array, offset, length); ++ ++ assertThat(value).isEqualTo(-1); ++ } ++ ++ @Test ++ public void readArray_lengthGreaterThanStreamRemaining() throws Exception { ++ byte[] array = new byte[5]; ++ int offset = 1; ++ int length = array.length - 1; // 4 ++ BoundedInputStream boundedInputStream = ++ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH); ++ ++ // Moves the position of the stream to end-2. ++ boundedInputStream.skip(TEST_BYTES_LENGTH - 2); ++ ++ // Reads the last two bytes of the stream to the array, and put the data at offset 1. ++ int value = boundedInputStream.read(array, offset, length); ++ ++ byte[] expectedArray = new byte[] {0, 40, 50, 0, 0}; ++ assertArrayEquals(expectedArray, array); ++ assertThat(value).isEqualTo(2); ++ ++ // Reachs the end of the stream, thus cannot read anymore. ++ assertThat(boundedInputStream.read()).isEqualTo(-1); ++ } ++ ++ private static BoundedInputStream createBoundedInputStream( ++ final byte[] testBytes, long start, long remaining) { ++ ByteBuffer buffer = ByteBuffer.wrap(testBytes); ++ SeekableByteChannelCompat channel = new ByteBufferChannel(buffer); ++ return new BoundedInputStream(channel, start, remaining); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java +index abda43058aa90..ce625de8034b7 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java +@@ -16,254 +16,252 @@ limitations under the License. + package org.tensorflow.lite.support.metadata; + + import static com.google.common.truth.Truth.assertThat; +-import static java.nio.charset.StandardCharsets.UTF_8; ++ + import static org.junit.Assert.assertThrows; + +-import java.nio.ByteBuffer; ++import static java.nio.charset.StandardCharsets.UTF_8; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + ++import java.nio.ByteBuffer; ++ + /** Tests of {@link ByteBufferChannel}. */ + @RunWith(RobolectricTestRunner.class) + public final class ByteBufferChannelTest { +- private static final String VALID_STRING = "1234567890"; +- private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8)); +- private final int validByteBufferLength = validByteBuffer.limit(); +- +- @Test +- public void byteBufferChannel_validByteBuffer() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- assertThat(byteBufferChannel).isNotNull(); +- } +- +- @Test +- public void byteBufferChannel_nullByteBuffer_throwsException() { +- NullPointerException exception = +- assertThrows(NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/ null)); +- assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null."); +- } +- +- @Test +- public void isOpen_openedByteBufferChannel() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- assertThat(byteBufferChannel.isOpen()).isTrue(); +- } +- +- @Test +- public void position_newByteBufferChannelWithInitialPosition0() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long position = byteBufferChannel.position(); +- +- long expectedPosition = 0; +- assertThat(position).isEqualTo(expectedPosition); +- } +- +- @Test +- public void position_validNewPosition() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long validNewPosition = 6; +- +- byteBufferChannel.position(validNewPosition); +- long position = byteBufferChannel.position(); +- +- assertThat(position).isEqualTo(validNewPosition); +- } +- +- @Test +- public void position_negtiveNewPosition_throwsException() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long invalidNewPosition = -1; +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE."); +- } +- +- @Test +- public void position_newPositionGreaterThanMaxIntegerValue_throwsException() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long invalidNewPosition = Integer.MAX_VALUE + 1; +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE."); +- } +- +- @Test +- public void position_newPositionGreaterThanByfferLength_throwsException() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long invalidNewPosition = (long) validByteBufferLength + 1; +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition)); +- assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)"); +- } +- +- @Test +- public void read_fromPosition0() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long validNewPosition = 0; +- +- byteBufferChannel.position(validNewPosition); +- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); +- int numBytes = byteBufferChannel.read(dstBuffer); +- +- assertThat(numBytes).isEqualTo(validByteBufferLength); +- assertThat(dstBuffer).isEqualTo(validByteBuffer); +- } +- +- @Test +- public void read_fromPosition5() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long validNewPosition = 5; +- +- byteBufferChannel.position(validNewPosition); +- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); +- int numBytes = byteBufferChannel.read(dstBuffer); +- +- assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition); +- String dstString = convertByteByfferToString(dstBuffer, numBytes); +- String expectedString = "67890"; +- assertThat(dstString).isEqualTo(expectedString); +- } +- +- @Test +- public void read_fromPositionValidByteBufferLength() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long validNewPosition = validByteBufferLength; +- +- byteBufferChannel.position(validNewPosition); +- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); +- int numBytes = byteBufferChannel.read(dstBuffer); +- +- assertThat(numBytes).isEqualTo(-1); +- } +- +- @Test +- public void read_dstBufferRemaining0() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long validNewPosition = 0; +- +- byteBufferChannel.position(validNewPosition); +- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); +- dstBuffer.position(validByteBufferLength); +- int numBytes = byteBufferChannel.read(dstBuffer); +- +- assertThat(numBytes).isEqualTo(0); +- String dstString = convertByteByfferToString(dstBuffer, numBytes); +- String expectedString = ""; +- assertThat(dstString).isEqualTo(expectedString); +- } +- +- @Test +- public void read_dstBufferIsSmallerThanTheBufferChannel() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- int dstBufferLength = 3; +- +- ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength); +- int numBytes = byteBufferChannel.read(dstBuffer); +- +- assertThat(numBytes).isEqualTo(dstBufferLength); +- assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength); +- +- String dstString = convertByteByfferToString(dstBuffer, dstBufferLength); +- String expectedString = "123"; +- assertThat(dstString).isEqualTo(expectedString); +- } +- +- @Test +- public void size_validBuffer() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength); +- } +- +- @Test +- public void truncate_validSizeAndPosition0() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long truncateSize = 3; +- +- byteBufferChannel.truncate(truncateSize); +- +- assertThat(byteBufferChannel.size()).isEqualTo(truncateSize); +- assertThat(byteBufferChannel.position()).isEqualTo(0); +- } +- +- @Test +- public void truncate_validSizeAndPosition5() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long validNewPosition = 5; +- +- byteBufferChannel.position(validNewPosition); +- long truncateSize = 3; +- byteBufferChannel.truncate(truncateSize); +- +- assertThat(byteBufferChannel.position()).isEqualTo(truncateSize); +- } +- +- @Test +- public void truncate_sizeNotSmallerThanBufferSize() { +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- long truncateSize = (long) validByteBufferLength; +- +- byteBufferChannel.truncate(truncateSize); +- +- assertThat(byteBufferChannel.position()).isEqualTo(0); +- } +- +- @Test +- public void write_srcBufferSmallerThanBufferChannel() { +- String srcString = "5555"; +- long newPosition = 3; +- String expectedString = "1235555890"; +- ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8)); +- +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- byteBufferChannel.position(newPosition); +- byteBufferChannel.write(srcBuffer); +- +- assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length()); +- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); +- byteBufferChannel.position(0); +- byteBufferChannel.read(dstBuffer); +- ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8)); +- dstBuffer.rewind(); +- expectedBuffer.rewind(); +- assertThat(dstBuffer).isEqualTo(expectedBuffer); +- } +- +- @Test +- public void write_srcBufferGreaterThanBufferChannel() { +- String srcString = "5555"; +- long newPosition = 8; +- String expectedString = "1234567855"; +- ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8)); +- +- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); +- byteBufferChannel.position(newPosition); +- byteBufferChannel.write(srcBuffer); +- assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength); +- +- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); +- byteBufferChannel.position(0); +- byteBufferChannel.read(dstBuffer); +- ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8)); +- dstBuffer.rewind(); +- expectedBuffer.rewind(); +- assertThat(dstBuffer).isEqualTo(expectedBuffer); +- } +- +- private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) { +- byte[] bytes = new byte[arrLength]; +- buffer.rewind(); +- buffer.get(bytes); +- return new String(bytes, UTF_8); +- } ++ private static final String VALID_STRING = "1234567890"; ++ private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8)); ++ private final int validByteBufferLength = validByteBuffer.limit(); ++ ++ @Test ++ public void byteBufferChannel_validByteBuffer() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ assertThat(byteBufferChannel).isNotNull(); ++ } ++ ++ @Test ++ public void byteBufferChannel_nullByteBuffer_throwsException() { ++ NullPointerException exception = assertThrows( ++ NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/null)); ++ assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null."); ++ } ++ ++ @Test ++ public void isOpen_openedByteBufferChannel() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ assertThat(byteBufferChannel.isOpen()).isTrue(); ++ } ++ ++ @Test ++ public void position_newByteBufferChannelWithInitialPosition0() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long position = byteBufferChannel.position(); ++ ++ long expectedPosition = 0; ++ assertThat(position).isEqualTo(expectedPosition); ++ } ++ ++ @Test ++ public void position_validNewPosition() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long validNewPosition = 6; ++ ++ byteBufferChannel.position(validNewPosition); ++ long position = byteBufferChannel.position(); ++ ++ assertThat(position).isEqualTo(validNewPosition); ++ } ++ ++ @Test ++ public void position_negtiveNewPosition_throwsException() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long invalidNewPosition = -1; ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> byteBufferChannel.position(invalidNewPosition)); ++ assertThat(exception).hasMessageThat().isEqualTo( ++ "The new position should be non-negative and be less than Integer.MAX_VALUE."); ++ } ++ ++ @Test ++ public void position_newPositionGreaterThanMaxIntegerValue_throwsException() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long invalidNewPosition = Integer.MAX_VALUE + 1; ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> byteBufferChannel.position(invalidNewPosition)); ++ assertThat(exception).hasMessageThat().isEqualTo( ++ "The new position should be non-negative and be less than Integer.MAX_VALUE."); ++ } ++ ++ @Test ++ public void position_newPositionGreaterThanByfferLength_throwsException() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long invalidNewPosition = (long) validByteBufferLength + 1; ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> byteBufferChannel.position(invalidNewPosition)); ++ assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)"); ++ } ++ ++ @Test ++ public void read_fromPosition0() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long validNewPosition = 0; ++ ++ byteBufferChannel.position(validNewPosition); ++ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); ++ int numBytes = byteBufferChannel.read(dstBuffer); ++ ++ assertThat(numBytes).isEqualTo(validByteBufferLength); ++ assertThat(dstBuffer).isEqualTo(validByteBuffer); ++ } ++ ++ @Test ++ public void read_fromPosition5() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long validNewPosition = 5; ++ ++ byteBufferChannel.position(validNewPosition); ++ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); ++ int numBytes = byteBufferChannel.read(dstBuffer); ++ ++ assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition); ++ String dstString = convertByteByfferToString(dstBuffer, numBytes); ++ String expectedString = "67890"; ++ assertThat(dstString).isEqualTo(expectedString); ++ } ++ ++ @Test ++ public void read_fromPositionValidByteBufferLength() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long validNewPosition = validByteBufferLength; ++ ++ byteBufferChannel.position(validNewPosition); ++ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); ++ int numBytes = byteBufferChannel.read(dstBuffer); ++ ++ assertThat(numBytes).isEqualTo(-1); ++ } ++ ++ @Test ++ public void read_dstBufferRemaining0() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long validNewPosition = 0; ++ ++ byteBufferChannel.position(validNewPosition); ++ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); ++ dstBuffer.position(validByteBufferLength); ++ int numBytes = byteBufferChannel.read(dstBuffer); ++ ++ assertThat(numBytes).isEqualTo(0); ++ String dstString = convertByteByfferToString(dstBuffer, numBytes); ++ String expectedString = ""; ++ assertThat(dstString).isEqualTo(expectedString); ++ } ++ ++ @Test ++ public void read_dstBufferIsSmallerThanTheBufferChannel() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ int dstBufferLength = 3; ++ ++ ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength); ++ int numBytes = byteBufferChannel.read(dstBuffer); ++ ++ assertThat(numBytes).isEqualTo(dstBufferLength); ++ assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength); ++ ++ String dstString = convertByteByfferToString(dstBuffer, dstBufferLength); ++ String expectedString = "123"; ++ assertThat(dstString).isEqualTo(expectedString); ++ } ++ ++ @Test ++ public void size_validBuffer() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength); ++ } ++ ++ @Test ++ public void truncate_validSizeAndPosition0() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long truncateSize = 3; ++ ++ byteBufferChannel.truncate(truncateSize); ++ ++ assertThat(byteBufferChannel.size()).isEqualTo(truncateSize); ++ assertThat(byteBufferChannel.position()).isEqualTo(0); ++ } ++ ++ @Test ++ public void truncate_validSizeAndPosition5() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long validNewPosition = 5; ++ ++ byteBufferChannel.position(validNewPosition); ++ long truncateSize = 3; ++ byteBufferChannel.truncate(truncateSize); ++ ++ assertThat(byteBufferChannel.position()).isEqualTo(truncateSize); ++ } ++ ++ @Test ++ public void truncate_sizeNotSmallerThanBufferSize() { ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ long truncateSize = (long) validByteBufferLength; ++ ++ byteBufferChannel.truncate(truncateSize); ++ ++ assertThat(byteBufferChannel.position()).isEqualTo(0); ++ } ++ ++ @Test ++ public void write_srcBufferSmallerThanBufferChannel() { ++ String srcString = "5555"; ++ long newPosition = 3; ++ String expectedString = "1235555890"; ++ ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8)); ++ ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ byteBufferChannel.position(newPosition); ++ byteBufferChannel.write(srcBuffer); ++ ++ assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length()); ++ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); ++ byteBufferChannel.position(0); ++ byteBufferChannel.read(dstBuffer); ++ ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8)); ++ dstBuffer.rewind(); ++ expectedBuffer.rewind(); ++ assertThat(dstBuffer).isEqualTo(expectedBuffer); ++ } ++ ++ @Test ++ public void write_srcBufferGreaterThanBufferChannel() { ++ String srcString = "5555"; ++ long newPosition = 8; ++ String expectedString = "1234567855"; ++ ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8)); ++ ++ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer); ++ byteBufferChannel.position(newPosition); ++ byteBufferChannel.write(srcBuffer); ++ assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength); ++ ++ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength); ++ byteBufferChannel.position(0); ++ byteBufferChannel.read(dstBuffer); ++ ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8)); ++ dstBuffer.rewind(); ++ expectedBuffer.rewind(); ++ assertThat(dstBuffer).isEqualTo(expectedBuffer); ++ } ++ ++ private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) { ++ byte[] bytes = new byte[arrLength]; ++ buffer.rewind(); ++ buffer.get(bytes); ++ return new String(bytes, UTF_8); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java +index 67fc50d9f57b1..9f1173a1ea19b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java +@@ -16,24 +16,20 @@ limitations under the License. + package org.tensorflow.lite.support.metadata; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertArrayEquals; + import static org.junit.Assert.assertThrows; + + import android.content.Context; + import android.content.res.AssetFileDescriptor; ++ + import androidx.test.core.app.ApplicationProvider; ++ + import com.google.flatbuffers.FlatBufferBuilder; +-import java.io.FileInputStream; +-import java.io.InputStream; +-import java.nio.ByteBuffer; +-import java.nio.channels.FileChannel; +-import java.util.Arrays; +-import java.util.Collection; +-import java.util.HashSet; +-import java.util.Random; +-import java.util.Set; ++ + import org.apache.commons.io.IOUtils; + import org.checkerframework.checker.nullness.qual.Nullable; ++import org.junit.Ignore; + import org.junit.Test; + import org.junit.runner.RunWith; + import org.junit.runners.Suite; +@@ -56,931 +52,903 @@ import org.tensorflow.lite.support.metadata.schema.ModelMetadata; + import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata; + import org.tensorflow.lite.support.metadata.schema.TensorMetadata; + +-import org.junit.Ignore; ++import java.io.FileInputStream; ++import java.io.InputStream; ++import java.nio.ByteBuffer; ++import java.nio.channels.FileChannel; ++import java.util.Arrays; ++import java.util.Collection; ++import java.util.HashSet; ++import java.util.Random; ++import java.util.Set; + + /** Tests of {@link MetadataExtractor}. */ + @RunWith(Suite.class) + @SuiteClasses({MetadataExtractorTest.General.class, MetadataExtractorTest.InputTensorType.class}) + public class MetadataExtractorTest { +- private static final int[] validShape = new int[] {4, 10, 10, 3}; +- private static final byte DATA_TYPE = TensorType.UINT8; +- private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties; +- private static final float VALID_SCALE = 3.3f; +- private static final long VALID_ZERO_POINT = 2; +- private static final float DEFAULT_SCALE = 0.0f; +- private static final long DEFAULT_ZERO_POINT = 0; +- private static final String MODEL_NAME = "model.tflite"; +- // Scale and zero point should both be a single value, not an array. +- private static final float[] invalidScale = new float[] {0.0f, 1.2f}; +- private static final long[] invalidZeroPoint = new long[] {1, 2}; +- private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite"; +- // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file. +- private static final String VALID_LABEL_FILE_NAME = "labels.txt"; +- // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite. +- private static final String INVALID_LABEL_FILE_NAME = "invalid.txt"; +- private static final int EMPTY_FLATBUFFER_VECTOR = -1; +- private static final String TFLITE_MODEL_IDENTIFIER = "TFL3"; +- private static final String TFLITE_METADATA_IDENTIFIER = "M001"; +- +- /** General tests of MetadataExtractor. */ +- @RunWith(RobolectricTestRunner.class) +- public static final class General extends MetadataExtractorTest { +- +- @Test +- public void hasMetadata_modelWithMetadata() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- assertThat(metadataExtractor.hasMetadata()).isTrue(); +- } +- +- @Test +- public void hasMetadata_modelWithoutMetadata() throws Exception { +- // Creates a model flatbuffer without metadata. +- ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata); +- assertThat(metadataExtractor.hasMetadata()).isFalse(); +- } +- +- @Ignore +- @Test +- public void getAssociatedFile_validAssociateFile() throws Exception { +- ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); +- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); +- InputStream associateFileStream = +- mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME); +- +- // Reads the golden file from context. +- Context context = ApplicationProvider.getApplicationContext(); +- InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME); +- assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream)).isTrue(); +- } +- +- @Ignore +- @Test +- public void getAssociatedFile_invalidAssociateFile() throws Exception { +- ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); +- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format( +- "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME)); +- } +- +- @Ignore +- @Test +- public void getAssociatedFile_nullFileName() throws Exception { +- ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); +- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/ null)); +- assertThat(exception) +- .hasMessageThat() +- .contains("The file, null, does not exist in the zip file."); +- } +- +- @Test +- public void getAssociatedFile_nonZipModel_throwsException() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalStateException exception = +- assertThrows( +- IllegalStateException.class, +- () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME)); +- assertThat(exception) +- .hasMessageThat() +- .contains("This model does not contain associated files, and is not a Zip file."); +- } +- +- @Test +- public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalStateException exception = +- assertThrows(IllegalStateException.class, metadataExtractor::getAssociatedFileNames); +- assertThat(exception) +- .hasMessageThat() +- .contains("This model does not contain associated files, and is not a Zip file."); +- } +- +- @Ignore +- @Test +- public void getAssociatedFileNames_validFileNames() throws Exception { +- ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); +- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); +- Set<String> expectedSet = new HashSet<>(); +- expectedSet.add(VALID_LABEL_FILE_NAME); +- assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet); +- } +- +- @Test +- public void metadataExtractor_loadNullBuffer_throwsException() { +- ByteBuffer nullBuffer = null; +- NullPointerException exception = +- assertThrows(NullPointerException.class, () -> new MetadataExtractor(nullBuffer)); +- assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null."); +- } +- +- @Test +- public void metadataExtractor_loadRandomBuffer_throwsException() { +- ByteBuffer randomBuffer = createRandomByteBuffer(); +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The identifier of the model is invalid. The buffer may not be a valid TFLite model" +- + " flatbuffer."); +- } +- +- @Test +- public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() { +- // Creates a model with an invalid identifier. +- String invalidIdentifier = "INVI"; +- FlatBufferBuilder builder = new FlatBufferBuilder(); +- Model.startModel(builder); +- int model = Model.endModel(builder); +- builder.finish(model, invalidIdentifier); +- ByteBuffer modelBuffer = builder.dataBuffer(); +- +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The identifier of the model is invalid. The buffer may not be a valid TFLite model" +- + " flatbuffer."); +- } +- +- @Test +- public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() { +- // Creates a model with metadata which contains an invalid identifier. +- String invalidIdentifier = "INVI"; +- ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null); +- ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE); +- +- IllegalArgumentException exception = +- assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer)); +- assertThat(exception) +- .hasMessageThat() +- .contains( +- "The identifier of the metadata is invalid. The buffer may not be a valid TFLite" +- + " metadata flatbuffer."); +- } +- +- @Test +- public void getInputTensorCount_validModelFile() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- int count = metadataExtractor.getInputTensorCount(); +- assertThat(count).isEqualTo(3); +- } +- +- @Test +- public void getOutputTensorCount_validModelFile() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- int count = metadataExtractor.getOutputTensorCount(); +- assertThat(count).isEqualTo(3); +- } +- +- @Test +- public void getInputTensorShape_validTensorShape() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- int[] shape = metadataExtractor.getInputTensorShape(0); +- assertArrayEquals(validShape, shape); +- } +- +- @Test +- public void getInputTensorShape_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- int[] shape = metadataExtractor.getInputTensorShape(1); +- assertThat(shape).isEmpty(); +- } +- +- @Test +- public void getInputTensorType_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- byte type = metadataExtractor.getInputTensorType(1); +- assertThat(type).isEqualTo(TensorType.FLOAT32); +- } +- +- @Test +- public void getOutputTensorShape_validTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- int[] shape = metadataExtractor.getOutputTensorShape(0); +- assertArrayEquals(validShape, shape); +- } +- +- @Test +- public void getOutputTensorShape_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- int[] shape = metadataExtractor.getOutputTensorShape(1); +- assertThat(shape).isEmpty(); +- } +- +- @Test +- public void getOutputTensorType_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- byte type = metadataExtractor.getOutputTensorType(1); +- assertThat(type).isEqualTo(TensorType.FLOAT32); +- } +- +- @Test +- public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException() +- throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3)); +- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); +- } +- +- @Test +- public void getInputTensorShape_negtiveIndex_throwsException() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(-1)); +- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); +- } +- +- @Test +- public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException() +- throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(3)); +- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid."); +- } +- +- @Test +- public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(-1)); +- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid."); +- } +- +- @Test +- public void getModelMetadata_modelWithMetadata() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- ModelMetadata modelMetadata = metadataExtractor.getModelMetadata(); +- assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME); +- } +- +- @Test +- public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception { +- // Creates a model flatbuffer without metadata. +- ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata); +- +- IllegalStateException exception = +- assertThrows(IllegalStateException.class, () -> metadataExtractor.getModelMetadata()); +- assertThat(exception) +- .hasMessageThat() +- .contains("This model does not contain model metadata."); +- } +- +- @Test +- public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() { +- // Creates a metadata FlatBuffer without empty subgraph metadata. +- FlatBufferBuilder builder = new FlatBufferBuilder(); +- SubGraphMetadata.startSubGraphMetadata(builder); +- int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder); +- int subgraphsMetadata = +- ModelMetadata.createSubgraphMetadataVector(builder, new int[] {subgraph1Metadata}); +- +- ModelMetadata.startModelMetadata(builder); +- ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata); +- int modelMetadata = ModelMetadata.endModelMetadata(builder); +- builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER); +- ByteBuffer emptyMetadata = builder.dataBuffer(); +- ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- "The number of input tensors in the model is 3. The number of input tensors that" +- + " recorded in the metadata is 0. These two values does not match."); +- } +- +- @Test +- public void metadataExtractor_modelWithEmptyMetadata_throwsException() { +- // Creates a empty metadata FlatBuffer. +- FlatBufferBuilder builder = new FlatBufferBuilder(); +- ModelMetadata.startModelMetadata(builder); +- int modelMetadata = ModelMetadata.endModelMetadata(builder); +- builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER); +- +- ByteBuffer emptyMetadata = builder.dataBuffer(); +- ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata)); +- assertThat(exception) +- .hasMessageThat() +- .contains("The metadata flatbuffer does not contain any subgraph metadata."); +- } +- +- @Test +- public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception { +- // Creates a model flatbuffer without metadata. +- ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE); +- +- // It is allowed to create a model without metadata, but invoking methods that reads metadata +- // is not allowed. +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata); +- +- IllegalStateException exception = +- assertThrows( +- IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0)); +- assertThat(exception) +- .hasMessageThat() +- .contains("This model does not contain model metadata."); +- } +- +- @Test +- public void metadataExtractor_modelWithIrrelevantMetadata_throwsException() throws Exception { +- // Creates a model with irrelevant metadata. +- FlatBufferBuilder builder = new FlatBufferBuilder(); +- SubGraph.startSubGraph(builder); +- int subgraph = SubGraph.endSubGraph(builder); +- +- int metadataName = builder.createString("Irrelevant metadata"); +- Metadata.startMetadata(builder); +- Metadata.addName(builder, metadataName); +- int metadata = Metadata.endMetadata(builder); +- int metadataArray = Model.createMetadataVector(builder, new int[] {metadata}); +- +- // Creates Model. +- int[] subgraphs = new int[1]; +- subgraphs[0] = subgraph; +- int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs); +- Model.startModel(builder); +- Model.addSubgraphs(builder, modelSubgraphs); +- Model.addMetadata(builder, metadataArray); +- int model = Model.endModel(builder); +- builder.finish(model, TFLITE_MODEL_IDENTIFIER); +- ByteBuffer modelBuffer = builder.dataBuffer(); +- +- // It is allowed to create a model without metadata, but invoking methods that reads metadata +- // is not allowed. +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer); +- +- IllegalStateException exception = +- assertThrows( +- IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0)); +- assertThat(exception) +- .hasMessageThat() +- .contains("This model does not contain model metadata."); +- } +- +- @Test +- public void getInputTensorMetadata_validTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0); +- assertThat(inputMetadata.content().contentPropertiesType()) +- .isEqualTo(CONTENT_PROPERTIES_TYPE); +- } +- +- @Test +- public void getInputTensorMetadata_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1); +- assertThat(inputMetadata.content()).isNull(); +- } +- +- @Test +- public void getInputTensorMetadata_invalidTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2); +- assertThat(inputMetadata.content().contentPropertiesType()) +- .isEqualTo(CONTENT_PROPERTIES_TYPE); +- } +- +- @Test +- public void getOutputTensorMetadata_validTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0); +- assertThat(outputMetadata.content().contentPropertiesType()) +- .isEqualTo(CONTENT_PROPERTIES_TYPE); +- } +- +- @Test +- public void getOutputTensorMetadata_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1); +- assertThat(outputMetadata.content()).isNull(); +- } +- +- @Test +- public void getOutputTensorMetadata_invalidTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2); +- assertThat(outputMetadata.content().contentPropertiesType()) +- .isEqualTo(CONTENT_PROPERTIES_TYPE); +- } +- +- @Test +- public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException() +- throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(3)); +- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); +- } +- +- @Test +- public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(-1)); +- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); +- } +- +- @Test +- public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException() +- throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(3)); +- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid."); +- } +- +- @Test +- public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(-1)); +- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid."); +- } +- +- @Test +- public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(0); +- assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE); +- assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT); +- } +- +- @Test +- public void getInputTensorQuantizationParams_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(1); +- // Scale and zero point are expected to be 1.0f and 0, respectively as default. +- assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE); +- assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT); +- } +- +- @Test +- public void getInputTensorQuantizationParams_invalidScale() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> metadataExtractor.getInputTensorQuantizationParams(2)); +- assertThat(exception) +- .hasMessageThat() +- .contains("Input and output tensors do not support per-channel quantization."); +- } +- +- @Test +- public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- QuantizationParams quantizationParams = +- metadataExtractor.getOutputTensorQuantizationParams(0); +- assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE); +- assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT); +- } +- +- @Test +- public void getOutputTensorQuantizationParams_emptyTensor() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- QuantizationParams quantizationParams = +- metadataExtractor.getOutputTensorQuantizationParams(1); +- // Scale and zero point are expected to be 1.0f and 0, respectively as default. +- assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE); +- assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT); +- } +- +- @Test +- public void getOutputTensorQuantizationParams_invalidScale() throws Exception { +- // Creates a model flatbuffer with metadata. +- ByteBuffer modelWithMetadata = createModelByteBuffer(); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> metadataExtractor.getOutputTensorQuantizationParams(2)); +- assertThat(exception) +- .hasMessageThat() +- .contains("Input and output tensors do not support per-channel quantization."); +- } +- +- @Test +- public void isMinimumParserVersionSatisfied_olderVersion() throws Exception { +- // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will +- // precede any furture versions. +- String minVersion = "0.10"; +- // Creates a metadata using the above version. +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- +- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); +- } +- +- @Test +- public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception { +- // A version the same as the current one. +- String minVersion = MetadataParser.VERSION; +- // Creates a metadata using the above version. +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- +- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); +- } +- +- @Test +- public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception { +- // A version the same as the current one, but with longer length. +- String minVersion = MetadataParser.VERSION + ".0"; +- // Creates a metadata using the above version. +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- +- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); +- } +- +- @Test +- public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception { +- // An empty version, which can be generated before the first versioned release. +- String minVersion = null; +- // Creates a metadata using the above version. +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- +- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); +- } +- +- @Test +- public void isMinimumParserVersionSatisfied_newerVersion() throws Exception { +- // Creates a version newer than the current one by appending "1" to the end of the current +- // version for testing purposes. For example, 1.0.0 becomes 1.0.01. +- String minVersion = MetadataParser.VERSION + "1"; +- // Creates a metadata using the above version. +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- +- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse(); +- } +- +- @Test +- public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception { +- // Creates a version newer than the current one by appending ".1" to the end of the current +- // version for testing purposes. For example, 1.0.0 becomes 1.0.0.1. +- String minVersion = MetadataParser.VERSION + ".1"; +- // Creates a metadata using the above version. +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); +- +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- +- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse(); +- } +- } +- +- /** Parameterized tests for the input tensor data type. */ +- @RunWith(ParameterizedRobolectricTestRunner.class) +- public static final class InputTensorType extends MetadataExtractorTest { +- /** The tensor type that used to create the model buffer. */ +- @Parameter(0) +- public byte tensorType; +- +- /** A list of TensorType that is used in the test. */ +- @Parameters +- public static Collection<Object[]> data() { +- return Arrays.asList( +- new Object[][] { +- {TensorType.FLOAT32}, {TensorType.INT32}, +- {TensorType.UINT8}, {TensorType.INT64}, +- {TensorType.STRING} +- }); +- } +- +- @Test +- public void getInputTensorType_validTensor() throws Exception { +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType); +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- byte type = metadataExtractor.getInputTensorType(0); +- assertThat(type).isEqualTo(tensorType); +- } +- +- @Test +- public void getOutputTensorType_validTensor() throws Exception { +- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null); +- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType); +- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); +- byte type = metadataExtractor.getOutputTensorType(0); +- assertThat(type).isEqualTo(tensorType); +- } +- } +- +- /** +- * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and three +- * outputs. +- */ +- private static ByteBuffer createMetadataByteBuffer( +- String identifier, @Nullable String minVersionStr) { +- FlatBufferBuilder builder = new FlatBufferBuilder(); +- +- Content.startContent(builder); +- Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE); +- int content = Content.endContent(builder); +- +- TensorMetadata.startTensorMetadata(builder); +- TensorMetadata.addContent(builder, content); +- int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder); +- +- TensorMetadata.startTensorMetadata(builder); +- int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder); +- +- TensorMetadata.startTensorMetadata(builder); +- TensorMetadata.addContent(builder, content); +- int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder); +- +- int[] tensorMetadataArray = +- new int[] {metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor}; +- int inputTensorMetadata = +- SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray); +- int outputTensorMetadata = +- SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray); +- +- SubGraphMetadata.startSubGraphMetadata(builder); +- SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata); +- SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata); +- int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder); +- +- int[] subgraphMetadataArray = new int[] {subgraph1Metadata}; +- int subgraphsMetadata = +- ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray); +- +- int modelName = builder.createString(MODEL_NAME); +- if (minVersionStr != null) { +- int minVersion = builder.createString(minVersionStr); +- ModelMetadata.startModelMetadata(builder); +- ModelMetadata.addMinParserVersion(builder, minVersion); +- } else { +- // If minVersionStr is null, skip generating the field in the metadata. +- ModelMetadata.startModelMetadata(builder); +- } +- ModelMetadata.addName(builder, modelName); +- ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata); +- int modelMetadata = ModelMetadata.endModelMetadata(builder); +- +- builder.finish(modelMetadata, identifier); +- return builder.dataBuffer(); +- } +- +- private static int createQuantizationParameters( +- FlatBufferBuilder builder, float[] scale, long[] zeroPoint) { +- int inputScale = QuantizationParameters.createScaleVector(builder, scale); +- int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint); +- QuantizationParameters.startQuantizationParameters(builder); +- QuantizationParameters.addScale(builder, inputScale); +- QuantizationParameters.addZeroPoint(builder, inputZeroPoint); +- return QuantizationParameters.endQuantizationParameters(builder); +- } +- +- private static int createTensor( +- FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) { +- int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape); +- Tensor.startTensor(builder); +- Tensor.addShape(builder, inputShapeVector1); +- Tensor.addType(builder, inputType); +- Tensor.addQuantization(builder, inputQuantization); +- return Tensor.endTensor(builder); +- } +- +- /** +- * Creates an example model flatbuffer, which contains one subgraph with three inputs and three +- * output. +- */ +- private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) { +- FlatBufferBuilder builder = new FlatBufferBuilder(); +- +- // Creates a valid set of quantization parameters. +- int validQuantization = +- createQuantizationParameters( +- builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT}); +- +- // Creates an invalid set of quantization parameters. +- int inValidQuantization = createQuantizationParameters(builder, invalidScale, invalidZeroPoint); +- +- // Creates an input Tensor with valid quantization parameters. +- int validTensor = createTensor(builder, validShape, dataType, validQuantization); +- +- // Creates an empty input Tensor. +- Tensor.startTensor(builder); +- int emptyTensor = Tensor.endTensor(builder); +- +- // Creates an input Tensor with invalid quantization parameters. +- int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization); +- +- // Creates the SubGraph. +- int[] tensors = new int[6]; +- tensors[0] = validTensor; +- tensors[1] = emptyTensor; +- tensors[2] = invalidTensor; +- tensors[3] = validTensor; +- tensors[4] = emptyTensor; +- tensors[5] = invalidTensor; +- int subgraphTensors = SubGraph.createTensorsVector(builder, tensors); +- +- int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2}); +- int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5}); +- +- SubGraph.startSubGraph(builder); +- SubGraph.addTensors(builder, subgraphTensors); +- SubGraph.addInputs(builder, subgraphInputs); +- SubGraph.addOutputs(builder, subgraphOutputs); +- int subgraph = SubGraph.endSubGraph(builder); +- +- // Creates the Model. +- int[] subgraphs = new int[1]; +- subgraphs[0] = subgraph; +- int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs); +- +- // Inserts metadataBuffer into the model if it's not null. +- int modelBuffers = EMPTY_FLATBUFFER_VECTOR; +- int metadataArray = EMPTY_FLATBUFFER_VECTOR; +- if (metadataBuffer != null) { +- int data = Buffer.createDataVector(builder, metadataBuffer); +- Buffer.startBuffer(builder); +- Buffer.addData(builder, data); +- int buffer = Buffer.endBuffer(builder); +- modelBuffers = Model.createBuffersVector(builder, new int[] {buffer}); +- +- int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME); +- Metadata.startMetadata(builder); +- Metadata.addName(builder, metadataName); +- Metadata.addBuffer(builder, 0); +- int metadata = Metadata.endMetadata(builder); +- metadataArray = Model.createMetadataVector(builder, new int[] {metadata}); +- } +- +- Model.startModel(builder); +- Model.addSubgraphs(builder, modelSubgraphs); +- if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) { +- Model.addBuffers(builder, modelBuffers); +- Model.addMetadata(builder, metadataArray); ++ private static final int[] validShape = new int[] {4, 10, 10, 3}; ++ private static final byte DATA_TYPE = TensorType.UINT8; ++ private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties; ++ private static final float VALID_SCALE = 3.3f; ++ private static final long VALID_ZERO_POINT = 2; ++ private static final float DEFAULT_SCALE = 0.0f; ++ private static final long DEFAULT_ZERO_POINT = 0; ++ private static final String MODEL_NAME = "model.tflite"; ++ // Scale and zero point should both be a single value, not an array. ++ private static final float[] invalidScale = new float[] {0.0f, 1.2f}; ++ private static final long[] invalidZeroPoint = new long[] {1, 2}; ++ private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite"; ++ // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file. ++ private static final String VALID_LABEL_FILE_NAME = "labels.txt"; ++ // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite. ++ private static final String INVALID_LABEL_FILE_NAME = "invalid.txt"; ++ private static final int EMPTY_FLATBUFFER_VECTOR = -1; ++ private static final String TFLITE_MODEL_IDENTIFIER = "TFL3"; ++ private static final String TFLITE_METADATA_IDENTIFIER = "M001"; ++ ++ /** General tests of MetadataExtractor. */ ++ @RunWith(RobolectricTestRunner.class) ++ public static final class General extends MetadataExtractorTest { ++ @Test ++ public void hasMetadata_modelWithMetadata() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ assertThat(metadataExtractor.hasMetadata()).isTrue(); ++ } ++ ++ @Test ++ public void hasMetadata_modelWithoutMetadata() throws Exception { ++ // Creates a model flatbuffer without metadata. ++ ByteBuffer modelWithoutMetadata = ++ createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata); ++ assertThat(metadataExtractor.hasMetadata()).isFalse(); ++ } ++ ++ @Ignore ++ @Test ++ public void getAssociatedFile_validAssociateFile() throws Exception { ++ ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); ++ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); ++ InputStream associateFileStream = ++ mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME); ++ ++ // Reads the golden file from context. ++ Context context = ApplicationProvider.getApplicationContext(); ++ InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME); ++ assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream)) ++ .isTrue(); ++ } ++ ++ @Ignore ++ @Test ++ public void getAssociatedFile_invalidAssociateFile() throws Exception { ++ ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); ++ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( ++ "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME)); ++ } ++ ++ @Ignore ++ @Test ++ public void getAssociatedFile_nullFileName() throws Exception { ++ ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); ++ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/null)); ++ assertThat(exception).hasMessageThat().contains( ++ "The file, null, does not exist in the zip file."); ++ } ++ ++ @Test ++ public void getAssociatedFile_nonZipModel_throwsException() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalStateException exception = assertThrows(IllegalStateException.class, ++ () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME)); ++ assertThat(exception).hasMessageThat().contains( ++ "This model does not contain associated files, and is not a Zip file."); ++ } ++ ++ @Test ++ public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalStateException exception = assertThrows( ++ IllegalStateException.class, metadataExtractor::getAssociatedFileNames); ++ assertThat(exception).hasMessageThat().contains( ++ "This model does not contain associated files, and is not a Zip file."); ++ } ++ ++ @Ignore ++ @Test ++ public void getAssociatedFileNames_validFileNames() throws Exception { ++ ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); ++ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer); ++ Set<String> expectedSet = new HashSet<>(); ++ expectedSet.add(VALID_LABEL_FILE_NAME); ++ assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet); ++ } ++ ++ @Test ++ public void metadataExtractor_loadNullBuffer_throwsException() { ++ ByteBuffer nullBuffer = null; ++ NullPointerException exception = assertThrows( ++ NullPointerException.class, () -> new MetadataExtractor(nullBuffer)); ++ assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null."); ++ } ++ ++ @Test ++ public void metadataExtractor_loadRandomBuffer_throwsException() { ++ ByteBuffer randomBuffer = createRandomByteBuffer(); ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer)); ++ assertThat(exception).hasMessageThat().contains( ++ "The identifier of the model is invalid. The buffer may not be a valid TFLite model" ++ + " flatbuffer."); ++ } ++ ++ @Test ++ public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() { ++ // Creates a model with an invalid identifier. ++ String invalidIdentifier = "INVI"; ++ FlatBufferBuilder builder = new FlatBufferBuilder(); ++ Model.startModel(builder); ++ int model = Model.endModel(builder); ++ builder.finish(model, invalidIdentifier); ++ ByteBuffer modelBuffer = builder.dataBuffer(); ++ ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer)); ++ assertThat(exception).hasMessageThat().contains( ++ "The identifier of the model is invalid. The buffer may not be a valid TFLite model" ++ + " flatbuffer."); ++ } ++ ++ @Test ++ public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() { ++ // Creates a model with metadata which contains an invalid identifier. ++ String invalidIdentifier = "INVI"; ++ ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null); ++ ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer)); ++ assertThat(exception).hasMessageThat().contains( ++ "The identifier of the metadata is invalid. The buffer may not be a valid TFLite" ++ + " metadata flatbuffer."); ++ } ++ ++ @Test ++ public void getInputTensorCount_validModelFile() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ int count = metadataExtractor.getInputTensorCount(); ++ assertThat(count).isEqualTo(3); ++ } ++ ++ @Test ++ public void getOutputTensorCount_validModelFile() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ int count = metadataExtractor.getOutputTensorCount(); ++ assertThat(count).isEqualTo(3); ++ } ++ ++ @Test ++ public void getInputTensorShape_validTensorShape() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ int[] shape = metadataExtractor.getInputTensorShape(0); ++ assertArrayEquals(validShape, shape); ++ } ++ ++ @Test ++ public void getInputTensorShape_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ int[] shape = metadataExtractor.getInputTensorShape(1); ++ assertThat(shape).isEmpty(); ++ } ++ ++ @Test ++ public void getInputTensorType_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ byte type = metadataExtractor.getInputTensorType(1); ++ assertThat(type).isEqualTo(TensorType.FLOAT32); ++ } ++ ++ @Test ++ public void getOutputTensorShape_validTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ int[] shape = metadataExtractor.getOutputTensorShape(0); ++ assertArrayEquals(validShape, shape); ++ } ++ ++ @Test ++ public void getOutputTensorShape_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ int[] shape = metadataExtractor.getOutputTensorShape(1); ++ assertThat(shape).isEmpty(); ++ } ++ ++ @Test ++ public void getOutputTensorType_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ byte type = metadataExtractor.getOutputTensorType(1); ++ assertThat(type).isEqualTo(TensorType.FLOAT32); ++ } ++ ++ @Test ++ public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException() ++ throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows( ++ IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3)); ++ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getInputTensorShape_negtiveIndex_throwsException() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getInputTensorShape(-1)); ++ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException() ++ throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getOutputTensorShape(3)); ++ assertThat(exception).hasMessageThat().contains( ++ "The outputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getOutputTensorShape(-1)); ++ assertThat(exception).hasMessageThat().contains( ++ "The outputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getModelMetadata_modelWithMetadata() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ModelMetadata modelMetadata = metadataExtractor.getModelMetadata(); ++ assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME); ++ } ++ ++ @Test ++ public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception { ++ // Creates a model flatbuffer without metadata. ++ ByteBuffer modelWithoutMetadata = ++ createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata); ++ ++ IllegalStateException exception = assertThrows( ++ IllegalStateException.class, () -> metadataExtractor.getModelMetadata()); ++ assertThat(exception).hasMessageThat().contains( ++ "This model does not contain model metadata."); ++ } ++ ++ @Test ++ public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() { ++ // Creates a metadata FlatBuffer without empty subgraph metadata. ++ FlatBufferBuilder builder = new FlatBufferBuilder(); ++ SubGraphMetadata.startSubGraphMetadata(builder); ++ int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder); ++ int subgraphsMetadata = ModelMetadata.createSubgraphMetadataVector( ++ builder, new int[] {subgraph1Metadata}); ++ ++ ModelMetadata.startModelMetadata(builder); ++ ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata); ++ int modelMetadata = ModelMetadata.endModelMetadata(builder); ++ builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER); ++ ByteBuffer emptyMetadata = builder.dataBuffer(); ++ ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> new MetadataExtractor(modelWithEmptyMetadata)); ++ assertThat(exception).hasMessageThat().isEqualTo( ++ "The number of input tensors in the model is 3. The number of input tensors that" ++ + " recorded in the metadata is 0. These two values does not match."); ++ } ++ ++ @Test ++ public void metadataExtractor_modelWithEmptyMetadata_throwsException() { ++ // Creates a empty metadata FlatBuffer. ++ FlatBufferBuilder builder = new FlatBufferBuilder(); ++ ModelMetadata.startModelMetadata(builder); ++ int modelMetadata = ModelMetadata.endModelMetadata(builder); ++ builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER); ++ ++ ByteBuffer emptyMetadata = builder.dataBuffer(); ++ ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> new MetadataExtractor(modelWithEmptyMetadata)); ++ assertThat(exception).hasMessageThat().contains( ++ "The metadata flatbuffer does not contain any subgraph metadata."); ++ } ++ ++ @Test ++ public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception { ++ // Creates a model flatbuffer without metadata. ++ ByteBuffer modelWithoutMetadata = ++ createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE); ++ ++ // It is allowed to create a model without metadata, but invoking methods that reads ++ // metadata is not allowed. ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata); ++ ++ IllegalStateException exception = assertThrows( ++ IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0)); ++ assertThat(exception).hasMessageThat().contains( ++ "This model does not contain model metadata."); ++ } ++ ++ @Test ++ public void metadataExtractor_modelWithIrrelevantMetadata_throwsException() ++ throws Exception { ++ // Creates a model with irrelevant metadata. ++ FlatBufferBuilder builder = new FlatBufferBuilder(); ++ SubGraph.startSubGraph(builder); ++ int subgraph = SubGraph.endSubGraph(builder); ++ ++ int metadataName = builder.createString("Irrelevant metadata"); ++ Metadata.startMetadata(builder); ++ Metadata.addName(builder, metadataName); ++ int metadata = Metadata.endMetadata(builder); ++ int metadataArray = Model.createMetadataVector(builder, new int[] {metadata}); ++ ++ // Creates Model. ++ int[] subgraphs = new int[1]; ++ subgraphs[0] = subgraph; ++ int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs); ++ Model.startModel(builder); ++ Model.addSubgraphs(builder, modelSubgraphs); ++ Model.addMetadata(builder, metadataArray); ++ int model = Model.endModel(builder); ++ builder.finish(model, TFLITE_MODEL_IDENTIFIER); ++ ByteBuffer modelBuffer = builder.dataBuffer(); ++ ++ // It is allowed to create a model without metadata, but invoking methods that reads ++ // metadata is not allowed. ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer); ++ ++ IllegalStateException exception = assertThrows( ++ IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0)); ++ assertThat(exception).hasMessageThat().contains( ++ "This model does not contain model metadata."); ++ } ++ ++ @Test ++ public void getInputTensorMetadata_validTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0); ++ assertThat(inputMetadata.content().contentPropertiesType()) ++ .isEqualTo(CONTENT_PROPERTIES_TYPE); ++ } ++ ++ @Test ++ public void getInputTensorMetadata_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1); ++ assertThat(inputMetadata.content()).isNull(); ++ } ++ ++ @Test ++ public void getInputTensorMetadata_invalidTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2); ++ assertThat(inputMetadata.content().contentPropertiesType()) ++ .isEqualTo(CONTENT_PROPERTIES_TYPE); ++ } ++ ++ @Test ++ public void getOutputTensorMetadata_validTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0); ++ assertThat(outputMetadata.content().contentPropertiesType()) ++ .isEqualTo(CONTENT_PROPERTIES_TYPE); ++ } ++ ++ @Test ++ public void getOutputTensorMetadata_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1); ++ assertThat(outputMetadata.content()).isNull(); ++ } ++ ++ @Test ++ public void getOutputTensorMetadata_invalidTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2); ++ assertThat(outputMetadata.content().contentPropertiesType()) ++ .isEqualTo(CONTENT_PROPERTIES_TYPE); ++ } ++ ++ @Test ++ public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException() ++ throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getInputTensorMetadata(3)); ++ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getInputTensorMetadata(-1)); ++ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException() ++ throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getOutputTensorMetadata(3)); ++ assertThat(exception).hasMessageThat().contains( ++ "The outputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getOutputTensorMetadata(-1)); ++ assertThat(exception).hasMessageThat().contains( ++ "The outputIndex specified is invalid."); ++ } ++ ++ @Test ++ public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ QuantizationParams quantizationParams = ++ metadataExtractor.getInputTensorQuantizationParams(0); ++ assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE); ++ assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT); ++ } ++ ++ @Test ++ public void getInputTensorQuantizationParams_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ QuantizationParams quantizationParams = ++ metadataExtractor.getInputTensorQuantizationParams(1); ++ // Scale and zero point are expected to be 1.0f and 0, respectively as default. ++ assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE); ++ assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT); ++ } ++ ++ @Test ++ public void getInputTensorQuantizationParams_invalidScale() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getInputTensorQuantizationParams(2)); ++ assertThat(exception).hasMessageThat().contains( ++ "Input and output tensors do not support per-channel quantization."); ++ } ++ ++ @Test ++ public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ QuantizationParams quantizationParams = ++ metadataExtractor.getOutputTensorQuantizationParams(0); ++ assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE); ++ assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT); ++ } ++ ++ @Test ++ public void getOutputTensorQuantizationParams_emptyTensor() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ QuantizationParams quantizationParams = ++ metadataExtractor.getOutputTensorQuantizationParams(1); ++ // Scale and zero point are expected to be 1.0f and 0, respectively as default. ++ assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE); ++ assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT); ++ } ++ ++ @Test ++ public void getOutputTensorQuantizationParams_invalidScale() throws Exception { ++ // Creates a model flatbuffer with metadata. ++ ByteBuffer modelWithMetadata = createModelByteBuffer(); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> metadataExtractor.getOutputTensorQuantizationParams(2)); ++ assertThat(exception).hasMessageThat().contains( ++ "Input and output tensors do not support per-channel quantization."); ++ } ++ ++ @Test ++ public void isMinimumParserVersionSatisfied_olderVersion() throws Exception { ++ // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will ++ // precede any furture versions. ++ String minVersion = "0.10"; ++ // Creates a metadata using the above version. ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ++ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); ++ } ++ ++ @Test ++ public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception { ++ // A version the same as the current one. ++ String minVersion = MetadataParser.VERSION; ++ // Creates a metadata using the above version. ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ++ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); ++ } ++ ++ @Test ++ public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception { ++ // A version the same as the current one, but with longer length. ++ String minVersion = MetadataParser.VERSION + ".0"; ++ // Creates a metadata using the above version. ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ++ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); ++ } ++ ++ @Test ++ public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception { ++ // An empty version, which can be generated before the first versioned release. ++ String minVersion = null; ++ // Creates a metadata using the above version. ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ++ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue(); ++ } ++ ++ @Test ++ public void isMinimumParserVersionSatisfied_newerVersion() throws Exception { ++ // Creates a version newer than the current one by appending "1" to the end of the ++ // current version for testing purposes. For example, 1.0.0 becomes 1.0.01. ++ String minVersion = MetadataParser.VERSION + "1"; ++ // Creates a metadata using the above version. ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ++ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse(); ++ } ++ ++ @Test ++ public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception { ++ // Creates a version newer than the current one by appending ".1" to the end of the ++ // current version for testing purposes. For example, 1.0.0 becomes 1.0.0.1. ++ String minVersion = MetadataParser.VERSION + ".1"; ++ // Creates a metadata using the above version. ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE); ++ ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ ++ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse(); ++ } ++ } ++ ++ /** Parameterized tests for the input tensor data type. */ ++ @RunWith(ParameterizedRobolectricTestRunner.class) ++ public static final class InputTensorType extends MetadataExtractorTest { ++ /** The tensor type that used to create the model buffer. */ ++ @Parameter(0) ++ public byte tensorType; ++ ++ /** A list of TensorType that is used in the test. */ ++ @Parameters ++ public static Collection<Object[]> data() { ++ return Arrays.asList(new Object[][] {{TensorType.FLOAT32}, {TensorType.INT32}, ++ {TensorType.UINT8}, {TensorType.INT64}, {TensorType.STRING}}); ++ } ++ ++ @Test ++ public void getInputTensorType_validTensor() throws Exception { ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType); ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ byte type = metadataExtractor.getInputTensorType(0); ++ assertThat(type).isEqualTo(tensorType); ++ } ++ ++ @Test ++ public void getOutputTensorType_validTensor() throws Exception { ++ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null); ++ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType); ++ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata); ++ byte type = metadataExtractor.getOutputTensorType(0); ++ assertThat(type).isEqualTo(tensorType); ++ } ++ } ++ ++ /** ++ * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and ++ * three outputs. ++ */ ++ private static ByteBuffer createMetadataByteBuffer( ++ String identifier, @Nullable String minVersionStr) { ++ FlatBufferBuilder builder = new FlatBufferBuilder(); ++ ++ Content.startContent(builder); ++ Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE); ++ int content = Content.endContent(builder); ++ ++ TensorMetadata.startTensorMetadata(builder); ++ TensorMetadata.addContent(builder, content); ++ int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder); ++ ++ TensorMetadata.startTensorMetadata(builder); ++ int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder); ++ ++ TensorMetadata.startTensorMetadata(builder); ++ TensorMetadata.addContent(builder, content); ++ int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder); ++ ++ int[] tensorMetadataArray = new int[] { ++ metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor}; ++ int inputTensorMetadata = ++ SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray); ++ int outputTensorMetadata = ++ SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray); ++ ++ SubGraphMetadata.startSubGraphMetadata(builder); ++ SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata); ++ SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata); ++ int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder); ++ ++ int[] subgraphMetadataArray = new int[] {subgraph1Metadata}; ++ int subgraphsMetadata = ++ ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray); ++ ++ int modelName = builder.createString(MODEL_NAME); ++ if (minVersionStr != null) { ++ int minVersion = builder.createString(minVersionStr); ++ ModelMetadata.startModelMetadata(builder); ++ ModelMetadata.addMinParserVersion(builder, minVersion); ++ } else { ++ // If minVersionStr is null, skip generating the field in the metadata. ++ ModelMetadata.startModelMetadata(builder); ++ } ++ ModelMetadata.addName(builder, modelName); ++ ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata); ++ int modelMetadata = ModelMetadata.endModelMetadata(builder); ++ ++ builder.finish(modelMetadata, identifier); ++ return builder.dataBuffer(); ++ } ++ ++ private static int createQuantizationParameters( ++ FlatBufferBuilder builder, float[] scale, long[] zeroPoint) { ++ int inputScale = QuantizationParameters.createScaleVector(builder, scale); ++ int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint); ++ QuantizationParameters.startQuantizationParameters(builder); ++ QuantizationParameters.addScale(builder, inputScale); ++ QuantizationParameters.addZeroPoint(builder, inputZeroPoint); ++ return QuantizationParameters.endQuantizationParameters(builder); ++ } ++ ++ private static int createTensor( ++ FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) { ++ int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape); ++ Tensor.startTensor(builder); ++ Tensor.addShape(builder, inputShapeVector1); ++ Tensor.addType(builder, inputType); ++ Tensor.addQuantization(builder, inputQuantization); ++ return Tensor.endTensor(builder); ++ } ++ ++ /** ++ * Creates an example model flatbuffer, which contains one subgraph with three inputs and three ++ * output. ++ */ ++ private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) { ++ FlatBufferBuilder builder = new FlatBufferBuilder(); ++ ++ // Creates a valid set of quantization parameters. ++ int validQuantization = createQuantizationParameters( ++ builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT}); ++ ++ // Creates an invalid set of quantization parameters. ++ int inValidQuantization = ++ createQuantizationParameters(builder, invalidScale, invalidZeroPoint); ++ ++ // Creates an input Tensor with valid quantization parameters. ++ int validTensor = createTensor(builder, validShape, dataType, validQuantization); ++ ++ // Creates an empty input Tensor. ++ Tensor.startTensor(builder); ++ int emptyTensor = Tensor.endTensor(builder); ++ ++ // Creates an input Tensor with invalid quantization parameters. ++ int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization); ++ ++ // Creates the SubGraph. ++ int[] tensors = new int[6]; ++ tensors[0] = validTensor; ++ tensors[1] = emptyTensor; ++ tensors[2] = invalidTensor; ++ tensors[3] = validTensor; ++ tensors[4] = emptyTensor; ++ tensors[5] = invalidTensor; ++ int subgraphTensors = SubGraph.createTensorsVector(builder, tensors); ++ ++ int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2}); ++ int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5}); ++ ++ SubGraph.startSubGraph(builder); ++ SubGraph.addTensors(builder, subgraphTensors); ++ SubGraph.addInputs(builder, subgraphInputs); ++ SubGraph.addOutputs(builder, subgraphOutputs); ++ int subgraph = SubGraph.endSubGraph(builder); ++ ++ // Creates the Model. ++ int[] subgraphs = new int[1]; ++ subgraphs[0] = subgraph; ++ int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs); ++ ++ // Inserts metadataBuffer into the model if it's not null. ++ int modelBuffers = EMPTY_FLATBUFFER_VECTOR; ++ int metadataArray = EMPTY_FLATBUFFER_VECTOR; ++ if (metadataBuffer != null) { ++ int data = Buffer.createDataVector(builder, metadataBuffer); ++ Buffer.startBuffer(builder); ++ Buffer.addData(builder, data); ++ int buffer = Buffer.endBuffer(builder); ++ modelBuffers = Model.createBuffersVector(builder, new int[] {buffer}); ++ ++ int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME); ++ Metadata.startMetadata(builder); ++ Metadata.addName(builder, metadataName); ++ Metadata.addBuffer(builder, 0); ++ int metadata = Metadata.endMetadata(builder); ++ metadataArray = Model.createMetadataVector(builder, new int[] {metadata}); ++ } ++ ++ Model.startModel(builder); ++ Model.addSubgraphs(builder, modelSubgraphs); ++ if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) { ++ Model.addBuffers(builder, modelBuffers); ++ Model.addMetadata(builder, metadataArray); ++ } ++ int model = Model.endModel(builder); ++ builder.finish(model, TFLITE_MODEL_IDENTIFIER); ++ ++ return builder.dataBuffer(); ++ } ++ ++ /** Creates an example model flatbuffer with the default metadata and data type. */ ++ private static ByteBuffer createModelByteBuffer() { ++ ByteBuffer metadata = ++ createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/null); ++ return createModelByteBuffer(metadata, DATA_TYPE); ++ } ++ ++ private static ByteBuffer loadMobileNetBuffer() throws Exception { ++ Context context = ApplicationProvider.getApplicationContext(); ++ // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that ++ // contains a label file as the associated file. ++ AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); ++ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); ++ FileChannel fileChannel = inputStream.getChannel(); ++ long startOffset = fileDescriptor.getStartOffset(); ++ long declaredLength = fileDescriptor.getDeclaredLength(); ++ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); ++ } ++ ++ private static ByteBuffer createRandomByteBuffer() { ++ byte[] buffer = new byte[20]; ++ new Random().nextBytes(buffer); ++ return ByteBuffer.wrap(buffer); + } +- int model = Model.endModel(builder); +- builder.finish(model, TFLITE_MODEL_IDENTIFIER); +- +- return builder.dataBuffer(); +- } +- +- /** Creates an example model flatbuffer with the default metadata and data type. */ +- private static ByteBuffer createModelByteBuffer() { +- ByteBuffer metadata = +- createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/ null); +- return createModelByteBuffer(metadata, DATA_TYPE); +- } +- +- private static ByteBuffer loadMobileNetBuffer() throws Exception { +- Context context = ApplicationProvider.getApplicationContext(); +- // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that +- // contains a label file as the associated file. +- AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); +- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); +- FileChannel fileChannel = inputStream.getChannel(); +- long startOffset = fileDescriptor.getStartOffset(); +- long declaredLength = fileDescriptor.getDeclaredLength(); +- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); +- } +- +- private static ByteBuffer createRandomByteBuffer() { +- byte[] buffer = new byte[20]; +- new Random().nextBytes(buffer); +- return ByteBuffer.wrap(buffer); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java +index a47566fec06e9..eede6750ea479 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java +@@ -17,20 +17,20 @@ package org.tensorflow.lite.support.metadata; + + import static com.google.common.truth.Truth.assertThat; + +-import java.util.regex.Pattern; + import org.junit.Test; + import org.junit.runner.RunWith; + import org.junit.runners.JUnit4; + ++import java.util.regex.Pattern; ++ + /** Tests of {@link MetadataParser}. */ + @RunWith(JUnit4.class) + public final class MetadataParserTest { +- +- @Test +- public void version_wellFormedAsSemanticVersion() throws Exception { +- // Validates that the version is well-formed (x.y.z). +- String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+"; +- Pattern r = Pattern.compile(pattern); +- assertThat(MetadataParser.VERSION).matches(r); +- } ++ @Test ++ public void version_wellFormedAsSemanticVersion() throws Exception { ++ // Validates that the version is well-formed (x.y.z). ++ String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+"; ++ Pattern r = Pattern.compile(pattern); ++ assertThat(MetadataParser.VERSION).matches(r); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java +index 61231e902e03e..80d2ddc6fd34e 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java +@@ -16,11 +16,20 @@ limitations under the License. + package org.tensorflow.lite.support.metadata; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.content.Context; + import android.content.res.AssetFileDescriptor; ++ + import androidx.test.core.app.ApplicationProvider; ++ ++import org.apache.commons.io.IOUtils; ++import org.junit.Ignore; ++import org.junit.Test; ++import org.junit.runner.RunWith; ++import org.robolectric.RobolectricTestRunner; ++ + import java.io.FileInputStream; + import java.io.InputStream; + import java.nio.ByteBuffer; +@@ -28,113 +37,102 @@ import java.nio.channels.FileChannel; + import java.util.HashSet; + import java.util.Set; + import java.util.zip.ZipException; +-import org.apache.commons.io.IOUtils; +-import org.junit.Test; +-import org.junit.runner.RunWith; +-import org.robolectric.RobolectricTestRunner; +- +-import org.junit.Ignore; + + /** Tests of {@link ZipFile}. */ + @RunWith(RobolectricTestRunner.class) + public final class ZipFileTest { +- +- // The TFLite model file is a zip file. +- private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite"; +- // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file. +- private static final String VALID_LABEL_FILE_NAME = "labels.txt"; +- // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite. +- private static final String INVALID_LABEL_FILE_NAME = "invalid.txt"; +- private final Context context = ApplicationProvider.getApplicationContext(); +- +- @Test +- public void zipFile_nullChannel_throwsException() throws Exception { +- NullPointerException exception = +- assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null)); +- assertThat(exception).hasMessageThat().isEqualTo("The object reference is null."); +- } +- +- @Test +- public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception { +- // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22. +- ByteBuffer modelBuffer = ByteBuffer.allocate(21); +- ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer); +- +- ZipException exception = +- assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel)); +- assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive."); +- } +- +- @Test +- public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception { +- // An invalid zip file that meets the size requirement but does not contain the zip signature. +- ByteBuffer modelBuffer = ByteBuffer.allocate(22); +- ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer); +- +- ZipException exception = +- assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel)); +- assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive."); +- } +- +- @Ignore +- @Test +- public void getFileNames_correctFileName() throws Exception { +- ByteBufferChannel modelChannel = loadModel(MODEL_PATH); +- ZipFile zipFile = ZipFile.createFrom(modelChannel); +- Set<String> expectedSet = new HashSet<>(); +- expectedSet.add(VALID_LABEL_FILE_NAME); +- assertThat(zipFile.getFileNames()).isEqualTo(expectedSet); +- } +- +- @Ignore +- @Test +- public void getRawInputStream_existentFile() throws Exception { +- ByteBufferChannel modelChannel = loadModel(MODEL_PATH); +- ZipFile zipFile = ZipFile.createFrom(modelChannel); +- InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME); +- +- // Reads the golden file from context. +- InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME); +- assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue(); +- } +- +- @Ignore +- @Test +- public void getRawInputStream_nonExistentFile() throws Exception { +- ByteBufferChannel modelChannel = loadModel(MODEL_PATH); +- ZipFile zipFile = ZipFile.createFrom(modelChannel); +- +- IllegalArgumentException exception = +- assertThrows( +- IllegalArgumentException.class, +- () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME)); +- assertThat(exception) +- .hasMessageThat() +- .isEqualTo( +- String.format( ++ // The TFLite model file is a zip file. ++ private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite"; ++ // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file. ++ private static final String VALID_LABEL_FILE_NAME = "labels.txt"; ++ // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite. ++ private static final String INVALID_LABEL_FILE_NAME = "invalid.txt"; ++ private final Context context = ApplicationProvider.getApplicationContext(); ++ ++ @Test ++ public void zipFile_nullChannel_throwsException() throws Exception { ++ NullPointerException exception = ++ assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null)); ++ assertThat(exception).hasMessageThat().isEqualTo("The object reference is null."); ++ } ++ ++ @Test ++ public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception { ++ // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22. ++ ByteBuffer modelBuffer = ByteBuffer.allocate(21); ++ ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer); ++ ++ ZipException exception = ++ assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel)); ++ assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive."); ++ } ++ ++ @Test ++ public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception { ++ // An invalid zip file that meets the size requirement but does not contain the zip ++ // signature. ++ ByteBuffer modelBuffer = ByteBuffer.allocate(22); ++ ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer); ++ ++ ZipException exception = ++ assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel)); ++ assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive."); ++ } ++ ++ @Ignore ++ @Test ++ public void getFileNames_correctFileName() throws Exception { ++ ByteBufferChannel modelChannel = loadModel(MODEL_PATH); ++ ZipFile zipFile = ZipFile.createFrom(modelChannel); ++ Set<String> expectedSet = new HashSet<>(); ++ expectedSet.add(VALID_LABEL_FILE_NAME); ++ assertThat(zipFile.getFileNames()).isEqualTo(expectedSet); ++ } ++ ++ @Ignore ++ @Test ++ public void getRawInputStream_existentFile() throws Exception { ++ ByteBufferChannel modelChannel = loadModel(MODEL_PATH); ++ ZipFile zipFile = ZipFile.createFrom(modelChannel); ++ InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME); ++ ++ // Reads the golden file from context. ++ InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME); ++ assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue(); ++ } ++ ++ @Ignore ++ @Test ++ public void getRawInputStream_nonExistentFile() throws Exception { ++ ByteBufferChannel modelChannel = loadModel(MODEL_PATH); ++ ZipFile zipFile = ZipFile.createFrom(modelChannel); ++ ++ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, ++ () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME)); ++ assertThat(exception).hasMessageThat().isEqualTo(String.format( + "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME)); +- } +- +- @Ignore +- @Test +- public void close_validStatus() throws Exception { +- ByteBufferChannel modelChannel = loadModel(MODEL_PATH); +- ZipFile zipFile = ZipFile.createFrom(modelChannel); +- // Should do nothing (including not throwing an exception). +- zipFile.close(); +- } +- +- private static ByteBufferChannel loadModel(String modelPath) throws Exception { +- // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet +- // model is a zip file that contains a label file as the associated file. +- Context context = ApplicationProvider.getApplicationContext(); +- AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath); +- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); +- FileChannel fileChannel = inputStream.getChannel(); +- long startOffset = fileDescriptor.getStartOffset(); +- long declaredLength = fileDescriptor.getDeclaredLength(); +- ByteBuffer modelBuffer = +- fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); +- return new ByteBufferChannel(modelBuffer); +- } ++ } ++ ++ @Ignore ++ @Test ++ public void close_validStatus() throws Exception { ++ ByteBufferChannel modelChannel = loadModel(MODEL_PATH); ++ ZipFile zipFile = ZipFile.createFrom(modelChannel); ++ // Should do nothing (including not throwing an exception). ++ zipFile.close(); ++ } ++ ++ private static ByteBufferChannel loadModel(String modelPath) throws Exception { ++ // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet ++ // model is a zip file that contains a label file as the associated file. ++ Context context = ApplicationProvider.getApplicationContext(); ++ AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath); ++ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); ++ FileChannel fileChannel = inputStream.getChannel(); ++ long startOffset = fileDescriptor.getStartOffset(); ++ long declaredLength = fileDescriptor.getDeclaredLength(); ++ ByteBuffer modelBuffer = ++ fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); ++ return new ByteBufferChannel(modelBuffer); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h +index 110186bb63a1b..0c494915e7357 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h +@@ -19,7 +19,8 @@ + NS_ASSUME_NONNULL_BEGIN + + /** Types of image sources. */ +-typedef NSInteger GMLImageSourceType NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType); ++typedef NSInteger GMLImageSourceType NS_TYPED_ENUM ++ NS_SWIFT_NAME(MLImageSourceType); + /** Image source is a `UIImage`. */ + static const GMLImageSourceType GMLImageSourceTypeImage = 0; + /** Image source is a `CVPixelBuffer`. */ +@@ -38,8 +39,9 @@ NS_SWIFT_NAME(MLImage) + @property(nonatomic, readonly) CGFloat height; + + /** +- * The display orientation of the image. If `imageSourceType` is `.image`, the default value is +- * `image.imageOrientation`; otherwise the default value is `.up`. ++ * The display orientation of the image. If `imageSourceType` is `.image`, the ++ * default value is `image.imageOrientation`; otherwise the default value is ++ * `.up`. + */ + @property(nonatomic) UIImageOrientation orientation; + +@@ -47,30 +49,34 @@ NS_SWIFT_NAME(MLImage) + @property(nonatomic, readonly) GMLImageSourceType imageSourceType; + + /** The source image. `nil` if `imageSourceType` is not `.image`. */ +-@property(nonatomic, readonly, nullable) UIImage *image; ++@property(nonatomic, readonly, nullable) UIImage* image; + +-/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`. */ ++/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`. ++ */ + @property(nonatomic, readonly, nullable) CVPixelBufferRef pixelBuffer; + +-/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`. */ ++/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`. ++ */ + @property(nonatomic, readonly, nullable) CMSampleBufferRef sampleBuffer; + + /** + * Initializes an `MLImage` object with the given image. + * +- * @param image The image to use as the source. Its `CGImage` property must not be `NULL`. +- * @return A new `MLImage` instance with the given image as the source. `nil` if the given `image` +- * is `nil` or invalid. ++ * @param image The image to use as the source. Its `CGImage` property must not ++ * be `NULL`. ++ * @return A new `MLImage` instance with the given image as the source. `nil` if ++ * the given `image` is `nil` or invalid. + */ +-- (nullable instancetype)initWithImage:(UIImage *)image NS_DESIGNATED_INITIALIZER; ++- (nullable instancetype)initWithImage:(UIImage*)image ++ NS_DESIGNATED_INITIALIZER; + + /** + * Initializes an `MLImage` object with the given pixel buffer. + * +- * @param pixelBuffer The pixel buffer to use as the source. It will be retained by the new +- * `MLImage` instance for the duration of its lifecycle. +- * @return A new `MLImage` instance with the given pixel buffer as the source. `nil` if the given +- * pixel buffer is `nil` or invalid. ++ * @param pixelBuffer The pixel buffer to use as the source. It will be retained ++ * by the new `MLImage` instance for the duration of its lifecycle. ++ * @return A new `MLImage` instance with the given pixel buffer as the source. ++ * `nil` if the given pixel buffer is `nil` or invalid. + */ + - (nullable instancetype)initWithPixelBuffer:(CVPixelBufferRef)pixelBuffer + NS_DESIGNATED_INITIALIZER; +@@ -78,12 +84,13 @@ NS_SWIFT_NAME(MLImage) + /** + * Initializes an `MLImage` object with the given sample buffer. + * +- * @param sampleBuffer The sample buffer to use as the source. It will be retained by the new +- * `MLImage` instance for the duration of its lifecycle. The sample buffer must be based on a +- * pixel buffer (not compressed data). In practice, it should be the video output of the camera +- * on an iOS device, not other arbitrary types of `CMSampleBuffer`s. +- * @return A new `MLImage` instance with the given sample buffer as the source. `nil` if the given +- * sample buffer is `nil` or invalid. ++ * @param sampleBuffer The sample buffer to use as the source. It will be ++ * retained by the new `MLImage` instance for the duration of its lifecycle. The ++ * sample buffer must be based on a pixel buffer (not compressed data). In ++ * practice, it should be the video output of the camera on an iOS device, not ++ * other arbitrary types of `CMSampleBuffer`s. ++ * @return A new `MLImage` instance with the given sample buffer as the source. ++ * `nil` if the given sample buffer is `nil` or invalid. + */ + - (nullable instancetype)initWithSampleBuffer:(CMSampleBufferRef)sampleBuffer + NS_DESIGNATED_INITIALIZER; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m +index 094d4e01377d2..38ca74268acc1 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m +@@ -20,7 +20,7 @@ NS_ASSUME_NONNULL_BEGIN + + #pragma mark - Public + +-- (nullable instancetype)initWithImage:(UIImage *)image { ++- (nullable instancetype)initWithImage:(UIImage*)image { + if (image.CGImage == NULL) { + return nil; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m +index 59205747e416a..8abee1ab2f171 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m +@@ -22,8 +22,8 @@ + + NS_ASSUME_NONNULL_BEGIN + +-static NSString *const kTestImageName = @"grace_hopper"; +-static NSString *const kTestImageType = @"jpg"; ++static NSString* const kTestImageName = @"grace_hopper"; ++static NSString* const kTestImageType = @"jpg"; + static CGFloat kTestImageWidthInPixels = 517.0f; + static CGFloat kTestImageHeightInPixels = 606.0f; + +@@ -31,7 +31,7 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + @interface GMLImageTests : XCTestCase + + /** Test image. */ +-@property(nonatomic, nullable) UIImage *image; ++@property(nonatomic, nullable) UIImage* image; + + @end + +@@ -41,8 +41,9 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + + - (void)setUp { + [super setUp]; +- NSString *imageName = [[NSBundle bundleForClass:[self class]] pathForResource:kTestImageName +- ofType:kTestImageType]; ++ NSString* imageName = ++ [[NSBundle bundleForClass:[self class]] pathForResource:kTestImageName ++ ofType:kTestImageType]; + self.image = [[UIImage alloc] initWithContentsOfFile:imageName]; + } + +@@ -52,53 +53,59 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + } + + - (void)testInitWithImage { +- GMLImage *mlImage = [[GMLImage alloc] initWithImage:self.image]; ++ GMLImage* mlImage = [[GMLImage alloc] initWithImage:self.image]; + XCTAssertNotNil(mlImage); + XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypeImage); + XCTAssertEqual(mlImage.orientation, self.image.imageOrientation); + mlImage.orientation = UIImageOrientationDown; + XCTAssertEqual(mlImage.orientation, UIImageOrientationDown); +- XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, FLT_EPSILON); +- XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, FLT_EPSILON); ++ XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, ++ FLT_EPSILON); ++ XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, ++ FLT_EPSILON); + } + + - (void)testInitWithImage_nilImage { +- GMLImage *mlImage = [[GMLImage alloc] initWithImage:nil]; ++ GMLImage* mlImage = [[GMLImage alloc] initWithImage:nil]; + XCTAssertNil(mlImage); + } + + - (void)testInitWithSampleBuffer { + CMSampleBufferRef sampleBuffer = [self sampleBuffer]; +- GMLImage *mlImage = [[GMLImage alloc] initWithSampleBuffer:sampleBuffer]; ++ GMLImage* mlImage = [[GMLImage alloc] initWithSampleBuffer:sampleBuffer]; + XCTAssertNotNil(mlImage); + XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypeSampleBuffer); + XCTAssertEqual(mlImage.orientation, UIImageOrientationUp); + mlImage.orientation = UIImageOrientationDown; + XCTAssertEqual(mlImage.orientation, UIImageOrientationDown); +- XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, FLT_EPSILON); +- XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, FLT_EPSILON); ++ XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, ++ FLT_EPSILON); ++ XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, ++ FLT_EPSILON); + } + + - (void)testInitWithSampleBuffer_nilImage { +- GMLImage *mlImage = [[GMLImage alloc] initWithSampleBuffer:nil]; ++ GMLImage* mlImage = [[GMLImage alloc] initWithSampleBuffer:nil]; + XCTAssertNil(mlImage); + } + + - (void)testInitWithPixelBuffer { + CMSampleBufferRef sampleBuffer = [self sampleBuffer]; + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); +- GMLImage *mlImage = [[GMLImage alloc] initWithPixelBuffer:pixelBuffer]; ++ GMLImage* mlImage = [[GMLImage alloc] initWithPixelBuffer:pixelBuffer]; + XCTAssertNotNil(mlImage); + XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypePixelBuffer); + XCTAssertEqual(mlImage.orientation, UIImageOrientationUp); + mlImage.orientation = UIImageOrientationDown; + XCTAssertEqual(mlImage.orientation, UIImageOrientationDown); +- XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, FLT_EPSILON); +- XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, FLT_EPSILON); ++ XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, ++ FLT_EPSILON); ++ XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, ++ FLT_EPSILON); + } + + - (void)testInitWithPixelBuffer_nilImage { +- GMLImage *mlImage = [[GMLImage alloc] initWithPixelBuffer:nil]; ++ GMLImage* mlImage = [[GMLImage alloc] initWithPixelBuffer:nil]; + XCTAssertNil(mlImage); + } + +@@ -117,17 +124,18 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + size_t bpr = CGImageGetBytesPerRow(CGImage); + + CGDataProviderRef provider = CGImageGetDataProvider(CGImage); +- NSData *imageRGBAData = (id)CFBridgingRelease(CGDataProviderCopyData(provider)); ++ NSData* imageRGBAData = ++ (id)CFBridgingRelease(CGDataProviderCopyData(provider)); + const uint8_t order[4] = {2, 1, 0, 3}; + +- NSData *imageBGRAData = nil; +- unsigned char *bgraPixel = (unsigned char *)malloc([imageRGBAData length]); ++ NSData* imageBGRAData = nil; ++ unsigned char* bgraPixel = (unsigned char*)malloc([imageRGBAData length]); + if (bgraPixel) { + vImage_Buffer src; + src.height = height; + src.width = width; + src.rowBytes = bpr; +- src.data = (void *)[imageRGBAData bytes]; ++ src.data = (void*)[imageRGBAData bytes]; + + vImage_Buffer dest; + dest.height = height; +@@ -136,11 +144,13 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + dest.data = bgraPixel; + + // Specify ordering changes in map. +- vImage_Error error = vImagePermuteChannels_ARGB8888(&src, &dest, order, kvImageNoFlags); ++ vImage_Error error = ++ vImagePermuteChannels_ARGB8888(&src, &dest, order, kvImageNoFlags); + + // Package the result. + if (error == kvImageNoError) { +- imageBGRAData = [NSData dataWithBytes:bgraPixel length:[imageRGBAData length]]; ++ imageBGRAData = [NSData dataWithBytes:bgraPixel ++ length:[imageRGBAData length]]; + } + + // Memory cleanup. +@@ -152,14 +162,15 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + } + + // Write data to `CMSampleBuffer`. +- NSDictionary *options = @{ +- (__bridge NSString *)kCVPixelBufferCGImageCompatibilityKey : @(YES), +- (__bridge NSString *)kCVPixelBufferCGBitmapContextCompatibilityKey : @(YES) ++ NSDictionary* options = @{ ++ (__bridge NSString*)kCVPixelBufferCGImageCompatibilityKey : @(YES), ++ (__bridge NSString*)kCVPixelBufferCGBitmapContextCompatibilityKey : @(YES) + }; + CVPixelBufferRef pixelBuffer; + CVReturn status = CVPixelBufferCreateWithBytes( +- kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, (void *)[imageBGRAData bytes], +- bpr, NULL, nil, (__bridge CFDictionaryRef)options, &pixelBuffer); ++ kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, ++ (void*)[imageBGRAData bytes], bpr, NULL, nil, ++ (__bridge CFDictionaryRef)options, &pixelBuffer); + + if (status != kCVReturnSuccess) { + XCTFail(@"Failed to create pixel buffer."); +@@ -167,10 +178,12 @@ static CGFloat kTestImageHeightInPixels = 606.0f; + + CVPixelBufferLockBaseAddress(pixelBuffer, 0); + CMVideoFormatDescriptionRef videoInfo = NULL; +- CMVideoFormatDescriptionCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, &videoInfo); ++ CMVideoFormatDescriptionCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, ++ &videoInfo); + + CMSampleBufferRef buffer; +- CMSampleBufferCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, true, NULL, NULL, videoInfo, ++ CMSampleBufferCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, true, ++ NULL, NULL, videoInfo, + &kCMTimingInfoInvalid, &buffer); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java +index a32fc24749e0c..59116a72a0533 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java +@@ -24,28 +24,27 @@ import android.graphics.Bitmap; + * {@link IllegalArgumentException} will be thrown. + */ + public final class BitmapExtractor { +- +- /** +- * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}. +- * +- * <p>Notice: Properties of the {@code image} like rotation will not take effects. +- * +- * @param image the image to extract {@link android.graphics.Bitmap} from. +- * @return the {@link android.graphics.Bitmap} stored in {@link MlImage} +- * @throws IllegalArgumentException when the extraction requires unsupported format or data type +- * conversions. +- */ +- public static Bitmap extract(MlImage image) { +- ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP); +- if (imageContainer != null) { +- return ((BitmapImageContainer) imageContainer).getBitmap(); +- } else { +- // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion. +- throw new IllegalArgumentException( +- "Extracting Bitmap from an MlImage created by objects other than Bitmap is not" +- + " supported"); ++ /** ++ * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}. ++ * ++ * <p>Notice: Properties of the {@code image} like rotation will not take effects. ++ * ++ * @param image the image to extract {@link android.graphics.Bitmap} from. ++ * @return the {@link android.graphics.Bitmap} stored in {@link MlImage} ++ * @throws IllegalArgumentException when the extraction requires unsupported format or data type ++ * conversions. ++ */ ++ public static Bitmap extract(MlImage image) { ++ ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP); ++ if (imageContainer != null) { ++ return ((BitmapImageContainer) imageContainer).getBitmap(); ++ } else { ++ // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion. ++ throw new IllegalArgumentException( ++ "Extracting Bitmap from an MlImage created by objects other than Bitmap is not" ++ + " supported"); ++ } + } +- } + +- private BitmapExtractor() {} ++ private BitmapExtractor() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java +index 77e63f0351449..b1b02f8e369ec 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java +@@ -16,44 +16,44 @@ limitations under the License. + package com.google.android.odml.image; + + import android.graphics.Bitmap; ++ + import com.google.android.odml.image.MlImage.ImageFormat; + + class BitmapImageContainer implements ImageContainer { ++ private final Bitmap bitmap; ++ private final ImageProperties properties; ++ ++ public BitmapImageContainer(Bitmap bitmap) { ++ this.bitmap = bitmap; ++ this.properties = ImageProperties.builder() ++ .setImageFormat(convertFormatCode(bitmap.getConfig())) ++ .setStorageType(MlImage.STORAGE_TYPE_BITMAP) ++ .build(); ++ } ++ ++ public Bitmap getBitmap() { ++ return bitmap; ++ } ++ ++ @Override ++ public ImageProperties getImageProperties() { ++ return properties; ++ } ++ ++ @Override ++ public void close() { ++ bitmap.recycle(); ++ } + +- private final Bitmap bitmap; +- private final ImageProperties properties; +- +- public BitmapImageContainer(Bitmap bitmap) { +- this.bitmap = bitmap; +- this.properties = ImageProperties.builder() +- .setImageFormat(convertFormatCode(bitmap.getConfig())) +- .setStorageType(MlImage.STORAGE_TYPE_BITMAP) +- .build(); +- } +- +- public Bitmap getBitmap() { +- return bitmap; +- } +- +- @Override +- public ImageProperties getImageProperties() { +- return properties; +- } +- +- @Override +- public void close() { +- bitmap.recycle(); +- } +- +- @ImageFormat +- static int convertFormatCode(Bitmap.Config config) { +- switch (config) { +- case ALPHA_8: +- return MlImage.IMAGE_FORMAT_ALPHA; +- case ARGB_8888: +- return MlImage.IMAGE_FORMAT_RGBA; +- default: +- return MlImage.IMAGE_FORMAT_UNKNOWN; ++ @ImageFormat ++ static int convertFormatCode(Bitmap.Config config) { ++ switch (config) { ++ case ALPHA_8: ++ return MlImage.IMAGE_FORMAT_ALPHA; ++ case ARGB_8888: ++ return MlImage.IMAGE_FORMAT_RGBA; ++ default: ++ return MlImage.IMAGE_FORMAT_UNKNOWN; ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java +index fe9c35a8a6ede..6c4552bfdac3a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java +@@ -20,6 +20,7 @@ import android.graphics.Bitmap; + import android.graphics.Rect; + import android.net.Uri; + import android.provider.MediaStore; ++ + import java.io.IOException; + + /** +@@ -32,82 +33,76 @@ import java.io.IOException; + * <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in. + */ + public class BitmapMlImageBuilder { ++ // Mandatory fields. ++ private final Bitmap bitmap; + +- // Mandatory fields. +- private final Bitmap bitmap; +- +- // Optional fields. +- private int rotation; +- private Rect roi; +- private long timestamp; ++ // Optional fields. ++ private int rotation; ++ private Rect roi; ++ private long timestamp; + +- /** +- * Creates the builder with a mandatory {@link android.graphics.Bitmap}. +- * +- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values +- * will be set with default: +- * +- * <ul> +- * <li>rotation: 0 +- * </ul> +- * +- * @param bitmap image data object. +- */ +- public BitmapMlImageBuilder(Bitmap bitmap) { +- this.bitmap = bitmap; +- rotation = 0; +- roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight()); +- timestamp = 0; +- } ++ /** ++ * Creates the builder with a mandatory {@link android.graphics.Bitmap}. ++ * ++ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the ++ * values will be set with default: ++ * ++ * <ul> ++ * <li>rotation: 0 ++ * </ul> ++ * ++ * @param bitmap image data object. ++ */ ++ public BitmapMlImageBuilder(Bitmap bitmap) { ++ this.bitmap = bitmap; ++ rotation = 0; ++ roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight()); ++ timestamp = 0; ++ } + +- /** +- * Creates the builder to build {@link MlImage} from a file. +- * +- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values +- * will be set with default: +- * +- * <ul> +- * <li>rotation: 0 +- * </ul> +- * +- * @param context the application context. +- * @param uri the path to the resource file. +- */ +- public BitmapMlImageBuilder(Context context, Uri uri) throws IOException { +- this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); +- } ++ /** ++ * Creates the builder to build {@link MlImage} from a file. ++ * ++ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the ++ * values will be set with default: ++ * ++ * <ul> ++ * <li>rotation: 0 ++ * </ul> ++ * ++ * @param context the application context. ++ * @param uri the path to the resource file. ++ */ ++ public BitmapMlImageBuilder(Context context, Uri uri) throws IOException { ++ this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); ++ } + +- /** +- * Sets value for {@link MlImage#getRotation()}. +- * +- * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. +- */ +- public BitmapMlImageBuilder setRotation(int rotation) { +- MlImage.validateRotation(rotation); +- this.rotation = rotation; +- return this; +- } ++ /** ++ * Sets value for {@link MlImage#getRotation()}. ++ * ++ * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. ++ */ ++ public BitmapMlImageBuilder setRotation(int rotation) { ++ MlImage.validateRotation(rotation); ++ this.rotation = rotation; ++ return this; ++ } + +- /** Sets value for {@link MlImage#getRoi()}. */ +- BitmapMlImageBuilder setRoi(Rect roi) { +- this.roi = roi; +- return this; +- } ++ /** Sets value for {@link MlImage#getRoi()}. */ ++ BitmapMlImageBuilder setRoi(Rect roi) { ++ this.roi = roi; ++ return this; ++ } + +- /** Sets value for {@link MlImage#getTimestamp()}. */ +- BitmapMlImageBuilder setTimestamp(long timestamp) { +- this.timestamp = timestamp; +- return this; +- } ++ /** Sets value for {@link MlImage#getTimestamp()}. */ ++ BitmapMlImageBuilder setTimestamp(long timestamp) { ++ this.timestamp = timestamp; ++ return this; ++ } + +- /** Builds an {@link MlImage} instance. */ +- public MlImage build() { +- return new MlImage( +- new BitmapImageContainer(bitmap), +- rotation, +- roi, +- timestamp, +- bitmap.getWidth(), +- bitmap.getHeight()); +- } ++ /** Builds an {@link MlImage} instance. */ ++ public MlImage build() { ++ return new MlImage(new BitmapImageContainer(bitmap), rotation, roi, timestamp, ++ bitmap.getWidth(), bitmap.getHeight()); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java +index 7b86be6d1b533..d5861c8ca94ac 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java +@@ -19,8 +19,10 @@ import android.graphics.Bitmap; + import android.graphics.Bitmap.Config; + import android.os.Build.VERSION; + import android.os.Build.VERSION_CODES; ++ + import com.google.android.odml.image.MlImage.ImageFormat; + import com.google.auto.value.AutoValue; ++ + import java.nio.ByteBuffer; + import java.nio.ByteOrder; + import java.util.Locale; +@@ -32,229 +34,234 @@ import java.util.Locale; + * otherwise {@link IllegalArgumentException} will be thrown. + */ + public class ByteBufferExtractor { +- +- /** +- * Extracts a {@link ByteBuffer} from an {@link MlImage}. +- * +- * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link +- * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}. +- * +- * @see MlImage#getContainedImageProperties() +- * @return A read-only {@link ByteBuffer}. +- * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. +- */ +- public static ByteBuffer extract(MlImage image) { +- ImageContainer container = image.getContainer(); +- switch (container.getImageProperties().getStorageType()) { +- case MlImage.STORAGE_TYPE_BYTEBUFFER: +- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; +- return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); +- default: +- throw new IllegalArgumentException( +- "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not" +- + " supported"); +- } +- } +- +- /** +- * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}. +- * +- * <p>Notice: Properties of the {@code image} like rotation will not take effects. +- * +- * <p>Format conversion spec: +- * +- * <ul> +- * <li>When extracting RGB images to RGBA format, A channel will always set to 255. +- * <li>When extracting RGBA images to RGB format, A channel will be dropped. +- * </ul> +- * +- * @param image the image to extract buffer from. +- * @param targetFormat the image format of the result bytebuffer. +- * @return the readonly {@link ByteBuffer} stored in {@link MlImage} +- * @throws IllegalArgumentException when the extraction requires unsupported format or data type +- * conversions. +- */ +- static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) { +- ImageContainer container; +- ImageProperties byteBufferProperties = +- ImageProperties.builder() +- .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) +- .setImageFormat(targetFormat) +- .build(); +- if ((container = image.getContainer(byteBufferProperties)) != null) { +- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; +- return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); +- } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) { +- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; +- @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); +- return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) +- .asReadOnlyBuffer(); +- } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) { +- BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; +- ByteBuffer byteBuffer = +- extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) +- .asReadOnlyBuffer(); +- image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat)); +- return byteBuffer; +- } else { +- throw new IllegalArgumentException( +- "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or" +- + " Bytebuffer is not supported"); +- } +- } +- +- /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ +- @AutoValue +- abstract static class Result { + /** +- * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MlImage)}. ++ * Extracts a {@link ByteBuffer} from an {@link MlImage}. ++ * ++ * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link ++ * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}. ++ * ++ * @see MlImage#getContainedImageProperties() ++ * @return A read-only {@link ByteBuffer}. ++ * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. + */ +- public abstract ByteBuffer buffer(); ++ public static ByteBuffer extract(MlImage image) { ++ ImageContainer container = image.getContainer(); ++ switch (container.getImageProperties().getStorageType()) { ++ case MlImage.STORAGE_TYPE_BYTEBUFFER: ++ ByteBufferImageContainer byteBufferImageContainer = ++ (ByteBufferImageContainer) container; ++ return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); ++ default: ++ throw new IllegalArgumentException( ++ "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not" ++ + " supported"); ++ } ++ } + + /** +- * Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(MlImage)}. ++ * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}. ++ * ++ * <p>Notice: Properties of the {@code image} like rotation will not take effects. ++ * ++ * <p>Format conversion spec: ++ * ++ * <ul> ++ * <li>When extracting RGB images to RGBA format, A channel will always set to 255. ++ * <li>When extracting RGBA images to RGB format, A channel will be dropped. ++ * </ul> ++ * ++ * @param image the image to extract buffer from. ++ * @param targetFormat the image format of the result bytebuffer. ++ * @return the readonly {@link ByteBuffer} stored in {@link MlImage} ++ * @throws IllegalArgumentException when the extraction requires unsupported format or data type ++ * conversions. + */ +- @ImageFormat +- public abstract int format(); +- +- static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { +- return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); ++ static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) { ++ ImageContainer container; ++ ImageProperties byteBufferProperties = ++ ImageProperties.builder() ++ .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) ++ .setImageFormat(targetFormat) ++ .build(); ++ if ((container = image.getContainer(byteBufferProperties)) != null) { ++ ByteBufferImageContainer byteBufferImageContainer = ++ (ByteBufferImageContainer) container; ++ return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); ++ } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ++ ByteBufferImageContainer byteBufferImageContainer = ++ (ByteBufferImageContainer) container; ++ @ImageFormat ++ int sourceFormat = byteBufferImageContainer.getImageFormat(); ++ return convertByteBuffer( ++ byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) ++ .asReadOnlyBuffer(); ++ } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) { ++ BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; ++ ByteBuffer byteBuffer = ++ extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) ++ .asReadOnlyBuffer(); ++ image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat)); ++ return byteBuffer; ++ } else { ++ throw new IllegalArgumentException( ++ "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or" ++ + " Bytebuffer is not supported"); ++ } + } +- } + +- /** +- * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}. +- * +- * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. +- * +- * <p>Notice: Properties of the {@code image} like rotation will not take effects. +- * +- * @return the readonly {@link ByteBuffer} stored in {@link MlImage} +- * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with +- * given {@code imageFormat} +- */ +- static Result extractInRecommendedFormat(MlImage image) { +- ImageContainer container; +- if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) { +- Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); +- @ImageFormat int format = adviseImageFormat(bitmap); +- Result result = +- Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); ++ /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ ++ @AutoValue ++ abstract static class Result { ++ /** ++ * Gets the {@link ByteBuffer} in the result of {@link ++ * ByteBufferExtractor#extract(MlImage)}. ++ */ ++ public abstract ByteBuffer buffer(); + +- image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); +- return result; +- } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) { +- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; +- return Result.create( +- byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), +- byteBufferImageContainer.getImageFormat()); +- } else { +- throw new IllegalArgumentException( +- "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer" +- + " is not supported"); ++ /** ++ * Gets the {@link ImageFormat} in the result of {@link ++ * ByteBufferExtractor#extract(MlImage)}. ++ */ ++ @ImageFormat ++ public abstract int format(); ++ ++ static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { ++ return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); ++ } + } +- } + +- @ImageFormat +- private static int adviseImageFormat(Bitmap bitmap) { +- if (bitmap.getConfig() == Config.ARGB_8888) { +- return MlImage.IMAGE_FORMAT_RGBA; +- } else { +- throw new IllegalArgumentException( +- String.format( +- "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not" +- + " supported", +- bitmap.getConfig())); ++ /** ++ * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}. ++ * ++ * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid ++ * copy. ++ * ++ * <p>Notice: Properties of the {@code image} like rotation will not take effects. ++ * ++ * @return the readonly {@link ByteBuffer} stored in {@link MlImage} ++ * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with ++ * given {@code imageFormat} ++ */ ++ static Result extractInRecommendedFormat(MlImage image) { ++ ImageContainer container; ++ if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) { ++ Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); ++ @ImageFormat ++ int format = adviseImageFormat(bitmap); ++ Result result = Result.create( ++ extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); ++ ++ image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); ++ return result; ++ } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ++ ByteBufferImageContainer byteBufferImageContainer = ++ (ByteBufferImageContainer) container; ++ return Result.create(byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), ++ byteBufferImageContainer.getImageFormat()); ++ } else { ++ throw new IllegalArgumentException( ++ "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer" ++ + " is not supported"); ++ } + } +- } + +- private static ByteBuffer extractByteBufferFromBitmap( +- Bitmap bitmap, @ImageFormat int imageFormat) { +- if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { +- throw new IllegalArgumentException( +- "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not" +- + " supported"); ++ @ImageFormat ++ private static int adviseImageFormat(Bitmap bitmap) { ++ if (bitmap.getConfig() == Config.ARGB_8888) { ++ return MlImage.IMAGE_FORMAT_RGBA; ++ } else { ++ throw new IllegalArgumentException(String.format( ++ "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not" ++ + " supported", ++ bitmap.getConfig())); ++ } + } +- if (bitmap.getConfig() == Config.ARGB_8888) { +- if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) { +- ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); +- bitmap.copyPixelsToBuffer(buffer); +- buffer.rewind(); +- return buffer; +- } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) { +- // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be faster. +- int w = bitmap.getWidth(); +- int h = bitmap.getHeight(); +- int[] pixels = new int[w * h]; +- bitmap.getPixels(pixels, 0, w, 0, 0, w, h); +- ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3); +- buffer.order(ByteOrder.nativeOrder()); +- for (int pixel : pixels) { +- // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA +- buffer.put((byte) ((pixel >> 16) & 0xff)); +- buffer.put((byte) ((pixel >> 8) & 0xff)); +- buffer.put((byte) (pixel & 0xff)); ++ ++ private static ByteBuffer extractByteBufferFromBitmap( ++ Bitmap bitmap, @ImageFormat int imageFormat) { ++ if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { ++ throw new IllegalArgumentException( ++ "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not" ++ + " supported"); + } +- buffer.rewind(); +- return buffer; +- } ++ if (bitmap.getConfig() == Config.ARGB_8888) { ++ if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) { ++ ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); ++ bitmap.copyPixelsToBuffer(buffer); ++ buffer.rewind(); ++ return buffer; ++ } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) { ++ // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be ++ // faster. ++ int w = bitmap.getWidth(); ++ int h = bitmap.getHeight(); ++ int[] pixels = new int[w * h]; ++ bitmap.getPixels(pixels, 0, w, 0, 0, w, h); ++ ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3); ++ buffer.order(ByteOrder.nativeOrder()); ++ for (int pixel : pixels) { ++ // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns ++ // RGBA ++ buffer.put((byte) ((pixel >> 16) & 0xff)); ++ buffer.put((byte) ((pixel >> 8) & 0xff)); ++ buffer.put((byte) (pixel & 0xff)); ++ } ++ buffer.rewind(); ++ return buffer; ++ } ++ } ++ throw new IllegalArgumentException(String.format( ++ "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format" ++ + " %d is not supported", ++ bitmap.getConfig(), imageFormat)); + } +- throw new IllegalArgumentException( +- String.format( +- "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format" +- + " %d is not supported", +- bitmap.getConfig(), imageFormat)); +- } + +- private static ByteBuffer convertByteBuffer( +- ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { +- if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) { +- ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); +- // Extend the buffer when the target is longer than the source. Use two cursors and sweep the +- // array reversely to convert in-place. +- byte[] array = new byte[target.capacity()]; +- source.get(array, 0, source.capacity()); +- source.rewind(); +- int rgbCursor = source.capacity(); +- int rgbaCursor = target.capacity(); +- while (rgbCursor != rgbaCursor) { +- array[--rgbaCursor] = (byte) 0xff; // A +- array[--rgbaCursor] = array[--rgbCursor]; // B +- array[--rgbaCursor] = array[--rgbCursor]; // G +- array[--rgbaCursor] = array[--rgbCursor]; // R +- } +- target.put(array, 0, target.capacity()); +- target.rewind(); +- return target; +- } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA +- && targetFormat == MlImage.IMAGE_FORMAT_RGB) { +- ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); +- // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the +- // array to convert in-place. +- byte[] array = new byte[source.capacity()]; +- source.get(array, 0, source.capacity()); +- source.rewind(); +- int rgbaCursor = 0; +- int rgbCursor = 0; +- while (rgbaCursor < array.length) { +- array[rgbCursor++] = array[rgbaCursor++]; // R +- array[rgbCursor++] = array[rgbaCursor++]; // G +- array[rgbCursor++] = array[rgbaCursor++]; // B +- rgbaCursor++; +- } +- target.put(array, 0, target.capacity()); +- target.rewind(); +- return target; +- } else { +- throw new IllegalArgumentException( +- String.format( +- Locale.ENGLISH, +- "Convert bytebuffer image format from %d to %d is not supported", +- sourceFormat, +- targetFormat)); ++ private static ByteBuffer convertByteBuffer( ++ ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { ++ if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) { ++ ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); ++ // Extend the buffer when the target is longer than the source. Use two cursors and ++ // sweep the array reversely to convert in-place. ++ byte[] array = new byte[target.capacity()]; ++ source.get(array, 0, source.capacity()); ++ source.rewind(); ++ int rgbCursor = source.capacity(); ++ int rgbaCursor = target.capacity(); ++ while (rgbCursor != rgbaCursor) { ++ array[--rgbaCursor] = (byte) 0xff; // A ++ array[--rgbaCursor] = array[--rgbCursor]; // B ++ array[--rgbaCursor] = array[--rgbCursor]; // G ++ array[--rgbaCursor] = array[--rgbCursor]; // R ++ } ++ target.put(array, 0, target.capacity()); ++ target.rewind(); ++ return target; ++ } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA ++ && targetFormat == MlImage.IMAGE_FORMAT_RGB) { ++ ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); ++ // Shrink the buffer when the target is shorter than the source. Use two cursors and ++ // sweep the array to convert in-place. ++ byte[] array = new byte[source.capacity()]; ++ source.get(array, 0, source.capacity()); ++ source.rewind(); ++ int rgbaCursor = 0; ++ int rgbCursor = 0; ++ while (rgbaCursor < array.length) { ++ array[rgbCursor++] = array[rgbaCursor++]; // R ++ array[rgbCursor++] = array[rgbaCursor++]; // G ++ array[rgbCursor++] = array[rgbaCursor++]; // B ++ rgbaCursor++; ++ } ++ target.put(array, 0, target.capacity()); ++ target.rewind(); ++ return target; ++ } else { ++ throw new IllegalArgumentException(String.format(Locale.ENGLISH, ++ "Convert bytebuffer image format from %d to %d is not supported", sourceFormat, ++ targetFormat)); ++ } + } +- } + +- // ByteBuffer is not able to be instantiated. +- private ByteBufferExtractor() {} ++ // ByteBuffer is not able to be instantiated. ++ private ByteBufferExtractor() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java +index 9fbc3cbb94994..f872db485a8a2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java +@@ -16,42 +16,40 @@ limitations under the License. + package com.google.android.odml.image; + + import com.google.android.odml.image.MlImage.ImageFormat; ++ + import java.nio.ByteBuffer; + + class ByteBufferImageContainer implements ImageContainer { +- +- private final ByteBuffer buffer; +- private final ImageProperties properties; +- +- public ByteBufferImageContainer( +- ByteBuffer buffer, +- @ImageFormat int imageFormat) { +- this.buffer = buffer; +- this.properties = ImageProperties.builder() +- .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) +- .setImageFormat(imageFormat) +- .build(); +- } +- +- public ByteBuffer getByteBuffer() { +- return buffer; +- } +- +- @Override +- public ImageProperties getImageProperties() { +- return properties; +- } +- +- /** +- * Returns the image format. +- */ +- @ImageFormat +- public int getImageFormat() { +- return properties.getImageFormat(); +- } +- +- @Override +- public void close() { +- // No op for ByteBuffer. +- } ++ private final ByteBuffer buffer; ++ private final ImageProperties properties; ++ ++ public ByteBufferImageContainer(ByteBuffer buffer, @ImageFormat int imageFormat) { ++ this.buffer = buffer; ++ this.properties = ImageProperties.builder() ++ .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) ++ .setImageFormat(imageFormat) ++ .build(); ++ } ++ ++ public ByteBuffer getByteBuffer() { ++ return buffer; ++ } ++ ++ @Override ++ public ImageProperties getImageProperties() { ++ return properties; ++ } ++ ++ /** ++ * Returns the image format. ++ */ ++ @ImageFormat ++ public int getImageFormat() { ++ return properties.getImageFormat(); ++ } ++ ++ @Override ++ public void close() { ++ // No op for ByteBuffer. ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java +index 421e2b8f0de31..f4b0b31dd5e3b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java +@@ -16,7 +16,9 @@ limitations under the License. + package com.google.android.odml.image; + + import android.graphics.Rect; ++ + import com.google.android.odml.image.MlImage.ImageFormat; ++ + import java.nio.ByteBuffer; + + /** +@@ -28,79 +30,74 @@ import java.nio.ByteBuffer; + * <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in. + */ + public class ByteBufferMlImageBuilder { ++ // Mandatory fields. ++ private final ByteBuffer buffer; ++ private final int width; ++ private final int height; ++ @ImageFormat ++ private final int imageFormat; + +- // Mandatory fields. +- private final ByteBuffer buffer; +- private final int width; +- private final int height; +- @ImageFormat private final int imageFormat; +- +- // Optional fields. +- private int rotation; +- private Rect roi; +- private long timestamp; ++ // Optional fields. ++ private int rotation; ++ private Rect roi; ++ private long timestamp; + +- /** +- * Creates the builder with mandatory {@link ByteBuffer} and the represented image. +- * +- * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height} +- * and {@code imageFormat}. +- * +- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values +- * will be set with default: +- * +- * <ul> +- * <li>rotation: 0 +- * </ul> +- * +- * @param byteBuffer image data object. +- * @param width the width of the represented image. +- * @param height the height of the represented image. +- * @param imageFormat how the data encode the image. +- */ +- public ByteBufferMlImageBuilder( +- ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { +- this.buffer = byteBuffer; +- this.width = width; +- this.height = height; +- this.imageFormat = imageFormat; +- // TODO(b/180504869): Validate bytebuffer size with width, height and image format +- this.rotation = 0; +- this.roi = new Rect(0, 0, width, height); +- this.timestamp = 0; +- } ++ /** ++ * Creates the builder with mandatory {@link ByteBuffer} and the represented image. ++ * ++ * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code ++ * height} and {@code imageFormat}. ++ * ++ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the ++ * values will be set with default: ++ * ++ * <ul> ++ * <li>rotation: 0 ++ * </ul> ++ * ++ * @param byteBuffer image data object. ++ * @param width the width of the represented image. ++ * @param height the height of the represented image. ++ * @param imageFormat how the data encode the image. ++ */ ++ public ByteBufferMlImageBuilder( ++ ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { ++ this.buffer = byteBuffer; ++ this.width = width; ++ this.height = height; ++ this.imageFormat = imageFormat; ++ // TODO(b/180504869): Validate bytebuffer size with width, height and image format ++ this.rotation = 0; ++ this.roi = new Rect(0, 0, width, height); ++ this.timestamp = 0; ++ } + +- /** +- * Sets value for {@link MlImage#getRotation()}. +- * +- * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. +- */ +- public ByteBufferMlImageBuilder setRotation(int rotation) { +- MlImage.validateRotation(rotation); +- this.rotation = rotation; +- return this; +- } ++ /** ++ * Sets value for {@link MlImage#getRotation()}. ++ * ++ * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. ++ */ ++ public ByteBufferMlImageBuilder setRotation(int rotation) { ++ MlImage.validateRotation(rotation); ++ this.rotation = rotation; ++ return this; ++ } + +- /** Sets value for {@link MlImage#getRoi()}. */ +- ByteBufferMlImageBuilder setRoi(Rect roi) { +- this.roi = roi; +- return this; +- } ++ /** Sets value for {@link MlImage#getRoi()}. */ ++ ByteBufferMlImageBuilder setRoi(Rect roi) { ++ this.roi = roi; ++ return this; ++ } + +- /** Sets value for {@link MlImage#getTimestamp()}. */ +- ByteBufferMlImageBuilder setTimestamp(long timestamp) { +- this.timestamp = timestamp; +- return this; +- } ++ /** Sets value for {@link MlImage#getTimestamp()}. */ ++ ByteBufferMlImageBuilder setTimestamp(long timestamp) { ++ this.timestamp = timestamp; ++ return this; ++ } + +- /** Builds an {@link MlImage} instance. */ +- public MlImage build() { +- return new MlImage( +- new ByteBufferImageContainer(buffer, imageFormat), +- rotation, +- roi, +- timestamp, +- width, +- height); +- } ++ /** Builds an {@link MlImage} instance. */ ++ public MlImage build() { ++ return new MlImage(new ByteBufferImageContainer(buffer, imageFormat), rotation, roi, ++ timestamp, width, height); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java +index 25ed2312ce580..bfa7c0a292f4f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java +@@ -20,11 +20,11 @@ import com.google.android.odml.image.annotation.KeepForSdk; + /** Manages internal image data storage. The interface is package-private. */ + @KeepForSdk + interface ImageContainer { +- /** Returns the properties of the contained image. */ +- @KeepForSdk +- ImageProperties getImageProperties(); ++ /** Returns the properties of the contained image. */ ++ @KeepForSdk ++ ImageProperties getImageProperties(); + +- /** Close the image container and releases the image resource inside. */ +- @KeepForSdk +- void close(); ++ /** Close the image container and releases the image resource inside. */ ++ @KeepForSdk ++ void close(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java +index 717bc5f9935ed..a61e97b81b872 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java +@@ -24,63 +24,61 @@ import com.google.auto.value.extension.memoized.Memoized; + /** Groups a set of properties to describe how an image is stored. */ + @AutoValue + public abstract class ImageProperties { +- +- /** +- * Gets the pixel format of the image. +- * +- * @see MlImage.ImageFormat +- */ +- @ImageFormat +- public abstract int getImageFormat(); +- +- /** +- * Gets the storage type of the image. +- * +- * @see MlImage.StorageType +- */ +- @StorageType +- public abstract int getStorageType(); +- +- @Memoized +- @Override +- public abstract int hashCode(); +- +- /** +- * Creates a builder of {@link ImageProperties}. +- * +- * @see ImageProperties.Builder +- */ +- @KeepForSdk +- static Builder builder() { +- return new AutoValue_ImageProperties.Builder(); +- } +- +- /** Builds a {@link ImageProperties}. */ +- @AutoValue.Builder +- @KeepForSdk +- abstract static class Builder { ++ /** ++ * Gets the pixel format of the image. ++ * ++ * @see MlImage.ImageFormat ++ */ ++ @ImageFormat ++ public abstract int getImageFormat(); + + /** +- * Sets the {@link MlImage.ImageFormat}. ++ * Gets the storage type of the image. + * +- * @see ImageProperties#getImageFormat ++ * @see MlImage.StorageType + */ +- @KeepForSdk +- abstract Builder setImageFormat(@ImageFormat int value); ++ @StorageType ++ public abstract int getStorageType(); ++ ++ @Memoized ++ @Override ++ public abstract int hashCode(); + + /** +- * Sets the {@link MlImage.StorageType}. ++ * Creates a builder of {@link ImageProperties}. + * +- * @see ImageProperties#getStorageType ++ * @see ImageProperties.Builder + */ + @KeepForSdk +- abstract Builder setStorageType(@StorageType int value); ++ static Builder builder() { ++ return new AutoValue_ImageProperties.Builder(); ++ } + +- /** Builds the {@link ImageProperties}. */ ++ /** Builds a {@link ImageProperties}. */ ++ @AutoValue.Builder + @KeepForSdk +- abstract ImageProperties build(); +- } ++ abstract static class Builder { ++ /** ++ * Sets the {@link MlImage.ImageFormat}. ++ * ++ * @see ImageProperties#getImageFormat ++ */ ++ @KeepForSdk ++ abstract Builder setImageFormat(@ImageFormat int value); ++ ++ /** ++ * Sets the {@link MlImage.StorageType}. ++ * ++ * @see ImageProperties#getStorageType ++ */ ++ @KeepForSdk ++ abstract Builder setStorageType(@StorageType int value); ++ ++ /** Builds the {@link ImageProperties}. */ ++ @KeepForSdk ++ abstract ImageProperties build(); ++ } + +- // Hide the constructor. +- ImageProperties() {} ++ // Hide the constructor. ++ ImageProperties() {} + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java +index 9365d0b2a422e..9ed88ee30c62f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java +@@ -19,55 +19,56 @@ import android.media.Image; + import android.os.Build; + import android.os.Build.VERSION; + import android.os.Build.VERSION_CODES; ++ + import androidx.annotation.RequiresApi; ++ + import com.google.android.odml.image.MlImage.ImageFormat; + + @RequiresApi(VERSION_CODES.KITKAT) + class MediaImageContainer implements ImageContainer { ++ private final Image mediaImage; ++ private final ImageProperties properties; + +- private final Image mediaImage; +- private final ImageProperties properties; +- +- public MediaImageContainer(Image mediaImage) { +- this.mediaImage = mediaImage; +- this.properties = ImageProperties.builder() +- .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE) +- .setImageFormat(convertFormatCode(mediaImage.getFormat())) +- .build(); +- } +- +- public Image getImage() { +- return mediaImage; +- } ++ public MediaImageContainer(Image mediaImage) { ++ this.mediaImage = mediaImage; ++ this.properties = ImageProperties.builder() ++ .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE) ++ .setImageFormat(convertFormatCode(mediaImage.getFormat())) ++ .build(); ++ } + +- @Override +- public ImageProperties getImageProperties() { +- return properties; +- } ++ public Image getImage() { ++ return mediaImage; ++ } + +- @Override +- public void close() { +- mediaImage.close(); +- } ++ @Override ++ public ImageProperties getImageProperties() { ++ return properties; ++ } + +- @ImageFormat +- static int convertFormatCode(int graphicsFormat) { +- // We only cover the format mentioned in +- // https://developer.android.com/reference/android/media/Image#getFormat() +- if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { +- if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { +- return MlImage.IMAGE_FORMAT_RGBA; +- } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { +- return MlImage.IMAGE_FORMAT_RGB; +- } ++ @Override ++ public void close() { ++ mediaImage.close(); + } +- switch (graphicsFormat) { +- case android.graphics.ImageFormat.JPEG: +- return MlImage.IMAGE_FORMAT_JPEG; +- case android.graphics.ImageFormat.YUV_420_888: +- return MlImage.IMAGE_FORMAT_YUV_420_888; +- default: +- return MlImage.IMAGE_FORMAT_UNKNOWN; ++ ++ @ImageFormat ++ static int convertFormatCode(int graphicsFormat) { ++ // We only cover the format mentioned in ++ // https://developer.android.com/reference/android/media/Image#getFormat() ++ if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { ++ if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { ++ return MlImage.IMAGE_FORMAT_RGBA; ++ } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { ++ return MlImage.IMAGE_FORMAT_RGB; ++ } ++ } ++ switch (graphicsFormat) { ++ case android.graphics.ImageFormat.JPEG: ++ return MlImage.IMAGE_FORMAT_JPEG; ++ case android.graphics.ImageFormat.YUV_420_888: ++ return MlImage.IMAGE_FORMAT_YUV_420_888; ++ default: ++ return MlImage.IMAGE_FORMAT_UNKNOWN; ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java +index 73aadabb38789..59ed98b569fa2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java +@@ -17,6 +17,7 @@ package com.google.android.odml.image; + + import android.media.Image; + import android.os.Build.VERSION_CODES; ++ + import androidx.annotation.RequiresApi; + + /** +@@ -27,26 +28,25 @@ import androidx.annotation.RequiresApi; + */ + @RequiresApi(VERSION_CODES.KITKAT) + public class MediaImageExtractor { +- +- private MediaImageExtractor() {} +- +- /** +- * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for +- * {@link MlImage} that built from {@link MediaMlImageBuilder}. +- * +- * <p>Notice: Properties of the {@code image} like rotation will not take effects. +- * +- * @param image the image to extract {@link android.media.Image} from. +- * @return {@link android.media.Image} that stored in {@link MlImage}. +- * @throws IllegalArgumentException if the extraction failed. +- */ +- public static Image extract(MlImage image) { +- ImageContainer container; +- if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) { +- return ((MediaImageContainer) container).getImage(); ++ private MediaImageExtractor() {} ++ ++ /** ++ * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for ++ * {@link MlImage} that built from {@link MediaMlImageBuilder}. ++ * ++ * <p>Notice: Properties of the {@code image} like rotation will not take effects. ++ * ++ * @param image the image to extract {@link android.media.Image} from. ++ * @return {@link android.media.Image} that stored in {@link MlImage}. ++ * @throws IllegalArgumentException if the extraction failed. ++ */ ++ public static Image extract(MlImage image) { ++ ImageContainer container; ++ if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) { ++ return ((MediaImageContainer) container).getImage(); ++ } ++ throw new IllegalArgumentException( ++ "Extract Media Image from an MlImage created by objects other than Media Image" ++ + " is not supported"); + } +- throw new IllegalArgumentException( +- "Extract Media Image from an MlImage created by objects other than Media Image" +- + " is not supported"); +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java +index e96ab38317bac..80771bdb91890 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java +@@ -18,6 +18,7 @@ package com.google.android.odml.image; + import android.graphics.Rect; + import android.media.Image; + import android.os.Build.VERSION_CODES; ++ + import androidx.annotation.RequiresApi; + + /** +@@ -30,65 +31,59 @@ import androidx.annotation.RequiresApi; + */ + @RequiresApi(VERSION_CODES.KITKAT) + public class MediaMlImageBuilder { ++ // Mandatory fields. ++ private final Image mediaImage; + +- // Mandatory fields. +- private final Image mediaImage; +- +- // Optional fields. +- private int rotation; +- private Rect roi; +- private long timestamp; ++ // Optional fields. ++ private int rotation; ++ private Rect roi; ++ private long timestamp; + +- /** +- * Creates the builder with a mandatory {@link android.media.Image}. +- * +- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values +- * will be set with default: +- * +- * <ul> +- * <li>rotation: 0 +- * </ul> +- * +- * @param mediaImage image data object. +- */ +- public MediaMlImageBuilder(Image mediaImage) { +- this.mediaImage = mediaImage; +- this.rotation = 0; +- this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight()); +- this.timestamp = 0; +- } ++ /** ++ * Creates the builder with a mandatory {@link android.media.Image}. ++ * ++ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the ++ * values will be set with default: ++ * ++ * <ul> ++ * <li>rotation: 0 ++ * </ul> ++ * ++ * @param mediaImage image data object. ++ */ ++ public MediaMlImageBuilder(Image mediaImage) { ++ this.mediaImage = mediaImage; ++ this.rotation = 0; ++ this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight()); ++ this.timestamp = 0; ++ } + +- /** +- * Sets value for {@link MlImage#getRotation()}. +- * +- * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. +- */ +- public MediaMlImageBuilder setRotation(int rotation) { +- MlImage.validateRotation(rotation); +- this.rotation = rotation; +- return this; +- } ++ /** ++ * Sets value for {@link MlImage#getRotation()}. ++ * ++ * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. ++ */ ++ public MediaMlImageBuilder setRotation(int rotation) { ++ MlImage.validateRotation(rotation); ++ this.rotation = rotation; ++ return this; ++ } + +- /** Sets value for {@link MlImage#getRoi()}. */ +- MediaMlImageBuilder setRoi(Rect roi) { +- this.roi = roi; +- return this; +- } ++ /** Sets value for {@link MlImage#getRoi()}. */ ++ MediaMlImageBuilder setRoi(Rect roi) { ++ this.roi = roi; ++ return this; ++ } + +- /** Sets value for {@link MlImage#getTimestamp()}. */ +- MediaMlImageBuilder setTimestamp(long timestamp) { +- this.timestamp = timestamp; +- return this; +- } ++ /** Sets value for {@link MlImage#getTimestamp()}. */ ++ MediaMlImageBuilder setTimestamp(long timestamp) { ++ this.timestamp = timestamp; ++ return this; ++ } + +- /** Builds an {@link MlImage} instance. */ +- public MlImage build() { +- return new MlImage( +- new MediaImageContainer(mediaImage), +- rotation, +- roi, +- timestamp, +- mediaImage.getWidth(), +- mediaImage.getHeight()); +- } ++ /** Builds an {@link MlImage} instance. */ ++ public MlImage build() { ++ return new MlImage(new MediaImageContainer(mediaImage), rotation, roi, timestamp, ++ mediaImage.getWidth(), mediaImage.getHeight()); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java +index 975ff7c0908c7..7e21e6ad428f2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java +@@ -16,9 +16,12 @@ limitations under the License. + package com.google.android.odml.image; + + import android.graphics.Rect; ++ + import androidx.annotation.IntDef; + import androidx.annotation.Nullable; ++ + import com.google.android.odml.image.annotation.KeepForSdk; ++ + import java.io.Closeable; + import java.lang.annotation.Retention; + import java.lang.annotation.RetentionPolicy; +@@ -62,228 +65,232 @@ import java.util.Map.Entry; + * and multiple storages. + */ + public class MlImage implements Closeable { ++ /** Specifies the image format of an image. */ ++ @IntDef({ ++ IMAGE_FORMAT_UNKNOWN, ++ IMAGE_FORMAT_RGBA, ++ IMAGE_FORMAT_RGB, ++ IMAGE_FORMAT_NV12, ++ IMAGE_FORMAT_NV21, ++ IMAGE_FORMAT_YV12, ++ IMAGE_FORMAT_YV21, ++ IMAGE_FORMAT_YUV_420_888, ++ IMAGE_FORMAT_ALPHA, ++ IMAGE_FORMAT_JPEG, ++ }) ++ @Retention(RetentionPolicy.SOURCE) ++ public @interface ImageFormat {} ++ ++ public static final int IMAGE_FORMAT_UNKNOWN = 0; ++ public static final int IMAGE_FORMAT_RGBA = 1; ++ public static final int IMAGE_FORMAT_RGB = 2; ++ public static final int IMAGE_FORMAT_NV12 = 3; ++ public static final int IMAGE_FORMAT_NV21 = 4; ++ public static final int IMAGE_FORMAT_YV12 = 5; ++ public static final int IMAGE_FORMAT_YV21 = 6; ++ public static final int IMAGE_FORMAT_YUV_420_888 = 7; ++ public static final int IMAGE_FORMAT_ALPHA = 8; ++ public static final int IMAGE_FORMAT_JPEG = 9; ++ ++ /** Specifies the image container type. Would be useful for choosing extractors. */ ++ @IntDef({ ++ STORAGE_TYPE_BITMAP, ++ STORAGE_TYPE_BYTEBUFFER, ++ STORAGE_TYPE_MEDIA_IMAGE, ++ STORAGE_TYPE_IMAGE_PROXY, ++ }) ++ @Retention(RetentionPolicy.SOURCE) ++ public @interface StorageType {} ++ ++ public static final int STORAGE_TYPE_BITMAP = 1; ++ public static final int STORAGE_TYPE_BYTEBUFFER = 2; ++ public static final int STORAGE_TYPE_MEDIA_IMAGE = 3; ++ public static final int STORAGE_TYPE_IMAGE_PROXY = 4; ++ ++ /** ++ * Returns a list of supported image properties for this {@link MlImage}. ++ * ++ * <p>Currently {@link MlImage} only support single storage type so the size of return list will ++ * always be 1. ++ * ++ * @see ImageProperties ++ */ ++ public List<ImageProperties> getContainedImageProperties() { ++ return Collections.singletonList(getContainer().getImageProperties()); ++ } ++ ++ /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */ ++ public int getRotation() { ++ return rotation; ++ } ++ ++ /** Returns the timestamp attached to the image. */ ++ long getTimestamp() { ++ return timestamp; ++ } ++ ++ /** Returns the width of the image. */ ++ public int getWidth() { ++ return width; ++ } ++ ++ /** Returns the height of the image. */ ++ public int getHeight() { ++ return height; ++ } + +- /** Specifies the image format of an image. */ +- @IntDef({ +- IMAGE_FORMAT_UNKNOWN, +- IMAGE_FORMAT_RGBA, +- IMAGE_FORMAT_RGB, +- IMAGE_FORMAT_NV12, +- IMAGE_FORMAT_NV21, +- IMAGE_FORMAT_YV12, +- IMAGE_FORMAT_YV21, +- IMAGE_FORMAT_YUV_420_888, +- IMAGE_FORMAT_ALPHA, +- IMAGE_FORMAT_JPEG, +- }) +- @Retention(RetentionPolicy.SOURCE) +- public @interface ImageFormat {} +- +- public static final int IMAGE_FORMAT_UNKNOWN = 0; +- public static final int IMAGE_FORMAT_RGBA = 1; +- public static final int IMAGE_FORMAT_RGB = 2; +- public static final int IMAGE_FORMAT_NV12 = 3; +- public static final int IMAGE_FORMAT_NV21 = 4; +- public static final int IMAGE_FORMAT_YV12 = 5; +- public static final int IMAGE_FORMAT_YV21 = 6; +- public static final int IMAGE_FORMAT_YUV_420_888 = 7; +- public static final int IMAGE_FORMAT_ALPHA = 8; +- public static final int IMAGE_FORMAT_JPEG = 9; +- +- /** Specifies the image container type. Would be useful for choosing extractors. */ +- @IntDef({ +- STORAGE_TYPE_BITMAP, +- STORAGE_TYPE_BYTEBUFFER, +- STORAGE_TYPE_MEDIA_IMAGE, +- STORAGE_TYPE_IMAGE_PROXY, +- }) +- @Retention(RetentionPolicy.SOURCE) +- public @interface StorageType {} +- +- public static final int STORAGE_TYPE_BITMAP = 1; +- public static final int STORAGE_TYPE_BYTEBUFFER = 2; +- public static final int STORAGE_TYPE_MEDIA_IMAGE = 3; +- public static final int STORAGE_TYPE_IMAGE_PROXY = 4; +- +- /** +- * Returns a list of supported image properties for this {@link MlImage}. +- * +- * <p>Currently {@link MlImage} only support single storage type so the size of return list will +- * always be 1. +- * +- * @see ImageProperties +- */ +- public List<ImageProperties> getContainedImageProperties() { +- return Collections.singletonList(getContainer().getImageProperties()); +- } +- +- /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */ +- public int getRotation() { +- return rotation; +- } +- +- /** Returns the timestamp attached to the image. */ +- long getTimestamp() { +- return timestamp; +- } +- +- /** Returns the width of the image. */ +- public int getWidth() { +- return width; +- } +- +- /** Returns the height of the image. */ +- public int getHeight() { +- return height; +- } +- +- /** Returns the region-of-interest rectangle attached to the image. */ +- Rect getRoi() { +- Rect result = new Rect(); +- result.set(roi); +- return result; +- } +- +- /** Acquires a reference on this {@link MlImage}. This will increase the reference count by 1. */ +- private synchronized void acquire() { +- referenceCount += 1; +- } +- +- /** +- * Removes a reference that was previously acquired or init. +- * +- * <p>When {@link MlImage} is created, it has 1 reference count. +- * +- * <p>When the reference count becomes 0, it will release the resource under the hood. +- */ +- @Override +- // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount +- public synchronized void close() { +- referenceCount -= 1; +- if (referenceCount == 0) { +- for (ImageContainer imageContainer : containerMap.values()) { +- imageContainer.close(); +- } ++ /** Returns the region-of-interest rectangle attached to the image. */ ++ Rect getRoi() { ++ Rect result = new Rect(); ++ result.set(roi); ++ return result; + } +- } +- +- /** +- * Advanced API access for {@link MlImage}. +- * +- * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference +- * count for {@link MlImage}. However, an App developer should avoid using the following APIs. +- * +- * <p>APIs inside are treated as internal APIs which are subject to change. +- */ +- public static final class Internal { + + /** + * Acquires a reference on this {@link MlImage}. This will increase the reference count by 1. ++ */ ++ private synchronized void acquire() { ++ referenceCount += 1; ++ } ++ ++ /** ++ * Removes a reference that was previously acquired or init. ++ * ++ * <p>When {@link MlImage} is created, it has 1 reference count. + * +- * <p>This method is more useful for image consumer to acquire a reference so image resource +- * will not be closed accidentally. As image creator, normal developer doesn't need to call this +- * method. ++ * <p>When the reference count becomes 0, it will release the resource under the hood. ++ */ ++ @Override ++ // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount ++ public synchronized void close() { ++ referenceCount -= 1; ++ if (referenceCount == 0) { ++ for (ImageContainer imageContainer : containerMap.values()) { ++ imageContainer.close(); ++ } ++ } ++ } ++ ++ /** ++ * Advanced API access for {@link MlImage}. + * +- * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link +- * #close()} to indicate it doesn't need this {@link MlImage} anymore. ++ * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference ++ * count for {@link MlImage}. However, an App developer should avoid using the following APIs. + * +- * @see #close() ++ * <p>APIs inside are treated as internal APIs which are subject to change. + */ +- public void acquire() { +- image.acquire(); ++ public static final class Internal { ++ /** ++ * Acquires a reference on this {@link MlImage}. This will increase the reference count ++ * by 1. ++ * ++ * <p>This method is more useful for image consumer to acquire a reference so image resource ++ * will not be closed accidentally. As image creator, normal developer doesn't need to call ++ * this method. ++ * ++ * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link ++ * #close()} to indicate it doesn't need this {@link MlImage} anymore. ++ * ++ * @see #close() ++ */ ++ public void acquire() { ++ image.acquire(); ++ } ++ ++ private final MlImage image; ++ ++ // Only MlImage creates the internal helper. ++ private Internal(MlImage image) { ++ this.image = image; ++ } ++ } ++ ++ /** Gets {@link Internal} object which contains internal APIs. */ ++ public Internal getInternal() { ++ return new Internal(this); + } + +- private final MlImage image; ++ private final Map<ImageProperties, ImageContainer> containerMap; ++ private final int rotation; ++ private final Rect roi; ++ private final long timestamp; ++ private final int width; ++ private final int height; ++ ++ private int referenceCount; ++ ++ /** Constructs an {@link MlImage} with a built container. */ ++ @KeepForSdk ++ MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width, ++ int height) { ++ this.containerMap = new HashMap<>(); ++ containerMap.put(container.getImageProperties(), container); ++ this.rotation = rotation; ++ this.roi = new Rect(); ++ this.roi.set(roi); ++ this.timestamp = timestamp; ++ this.width = width; ++ this.height = height; ++ this.referenceCount = 1; ++ } ++ ++ /** ++ * Gets one available container. ++ * ++ * @return the current container. ++ */ ++ @KeepForSdk ++ ImageContainer getContainer() { ++ // According to the design, in the future we will support multiple containers in one image. ++ // Currently just return the original container. ++ // TODO(b/182443927): Cache multiple containers in MlImage. ++ return containerMap.values().iterator().next(); ++ } + +- // Only MlImage creates the internal helper. +- private Internal(MlImage image) { +- this.image = image; ++ /** ++ * Gets container from required {@code storageType}. Returns {@code null} if not existed. ++ * ++ * <p>If there are multiple containers with required {@code storageType}, returns the first one. ++ */ ++ @Nullable ++ @KeepForSdk ++ ImageContainer getContainer(@StorageType int storageType) { ++ for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { ++ if (entry.getKey().getStorageType() == storageType) { ++ return entry.getValue(); ++ } ++ } ++ return null; + } +- } +- +- /** Gets {@link Internal} object which contains internal APIs. */ +- public Internal getInternal() { +- return new Internal(this); +- } +- +- private final Map<ImageProperties, ImageContainer> containerMap; +- private final int rotation; +- private final Rect roi; +- private final long timestamp; +- private final int width; +- private final int height; +- +- private int referenceCount; +- +- /** Constructs an {@link MlImage} with a built container. */ +- @KeepForSdk +- MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width, int height) { +- this.containerMap = new HashMap<>(); +- containerMap.put(container.getImageProperties(), container); +- this.rotation = rotation; +- this.roi = new Rect(); +- this.roi.set(roi); +- this.timestamp = timestamp; +- this.width = width; +- this.height = height; +- this.referenceCount = 1; +- } +- +- /** +- * Gets one available container. +- * +- * @return the current container. +- */ +- @KeepForSdk +- ImageContainer getContainer() { +- // According to the design, in the future we will support multiple containers in one image. +- // Currently just return the original container. +- // TODO(b/182443927): Cache multiple containers in MlImage. +- return containerMap.values().iterator().next(); +- } +- +- /** +- * Gets container from required {@code storageType}. Returns {@code null} if not existed. +- * +- * <p>If there are multiple containers with required {@code storageType}, returns the first one. +- */ +- @Nullable +- @KeepForSdk +- ImageContainer getContainer(@StorageType int storageType) { +- for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { +- if (entry.getKey().getStorageType() == storageType) { +- return entry.getValue(); +- } ++ ++ /** ++ * Gets container from required {@code imageProperties}. Returns {@code null} if non existed. ++ */ ++ @Nullable ++ @KeepForSdk ++ ImageContainer getContainer(ImageProperties imageProperties) { ++ return containerMap.get(imageProperties); + } +- return null; +- } +- +- /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ +- @Nullable +- @KeepForSdk +- ImageContainer getContainer(ImageProperties imageProperties) { +- return containerMap.get(imageProperties); +- } +- +- /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ +- boolean addContainer(ImageContainer container) { +- ImageProperties imageProperties = container.getImageProperties(); +- if (containerMap.containsKey(imageProperties)) { +- return false; ++ ++ /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ ++ boolean addContainer(ImageContainer container) { ++ ImageProperties imageProperties = container.getImageProperties(); ++ if (containerMap.containsKey(imageProperties)) { ++ return false; ++ } ++ containerMap.put(imageProperties, container); ++ return true; + } +- containerMap.put(imageProperties, container); +- return true; +- } +- +- /** +- * Validates rotation values for builders. Only supports 0, 90, 180, 270. +- * +- * @throws IllegalArgumentException if the rotation value is invalid. +- */ +- static void validateRotation(int rotation) { +- if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) { +- throw new IllegalArgumentException( +- "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270."); ++ ++ /** ++ * Validates rotation values for builders. Only supports 0, 90, 180, 270. ++ * ++ * @throws IllegalArgumentException if the rotation value is invalid. ++ */ ++ static void validateRotation(int rotation) { ++ if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) { ++ throw new IllegalArgumentException( ++ "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270."); ++ } + } +- } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java +index 44eb1198884fa..8408a0e424a9b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java +@@ -16,39 +16,37 @@ limitations under the License. + package com.google.android.odml.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.graphics.Bitmap; +-import java.nio.ByteBuffer; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + ++import java.nio.ByteBuffer; ++ + /** Unit test for {@link BitmapExtractor}. */ + @RunWith(RobolectricTestRunner.class) + public class BitmapExtractorTest { ++ @Test ++ public void extract_fromBitmap_succeeds() { ++ Bitmap bitmap = TestImageCreator.createRgbaBitmap(); ++ MlImage image = new BitmapMlImageBuilder(bitmap).build(); ++ ++ Bitmap result = BitmapExtractor.extract(image); ++ ++ assertThat(result).isSameInstanceAs(bitmap); ++ } ++ ++ @Test ++ public void extract_fromByteBuffer_throwsException() { ++ ByteBuffer buffer = TestImageCreator.createRgbBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) ++ .build(); + +- @Test +- public void extract_fromBitmap_succeeds() { +- Bitmap bitmap = TestImageCreator.createRgbaBitmap(); +- MlImage image = new BitmapMlImageBuilder(bitmap).build(); +- +- Bitmap result = BitmapExtractor.extract(image); +- +- assertThat(result).isSameInstanceAs(bitmap); +- } +- +- @Test +- public void extract_fromByteBuffer_throwsException() { +- ByteBuffer buffer = TestImageCreator.createRgbBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- buffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGB) +- .build(); +- +- assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image)); +- } ++ assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java +index f9908210f2970..9a4051cdf8f6a 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java +@@ -16,11 +16,13 @@ limitations under the License. + package com.google.android.odml.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.graphics.Bitmap; + import android.graphics.Bitmap.Config; + import android.graphics.Rect; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; +@@ -28,63 +30,59 @@ import org.robolectric.RobolectricTestRunner; + /** Tests for {@link BitmapMlImageBuilder} */ + @RunWith(RobolectricTestRunner.class) + public final class BitmapMlImageBuilderTest { +- +- @Test +- public void build_fromBitmap_succeeds() { +- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); +- +- MlImage image = new BitmapMlImageBuilder(bitmap).build(); +- ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP); +- +- assertThat(image.getWidth()).isEqualTo(20); +- assertThat(image.getHeight()).isEqualTo(25); +- assertThat(image.getContainedImageProperties()) +- .containsExactly( +- ImageProperties.builder() +- .setImageFormat(MlImage.IMAGE_FORMAT_RGBA) +- .setStorageType(MlImage.STORAGE_TYPE_BITMAP) +- .build()); +- assertThat(((BitmapImageContainer) container).getBitmap().getConfig()) +- .isEqualTo(Config.ARGB_8888); +- } +- +- @Test +- public void build_withOptionalProperties_succeeds() { +- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); +- +- MlImage image = +- new BitmapMlImageBuilder(bitmap) +- .setRoi(new Rect(0, 5, 10, 15)) +- .setRotation(90) +- .setTimestamp(12345) +- .build(); +- +- assertThat(image.getTimestamp()).isEqualTo(12345); +- assertThat(image.getRotation()).isEqualTo(90); +- assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); +- } +- +- @Test +- public void build_withInvalidRotation_throwsException() { +- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); +- BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap); +- +- assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); +- } +- +- @Test +- public void release_recyclesBitmap() { +- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); +- +- MlImage image = +- new BitmapMlImageBuilder(bitmap) +- .setRoi(new Rect(0, 5, 10, 15)) +- .setRotation(90) +- .setTimestamp(12345) +- .build(); +- assertThat(bitmap.isRecycled()).isFalse(); +- image.close(); +- +- assertThat(bitmap.isRecycled()).isTrue(); +- } ++ @Test ++ public void build_fromBitmap_succeeds() { ++ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); ++ ++ MlImage image = new BitmapMlImageBuilder(bitmap).build(); ++ ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP); ++ ++ assertThat(image.getWidth()).isEqualTo(20); ++ assertThat(image.getHeight()).isEqualTo(25); ++ assertThat(image.getContainedImageProperties()) ++ .containsExactly(ImageProperties.builder() ++ .setImageFormat(MlImage.IMAGE_FORMAT_RGBA) ++ .setStorageType(MlImage.STORAGE_TYPE_BITMAP) ++ .build()); ++ assertThat(((BitmapImageContainer) container).getBitmap().getConfig()) ++ .isEqualTo(Config.ARGB_8888); ++ } ++ ++ @Test ++ public void build_withOptionalProperties_succeeds() { ++ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); ++ ++ MlImage image = new BitmapMlImageBuilder(bitmap) ++ .setRoi(new Rect(0, 5, 10, 15)) ++ .setRotation(90) ++ .setTimestamp(12345) ++ .build(); ++ ++ assertThat(image.getTimestamp()).isEqualTo(12345); ++ assertThat(image.getRotation()).isEqualTo(90); ++ assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); ++ } ++ ++ @Test ++ public void build_withInvalidRotation_throwsException() { ++ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); ++ BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap); ++ ++ assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); ++ } ++ ++ @Test ++ public void release_recyclesBitmap() { ++ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); ++ ++ MlImage image = new BitmapMlImageBuilder(bitmap) ++ .setRoi(new Rect(0, 5, 10, 15)) ++ .setRotation(90) ++ .setTimestamp(12345) ++ .build(); ++ assertThat(bitmap.isRecycled()).isFalse(); ++ image.close(); ++ ++ assertThat(bitmap.isRecycled()).isTrue(); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java +index 2ff49010443a5..e675ba9abd479 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java +@@ -16,15 +16,18 @@ limitations under the License. + package com.google.android.odml.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.graphics.Bitmap; +-import java.nio.Buffer; +-import java.nio.ByteBuffer; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + ++import java.nio.Buffer; ++import java.nio.ByteBuffer; ++ + /** + * Tests for {@link ByteBufferExtractor}. + * +@@ -35,145 +38,120 @@ import org.robolectric.RobolectricTestRunner; + */ + @RunWith(RobolectricTestRunner.class) + public final class ByteBufferExtractorTest { +- +- @Test +- public void extract_fromByteBuffer_succeeds() { +- ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- byteBuffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGB) +- .build(); +- +- ByteBuffer result = ByteBufferExtractor.extract(image); +- +- assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer); +- assertThat(result.isReadOnly()).isTrue(); +- } +- +- @Test +- public void extract_fromBitmap_throws() { +- Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap(); +- MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build(); +- +- assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image)); +- } +- +- @Test +- public void extract_rgbFromRgbByteBuffer_succeeds() { +- ByteBuffer buffer = TestImageCreator.createRgbBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- buffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGB) +- .build(); +- +- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); +- +- assertThat(result.isReadOnly()).isTrue(); +- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); +- } +- +- @Test +- public void extract_rgbFromRgbaByteBuffer_succeeds() { +- ByteBuffer buffer = TestImageCreator.createRgbaBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- buffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGBA) +- .build(); +- +- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); +- +- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); +- assertThat(buffer.position()).isEqualTo(0); +- } +- +- @Test +- public void extract_rgbaFromRgbByteBuffer_succeeds() { +- ByteBuffer buffer = TestImageCreator.createRgbBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- buffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGB) +- .build(); +- +- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA); +- +- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createOpaqueRgbaBuffer()); +- assertThat(buffer.position()).isEqualTo(0); +- } +- +- @Test +- public void extract_rgbFromRgbaBitmap_succeeds() { +- Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap(); +- MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build(); +- +- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); +- +- assertThat(result.isReadOnly()).isTrue(); +- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); +- +- // Verifies ByteBuffer is cached inside MlImage. +- ByteBufferImageContainer byteBufferImageContainer = +- (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); +- assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result); +- assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); +- +- // Verifies that extracted ByteBuffer is the cached one. +- ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); +- assertThat(result2).isEqualTo(result); +- } +- +- @Test +- public void extract_unsupportedFormatFromByteBuffer_throws() { +- ByteBuffer buffer = TestImageCreator.createRgbaBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- buffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGBA) +- .build(); +- +- assertThrows( +- IllegalArgumentException.class, +- () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888)); +- } +- +- @Test +- public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() { +- ByteBuffer buffer = TestImageCreator.createRgbBuffer(); +- MlImage image = +- new ByteBufferMlImageBuilder( +- buffer, +- TestImageCreator.getWidth(), +- TestImageCreator.getHeight(), +- MlImage.IMAGE_FORMAT_RGB) +- .build(); +- +- ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image); +- +- assertThat(result.buffer().isReadOnly()).isTrue(); +- assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); +- +- // Verifies ByteBuffer is cached inside MlImage. +- ByteBufferImageContainer byteBufferImageContainer = +- (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); +- assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer()); +- assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); +- +- // Verifies that extracted ByteBuffer is the cached one. +- ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image); +- assertThat(result2.buffer()).isEqualTo(result.buffer()); +- assertThat(result2.format()).isEqualTo(result.format()); +- } ++ @Test ++ public void extract_fromByteBuffer_succeeds() { ++ ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(byteBuffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) ++ .build(); ++ ++ ByteBuffer result = ByteBufferExtractor.extract(image); ++ ++ assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer); ++ assertThat(result.isReadOnly()).isTrue(); ++ } ++ ++ @Test ++ public void extract_fromBitmap_throws() { ++ Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap(); ++ MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build(); ++ ++ assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image)); ++ } ++ ++ @Test ++ public void extract_rgbFromRgbByteBuffer_succeeds() { ++ ByteBuffer buffer = TestImageCreator.createRgbBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) ++ .build(); ++ ++ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); ++ ++ assertThat(result.isReadOnly()).isTrue(); ++ assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); ++ } ++ ++ @Test ++ public void extract_rgbFromRgbaByteBuffer_succeeds() { ++ ByteBuffer buffer = TestImageCreator.createRgbaBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA) ++ .build(); ++ ++ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); ++ ++ assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); ++ assertThat(buffer.position()).isEqualTo(0); ++ } ++ ++ @Test ++ public void extract_rgbaFromRgbByteBuffer_succeeds() { ++ ByteBuffer buffer = TestImageCreator.createRgbBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) ++ .build(); ++ ++ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA); ++ ++ assertThat(result).isEquivalentAccordingToCompareTo( ++ TestImageCreator.createOpaqueRgbaBuffer()); ++ assertThat(buffer.position()).isEqualTo(0); ++ } ++ ++ @Test ++ public void extract_rgbFromRgbaBitmap_succeeds() { ++ Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap(); ++ MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build(); ++ ++ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); ++ ++ assertThat(result.isReadOnly()).isTrue(); ++ assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); ++ ++ // Verifies ByteBuffer is cached inside MlImage. ++ ByteBufferImageContainer byteBufferImageContainer = ++ (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); ++ assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result); ++ assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); ++ ++ // Verifies that extracted ByteBuffer is the cached one. ++ ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); ++ assertThat(result2).isEqualTo(result); ++ } ++ ++ @Test ++ public void extract_unsupportedFormatFromByteBuffer_throws() { ++ ByteBuffer buffer = TestImageCreator.createRgbaBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA) ++ .build(); ++ ++ assertThrows(IllegalArgumentException.class, ++ () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888)); ++ } ++ ++ @Test ++ public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() { ++ ByteBuffer buffer = TestImageCreator.createRgbBuffer(); ++ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), ++ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) ++ .build(); ++ ++ ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image); ++ ++ assertThat(result.buffer().isReadOnly()).isTrue(); ++ assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); ++ ++ // Verifies ByteBuffer is cached inside MlImage. ++ ByteBufferImageContainer byteBufferImageContainer = ++ (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); ++ assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer()); ++ assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); ++ ++ // Verifies that extracted ByteBuffer is the cached one. ++ ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image); ++ assertThat(result2.buffer()).isEqualTo(result.buffer()); ++ assertThat(result2.format()).isEqualTo(result.format()); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java +index 45ba77934a61f..374c82b3f4e8d 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java +@@ -16,61 +16,62 @@ limitations under the License. + package com.google.android.odml.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + + import android.graphics.Rect; +-import java.nio.ByteBuffer; ++ + import org.junit.Test; + import org.junit.runner.RunWith; + import org.robolectric.RobolectricTestRunner; + ++import java.nio.ByteBuffer; ++ + /** Tests for {@link ByteBufferMlImageBuilder} */ + @RunWith(RobolectricTestRunner.class) + public final class ByteBufferMlImageBuilderTest { ++ @Test ++ public void build_fromByteBuffer_succeeds() { ++ ByteBuffer buffer = ByteBuffer.allocate(500); ++ ++ MlImage image = ++ new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build(); ++ ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); ++ ++ assertThat(image.getWidth()).isEqualTo(20); ++ assertThat(image.getHeight()).isEqualTo(25); ++ assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25)); ++ assertThat(image.getRotation()).isEqualTo(0); ++ assertThat(image.getContainedImageProperties()) ++ .containsExactly(ImageProperties.builder() ++ .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) ++ .setImageFormat(MlImage.IMAGE_FORMAT_RGB) ++ .build()); ++ assertThat(((ByteBufferImageContainer) container).getImageFormat()) ++ .isEqualTo(MlImage.IMAGE_FORMAT_RGB); ++ } ++ ++ @Test ++ public void build_withOptionalProperties_succeeds() { ++ ByteBuffer buffer = ByteBuffer.allocate(500); ++ ++ MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB) ++ .setRoi(new Rect(0, 5, 10, 15)) ++ .setRotation(90) ++ .setTimestamp(12345) ++ .build(); ++ ++ assertThat(image.getTimestamp()).isEqualTo(12345); ++ assertThat(image.getRotation()).isEqualTo(90); ++ assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); ++ } ++ ++ @Test ++ public void build_withInvalidRotation_throwsException() { ++ ByteBuffer buffer = ByteBuffer.allocate(500); ++ ByteBufferMlImageBuilder builder = ++ new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB); + +- @Test +- public void build_fromByteBuffer_succeeds() { +- ByteBuffer buffer = ByteBuffer.allocate(500); +- +- MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build(); +- ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); +- +- assertThat(image.getWidth()).isEqualTo(20); +- assertThat(image.getHeight()).isEqualTo(25); +- assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25)); +- assertThat(image.getRotation()).isEqualTo(0); +- assertThat(image.getContainedImageProperties()) +- .containsExactly( +- ImageProperties.builder() +- .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) +- .setImageFormat(MlImage.IMAGE_FORMAT_RGB) +- .build()); +- assertThat(((ByteBufferImageContainer) container).getImageFormat()) +- .isEqualTo(MlImage.IMAGE_FORMAT_RGB); +- } +- +- @Test +- public void build_withOptionalProperties_succeeds() { +- ByteBuffer buffer = ByteBuffer.allocate(500); +- +- MlImage image = +- new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB) +- .setRoi(new Rect(0, 5, 10, 15)) +- .setRotation(90) +- .setTimestamp(12345) +- .build(); +- +- assertThat(image.getTimestamp()).isEqualTo(12345); +- assertThat(image.getRotation()).isEqualTo(90); +- assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); +- } +- +- @Test +- public void build_withInvalidRotation_throwsException() { +- ByteBuffer buffer = ByteBuffer.allocate(500); +- ByteBufferMlImageBuilder builder = +- new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB); +- +- assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); +- } ++ assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java +index 67ed4a7f6e2c4..fa832671e4458 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java +@@ -16,6 +16,7 @@ limitations under the License. + package com.google.android.odml.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.mockito.Mockito.when; + +@@ -23,6 +24,7 @@ import android.graphics.Bitmap; + import android.graphics.Bitmap.Config; + import android.graphics.ImageFormat; + import android.media.Image; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -33,34 +35,34 @@ import org.robolectric.RobolectricTestRunner; + /** Tests for {@link MediaImageExtractor} */ + @RunWith(RobolectricTestRunner.class) + public final class MediaImageExtractorTest { +- private static final int HEIGHT = 100; +- private static final int WIDTH = 50; ++ private static final int HEIGHT = 100; ++ private static final int WIDTH = 50; + +- @Mock private Image mediaImage; ++ @Mock ++ private Image mediaImage; + +- @Before +- public void setUp() { +- MockitoAnnotations.initMocks(this); ++ @Before ++ public void setUp() { ++ MockitoAnnotations.initMocks(this); + +- when(mediaImage.getHeight()).thenReturn(HEIGHT); +- when(mediaImage.getWidth()).thenReturn(WIDTH); +- when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888); +- } ++ when(mediaImage.getHeight()).thenReturn(HEIGHT); ++ when(mediaImage.getWidth()).thenReturn(WIDTH); ++ when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888); ++ } + +- @Test +- public void extract_fromMediaMlImage_succeeds() { +- MlImage image = new MediaMlImageBuilder(mediaImage).build(); +- Image extractedMediaImage = MediaImageExtractor.extract(image); ++ @Test ++ public void extract_fromMediaMlImage_succeeds() { ++ MlImage image = new MediaMlImageBuilder(mediaImage).build(); ++ Image extractedMediaImage = MediaImageExtractor.extract(image); + +- assertThat(extractedMediaImage).isSameInstanceAs(image); +- } ++ assertThat(extractedMediaImage).isSameInstanceAs(image); ++ } + +- @Test +- public void extract_fromBitmapMlImage_throwsException() { +- MlImage image = +- new BitmapMlImageBuilder( ++ @Test ++ public void extract_fromBitmapMlImage_throwsException() { ++ MlImage image = new BitmapMlImageBuilder( + Bitmap.createBitmap(/* width= */ 20, /* height= */ 25, Config.ARGB_8888)) +- .build(); +- assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image)); +- } ++ .build(); ++ assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java +index 4f589874bfaf8..60397feceb067 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java +@@ -16,12 +16,14 @@ limitations under the License. + package com.google.android.odml.image; + + import static com.google.common.truth.Truth.assertThat; ++ + import static org.junit.Assert.assertThrows; + import static org.mockito.Mockito.when; + + import android.graphics.ImageFormat; + import android.graphics.Rect; + import android.media.Image; ++ + import org.junit.Before; + import org.junit.Test; + import org.junit.runner.RunWith; +@@ -32,58 +34,57 @@ import org.robolectric.RobolectricTestRunner; + /** Tests for {@link MediaMlImageBuilder} */ + @RunWith(RobolectricTestRunner.class) + public final class MediaMlImageBuilderTest { +- private static final int HEIGHT = 100; +- private static final int WIDTH = 50; +- +- @Mock private Image mediaImage; +- +- @Before +- public void setUp() { +- MockitoAnnotations.initMocks(this); +- +- when(mediaImage.getHeight()).thenReturn(HEIGHT); +- when(mediaImage.getWidth()).thenReturn(WIDTH); +- when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888); +- } +- +- @Test +- public void build_fromMediaImage_succeeds() { +- MlImage image = new MediaMlImageBuilder(mediaImage).build(); +- ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE); +- +- assertThat(image.getWidth()).isEqualTo(WIDTH); +- assertThat(image.getHeight()).isEqualTo(HEIGHT); +- assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT)); +- assertThat(image.getRotation()).isEqualTo(0); +- assertThat(image.getTimestamp()).isAtLeast(0); +- assertThat(image.getContainedImageProperties()) +- .containsExactly( +- ImageProperties.builder() +- .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE) +- .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888) +- .build()); +- assertThat(((MediaImageContainer) container).getImage().getFormat()) +- .isEqualTo(ImageFormat.YUV_420_888); +- } +- +- @Test +- public void build_withOptionalProperties_succeeds() { +- MlImage image = +- new MediaMlImageBuilder(mediaImage) +- .setTimestamp(12345) +- .setRoi(new Rect(0, 5, 10, 15)) +- .setRotation(90) +- .build(); +- +- assertThat(image.getTimestamp()).isEqualTo(12345); +- assertThat(image.getRotation()).isEqualTo(90); +- assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); +- } +- +- @Test +- public void build_withInvalidRotation_throwsException() { +- MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage); +- +- assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); +- } ++ private static final int HEIGHT = 100; ++ private static final int WIDTH = 50; ++ ++ @Mock ++ private Image mediaImage; ++ ++ @Before ++ public void setUp() { ++ MockitoAnnotations.initMocks(this); ++ ++ when(mediaImage.getHeight()).thenReturn(HEIGHT); ++ when(mediaImage.getWidth()).thenReturn(WIDTH); ++ when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888); ++ } ++ ++ @Test ++ public void build_fromMediaImage_succeeds() { ++ MlImage image = new MediaMlImageBuilder(mediaImage).build(); ++ ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE); ++ ++ assertThat(image.getWidth()).isEqualTo(WIDTH); ++ assertThat(image.getHeight()).isEqualTo(HEIGHT); ++ assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT)); ++ assertThat(image.getRotation()).isEqualTo(0); ++ assertThat(image.getTimestamp()).isAtLeast(0); ++ assertThat(image.getContainedImageProperties()) ++ .containsExactly(ImageProperties.builder() ++ .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE) ++ .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888) ++ .build()); ++ assertThat(((MediaImageContainer) container).getImage().getFormat()) ++ .isEqualTo(ImageFormat.YUV_420_888); ++ } ++ ++ @Test ++ public void build_withOptionalProperties_succeeds() { ++ MlImage image = new MediaMlImageBuilder(mediaImage) ++ .setTimestamp(12345) ++ .setRoi(new Rect(0, 5, 10, 15)) ++ .setRotation(90) ++ .build(); ++ ++ assertThat(image.getTimestamp()).isEqualTo(12345); ++ assertThat(image.getRotation()).isEqualTo(90); ++ assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); ++ } ++ ++ @Test ++ public void build_withInvalidRotation_throwsException() { ++ MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage); ++ ++ assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); ++ } + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java +index c9e7134bedd93..28f54be2c70a3 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java +@@ -17,6 +17,7 @@ package com.google.android.odml.image; + + import android.graphics.Bitmap; + import android.graphics.Color; ++ + import java.nio.ByteBuffer; + + /** +@@ -35,113 +36,113 @@ import java.nio.ByteBuffer; + * <p>The created {@link Bitmap} is not pre-multiplied. + */ + final class TestImageCreator { ++ private static final int RED = 0x73; ++ private static final int GREEN = 0x85; ++ private static final int BLUE = 0x96; ++ private static final int ALPHA = 0x70; ++ ++ static int getWidth() { ++ return 10; ++ } ++ ++ static int getHeight() { ++ return 2; ++ } ++ ++ /** ++ * Creates an example non-pre-multiplied bitmap which is 100% opaque. ++ * ++ * @see TestImageCreator for details. ++ */ ++ static Bitmap createOpaqueRgbaBitmap() { ++ return createRgbaBitmap(0xff); ++ } ++ ++ /** ++ * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel. ++ * ++ * @see TestImageCreator for details. ++ */ ++ static Bitmap createRgbaBitmap() { ++ return createRgbaBitmap(ALPHA); ++ } + +- private static final int RED = 0x73; +- private static final int GREEN = 0x85; +- private static final int BLUE = 0x96; +- private static final int ALPHA = 0x70; +- +- static int getWidth() { +- return 10; +- } +- +- static int getHeight() { +- return 2; +- } +- +- /** +- * Creates an example non-pre-multiplied bitmap which is 100% opaque. +- * +- * @see TestImageCreator for details. +- */ +- static Bitmap createOpaqueRgbaBitmap() { +- return createRgbaBitmap(0xff); +- } +- +- /** +- * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel. +- * +- * @see TestImageCreator for details. +- */ +- static Bitmap createRgbaBitmap() { +- return createRgbaBitmap(ALPHA); +- } +- +- /** +- * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code alpha}. +- */ +- static Bitmap createRgbaBitmap(int alpha) { +- int[] colors = new int[20]; +- for (int i = 0; i < 5; i++) { +- colors[i] = Color.argb(alpha, 0, 0, BLUE); +- colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff); +- colors[i + 10] = Color.argb(alpha, 0, GREEN, 0); +- colors[i + 15] = Color.argb(alpha, RED, 0, 0); ++ /** ++ * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code ++ * alpha}. ++ */ ++ static Bitmap createRgbaBitmap(int alpha) { ++ int[] colors = new int[20]; ++ for (int i = 0; i < 5; i++) { ++ colors[i] = Color.argb(alpha, 0, 0, BLUE); ++ colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff); ++ colors[i + 10] = Color.argb(alpha, 0, GREEN, 0); ++ colors[i + 15] = Color.argb(alpha, RED, 0, 0); ++ } ++ // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates ++ // pre-multiplied bitmaps. ++ Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888); ++ bitmap.setPremultiplied(false); ++ bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2); ++ return bitmap; + } +- // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates pre-multiplied +- // bitmaps. +- Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888); +- bitmap.setPremultiplied(false); +- bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2); +- return bitmap; +- } +- +- /** +- * Creates an example 10*10*3 bytebuffer in R-G-B format. +- * +- * @see TestImageCreator for details. +- */ +- static ByteBuffer createRgbBuffer() { +- return createRgbOrRgbaBuffer(false, 0xff); +- } +- +- /** +- * Creates an example 10*10*4 bytebuffer in R-G-B-A format. +- * +- * @see TestImageCreator for details. +- */ +- static ByteBuffer createRgbaBuffer() { +- return createRgbOrRgbaBuffer(true, ALPHA); +- } +- +- /** +- * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF. +- * +- * @see TestImageCreator for details. +- */ +- static ByteBuffer createOpaqueRgbaBuffer() { +- return createRgbOrRgbaBuffer(true, 0xff); +- } +- +- /** +- * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc. +- * +- * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored. +- * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}. +- */ +- static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) { +- int capacity = withAlpha ? 80 : 60; +- ByteBuffer buffer = ByteBuffer.allocateDirect(capacity); +- putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5); +- putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5); +- putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5); +- putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5); +- buffer.rewind(); +- return buffer; +- } +- +- private static void putColorInByteBuffer( +- ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) { +- for (int i = 0; i < num; i++) { +- buffer.put((byte) r); +- buffer.put((byte) g); +- buffer.put((byte) b); +- if (withAlpha) { +- buffer.put((byte) alpha); +- } ++ ++ /** ++ * Creates an example 10*10*3 bytebuffer in R-G-B format. ++ * ++ * @see TestImageCreator for details. ++ */ ++ static ByteBuffer createRgbBuffer() { ++ return createRgbOrRgbaBuffer(false, 0xff); ++ } ++ ++ /** ++ * Creates an example 10*10*4 bytebuffer in R-G-B-A format. ++ * ++ * @see TestImageCreator for details. ++ */ ++ static ByteBuffer createRgbaBuffer() { ++ return createRgbOrRgbaBuffer(true, ALPHA); ++ } ++ ++ /** ++ * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF. ++ * ++ * @see TestImageCreator for details. ++ */ ++ static ByteBuffer createOpaqueRgbaBuffer() { ++ return createRgbOrRgbaBuffer(true, 0xff); ++ } ++ ++ /** ++ * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc. ++ * ++ * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored. ++ * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}. ++ */ ++ static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) { ++ int capacity = withAlpha ? 80 : 60; ++ ByteBuffer buffer = ByteBuffer.allocateDirect(capacity); ++ putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5); ++ putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5); ++ putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5); ++ putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5); ++ buffer.rewind(); ++ return buffer; ++ } ++ ++ private static void putColorInByteBuffer( ++ ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) { ++ for (int i = 0; i < num; i++) { ++ buffer.put((byte) r); ++ buffer.put((byte) g); ++ buffer.put((byte) b); ++ if (withAlpha) { ++ buffer.put((byte) alpha); ++ } ++ } + } +- } + +- // Should not be instantiated. +- private TestImageCreator() {} ++ // Should not be instantiated. ++ private TestImageCreator() {} + } +diff --git a/third_party/tflite_support/src/third_party/fft2d/fft.h b/third_party/tflite_support/src/third_party/fft2d/fft.h +index 36d838b7f6280..35dbcc766c169 100644 +--- a/third_party/tflite_support/src/third_party/fft2d/fft.h ++++ b/third_party/tflite_support/src/third_party/fft2d/fft.h +@@ -22,12 +22,12 @@ limitations under the License. + extern "C" { + #endif + +-extern void cdft(int, int, double *, int *, double *); +-extern void rdft(int, int, double *, int *, double *); +-extern void ddct(int, int, double *, int *, double *); +-extern void ddst(int, int, double *, int *, double *); +-extern void dfct(int, double *, double *, int *, double *); +-extern void dfst(int, double *, double *, int *, double *); ++extern void cdft(int, int, double*, int*, double*); ++extern void rdft(int, int, double*, int*, double*); ++extern void ddct(int, int, double*, int*, double*); ++extern void ddst(int, int, double*, int*, double*); ++extern void dfct(int, double*, double*, int*, double*); ++extern void dfst(int, double*, double*, int*, double*); + + #ifdef __cplusplus + } +diff --git a/third_party/tflite_support/src/third_party/fft2d/fft2d.h b/third_party/tflite_support/src/third_party/fft2d/fft2d.h +index d587b3b441ce2..d79441827d54c 100644 +--- a/third_party/tflite_support/src/third_party/fft2d/fft2d.h ++++ b/third_party/tflite_support/src/third_party/fft2d/fft2d.h +@@ -22,12 +22,12 @@ limitations under the License. + extern "C" { + #endif + +-extern void cdft2d(int, int, int, double **, double *, int *, double *); +-extern void rdft2d(int, int, int, double **, double *, int *, double *); +-extern void ddct2d(int, int, int, double **, double *, int *, double *); +-extern void ddst2d(int, int, int, double **, double *, int *, double *); +-extern void ddct8x8s(int isgn, double **a); +-extern void ddct16x16s(int isgn, double **a); ++extern void cdft2d(int, int, int, double**, double*, int*, double*); ++extern void rdft2d(int, int, int, double**, double*, int*, double*); ++extern void ddct2d(int, int, int, double**, double*, int*, double*); ++extern void ddst2d(int, int, int, double**, double*, int*, double*); ++extern void ddct8x8s(int isgn, double** a); ++extern void ddct16x16s(int isgn, double** a); + + #ifdef __cplusplus + } +-- +2.34.1.307.g9b7440fafd-goog +
diff --git a/third_party/tflite_support/patches/0012-rm-stdio-static-init.patch b/third_party/tflite_support/patches/0012-rm-stdio-static-init.patch new file mode 100644 index 0000000..4562b6d --- /dev/null +++ b/third_party/tflite_support/patches/0012-rm-stdio-static-init.patch
@@ -0,0 +1,38 @@ +From f3ab1569fa4dfd69f74d1bb9c6d1c2c26e9215ce Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Fri, 7 Jan 2022 09:30:06 -0800 +Subject: [PATCH] rm stdio static init + +--- + .../cc/task/core/tflite_engine.cc | 15 --------------- + 1 file changed, 15 deletions(-) + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +index 0b34bad4f18f7..41e06389af80b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +@@ -32,21 +32,6 @@ namespace tflite { + namespace task { + namespace core { + +-#ifdef __ANDROID__ +-// https://github.com/opencv/opencv/issues/14906 +-// "ios_base::Init" object is not a part of Android's "iostream" header (in case +-// of clang toolchain, NDK 20). +-// +-// Ref1: +-// https://en.cppreference.com/w/cpp/io/ios_base/Init +-// The header <iostream> behaves as if it defines (directly or indirectly) +-// an instance of std::ios_base::Init with static storage duration +-// +-// Ref2: +-// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74 +-static std::ios_base::Init s_iostream_initializer; +-#endif +- + using ::absl::StatusCode; + using ::tflite::proto::ComputeSettings; + using ::tflite::support::CreateStatusWithPayload; +-- +2.34.1.575.g55b058a8bb-goog +
diff --git a/third_party/tflite_support/src/.bazelrc b/third_party/tflite_support/src/.bazelrc new file mode 100644 index 0000000..02ba9da --- /dev/null +++ b/third_party/tflite_support/src/.bazelrc
@@ -0,0 +1,170 @@ +# This file is based on tensorflow's (v2.2.0) .bazelrc found here: +# https://github.com/tensorflow/tensorflow/blob/v2.2.0/.bazelrc + +# Sets the default Apple platform to macOS. +build:macos --apple_platform_type=macos + +# Flag to enable remote config. Required starting from TF 2.2. +common --experimental_repo_remote_exec + +# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1 +build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain +build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain + +# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. +build:android --copt=-w +build:linux --copt=-w +build:macos --copt=-w +build:windows --copt=/w + +# Android workspace configurations. Should be replaced by an interative configure in the future. +build --action_env ANDROID_NDK_HOME +build --action_env ANDROID_NDK_API_LEVEL +build --action_env ANDROID_BUILD_TOOLS_VERSION +build --action_env ANDROID_SDK_API_LEVEL +build --action_env ANDROID_SDK_HOME + +# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the +# target CPU to build transient dependencies correctly. See +# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu + +build:android --crosstool_top=//external:android/crosstool +build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +build:android_arm --config=android +build:android_arm --cpu=armeabi-v7a +build:android_arm --fat_apk_cpu=armeabi-v7a +build:android_arm64 --config=android +build:android_arm64 --cpu=arm64-v8a +build:android_arm64 --fat_apk_cpu=arm64-v8a +build:android_x86 --config=android +build:android_x86 --cpu=x86 +build:android_x86 --fat_apk_cpu=x86 +build:android_x86_64 --config=android +build:android_x86_64 --cpu=x86_64 +build:android_x86_64 --fat_apk_cpu=x86_64 + +# iOS configs for each architecture and the fat binary builds. +build:ios --apple_platform_type=ios +build:ios --apple_bitcode=embedded --copt=-fembed-bitcode +build:ios --copt=-Wno-c++11-narrowing +build:ios_armv7 --config=ios +build:ios_armv7 --cpu=ios_armv7 +build:ios_arm64 --config=ios +build:ios_arm64 --cpu=ios_arm64 +build:ios_x86_64 --config=ios +build:ios_x86_64 --cpu=ios_x86_64 +build:ios_fat --config=ios +build:ios_fat --ios_multi_cpus=armv7,arm64,x86_64 + +# TFLite build configs for generic embedded Linux +build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain +build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +build:elinux_aarch64 --config=elinux +build:elinux_aarch64 --cpu=aarch64 +build:elinux_aarch64 --distinct_host_configuration=true +build:elinux_armhf --config=elinux +build:elinux_armhf --cpu=armhf +build:elinux_armhf --distinct_host_configuration=true + +# By default, build TF in C++ 14 mode. +build:android --cxxopt=-std=c++14 +build:android --host_cxxopt=-std=c++14 +build:ios --cxxopt=-std=c++14 +build:ios --host_cxxopt=-std=c++14 +build:linux --cxxopt=-std=c++14 +build:linux --host_cxxopt=-std=c++14 +build:macos --cxxopt=-std=c++14 +build:macos --host_cxxopt=-std=c++14 +build:windows --cxxopt=/std:c++14 +build:windows --host_cxxopt=/std:c++14 + +# Config to use a mostly-static build and disable modular op registration +# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). +# By default, TensorFlow will build with a dependence on +# //tensorflow:libtensorflow_framework.so. +build:monolithic --define framework_shared_object=false + +# For projects which use TensorFlow as part of a Bazel build process, putting +# nothing in a bazelrc will default to a monolithic build. The following line +# opts in to modular op registration support by default. +build --define framework_shared_object=true + +# ASAN build +build:asan --strip=never +build:asan --copt -fsanitize=address +build:asan --copt -DADDRESS_SANITIZER +build:asan --copt -O1 +build:asan --copt -g +build:asan --copt -fno-omit-frame-pointer +build:asan --linkopt -fsanitize=address + +# dbg config, as a shorthand for '--config=opt -c dbg' +build:dbg --config=opt -c dbg +# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 +build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON +# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 +build:dbg --copt -DDEBUG_BUILD + +build --define=use_fast_cpp_protos=true +build --define=allow_oversize_protos=true + +# TF uses `standalone`, which is deprecated. +build --spawn_strategy=local +build -c opt + +# Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF +# compilation options. It also addresses memory use due to +# copy-on-write semantics of std::strings of the older ABI. +build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 + +# Make Bazel print out all options from rc files. +build --announce_rc + +# Other build flags. +build --define=grpc_no_ares=true + +# See https://github.com/bazelbuild/bazel/issues/7362 for information on what +# --incompatible_remove_legacy_whole_archive flag does. +# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate +# Tensorflow to the default, however test coverage wasn't enough to catch the +# errors. +# There is ongoing work on Bazel team's side to provide support for transitive +# shared libraries. As part of migrating to transitive shared libraries, we +# hope to provide a better mechanism for control over symbol exporting, and +# then tackle this issue again. +# +# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library +# archives in -whole_archive -no_whole_archive. +build --noincompatible_remove_legacy_whole_archive + +# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0 +# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC: +# https://github.com/tensorflow/community/pull/179 +build --noincompatible_prohibit_aapt1 + +# Build TF with C++ 17 features. +build:c++17 --cxxopt=-std=c++1z +build:c++17 --cxxopt=-stdlib=libc++ +build:c++1z --config=c++17 + +# Enable using platform specific build settings, except when cross-compiling for +# mobile platforms. +build --enable_platform_specific_config +build:android --noenable_platform_specific_config +build:ios --noenable_platform_specific_config + +# Suppress all warning messages. +build:short_logs --output_filter=DONT_MATCH_ANYTHING +build:verbose_logs --output_filter= +build --config=short_logs + +# Options to build TensorFlow 1.x or 2.x. +build:v1 --define=tf_api_version=1 +build:v2 --define=tf_api_version=2 +build:v1 --action_env=TF2_BEHAVIOR=0 +build:v2 --action_env=TF2_BEHAVIOR=1 +build --config=v2 +test --config=v2 + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user
diff --git a/third_party/tflite_support/src/.bazelversion b/third_party/tflite_support/src/.bazelversion new file mode 100644 index 0000000..47b6be3f --- /dev/null +++ b/third_party/tflite_support/src/.bazelversion
@@ -0,0 +1 @@ +3.7.2 \ No newline at end of file
diff --git a/third_party/tflite_support/src/README.md b/third_party/tflite_support/src/README.md index 67d4a8f..bab0201 100644 --- a/third_party/tflite_support/src/README.md +++ b/third_party/tflite_support/src/README.md
@@ -5,8 +5,8 @@ C++ (WIP), and Swift (WIP). The TFLite Support project consists of the following major components: -* **TFLite Support Library**: a cross-platform library that helps to - deploy TFLite models onto mobile devices. +* **TFLite Support Library**: a cross-platform library that helps to deploy + TFLite models onto mobile devices. * **TFLite Model Metadata**: (metadata populator and metadata extractor library): includes both human and machine readable information about what a model does and how to use the model. @@ -55,6 +55,11 @@ * `ANDROID_SDK_API_LEVEL` * `ANDROID_BUILD_TOOLS_VERSION` +## How to contribute + +Please issue a pull request and assign @xunkai55 or @lu-wang-g for a code +review. + ## Contact us Let us know what you think about TFLite Support by creating a
diff --git a/third_party/tflite_support/src/WORKSPACE b/third_party/tflite_support/src/WORKSPACE index 33e44e8..c7834d3 100644 --- a/third_party/tflite_support/src/WORKSPACE +++ b/third_party/tflite_support/src/WORKSPACE
@@ -1,9 +1,58 @@ workspace(name = "org_tensorflow_lite_support") load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external") -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") load("@//third_party/py:python_configure.bzl", "python_configure") +http_file( + name = "mobilebert_float", + sha256 = "883bf5d40f0b0ae435326bb21ed0f4c9004b22c3fd1539383fd16d68623696dd", + urls = ["https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1?lite-format=tflite"], +) + +http_file( + name = "mobilebert_with_metadata", + sha256 = "e79d3c70108bbdee02da657b679349cab46dbb859a05b599c76b53d98e82f272", + urls = ["https://tfhub.dev/tensorflow/lite-model/mobilebert/1/metadata/1?lite-format=tflite"], +) + +http_file( + name = "30k-clean", + sha256 = "fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336", + urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_qa/30k-clean.model"], +) + +http_file( + name = "mobilebert_vocab", + sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", + urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_qa/mobilebert_vocab.txt"], +) + + +http_file( + name = "albert", + sha256 = "4a29c7063c518925960229f49dd03e8da5d6682001cf73037815dcd98afd728a", + urls = ["https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"], +) + +http_file( + name = "albert_with_metadata", + sha256 = "8a8a91856b94b945e4a9f22f0332bbf105c3b6b878bb23abfc97eb89d3e8436a", + urls = ["https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/metadata/1?lite-format=tflite"], +) + +http_file( + name = "bert_nl_classifier", + sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600", + urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_nl_classifier/bert_nl_classifier.tflite"], +) + +http_file( + name = "bert_nl_classifier_no_metadata", + sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e", + urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_nl_classifier/bert_nl_classifier_no_metadata.tflite"], +) + http_archive( name = "io_bazel_rules_closure", sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", @@ -14,6 +63,14 @@ ], ) +# GoogleTest/GoogleMock framework. Used by most unit-tests. +http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/4ec4cd23f486bf70efcc5d2caa40f24368f752e3.zip"], + strip_prefix = "googletest-4ec4cd23f486bf70efcc5d2caa40f24368f752e3", + sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049", +) + # Apple and Swift rules. # https://github.com/bazelbuild/rules_apple/releases http_archive( @@ -37,15 +94,21 @@ ], ) -# tf-nightly-20200810 +# TF on 2021-11-09. +TENSORFLOW_COMMIT = "6a144e7763914d3f6141a7cdc9cb116cc23425f9" +TENSORFLOW_SHA256 = "cec9a514c09d2b171ad447f3413151b25a6c3d88d048148cced1e85db81f3617" http_archive( name = "org_tensorflow", - sha256 = "fc6d7c57cd9427e695a38ad00fb6ecc3f623bac792dd44ad73a3f85b338b68be", - strip_prefix = "tensorflow-8a4ffe2e1ae722cff5306778df0cfca8b7f503fe", + sha256 = TENSORFLOW_SHA256, + strip_prefix = "tensorflow-" + TENSORFLOW_COMMIT, urls = [ - "https://github.com/tensorflow/tensorflow/archive/8a4ffe2e1ae722cff5306778df0cfca8b7f503fe.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/" + TENSORFLOW_COMMIT + + ".tar.gz", ], - patches = ["@//third_party:tensorflow_lite_ios_build.patch"], + patches = [ + # We need to rename lite/ios/BUILD.apple to lite/ios/BUILD. + "@//third_party:tensorflow_lite_ios_build.patch", + ], patch_args = ["-p1"], ) @@ -60,11 +123,11 @@ third_party_http_archive( name = "pybind11", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", - "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.6.0.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz", ], - sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", - strip_prefix = "pybind11-2.4.3", + sha256 = "90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571", + strip_prefix = "pybind11-2.6.0", build_file = "//third_party:pybind11.BUILD", ) @@ -119,13 +182,14 @@ ], ) -# ABSL cpp library lts_2020_02_25 -# Needed for absl/status +# ABSL cpp library lts_2021_03_24 Patch2 +# See https://github.com/abseil/abseil-cpp/releases for details. +# Needed for absl/status and absl/status:statusor http_archive( name = "com_google_absl", build_file = "//third_party:com_google_absl.BUILD", urls = [ - "https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/20210324.2.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ @@ -134,8 +198,8 @@ patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20200225", - sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353" + strip_prefix = "abseil-cpp-20210324.2", + sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" ) http_archive( @@ -175,12 +239,12 @@ http_archive( name = "libyuv", - urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/6d603ec3f57dafddc424ef895e5d903915e94ba6.tar.gz"], - # Adding the constrain of sha256 and strip_prefix will cause failure. - # It seems that the downloaded libyuv was different every time, so that - # the specified sha256 and strip_prefix cannot match. - # sha256 = "ce196c72858456baa8022fa4a0dc18b77d619265dbc0e3d58e25ad15ca402522", - # strip_prefix = "libyuv-6d603ec3f57dafddc424ef895e5d903915e94ba6", + urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/39240f7149cffde62e3620344d222c8ab2c21178.tar.gz"], + # Adding the constrain of sha256 and strip_prefix will cause failure as of + # Jan 2021. It seems that the downloaded libyuv was different every time, + # so that the specified sha256 and strip_prefix cannot match. + # sha256 = "01c2e30eb8e83880f9ba382f6bece9c38cd5b07f9cadae46ef1d5a69e07fafaf", + # strip_prefix = "libyuv-39240f7149cffde62e3620344d222c8ab2c21178", build_file = "//third_party:libyuv.BUILD", ) @@ -243,18 +307,18 @@ ) http_archive( - name = "com_google_protobuf", - sha256 = "a79d19dcdf9139fa4b81206e318e33d245c4c9da1ffed21c87288ed4380426f9", - strip_prefix = "protobuf-3.11.4", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.11.4.tar.gz"], - patches = [ - "@//third_party:com_google_protobuf_fixes.diff" - ], - patch_args = [ - "-p1", - ], + name = "libedgetpu", + sha256 = "a179016a5874c58db969a5edd3fecf57610604e751b5c4d6d82ad58c383ffd64", + strip_prefix = "libedgetpu-ea1eaddbddece0c9ca1166e868f8fd03f4a3199e", + urls = [ + "https://github.com/google-coral/libedgetpu/archive/ea1eaddbddece0c9ca1166e868f8fd03f4a3199e.tar.gz" + ], ) +# Set up TensorFlow version for Coral. +load("@libedgetpu//:workspace.bzl", "libedgetpu_dependencies") +libedgetpu_dependencies(TENSORFLOW_COMMIT, TENSORFLOW_SHA256) + # AutoValue 1.6+ shades Guava, Auto Common, and JavaPoet. That's OK # because none of these jars become runtime dependencies. java_import_external( @@ -317,12 +381,38 @@ default_visibility = ["@com_google_auto_value//:__pkg__"], ) +http_archive( + name = "robolectric", + urls = ["https://github.com/robolectric/robolectric-bazel/archive/4.4.tar.gz"], + strip_prefix = "robolectric-bazel-4.4", +) +load("@robolectric//bazel:robolectric.bzl", "robolectric_repositories") +robolectric_repositories() + load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() + +RULES_JVM_EXTERNAL_TAG = "3.2" + +http_archive( + name = "rules_jvm_external", + strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG, + sha256 = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af", + url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG, +) + +load("@rules_jvm_external//:defs.bzl", "maven_install") + # Set up TF. -load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") -tf_workspace(tf_repo_name="@org_tensorflow") +load("@org_tensorflow//tensorflow:workspace3.bzl", "workspace") +workspace() +load("@org_tensorflow//tensorflow:workspace2.bzl", "workspace") # buildifier: disable=load +workspace() +load("@org_tensorflow//tensorflow:workspace1.bzl", "workspace") # buildifier: disable=load +workspace() +load("@org_tensorflow//tensorflow:workspace0.bzl", "workspace") # buildifier: disable=load +workspace() load("//third_party/tensorflow:tf_configure.bzl", "tf_configure") tf_configure(name = "local_config_tf") @@ -346,35 +436,38 @@ load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") bazel_version_repository(name = "bazel_version") - -# Set up Android. -load("//third_party/android:android_configure.bzl", "android_configure") -android_configure(name="local_config_android") -load("@local_config_android//:android.bzl", "android_workspace") -android_workspace() - python_configure(name = "local_config_python") +ATS_TAG = "androidx-test-1.3.0" +http_archive( + name = "android_test_support", + strip_prefix = "android-test-%s" % ATS_TAG, + urls = ["https://github.com/android/android-test/archive/%s.tar.gz" % ATS_TAG], +) +load("@android_test_support//:repo.bzl", "android_test_repositories") +android_test_repositories() # Maven dependencies. -RULES_JVM_EXTERNAL_TAG = "3.2" - -http_archive( - name = "rules_jvm_external", - strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG, - sha256 = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af", - url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG, -) - -load("@rules_jvm_external//:defs.bzl", "maven_install") - maven_install( artifacts = [ "androidx.annotation:annotation:aar:1.1.0", + "androidx.annotation:annotation-experimental:1.1.0", + "androidx.multidex:multidex:jar:2.0.1", + "androidx.test:core:jar:1.3.0", + "androidx.test.ext:junit:jar:1.1.2", + "androidx.test:runner:jar:1.3.0", + "com.google.android.odml:image:aar:1.0.0-beta1", + "com.google.truth:truth:jar:1.1", + "commons-io:commons-io:jar:2.8.0", + # Mockito >= 3.4.6 cannot pass bazel desugar. + "org.mockito:mockito-android:jar:3.0.0", + "org.mockito:mockito-core:jar:3.0.0", + "org.mockito:mockito-inline:jar:3.0.0", + "org.robolectric:robolectric:jar:4.4", + "junit:junit:jar:4.13", ], repositories = [ - "https://jcenter.bintray.com", "https://maven.google.com", "https://dl.google.com/dl/android/maven2", "https://repo1.maven.org/maven2", @@ -382,3 +475,23 @@ fetch_sources = True, version_conflict_policy = "pinned", ) + +http_archive( + name = "tf_toolchains", + sha256 = "d72b2e52baf0592f5b94347b128ef75422fc22f63dfcf2d5fd46bc732cab052b", + strip_prefix = "toolchains-1.3.0", + urls = [ + "http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v1.3.0.tar.gz", + "https://github.com/tensorflow/toolchains/archive/v1.3.0.tar.gz", + ], +) + +load("@tf_toolchains//toolchains/embedded/arm-linux:arm_linux_toolchain_configure.bzl", "arm_linux_toolchain_configure") + +# TFLite crossbuild toolchain for embeddeds Linux +arm_linux_toolchain_configure( + name = "local_config_embedded_arm", + build_file = "@tf_toolchains//toolchains/embedded/arm-linux:BUILD", + aarch64_repo = "../aarch64_linux_toolchain", + armhf_repo = "../armhf_linux_toolchain", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/BUILD index 258f29bc..0830d48 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/BUILD
@@ -8,17 +8,21 @@ exports_files(["LICENSE"]) -# LINT.IfChange package_group( name = "users", + includes = [ + ":internal", + ], packages = [ - "//tensorflow_lite_support/...", - "//third_party/py/tensorflow_examples/...", - "//third_party/tensorflow_models/...", ], ) -# Remove internal path from tensorflow_lite_support:users in the copybara file. -# LINT.ThenChange(//tensorflow_lite_support/copy.bara.sky) + +package_group( + name = "internal", + packages = [ + "//tensorflow_lite_support/...", + ], +) # Config setting for determining if we are building for Android. config_setting(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/README.md b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/README.md new file mode 100644 index 0000000..6221f02 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/README.md
@@ -0,0 +1,18 @@ +# Acceleration allowlisting + +A complementary directory for the work of +[accelerator allowlisting](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/acceleration) +in TensorFlow Lite. + +## Coral Edge TPU plugin + +The Coral Edge TPU delegate plugin used in the +[acceleration library](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/port/default/tflite_wrapper.h). +See +[CoralSettings](https://github.com/tensorflow/tensorflow/blob/896fecee319ffeb4af2a3c0b5436f3a55ab058fa/tensorflow/lite/experimental/acceleration/configuration/configuration.proto#L323) +about how to configure the Coral Edge TPU plugin. You can use the acceleration +library together with +[Task Library](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/cc/task). +Configure your desired accelerator, including the Coral plugin through the +options of each task, i.e. +[image_classifier_options](https://github.com/tensorflow/tflite-support/blob/43f1267b99f1dbc27c7c5b2e1111e1ff6b9121ea/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto#L79).
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD new file mode 100644 index 0000000..d6da77e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD
@@ -0,0 +1,136 @@ +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "gpu_plugin", + visibility = [ + "//visibility:public", + ], + deps = [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:gpu_plugin", + ], + alwayslink = 1, # For registration to always run. +) + +cc_library( + name = "nnapi_plugin", + visibility = [ + "//visibility:public", + ], + deps = [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin", + ], + alwayslink = 1, # For registration to always run. +) + +cc_library( + name = "hexagon_plugin", + visibility = [ + "//visibility:public", + ], + deps = [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:hexagon_plugin", + ], + alwayslink = 1, # For registration to always run. +) + +cc_library( + name = "xnnpack_plugin", + visibility = [ + "//visibility:public", + ], + deps = [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:xnnpack_plugin", + ], + alwayslink = 1, # For registration to always run. +) + +# To use the edgetpu_coral_plugin externally, add the following flags to the bazel command: +# --define darwinn_portable=1 +cc_library( + name = "edgetpu_coral_plugin", + srcs = ["edgetpu_coral_plugin.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_glog//:glog", + "@libedgetpu//tflite/public:edgetpu_c", + "@libedgetpu//tflite/public:oss_edgetpu_direct_all", # buildcleaner: keep + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_registry", + ], + alwayslink = 1, # For registration to always run. +) + +# To test it externally, plugin a Coral device, and run the following command: +# bazel test tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin_test \ +# --define darwinn_portable=1 +cc_test( + name = "edgetpu_coral_plugin_test", + srcs = ["edgetpu_coral_plugin_test.cc"], + data = [ + "//tensorflow_lite_support/acceleration/configuration/testdata:test_files", + ], + tags = [ + "manual", + "notap", # Requires edge TPU device. + ], + deps = [ + ":edgetpu_coral_plugin", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_registry", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +# Targets for delegate plugin library Maven release. + +# GPU delegate plugin library. +tflite_jni_binary( + name = "libgpu_delegate_plugin.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + visibility = ["//visibility:private"], + deps = [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration/c:gpu_plugin", + ], +) + +cc_library( + name = "gpu_delegate_plugin_native", + srcs = [ + ":libgpu_delegate_plugin.so", + ], + visibility = ["//visibility:private"], +) + +# Android target of Acceleration@Scale GPU plugin. +# Use this target when GPU delegate is selected in the Task Library Java API. +android_library( + name = "gpu_delegate_plugin_android", + visibility = [ + "//visibility:public", + ], + exports = [":gpu_delegate_plugin_native"], +) + +# AAR target of Acceleration@Scale GPU acceleration for OSS release. +aar_with_jni( + name = "gpu-delegate-plugin", + android_library = ":gpu_delegate_plugin_android", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc new file mode 100644 index 0000000..6a16d12 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
@@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <unordered_map> + +#include <glog/logging.h> +#include "absl/container/node_hash_map.h" // from @com_google_absl +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/strings/match.h" // from @com_google_absl +#include "absl/strings/numbers.h" // from @com_google_absl +#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" +#include "tflite/public/edgetpu_c.h" + +namespace tflite { +namespace delegates { +namespace { + +constexpr int kDEFAULT_USB_MAX_BULK_IN_QUEUE_LENGTH = 32; +constexpr char kUsb[] = "usb"; +constexpr char kPci[] = "pci"; + +inline std::string ConvertPerformance( + const CoralSettings_::Performance& from_performance) { + switch (from_performance) { + case CoralSettings_::Performance_LOW: + return "Low"; + case CoralSettings_::Performance_MEDIUM: + return "Medium"; + case CoralSettings_::Performance_HIGH: + return "High"; + default: + return "Max"; + } +} + +inline std::string ConvertBool(bool from_bool) { + return from_bool ? "True" : "False"; +} + +bool MatchDevice(const std::string& device, + const std::string& type, + int* index) { + const auto prefix(type + ":"); + if (!absl::StartsWith(device, prefix)) + return false; + if (!absl::SimpleAtoi(device.substr(prefix.size()), index)) + return false; + if (*index < 0) + return false; + return true; +} + +// device_index corresponds to specific device type, e.g. "usb:0" means the +// first USB device or "pci:0" means the first PCIe device. +TfLiteDelegate* CreateEdgeTpuDelegate( + absl::optional<edgetpu_device_type> device_type, + absl::optional<int> device_index, + const absl::node_hash_map<std::string, std::string>& device_options) { + std::vector<edgetpu_option> options(device_options.size()); + size_t i = 0; + for (auto& device_option : device_options) { + options[i++] = {device_option.first.c_str(), device_option.second.c_str()}; + } + + size_t num_devices; + std::unique_ptr<edgetpu_device, decltype(&edgetpu_free_devices)> devices( + edgetpu_list_devices(&num_devices), &edgetpu_free_devices); + + if (!device_index.has_value()) { + return CreateEdgeTpuDelegate(device_type, 0, device_options); + } else { + const int index = device_index.value(); + if (device_type.has_value()) { + int type_index = 0; + for (size_t i = 0; i < num_devices; i++) { + const auto& device = devices.get()[i]; + if (device.type == device_type.value() && type_index++ == index) + return edgetpu_create_delegate(device.type, device.path, + options.data(), options.size()); + } + } else { + if (index < num_devices) + return edgetpu_create_delegate(devices.get()[index].type, + devices.get()[index].path, + options.data(), options.size()); + } + return nullptr; + } +} + +TfLiteDelegate* CreateEdgeTpuDelegate( + const std::string& device, + const absl::node_hash_map<std::string, std::string>& options) { + if (device.empty()) { + return CreateEdgeTpuDelegate(absl::nullopt, absl::nullopt, options); + } else if (device == kUsb) { + return CreateEdgeTpuDelegate(EDGETPU_APEX_USB, absl::nullopt, options); + } else if (device == kPci) { + return CreateEdgeTpuDelegate(EDGETPU_APEX_PCI, absl::nullopt, options); + } else { + int index; + if (MatchDevice(device, "", &index)) { + return CreateEdgeTpuDelegate(absl::nullopt, index, options); + } else if (MatchDevice(device, kUsb, &index)) { + return CreateEdgeTpuDelegate(EDGETPU_APEX_USB, index, options); + } else if (MatchDevice(device, kPci, &index)) { + return CreateEdgeTpuDelegate(EDGETPU_APEX_PCI, index, options); + } else { + LOG(ERROR) << "Cannot match the given device string (" << device + << ") with a Coral device."; + return nullptr; + } + } +} + +class EdgeTpuCoralPlugin : public DelegatePluginInterface { + public: + TfLiteDelegatePtr Create() override { + return TfLiteDelegatePtr(CreateEdgeTpuDelegate(device_, options_), + edgetpu_free_delegate); + } + + int GetDelegateErrno(TfLiteDelegate* from_delegate) override { return 0; } + + static std::unique_ptr<DelegatePluginInterface> New( + const TFLiteSettings& acceleration) { + return absl::make_unique<EdgeTpuCoralPlugin>(acceleration); + } + + explicit EdgeTpuCoralPlugin(const TFLiteSettings& tflite_settings) { + const auto* coral_settings = tflite_settings.coral_settings(); + if (!coral_settings) { + return; + } + + device_ = coral_settings->device()->str(); + options_.insert( + {"Performance", ConvertPerformance(coral_settings->performance())}); + options_.insert( + {"Usb.AlwaysDfu", ConvertBool(coral_settings->usb_always_dfu())}); + options_.insert( + {"Usb.MaxBulkInQueueLength", + std::to_string(coral_settings->usb_max_bulk_in_queue_length() == 0 + ? kDEFAULT_USB_MAX_BULK_IN_QUEUE_LENGTH + : coral_settings->usb_max_bulk_in_queue_length())}); + } + + private: + std::string device_; + absl::node_hash_map<std::string, std::string> options_; +}; +} // namespace + +TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(EdgeTpuCoralPlugin, + EdgeTpuCoralPlugin::New); +} // namespace delegates +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc new file mode 100644 index 0000000..cc183a6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
@@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace delegates { +namespace { + +constexpr char kEdgeTpuModelFilePath[] = + "tensorflow_lite_support/acceleration/configuration/testdata/" + "mobilenet_v1_1.0_224_quant_edgetpu.tflite"; +constexpr char kRegularModelFilePath[] = + "tensorflow_lite_support/acceleration/configuration/testdata/" + "mobilenet_v1_1.0_224_quant.tflite"; +constexpr char kImagePath[] = + "tensorflow_lite_support/acceleration/configuration/testdata/" + "burger.jpg"; + +using ::tflite::task::vision::DecodeImageFromFile; +using ::tflite::task::vision::ImageData; +using ::tflite::task::vision::ImageDataFree; + +using EdgeTpuCoralPluginTest = testing::TestWithParam<std::string>; + +INSTANTIATE_TEST_SUITE_P(CoralPluginTests, + EdgeTpuCoralPluginTest, + testing::Values(kRegularModelFilePath, + kEdgeTpuModelFilePath)); + +TEST_P(EdgeTpuCoralPluginTest, CreateEdgeTpuCoralPlugin) { + // Create the Coral delegate from the Coral plugin. + flatbuffers::FlatBufferBuilder flatbuffer_builder; + auto settings = flatbuffers::GetTemporaryPointer( + flatbuffer_builder, + CreateTFLiteSettings(flatbuffer_builder, tflite::Delegate_EDGETPU_CORAL)); + auto plugin = ::tflite::delegates::DelegatePluginRegistry::CreateByName( + "EdgeTpuCoralPlugin", *settings); + auto coral_delegate = plugin->Create(); + + // Load the tflite model file. + std::unique_ptr<::tflite::FlatBufferModel> tflite_model = + ::tflite::FlatBufferModel::BuildFromFile(GetParam().c_str()); + ASSERT_NE(tflite_model, nullptr); + + // Create the tflite interpreter. + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr<::tflite::Interpreter> interpreter; + ASSERT_EQ(::tflite::InterpreterBuilder(*tflite_model, resolver)(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + interpreter->ModifyGraphWithDelegate(coral_delegate.get()); + + // Verifies that interpreter runs correctly. + // To open source the code under tensorflow/lite, the following code needs to + // be stript from the Task library dependency, meaning forking or rewriting + // `LoadImage` and `ImageData`. + // `ASSERT_OK_AND_ASSIGN` is not available externally. + auto rgb_image_or = DecodeImageFromFile(kImagePath); + ASSERT_TRUE(rgb_image_or.ok()); + + ImageData rgb_image = rgb_image_or.value(); + const uint8_t* input_data = rgb_image.pixel_data; + size_t input_data_byte_size = + rgb_image.width * rgb_image.height * rgb_image.channels * sizeof(uint8_t); + + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + uint8_t* input_tensor = interpreter->typed_input_tensor<uint8_t>(0); + memcpy(input_tensor, input_data, input_data_byte_size); + ASSERT_EQ(interpreter->Invoke(), kTfLiteOk); + uint8_t* output_tensor = interpreter->typed_output_tensor<uint8_t>(0); + // `cheeseburger` is the 935th item in the label file of + // "mobilenet_v1_1.0_224_quant_edgetpu.tflite". See labels.txt. + EXPECT_EQ(output_tensor[934], 255); + ImageDataFree(&rgb_image); +} + +} // namespace +} // namespace delegates +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/BUILD new file mode 100644 index 0000000..d3dcbaf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/BUILD
@@ -0,0 +1,12 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.tflite", + "*.jpg", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/burger.jpg b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/burger.jpg new file mode 100644 index 0000000..58ec72b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/burger.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant.tflite b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant.tflite new file mode 100644 index 0000000..437640b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant_edgetpu.tflite b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant_edgetpu.tflite new file mode 100644 index 0000000..416152a1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant_edgetpu.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/BUILD new file mode 100644 index 0000000..ecee826 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/BUILD
@@ -0,0 +1,23 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "common", + srcs = ["common.cc"], + hdrs = ["common.h"], +) + +cc_library( + name = "common_utils", + srcs = ["common_utils.cc"], + hdrs = ["common_utils.h"], + deps = [ + ":common", + "//tensorflow_lite_support/cc:common", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc new file mode 100644 index 0000000..f0974ed2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
@@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/common.h" + +#include <cstdlib> + +void TfLiteSupportErrorDelete(TfLiteSupportError* error) { + // `strdup` obtains memory using `malloc` and the memory needs to be + // released using `free`. + free(error->message); + delete error; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h new file mode 100644 index 0000000..3ced6422 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
@@ -0,0 +1,202 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_COMMON_H_ +#define TENSORFLOW_LITE_SUPPORT_C_COMMON_H_ + +// Defines C struct and error codes for describing any error returned from the C +// Task Library. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Error codes for TensorFlow Lite Task Library C APIs. +// +// Holds one to one mapping with `TfLiteSupportStatus` code starting from kError +// = 1. Omits `kOk` since `TfLiteErrorCode` is only to be used in the event of +// an error and does not account for success unlike `TfLiteSupportStatus`. In +// case of success, TensorFlow Lite Task Library C APIs return the appropriate +// return value and a null error. One to one mapping makes it easier to convert +// between `TfLiteSupportStatus` and `TfLiteSupportErrorCode` without long +// switch statements. +// +// Also holds error codes mapping to absl::Status::code() starting from +// kNotFound = 900 in cases where the absl::Status payload can't +// be mapped to a `TfLiteSupportStatus` code. kErrorCodeFirst and kErrorCodeLast +// are also provided for safety checks during conversion between +// `TfLiteSupportStatus` and `TfLiteSupportErrorCode`. In case of modifications +// in error codes, ensure that kErrorCodeFirst and kErrorCodeLast is +// respectively, set to the least and greatest enum value amongst the error +// codes mapping to TfLiteSupportStatus. +enum TfLiteSupportErrorCode { + // Unspecified error. + kError = 1, + // Invalid argument specified. + kInvalidArgumentError = 2, + // Invalid FlatBuffer file or buffer specified. + kInvalidFlatBufferError = 3, + // Model contains a builtin op that isn't supported by the OpResolver or + // delegates. + kUnsupportedBuiltinOpError = 4, + // Model contains a custom op that isn't supported by the OpResolver or + // delegates. + kUnsupportedCustomOpError = 5, + + // File I/O error codes. + + // No such file. + kFileNotFoundError = 100, + // Permission issue. + kFilePermissionDeniedError, + // I/O error when reading file. + kFileReadError, + // I/O error when mmap-ing file. + kFileMmapError, + + // TensorFlow Lite metadata error codes. + + // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. + kMetadataInvalidSchemaVersionError = 200, + // No such associated file within metadata, or file has not been packed. + kMetadataAssociatedFileNotFoundError, + // ZIP I/O error when unpacking an associated file. + kMetadataAssociatedFileZipError, + // Inconsistency error between the metadata and actual TF Lite model. + // E.g.: number of labels and output tensor values differ. + kMetadataInconsistencyError, + // Invalid process units specified. + // E.g.: multiple ProcessUnits with the same type for a given tensor. + kMetadataInvalidProcessUnitsError, + // Inconsistency error with the number of labels. + // E.g.: label files for different locales have a different number of labels. + kMetadataNumLabelsMismatchError, + // Score calibration parameters parsing error. + // E.g.: too many parameters provided in the corresponding associated file. + kMetadataMalformedScoreCalibrationError, + // Unexpected number of subgraphs for the current task. + // E.g.: image classification expects a single subgraph. + kMetadataInvalidNumSubgraphsError, + // A given tensor requires NormalizationOptions but none were found. + // E.g.: float input tensor requires normalization to preprocess input images. + kMetadataMissingNormalizationOptionsError, + // Invalid ContentProperties specified. + // E.g. expected ImageProperties, got BoundingBoxProperties. + kMetadataInvalidContentPropertiesError, + // Metadata is mandatory but was not found. + // E.g. current task requires TFLite Model Metadata but none was found. + kMetadataNotFoundError, + // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + // none was found or it was empty. + // E.g. current task requires labels but none were found. + kMetadataMissingLabelsError, + // The ProcessingUnit for tokenizer is not correctly configured. + // E.g BertTokenizer doesn't have a valid vocab file associated. + kMetadataInvalidTokenizerError, + + // Input tensor(s) error codes. + + // Unexpected number of input tensors for the current task. + // E.g. current task expects a single input tensor. + kInvalidNumInputTensorsError = 300, + // Unexpected input tensor dimensions for the current task. + // E.g.: only 4D input tensors supported. + kInvalidInputTensorDimensionsError, + // Unexpected input tensor type for the current task. + // E.g.: current task expects a uint8 pixel image as input. + kInvalidInputTensorTypeError, + // Unexpected input tensor bytes size. + // E.g.: size in bytes does not correspond to the expected number of pixels. + kInvalidInputTensorSizeError, + // No correct input tensor found for the model. + // E.g.: input tensor name is not part of the text model's input tensors. + kInputTensorNotFoundError, + + // Output tensor(s) error codes. + + // Unexpected output tensor dimensions for the current task. + // E.g.: only a batch size of 1 is supported. + kInvalidOutputTensorDimensionsError = 400, + // Unexpected input tensor type for the current task. + // E.g.: multi-head model with different output tensor types. + kInvalidOutputTensorTypeError, + // No correct output tensor found for the model. + // E.g.: output tensor name is not part of the text model's output tensors. + kOutputTensorNotFoundError, + // Unexpected number of output tensors for the current task. + // E.g.: current task expects a single output tensor. + kInvalidNumOutputTensorsError, + + // Image processing error codes. + + // Unspecified image processing failures. + kImageProcessingError = 500, + // Unexpected input or output buffer metadata. + // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + kImageProcessingInvalidArgumentError, + // Image processing operation failures. + // E.g. libyuv rotation failed for an unknown reason. + kImageProcessingBackendError, + + // Convenience error codes for condition checks during type casting. + // + // Codes mapping to absl status codes should not be considered for these + // ranges. + // They must be used exclsively for checking if error codes fall in valid + // ranges when converting between TfLiteSupportStatus and + // TfLiteSupportErrorCodee. + + // Ensure it holds the least enum value amongst error codes mapping to + // TfLiteSupportStatus. + kErrorCodeFirst = kError, + // Ensure it holds the greatest enum value amongst error codes mapping to + // TfLiteSupportStatus. + kErrorCodeLast = kImageProcessingBackendError, + + // Absl Status Codes Mapping + // + // Codes starting from 900 will be used to map absl::Status created by TfLite + // and are used as is by TfLite Support C++ layer. Such absl status objects + // don't have a TfLiteSupportStatus in the payload that can be mapped to other + // error codes in this struct. You must use the absl::Status::code() and map + // them to the following error codes in such cases. + // For more info on respective absl status codes, please see: + // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L91 + + // kNotFound indicates some requested entity (such as a file or directory) + // was not found. + kNotFoundError = 900, + // kInternal indicates an internal error has occurred + // and some invariants expected by the underlying system have not been + // satisfied. This error code is reserved for serious errors. + kInternalError, +}; + +// A `TfLiteSupportError` encapsulates an error code and a descriptive message +// to return in the event of an error being encountered in any TensorFlow Lite +// Task Library C API. +typedef struct TfLiteSupportError { + // Holds the error code. + enum TfLiteSupportErrorCode code; + // Detailed description of the error. + char* message; +} TfLiteSupportError; + +void TfLiteSupportErrorDelete(TfLiteSupportError* error); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_COMMON_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc new file mode 100644 index 0000000..39afb9c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
@@ -0,0 +1,111 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/common_utils.h" + +#include <string> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" + +namespace tflite { +namespace support { + +void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code, + const char* message, + TfLiteSupportError** error) { + if (error == nullptr) + return; + + *error = new TfLiteSupportError; + (*error)->code = code; + (*error)->message = strdup(message); +} + +void CreateTfLiteSupportErrorWithStatus(const absl::Status& status, + TfLiteSupportError** error) { + if (status.ok() || error == nullptr) + return; + + // Payload of absl::Status created by the tflite task library stores an + // appropriate value of the enum TfLiteSupportStatus. The integer value + // corresponding to the TfLiteSupportStatus enum stored in the payload is + // extracted here to later map to the appropriate error code to be returned. + // In cases where the enum is not stored in (payload is NULL or the payload + // string cannot be converted to an integer), we set the error code value to + // be 1 (kError of TfLiteErrorCode used in the C library to signify any errors + // not falling into other categories.) Since payload is of type absl::Cord + // that can be type cast into an absl::optional<std::string>, we use the + // std::stoi function to convert it into an integer code if possible. + int generic_error_code = static_cast<int>(kError); + int error_code; + try { + // Try converting payload to integer if payload is not empty. Otherwise + // convert a string signifying generic error code kError to integer. + error_code = std::stoi(static_cast<absl::optional<std::string>>( + status.GetPayload(kTfLiteSupportPayload)) + .value_or(std::to_string(generic_error_code))); + } catch (std::invalid_argument& e) { + // If non empty payload string cannot be converted to an integer. Set error + // code to 1(kError). + error_code = generic_error_code; + } + + // If error_code is outside the range of enum values possible or is kError, we + // try to map the absl::Status::code() to assign appropriate + // TfLiteSupportErrorCode or kError in default cases. Note: The mapping to + // absl::Status::code() is done to generate a more specific error code than + // kError in cases when the payload can't be mapped to TfLiteSupportStatus. + // This can happen when absl::Status returned by TfLite are in turn returned + // without moodification by TfLite Support Methods. + if (error_code > static_cast<int>(kErrorCodeLast) || + error_code <= static_cast<int>(kErrorCodeFirst)) { + switch (status.code()) { + case absl::StatusCode::kInternal: + error_code = kInternalError; + break; + case absl::StatusCode::kInvalidArgument: + error_code = kInvalidArgumentError; + break; + case absl::StatusCode::kNotFound: + error_code = kNotFoundError; + break; + default: + error_code = kError; + break; + } + } + + // Creates the TfLiteSupportError with the appropriate error + // TfLiteSupportErrorCode and message. TfLiteErrorCode has a one to one + // mapping with TfLiteSupportStatus starting from the value 1(kError) and + // hence will be correctly initialized if directly cast from the integer code + // derived from TfLiteSupportStatus stored in payload. TfLiteErrorCode omits + // kOk = 0 of TfLiteSupportStatus. + // + // Stores a string including absl status code and message(if non empty) as the + // error message See + // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514 + // for explanation. absl::Status::message() can also be used but not always + // guaranteed to be non empty. + CreateTfLiteSupportError( + static_cast<TfLiteSupportErrorCode>(error_code), + status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str(), + error); +} + +} // namespace support +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h new file mode 100644 index 0000000..551f64a5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
@@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_COMMON_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_C_COMMON_UTILS_H_ + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/c/common.h" + +// Utils for Conversion of absl::Status to TfLiteError +// ----------------------------------------------------------------- +// Meant to be used with task C apis. + +namespace tflite { +namespace support { + +// Creates a TfLiteSupportError with a TfLiteSupportErrorCode and message. +void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code, + const char* message, + TfLiteSupportError** error); + +// Creates a TfLiteSupportError from absl::Status and passes it back as a +// parameter which is a pointer to the error pointer. +// +// Example Usage With Image Classifier +// +// APIs: TfLiteImageClassifier* TfLiteImageClassifierFromOptions( +// const TfLiteImageClassifierOptions* options, +// TfLiteSupportError **error) { +// // Necessary checks +// tflite::support::StatusOr<std::unique_ptr<ImageClassifier>> classifier_status +// = // Call to create Cpp Image Classifier. +// if (classifier_status.ok()) { +// Code to return classifier +// } else { +// ::tflite::support::CreateTfLiteSupportErrorWithStatus(classifier_status.status(), +// error); +// return nullptr; +// } +//} +void CreateTfLiteSupportErrorWithStatus(const absl::Status& status, + TfLiteSupportError** error); + +} // namespace support +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_C_COMMON_UTILS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/core/BUILD new file mode 100644 index 0000000..659f165 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/core/BUILD
@@ -0,0 +1,9 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "base_options", + hdrs = ["base_options.h"], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/core/base_options.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/core/base_options.h new file mode 100644 index 0000000..1093b63 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/core/base_options.h
@@ -0,0 +1,73 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_CORE_BASE_OPTIONS_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_CORE_BASE_OPTIONS_H_ + +#include <stdint.h> + +// Defines C Structs for Base Options Shared by all tasks. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Holds cpu settings. +typedef struct TfLiteCpuSettings { + // Specifies the number of threads to be used for TFLite + // ops that support multi-threading when running inference with CPU. + // num_threads should be greater than 0 or equal to -1. Setting num_threads to + // -1 has the effect to let TFLite runtime set the value. + int num_threads; +} TfLiteCpuSettings; + +// Holds settings for one possible acceleration configuration. +typedef struct TfLiteComputeSettings { + // Holds cpu settings + TfLiteCpuSettings cpu_settings; +} TfLiteComputeSettings; + +// Represents external files used by the Task APIs (e.g. TF Lite Model File). +// For now you can only specify the path of the file using file_path: +// In future other sources may be supported. +typedef struct TfLiteExternalFile { + // The path to the file to open. + const char* file_path; + // Additional option for byte data when it's supported. +} TfLiteExternalFile; + +// Holds the base options that is used for creation of any type of task. It has +// fields withh important information acceleration configuration, tflite model +// source etc. +// This struct must be zero initialized before setting any options as this +// will result in seg faults. +typedef struct TfLiteBaseOptions { + // The external model file, as a single standalone TFLite file. It could be + // packed with TFLite Model Metadata[1] and associated files if exist. Fail to + // provide the necessary metadata and associated files might result in errors. + // Check the documentation for each task about the specific requirement. + // [1]: https://www.tensorflow.org/lite/convert/metadata + TfLiteExternalFile model_file; + + // Holds settings for one possible acceleration configuration + // including.cpu/gpu settings. Please see documentation of + // TfLiteComputeSettings and its members for more details. + TfLiteComputeSettings compute_settings; +} TfLiteBaseOptions; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_CORE_BASE_OPTIONS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/BUILD new file mode 100644 index 0000000..2d0cef6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/BUILD
@@ -0,0 +1,42 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "category", + hdrs = ["category.h"], +) + +cc_library( + name = "classification_result", + srcs = [ + "classification_result.cc", + ], + hdrs = ["classification_result.h"], + deps = [ + ":category", + ], +) + +cc_library( + name = "bounding_box", + hdrs = ["bounding_box.h"], +) + +cc_library( + name = "classification_options", + hdrs = ["classification_options.h"], +) + +cc_library( + name = "detection_result", + srcs = [ + "detection_result.cc", + ], + hdrs = ["detection_result.h"], + deps = [ + ":bounding_box", + ":category", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/bounding_box.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/bounding_box.h new file mode 100644 index 0000000..4a28f77 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/bounding_box.h
@@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_BOUNDING_BOX_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_BOUNDING_BOX_H_ + +#include <stdint.h> + +// Defines C Struct for Bounding Box Shared by Vision Tasks. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Holds the region of interest used for image classification. +typedef struct TfLiteBoundingBox { + // The X coordinate of the top-left corner, in pixels. + int origin_x; + + // The Y coordinate of the top-left corner, in pixels. + int origin_y; + + // The width of the bounding box, in pixels. + int width; + + // The height of the bounding box, in pixels. + int height; +} TfLiteBoundingBox; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_BOUNDING_BOX_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/category.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/category.h new file mode 100644 index 0000000..5c049ec --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/category.h
@@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CATEGORY_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CATEGORY_H_ + +// Defines C structure for a Category which encapsulates a single predicted +// class. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// A single predicted class. +typedef struct TfLiteCategory { + // The index of the class in the corresponding label map, usually packed in + // the TFLite Model Metadata [1]. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + int index; + + // The score for this class e.g. (but not necessarily) a probability in [0,1]. + float score; + + // A human readable name of the class filled from the label map. + char* display_name; + // An ID for the class, not necessarily human-readable (e.g. a Google + // Knowledge Graph ID [1]), filled from the label map. + // + // [1]: https://developers.google.com/knowledge-graph + char* label; +} TfLiteCategory; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_CATEGORY_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_options.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_options.h new file mode 100644 index 0000000..829e9e97 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_options.h
@@ -0,0 +1,66 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_OPTIONS_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_OPTIONS_H_ + +#include <stdint.h> + +// Defines C Struct for Classification Options Shared by All Classification +// Tasks. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Holds pointer to array of C strings and length for looping through the array. +typedef struct TfLiteStringArrayOption { + // Length of list. length can be used to loop through list. + int length; + + // Array of C strings. + char** list; +} TfLiteStringArrayOption; + +// Holds settings for any single classification task. +typedef struct TfLiteClassificationOptions { + // Optional denylist of class labels. If non NULL, classifications whose + // class label is in this set will be filtered out. Duplicate or unknown + // class labels are ignored. Mutually exclusive with label_allowlist. + TfLiteStringArrayOption label_denylist; + + // Optional allowlist of class labels. If non-empty, classifications whose + // class label is not in this set will be filtered out. Duplicate or unknown + // class labels are ignored. Mutually exclusive with label_denylist. + TfLiteStringArrayOption label_allowlist; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + char* display_names_local; + + // The maximum number of top-scored classification results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. Defaults to -1. + int max_results; + + // Score threshold, overrides the ones provided in the model metadata + // (if any). Results below this value are rejected. + float score_threshold; +} TfLiteClassificationOptions; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_OPTIONS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc new file mode 100644 index 0000000..c277df0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
@@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/processor/classification_result.h" + +#include <cstdlib> +#include <memory> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void TfLiteClassificationResultDelete( + TfLiteClassificationResult* classification_result) { + for (int head = 0; head < classification_result->size; ++head) { + TfLiteClassifications classifications = + classification_result->classifications[head]; + for (int rank = 0; rank < classifications.size; ++rank) { + // `strdup` obtains memory using `malloc` and the memory needs to be + // released using `free`. + free(classifications.categories[rank].display_name); + free(classifications.categories[rank].label); + } + + delete[] classifications.categories; + } + + delete[] classification_result->classifications; + delete classification_result; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h new file mode 100644 index 0000000..1b73365a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h
@@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_RESULT_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_RESULT_H_ + +#include "tensorflow_lite_support/c/task/processor/category.h" + +// Defines C structure for Classification Results and associated helper methods. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// List of predicted classes (aka labels) for a given image classifier head. +typedef struct TfLiteClassifications { + // The index of the image classifier head these classes refer to. This is + // useful for multi-head models. + int head_index; + + // Number of predicted classes which can be used to traverse the array of + // predicted classes. + int size; + + // The array of predicted classes, usually sorted by descending scores (e.g. + // from high to low probability). Since this array is dynamically allocated, + // use size to traverse through the array. + TfLiteCategory* categories; +} TfLiteClassifications; + +// Holds Image Classification results. +// Contains one set of results per image classifier head. +typedef struct TfLiteClassificationResult { + // Number of predicted classes which can be used to traverse the array of + // predicted classes. + int size; + + // Array of image classifier results per image classifier head. This array can + // have any number of results. size holds the size of this array. size should + // be used to traverse this array. + TfLiteClassifications* classifications; +} TfLiteClassificationResult; + +// Frees up the ClassificationResult Structure. +void TfLiteClassificationResultDelete( + TfLiteClassificationResult* classification_result); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_CLASSIFICATION_RESULT_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/detection_result.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/detection_result.h new file mode 100644 index 0000000..2f4c701 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/detection_result.h
@@ -0,0 +1,65 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_DETECTION_RESULT_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_DETECTION_RESULT_H_ + +#include "tensorflow_lite_support/c/task/processor/bounding_box.h" +#include "tensorflow_lite_support/c/task/processor/category.h" + +// Defines C structure for Object Detection Results and associated helper +// methods. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Bounding box and list of predicted classes (aka labels) for a detected +// object. +typedef struct TfLiteDetection { + // The bounding box of the detected object. + TfLiteBoundingBox bounding_box; + + // The array of predicted classes for the object detection represented by an + // instance of TfLiteDetection, usually sorted by descending scores (e.g. from + // high to low probability). Since this array is dynamically allocated, use + // size to traverse through the array. + TfLiteCategory* categories; + + // Number of detectd objects be used to traverse the array of the detected + // objects. + int size; +} TfLiteDetection; + +// Holds Object Detection results. +// Contains one set of results per detected object. +typedef struct TfLiteDetectionResult { + // Number of detectd objects be used to traverse the array of the detected + // objects. + int size; + + // Array of results per detected object. This array can + // have any number of results. size holds the size of this array. size should + // be used to traverse this array. + TfLiteDetection* detections; +} TfLiteDetectionResult; + +// Frees up the DetectionResult Structure. +void TfLiteDetectionResultDelete(TfLiteDetectionResult* detection_result); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_DETECTION_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/BUILD new file mode 100644 index 0000000..fd2e9740 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/BUILD
@@ -0,0 +1,79 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_library_with_tflite", +) + +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "bert_nl_classifier.h", + "nl_classifier.h", + "nl_classifier_common.h", + "bert_question_answerer.h", +]) + +cc_library_with_tflite( + name = "nl_classifier", + srcs = [ + "nl_classifier.cc", + ], + hdrs = [ + "nl_classifier.h", + "nl_classifier_common.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + ], + deps = [ + ":nl_classifier_common", + "//tensorflow_lite_support/cc/task/core:category", + "@com_google_absl//absl/strings", + ], +) + +cc_library_with_tflite( + name = "bert_nl_classifier", + srcs = [ + "bert_nl_classifier.cc", + ], + hdrs = [ + "bert_nl_classifier.h", + "nl_classifier_common.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text:bert_nl_classifier", + ], + deps = [ + ":nl_classifier_common", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "nl_classifier_common", + srcs = [ + "nl_classifier_common.cc", + ], + hdrs = [ + "nl_classifier_common.h", + ], +) + +cc_library_with_tflite( + name = "bert_question_answerer", + srcs = [ + "bert_question_answerer.cc", + ], + hdrs = [ + "bert_question_answerer.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text:bert_question_answerer", + "//tensorflow_lite_support/cc/task/text:question_answerer", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc new file mode 100644 index 0000000..52907f4f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
@@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/text/bert_nl_classifier.h" + +#include <memory> + +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" +#include "tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h" + +namespace { +using CategoryCpp = ::tflite::task::core::Category; +using BertNLClassifierCpp = ::tflite::task::text::BertNLClassifier; +using BertNLClassifierOptionsCpp = + ::tflite::task::text::BertNLClassifierOptions; + +const TfLiteBertNLClassifierOptions kBertNLClassifierOptionsDefault = {128}; +} // namespace + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct TfLiteBertNLClassifier { + std::unique_ptr<BertNLClassifierCpp> impl; +}; + +TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions( + const char* model_path, + const TfLiteBertNLClassifierOptions* options) { + BertNLClassifierOptionsCpp cc_options; + + cc_options.mutable_base_options()->mutable_model_file()->set_file_name( + model_path); + auto classifier_status = BertNLClassifierCpp::CreateFromOptions(cc_options); + + if (classifier_status.ok()) { + return new TfLiteBertNLClassifier{ + .impl = std::unique_ptr<BertNLClassifierCpp>( + dynamic_cast<BertNLClassifierCpp*>( + classifier_status.value().release()))}; + + } else { + return nullptr; + } +} + +TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path) { + return TfLiteBertNLClassifierCreateFromOptions( + model_path, &kBertNLClassifierOptionsDefault); +} + +Categories* TfLiteBertNLClassifierClassify( + const TfLiteBertNLClassifier* classifier, + const char* text) { + std::vector<CategoryCpp> results = + + classifier->impl->Classify(absl::string_view(text).data()); + size_t size = results.size(); + auto* categories = new Category[size]; + + for (size_t i = 0; i < size; ++i) { + categories[i].text = strdup(results[i].class_name.c_str()); + categories[i].score = results[i].score; + } + + auto* c_categories = new Categories; + c_categories->size = size; + c_categories->categories = categories; + return c_categories; +} + +void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier) { + delete classifier; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h new file mode 100644 index 0000000..94138a2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
@@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_NL_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_NL_CLASSIFIER_H_ + +#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h" +// -------------------------------------------------------------------------- +// C API for BertNLClassifier. +// +// Usage: +// // Create the model and interpreter options. +// TfLiteBertNLClassifier* classifier = +// TfLiteBertNLClassifierCreate("/path/to/model.tflite"); +// +// // Classification. +// Categories* categories = TfLiteBertNLClassifierClassify(classifier, +// question); +// +// // Dispose of the API object. +// TfLiteBertNLClassifierDelete(classifier); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct TfLiteBertNLClassifier TfLiteBertNLClassifier; + +typedef struct TfLiteBertNLClassifierOptions { + // Max number of tokens to pass to the model. + // + // Deprecated: max_seq_len is now read from the model (i.e. input tensor size) + // automatically. + int max_seq_len; +} TfLiteBertNLClassifierOptions; + +// Creates TfLiteBertNLClassifier from model path and options, returns nullptr +// if the file doesn't exist or is not a well formatted TFLite model path. +TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions( + const char* model_path, + const TfLiteBertNLClassifierOptions* options); + +// Creates TfLiteBertNLClassifier from model path and default options, returns +// nullptr if the file doesn't exist or is not a well formatted TFLite model +// path. +TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path); + +// Invokes the encapsulated TFLite model and classifies the input text. +Categories* TfLiteBertNLClassifierClassify( + const TfLiteBertNLClassifier* classifier, + const char* text); + +void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_NL_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc new file mode 100644 index 0000000..1887d52 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
@@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/text/bert_question_answerer.h" + +#include <memory> + +#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" +#include "tensorflow_lite_support/cc/task/text/question_answerer.h" + +namespace { +using BertQuestionAnswererCpp = ::tflite::task::text::BertQuestionAnswerer; +using QaAnswerCpp = ::tflite::task::text::QaAnswer; +} // namespace + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct TfLiteBertQuestionAnswerer { + std::unique_ptr<BertQuestionAnswererCpp> impl; +}; + +TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate( + const char* model_path) { + auto bert_qa_status = + BertQuestionAnswererCpp::CreateFromFile(std::string(model_path)); + if (bert_qa_status.ok()) { + return new TfLiteBertQuestionAnswerer{ + .impl = std::unique_ptr<BertQuestionAnswererCpp>( + dynamic_cast<BertQuestionAnswererCpp*>( + bert_qa_status.value().release()))}; + } else { + return nullptr; + } +} + +TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer( + const TfLiteBertQuestionAnswerer* question_answerer, + const char* context, + const char* question) { + std::vector<QaAnswerCpp> answers = question_answerer->impl->Answer( + absl::string_view(context).data(), absl::string_view(question).data()); + size_t size = answers.size(); + auto* qa_answers = new TfLiteQaAnswer[size]; + + for (size_t i = 0; i < size; ++i) { + qa_answers[i].start = answers[i].pos.start; + qa_answers[i].end = answers[i].pos.end; + qa_answers[i].logit = answers[i].pos.logit; + qa_answers[i].text = strdup(answers[i].text.c_str()); + } + + auto* c_answers = new TfLiteQaAnswers; + c_answers->size = size; + c_answers->answers = qa_answers; + return c_answers; +} + +void TfLiteBertQuestionAnswererDelete( + TfLiteBertQuestionAnswerer* bert_question_answerer) { + delete bert_question_answerer; +} + +void TfLiteQaAnswersDelete(TfLiteQaAnswers* qa_answers) { + for (int i = 0; i < qa_answers->size; i++) { + // `strdup` obtains memory using `malloc` and the memory needs to be + // released using `free`. + free(qa_answers->answers[i].text); + } + delete[] qa_answers->answers; + delete qa_answers; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h new file mode 100644 index 0000000..e9a1190 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
@@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_QUESTION_ANSWERER_H_ + +// -------------------------------------------------------------------------- +// C API for BertQuestionAnswerer. +// +// Usage: +// <pre><code> +// // Create the model and interpreter options. +// TfLiteBertQuestionAnswerer* qa_answerer = +// TfLiteBertQuestionAnswererCreate("/path/to/model.tflite"); +// +// // Answer a question. +// TfLiteQaAnswers* answers = TfLiteBertQuestionAnswererAnswer(qa_answerer, +// question); +// +// // Dispose of the API and QaAnswers objects. +// TfLiteBertQuestionAnswererDelete(qa_answerer); +// TfLiteQaAnswersDelete(answers); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct TfLiteBertQuestionAnswerer TfLiteBertQuestionAnswerer; + +typedef struct TfLiteQaAnswer { + int start; + int end; + float logit; + char* text; +} TfLiteQaAnswer; + +typedef struct TfLiteQaAnswers { + int size; + TfLiteQaAnswer* answers; +} TfLiteQaAnswers; + +// Creates TfLiteBertQuestionAnswerer from model path, returns nullptr if the +// file doesn't exist or is not a well formatted TFLite model path. +TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate( + const char* model_path); + +// Invokes the encapsulated TFLite model and answers a question based on +// context. +TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer( + const TfLiteBertQuestionAnswerer* question_answerer, + const char* context, + const char* question); + +void TfLiteBertQuestionAnswererDelete( + TfLiteBertQuestionAnswerer* bert_question_answerer); + +void TfLiteQaAnswersDelete(TfLiteQaAnswers* qa_answers); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_QUESTION_ANSWERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc new file mode 100644 index 0000000..1e6805c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
@@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/text/nl_classifier.h" + +#include <memory> + +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +namespace { +using CategoryCpp = ::tflite::task::core::Category; +using NLClassifierCpp = ::tflite::task::text::nlclassifier::NLClassifier; +using NLClassifierOptionsCpp = + ::tflite::task::text::nlclassifier::NLClassifierOptions; +} // namespace + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct TfLiteNLClassifier { + std::unique_ptr<NLClassifierCpp> impl; +}; + +TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions( + const char* model_path, + const TfLiteNLClassifierOptions* options) { + auto classifier_status = NLClassifierCpp::CreateFromFileAndOptions( + std::string(model_path), + { + .input_tensor_index = options->input_tensor_index, + .output_score_tensor_index = options->output_score_tensor_index, + .output_label_tensor_index = options->output_label_tensor_index, + .input_tensor_name = !options->input_tensor_name + ? "" + : std::string(options->input_tensor_name), + .output_score_tensor_name = + !options->output_score_tensor_name + ? "" + : std::string(options->output_score_tensor_name), + .output_label_tensor_name = + !options->output_label_tensor_name + ? "" + : std::string(options->output_label_tensor_name), + }); + + if (classifier_status.ok()) { + return new TfLiteNLClassifier{ + .impl = std::unique_ptr<NLClassifierCpp>(dynamic_cast<NLClassifierCpp*>( + classifier_status.value().release()))}; + } else { + return nullptr; + } +} + +Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier, + const char* text) { + std::vector<CategoryCpp> results = + classifier->impl->Classify(absl::string_view(text).data()); + size_t size = results.size(); + auto* categories = new Category[size]; + + for (size_t i = 0; i < size; ++i) { + categories[i].text = strdup(results[i].class_name.c_str()); + categories[i].score = results[i].score; + } + + auto* c_categories = new Categories; + c_categories->size = size; + c_categories->categories = categories; + return c_categories; +} + +void TfLiteNLClassifierDelete(TfLiteNLClassifier* classifier) { + delete classifier; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h new file mode 100644 index 0000000..389ca5d6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
@@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_H_ + +#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h" +// -------------------------------------------------------------------------- +// C API for NLClassifier. +// +// Usage: +// // Create the model and interpreter options. +// TfLiteNLClassifier* classifier = TfLiteNLClassifierCreate( +// "/path/to/model.tflite"); +// +// // Classification. +// Categories* categories = TfLiteNLClassifierClassify(classifier, question); +// +// // Dispose of the API object. +// TfLiteNLClassifierDelete(classifier); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct TfLiteNLClassifier TfLiteNLClassifier; + +typedef struct TfLiteNLClassifierOptions { + int input_tensor_index; + int output_score_tensor_index; + int output_label_tensor_index; + const char* input_tensor_name; + const char* output_score_tensor_name; + const char* output_label_tensor_name; +} TfLiteNLClassifierOptions; + +// Creates TfLiteNLClassifier from model path and options, returns nullptr if +// the file doesn't exist or is not a well formatted TFLite model path. +TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions( + const char* model_path, + const TfLiteNLClassifierOptions* options); + +// Invokes the encapsulated TFLite model and classifies the input text. +Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier, + const char* text); + +void TfLiteNLClassifierDelete(TfLiteNLClassifier* classifier); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc new file mode 100644 index 0000000..c26ce05 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc
@@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h" + +#include <memory> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void NLClassifierCategoriesDelete(Categories* categories) { + for (int i = 0; i < categories->size; i++) { + // `strdup` obtains memory using `malloc` and the memory needs to be + // released using `free`. + free(categories->categories[i].text); + } + delete[] categories->categories; + delete categories; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h new file mode 100644 index 0000000..ed4d1c8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h
@@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_COMMON_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_COMMON_H_ + +// C API for the NLClassifier results, Catergory. + +// TODO(b/197355311): deprecate this class and use the unified one with image +// and audio. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct Category { + char* text; + double score; +} Category; + +typedef struct Categories { + int size; + Category* categories; +} Categories; + +void NLClassifierCategoriesDelete(Categories* categories); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_COMMON_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/BUILD new file mode 100644 index 0000000..5bd132a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/BUILD
@@ -0,0 +1,50 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library_with_tflite( + name = "image_classifier", + srcs = [ + "image_classifier.cc", + ], + hdrs = [ + "image_classifier.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:image_classifier", + ], + deps = [ + "//tensorflow_lite_support/c:common", + "//tensorflow_lite_support/c:common_utils", + "//tensorflow_lite_support/c/task/core:base_options", + "//tensorflow_lite_support/c/task/processor:bounding_box", + "//tensorflow_lite_support/c/task/processor:classification_options", + "//tensorflow_lite_support/c/task/processor:classification_result", + "//tensorflow_lite_support/c/task/vision/core:frame_buffer", + "//tensorflow_lite_support/c/task/vision/utils:frame_buffer_cpp_c_utils", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + ], +) + +cc_library_with_tflite( + name = "object_detector", + hdrs = [ + "object_detector.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:object_detector", + ], + deps = [ + "//tensorflow_lite_support/c:common", + "//tensorflow_lite_support/c/task/core:base_options", + "//tensorflow_lite_support/c/task/processor:classification_options", + "//tensorflow_lite_support/c/task/processor:detection_result", + "//tensorflow_lite_support/c/task/vision/core:frame_buffer", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/BUILD new file mode 100644 index 0000000..75c97b4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "frame_buffer", + hdrs = [ + "frame_buffer.h", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/frame_buffer.h new file mode 100644 index 0000000..8cab267 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/frame_buffer.h
@@ -0,0 +1,81 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_H_ + +#include <stdint.h> + +// Defines C structs for holding the frame buffer. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Colorspace formats. +enum TfLiteFrameBufferFormat { + kRGBA, + kRGB, + kNV12, + kNV21, + kYV12, + kYV21, + kGRAY, + kUNKNOWN +}; + +// FrameBuffer content orientation follows EXIF specification. The name of +// each enum value defines the position of the 0th row and the 0th column of +// the image content. See http://jpegclub.org/exif_orientation.html for +// details. +enum TfLiteFrameBufferOrientation { + kTopLeft = 1, + kTopRight = 2, + kBottomRight = 3, + kBottomLeft = 4, + kLeftTop = 5, + kRightTop = 6, + kRightBottom = 7, + kLeftBottom = 8 +}; + +// Dimension information for the whole frame. +struct TfLiteFrameBufferDimension { + // The width dimension in pixel unit. + int width; + // The height dimension in pixel unit. + int height; +}; + +// A `FrameBuffer` provides a view into the provided backing buffer (e.g. camera +// frame or still image) with buffer format information. FrameBuffer doesn't +// take ownership of the provided backing buffer. The caller is responsible to +// manage the backing buffer lifecycle for the lifetime of the FrameBuffer. +typedef struct TfLiteFrameBuffer { + // Colorspace format of the frame buffer. + enum TfLiteFrameBufferFormat format; + // Orientation of the frame buffer. + enum TfLiteFrameBufferOrientation orientation; + // Dimension information for the whole frame. + struct TfLiteFrameBufferDimension dimension; + // Holds the backing buffer for the frame buffer. Only single planar images + // are supported as of now. + uint8_t* buffer; +} TfLiteFrameBuffer; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc new file mode 100644 index 0000000..8981e66 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
@@ -0,0 +1,236 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/vision/image_classifier.h" + +#include <memory> + +#include "tensorflow_lite_support/c/common_utils.h" +#include "tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h" +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" +#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" + +namespace { +using ::tflite::support::StatusOr; +using ClassificationResultCpp = ::tflite::task::vision::ClassificationResult; +using ClassificationsCpp = ::tflite::task::vision::Classifications; +using ClassCpp = ::tflite::task::vision::Class; +using BoundingBoxCpp = ::tflite::task::vision::BoundingBox; +using ImageClassifierCpp = ::tflite::task::vision::ImageClassifier; +using ImageClassifierOptionsCpp = + ::tflite::task::vision::ImageClassifierOptions; +using FrameBufferCpp = ::tflite::task::vision::FrameBuffer; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<ImageClassifierOptionsCpp> CreateImageClassifierCppOptionsFromCOptions( + const TfLiteImageClassifierOptions* c_options) { + if (c_options == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected non null options."), + TfLiteSupportStatus::kInvalidArgumentError); + } + + ImageClassifierOptionsCpp cpp_options = {}; + + // More file sources can be added in else ifs + if (c_options->base_options.model_file.file_path) + cpp_options.mutable_base_options()->mutable_model_file()->set_file_name( + c_options->base_options.model_file.file_path); + + // c_options->base_options.compute_settings.num_threads is expected to be + // set to value > 0 or -1. Otherwise invoking + // ImageClassifierCpp::CreateFromOptions() results in a not ok status. + cpp_options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->mutable_cpu_settings() + ->set_num_threads( + c_options->base_options.compute_settings.cpu_settings.num_threads); + + for (int i = 0; i < c_options->classification_options.label_denylist.length; + i++) + cpp_options.add_class_name_blacklist( + c_options->classification_options.label_denylist.list[i]); + + for (int i = 0; i < c_options->classification_options.label_allowlist.length; + i++) + cpp_options.add_class_name_whitelist( + c_options->classification_options.label_allowlist.list[i]); + + // Check needed since setting a nullptr for this field results in a segfault + // on invocation of ImageClassifierCpp::CreateFromOptions(). + if (c_options->classification_options.display_names_local) { + cpp_options.set_display_names_locale( + c_options->classification_options.display_names_local); + } + + // c_options->classification_options.max_results is expected to be set to -1 + // or any value > 0. Otherwise invoking + // ImageClassifierCpp::CreateFromOptions() results in a not ok status. + cpp_options.set_max_results(c_options->classification_options.max_results); + + cpp_options.set_score_threshold( + c_options->classification_options.score_threshold); + + return cpp_options; +} +} // namespace + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct TfLiteImageClassifier { + std::unique_ptr<ImageClassifierCpp> impl; +}; + +TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate() { + // Use brace-enclosed initializer list will break the Kokoro test. + TfLiteImageClassifierOptions options = {{{0}}}; + options.classification_options.max_results = -1; + options.classification_options.score_threshold = 0.0; + options.base_options.compute_settings.cpu_settings.num_threads = -1; + return options; +} + +TfLiteImageClassifier* TfLiteImageClassifierFromOptions( + const TfLiteImageClassifierOptions* options, + TfLiteSupportError** error) { + StatusOr<ImageClassifierOptionsCpp> cpp_option_status = + CreateImageClassifierCppOptionsFromCOptions(options); + + if (!cpp_option_status.ok()) { + ::tflite::support::CreateTfLiteSupportErrorWithStatus( + cpp_option_status.status(), error); + return nullptr; + } + + StatusOr<std::unique_ptr<ImageClassifierCpp>> classifier_status = + ImageClassifierCpp::CreateFromOptions(cpp_option_status.value()); + + if (classifier_status.ok()) { + return new TfLiteImageClassifier{.impl = + std::move(classifier_status.value())}; + } else { + ::tflite::support::CreateTfLiteSupportErrorWithStatus( + classifier_status.status(), error); + return nullptr; + } +} + +TfLiteClassificationResult* GetClassificationResultCStruct( + const ClassificationResultCpp& classification_result_cpp) { + auto c_classifications = + new TfLiteClassifications[classification_result_cpp + .classifications_size()]; + + for (int head = 0; head < classification_result_cpp.classifications_size(); + ++head) { + const ClassificationsCpp& classifications = + classification_result_cpp.classifications(head); + c_classifications[head].head_index = head; + + auto c_categories = new TfLiteCategory[classifications.classes_size()]; + c_classifications->size = classifications.classes_size(); + + for (int rank = 0; rank < classifications.classes_size(); ++rank) { + const ClassCpp& classification = classifications.classes(rank); + c_categories[rank].index = classification.index(); + c_categories[rank].score = classification.score(); + + if (classification.has_class_name()) + c_categories[rank].label = strdup(classification.class_name().c_str()); + else + c_categories[rank].label = nullptr; + + if (classification.has_display_name()) + c_categories[rank].display_name = + strdup(classification.display_name().c_str()); + else + c_categories[rank].display_name = nullptr; + } + c_classifications[head].categories = c_categories; + } + + auto c_classification_result = new TfLiteClassificationResult; + c_classification_result->classifications = c_classifications; + c_classification_result->size = + classification_result_cpp.classifications_size(); + + return c_classification_result; +} + +TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( + const TfLiteImageClassifier* classifier, + const TfLiteFrameBuffer* frame_buffer, + const TfLiteBoundingBox* roi, + TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null image classifier.", error); + return nullptr; + } + + StatusOr<std::unique_ptr<FrameBufferCpp>> cpp_frame_buffer_status = + ::tflite::task::vision::CreateCppFrameBuffer(frame_buffer); + if (!cpp_frame_buffer_status.ok()) { + tflite::support::CreateTfLiteSupportErrorWithStatus( + cpp_frame_buffer_status.status(), error); + return nullptr; + } + + BoundingBoxCpp cc_roi; + if (roi == nullptr) { + cc_roi.set_width(frame_buffer->dimension.width); + cc_roi.set_height(frame_buffer->dimension.height); + } else { + cc_roi.set_origin_x(roi->origin_x); + cc_roi.set_origin_y(roi->origin_y); + cc_roi.set_width(roi->width); + cc_roi.set_height(roi->height); + } + + // fnc_sample(cpp_frame_buffer_status); + StatusOr<ClassificationResultCpp> cpp_classification_result_status = + classifier->impl->Classify(*std::move(cpp_frame_buffer_status.value()), + cc_roi); + + if (!cpp_classification_result_status.ok()) { + tflite::support::CreateTfLiteSupportErrorWithStatus( + cpp_classification_result_status.status(), error); + return nullptr; + } + + return GetClassificationResultCStruct( + cpp_classification_result_status.value()); +} + +TfLiteClassificationResult* TfLiteImageClassifierClassify( + const TfLiteImageClassifier* classifier, + const TfLiteFrameBuffer* frame_buffer, + TfLiteSupportError** error) { + return TfLiteImageClassifierClassifyWithRoi(classifier, frame_buffer, nullptr, + error); +} + +void TfLiteImageClassifierDelete(TfLiteImageClassifier* classifier) { + delete classifier; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h new file mode 100644 index 0000000..8a53e5e2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
@@ -0,0 +1,214 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_ + +#include <stdint.h> + +#include "tensorflow_lite_support/c/common.h" +#include "tensorflow_lite_support/c/task/core/base_options.h" +#include "tensorflow_lite_support/c/task/processor/bounding_box.h" +#include "tensorflow_lite_support/c/task/processor/classification_options.h" +#include "tensorflow_lite_support/c/task/processor/classification_result.h" +#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h" + +// -------------------------------------------------------------------------- +/// C API for ImageClassifiier. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite +/// `ImageClassifier` API, but is useful for shared libraries where having +/// a stable ABI boundary is important. +/// +/// Usage: +/// <pre><code> +/// // Create the model +/// Using the options initialized with default values returned by +/// TfLiteImageClassifierOptionsCreate() makes sure that there will be no +/// undefined behaviour due to garbage values in unitialized members. +/// TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); +/// +/// Set the model file path in options +/// options.base_options.model_file.file_path = "/path/to/model.tflite"; +/// +/// If need be, set values for any options to customize behaviour. +/// options.base_options.compute_settings.cpu_settings.num_threads = 3 +/// +/// Create TfLiteImageClassifier using the options: +/// If error information is not nedded in case of failure: +/// TfLiteImageClassifier* image_classifier = +/// TfLiteImageClassifierFromOptions(&options, NULL); +/// +/// If error information is nedded in case of failure: +/// TfLiteSupportError* create_error = NULL; +/// TfLiteImageClassifier* image_classifier = +/// TfLiteImageClassifierFromOptions(&options, &create_error); +/// +/// if (!image_classifier) { +/// Handle failure. +/// Do something with `create_error`, if requested as illustrated above. +/// } +/// +/// Dispose of the create_error object. +/// TfLiteSupportErrorDelete(create_error); +/// +/// Classify an image +/// TfLiteFrameBuffer frame_buffer = { Initialize with image data } +/// +/// If error information is not nedded in case of failure: +/// TfLiteClassificationResult* classification_result = +/// TfLiteImageClassifierClassify(image_classifier, &frame_buffer, NULL); +/// +/// If error information is nedded in case of failure: +/// TfLiteSupportError* classify_error = NULL; +/// TfLiteClassificationResult* classification_result = +/// TfLiteImageClassifierClassify(image_classifier, &frame_buffer, +/// &classify_error); +/// +/// if (!classification_result) { +/// Handle failure. +/// Do something with `classify_error`, if requested as illustrated above. +/// } +/// +/// Dispose of the classify_error object. +/// TfLiteSupportErrorDelete(classify_error); +/// +/// Dispose of the API object. +/// TfLiteImageClassifierDelete(image_classifier); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct TfLiteImageClassifier TfLiteImageClassifier; + +typedef struct TfLiteImageClassifierOptions { + TfLiteClassificationOptions classification_options; + TfLiteBaseOptions base_options; +} TfLiteImageClassifierOptions; + +// Creates and returns TfLiteImageClassifierOptions initialized with default +// values. Default values are as follows: +// 1. .classification_options.max_results = -1, which returns all classification +// categories by default. +// 2. .base_options.compute_settings.tflite_settings.cpu_settings.num_threads = +// -1, which makes the TFLite runtime choose the value. +// 3. .classification_options.score_threshold = 0 +// 4. All pointers like .base_options.model_file.file_path, +// .base_options.classification_options.display_names_local, +// .classification_options.label_allowlist.list, +// options.classification_options.label_denylist.list are NULL. +// 5. All other integer values are initialized to 0. +TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(); + +// Creates TfLiteImageClassifier from options. +// .base_options.model_file.file_path in TfLiteImageClassifierOptions should be +// set to the path of the tflite model you wish to create the +// TfLiteImageClassifier with. +// Create TfLiteImageClassifierOptions using +// TfLiteImageClassifierOptionsCreate(). If need be, you can change the default +// values of options for customizing classification, If options are not created +// in the aforementioned way, you have to make sure that all members are +// initialized to respective default values and all pointer members are zero +// initialized to avoid any undefined behaviour. +// +// Returns the created image classifier in case of success. +// Returns nullptr on failure which happens commonly due to one of the following +// issues: +// 1. file doesn't exist or is not a well formatted. +// 2. options is nullptr. +// 3. Both options.classification_options.label_denylist and +// options.classification_options.label_allowlist are non empty. These +// fields are mutually exclusive. +// +// The caller can check if an error was encountered by testing if the returned +// value of the function is null. If the caller doesn't want the reason for +// failure, they can simply pass a NULL for the address of the error pointer as +// shown below: +// +// TfLiteImageClassifier* classifier = TfLiteImageClassifierFromOptions(options, +// NULL); +// +// If the caller wants to be informed of the reason for failure, they must pass +// the adress of a pointer of type TfLiteSupportError to the `error` param as +// shown below: +// +// TfLiteSupport *error = NULL: +// TfLiteImageClassifier* classifier = TfLiteImageClassifierFromOptions(options, +// &error); +// +// In case of unsuccessful execution, Once the function returns, the error +// pointer will point to a struct containing the error information. If error +// info is passed back to the caller, it is the responsibility of the caller to +// free the error struct by calling the following method defined in common.h: +// +// TfLiteSupportErrorDelete(error) +// +TfLiteImageClassifier* TfLiteImageClassifierFromOptions( + const TfLiteImageClassifierOptions* options, + TfLiteSupportError** error); + +// Invokes the encapsulated TFLite model and classifies the frame_buffer. +// Returns a pointer to the created classification result in case of success or +// NULL in case of failure. The caller must test the return value to identify +// success or failure. If the caller doesn't want the reason for failure, they +// can simply pass a NULL for the address of the error pointer as shown below: +// +// TfLiteClassificationResult* classification_result = +// TfLiteImageClassifierClassify(&options, NULL); +// +// If the caller wants to be informed of the reason for failure, they must pass +// the adress of a pointer of type TfLiteSupportError to the `error` param as +// shown below: +// +// TfLiteSupport *error = NULL: +// TfLiteImageClassifier* classifier = TfLiteImageClassifierFromOptions(options, +// &error); +// +// In case of unsuccessful execution, Once the function returns, the error +// pointer will point to a struct containing the error information. If error +// info is passed back to the caller, it is the responsibility of the caller to +// free the error struct by calling the following method defined in common.h: +// +// TfLiteSupportErrorDelete(error) +// +TfLiteClassificationResult* TfLiteImageClassifierClassify( + const TfLiteImageClassifier* classifier, + const TfLiteFrameBuffer* frame_buffer, + TfLiteSupportError** error); + +// Invokes the encapsulated TFLite model and classifies the region of the +// frame_buffer specified by the bounding box. Same as TfLiteImageClassifier* +// TfLiteImageClassifierFromOptions( +// const TfLiteImageClassifierOptions* options, TfLiteSupportError** error), +// except that the +// classification is performed based on the input region of interest. Cropping +// according to this region of interest is prepended to the pre-processing +// operations. +TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( + const TfLiteImageClassifier* classifier, + const TfLiteFrameBuffer* frame_buffer, + const TfLiteBoundingBox* roi, + TfLiteSupportError** error); + +// Disposes off the image classifier. +void TfLiteImageClassifierDelete(TfLiteImageClassifier* classifier); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h new file mode 100644 index 0000000..5a2d3e1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
@@ -0,0 +1,200 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_ + +#include <stdint.h> + +#include "tensorflow_lite_support/c/common.h" +#include "tensorflow_lite_support/c/task/core/base_options.h" +#include "tensorflow_lite_support/c/task/processor/classification_options.h" +#include "tensorflow_lite_support/c/task/processor/detection_result.h" +#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h" + +// -------------------------------------------------------------------------- +/// C API for Object Detector. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite +/// `ObjectDetector` API, but is useful for shared libraries where having +/// a stable ABI boundary is important. +/// +/// Usage: +/// <pre><code> +/// // Create the model +/// Using the options initialized with default values returned by +/// TfLiteObjectDetectorOptionsCreate() makes sure that there will be no +/// undefined behaviour due to garbage values in unitialized members. +/// TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate(); +/// +/// Set the model file path in options +/// options.base_options.model_file.file_path = "/path/to/model.tflite"; +/// +/// If need be, set values for any options to customize behaviour. +/// options.base_options.compute_settings.cpu_settings.num_threads = 3 +/// +/// Create TfLiteObjectDetector using the options: +/// If error information is not nedded in case of failure: +/// TfLiteObjectDetector* object_detector = +/// TfLiteObjectDetectorFromOptions(&options, NULL); +/// +/// If error information is nedded in case of failure: +/// TfLiteSupportError* create_error = NULL; +/// TfLiteObjectDetector* object_detector = +/// TfLiteObjectDetectorFromOptions(&options, &create_error); +/// +/// if (!object_detector) { +/// Handle failure. +/// Do something with `create_error`, if requested as illustrated above. +/// } +/// +/// Dispose of the create_error object. +/// TfLiteSupportErrorDelete(create_error); +/// +/// Classify an image +/// TfLiteFrameBuffer frame_buffer = { Initialize with image data } +/// +/// If error information is not nedded in case of failure: +/// TfLiteDetectionResult* detection_result = +/// TfLiteObjectDetectorClassify(object_detector, &frame_buffer, NULL); +/// +/// If error information is needed in case of failure: +/// TfLiteSupportError* detect_error = NULL; +/// TfLiteDetectionResult* detection_result = +/// TfLiteObjectDetectorDetect(object_detector, &frame_buffer, +/// &detect_error); +/// +/// if (!detection_result) { +/// Handle failure. +/// Do something with `detection_error`, if requested as illustrated above. +/// } +/// +/// Dispose of the detection_error object. +/// TfLiteSupportErrorDelete(detection_error); +/// +/// Dispose of the API object. +/// TfLiteObjectDetectorOptionsDelete(object_detector); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct TfLiteObjectDetector TfLiteObjectDetector; + +typedef struct TfLiteObjectDetectorOptions { + TfLiteClassificationOptions classification_options; + TfLiteBaseOptions base_options; +} TfLiteObjectDetectorOptions; + +// Creates and returns TfLiteObjectDetectorOptions initialized with default +// values. Default values are as follows: +// 1. .classification_options.max_results = -1, which returns all classification +// categories by default. +// 2. .base_options.compute_settings.tflite_settings.cpu_settings.num_threads = +// -1, which makes the TFLite runtime choose the value. +// 3. .classification_options.score_threshold = 0 +// 4. All pointers like .base_options.model_file.file_path, +// .base_options.classification_options.display_names_local, +// .classification_options.label_allowlist.list, +// options.classification_options.label_denylist.list are NULL. +// 5. All other integer values are initialized to 0. +TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate(); + +// Creates TfLiteObjectDetector from options. +// .base_options.model_file.file_path in TfLiteObjectDetectorOptions should be +// set to the path of the tflite model you wish to create the +// TfLiteObjectDetector with. +// Create TfLiteObjectDetectorOptions using +// TfLiteObjectDetectorOptionsCreate(). If need be, you can change the default +// values of options for customizing classification, If options are not created +// in the aforementioned way, you have to make sure that all members are +// initialized to respective default values and all pointer members are zero +// initialized to avoid any undefined behaviour. +// +// Returns the created object detector in case of success. +// Returns nullptr on failure which happens commonly due to one of the following +// issues: +// 1. file doesn't exist or is not a well formatted. +// 2. options is nullptr. +// 3. Both options.classification_options.label_denylist and +// options.classification_options.label_allowlist are non empty. These +// fields are mutually exclusive. +// +// The caller can check if an error was encountered by testing if the returned +// value of the function is null. If the caller doesn't want the reason for +// failure, they can simply pass a NULL for the address of the error pointer as +// shown below: +// +// TfLiteObjectDetector* detector = TfLiteObjectDetectorFromOptions(options, +// NULL); +// +// If the caller wants to be informed of the reason for failure, they must pass +// the adress of a pointer of type TfLiteSupportError to the `error` param as +// shown below: +// +// TfLiteSupport *error = NULL: +// TfLiteObjectDetector* classifier = TfLiteObjectDetectorFromOptions(options, +// &error); +// +// In case of unsuccessful execution, Once the function returns, the error +// pointer will point to a struct containing the error information. If error +// info is passed back to the caller, it is the responsibility of the caller to +// free the error struct by calling the following method defined in common.h: +// +// TfLiteSupportErrorDelete(error) +// +TfLiteObjectDetector* TfLiteObjectDetectorFromOptions( + const TfLiteObjectDetectorOptions* options, + TfLiteSupportError** error); + +// Invokes the encapsulated TFLite model and performs object detection on the +// frame_buffer. Returns a pointer to the created object detection result result +// in case of success or NULL in case of failure. The caller must test the +// return value to identify success or failure. If the caller doesn't want the +// reason for failure, they can simply pass a NULL for the address of the error +// pointer as shown below: +// +// TfLiteDetectionResult* detection_result = +// TfLiteObjectDetectorDetect(&options, NULL); +// +// If the caller wants to be informed of the reason for failure, they must pass +// the adress of a pointer of type TfLiteSupportError to the `error` param as +// shown below: +// +// TfLiteSupport *error = NULL: +// TfLiteObjectDetector* detector = TfLiteObjectDetectorFromOptions(options, +// &error); +// +// In case of unsuccessful execution, Once the function returns, the error +// pointer will point to a struct containing the error information. If error +// info is passed back to the caller, it is the responsibility of the caller to +// free the error struct by calling the following method defined in common.h: +// +// TfLiteSupportErrorDelete(error) +// +TfLiteDetectionResult* TfLiteObjectDetectorDetect( + const TfLiteObjectDetector* detector, + const TfLiteFrameBuffer* frame_buffer, + TfLiteSupportError** error); + +// Disposes off the object detector. +void TfLiteObjectDetectorDelete(TfLiteObjectDetector* detector); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/BUILD new file mode 100644 index 0000000..9b395b3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/BUILD
@@ -0,0 +1,22 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "frame_buffer_cpp_c_utils", + srcs = [ + "frame_buffer_cpp_c_utils.cc", + ], + hdrs = [ + "frame_buffer_cpp_c_utils.h", + ], + deps = [ + "//tensorflow_lite_support/c/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "@com_google_absl//absl/strings:str_format", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.cc new file mode 100644 index 0000000..a084789 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.cc
@@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h" + +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" + +namespace tflite { +namespace task { +namespace vision { + +namespace { +using FrameBufferCpp = ::tflite::task::vision::FrameBuffer; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +} // namespace + +StatusOr<std::unique_ptr<FrameBufferCpp>> CreateCppFrameBuffer( + const TfLiteFrameBuffer* frame_buffer) { + if (frame_buffer == nullptr) + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected non null frame buffer."), + TfLiteSupportStatus::kInvalidArgumentError); + + FrameBufferCpp::Format frame_buffer_format = + FrameBufferCpp::Format(frame_buffer->format); + + return CreateFromRawBuffer( + frame_buffer->buffer, + {frame_buffer->dimension.width, frame_buffer->dimension.height}, + frame_buffer_format); +} + +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h new file mode 100644 index 0000000..6d062da --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h
@@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_CPP_C_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_CPP_C_UTILS_H_ + +#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + +// Utils for Conversions between C and C++ FrameBuffer +// ----------------------------------------------------------------- +// Meant to be used with vision C apis. + +// Creates the C++ FrameBuffer from the C FrameBuffer +namespace tflite { +namespace task { +namespace vision { + +tflite::support::StatusOr<std::unique_ptr<tflite::task::vision::FrameBuffer>> +CreateCppFrameBuffer(const TfLiteFrameBuffer* frame_buffer); + +} // namespace vision +} // namespace task +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_CPP_C_UTILS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD new file mode 100644 index 0000000..788dce4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD
@@ -0,0 +1,34 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_test_with_tflite", +) + +package( + default_visibility = [ + "//visibility:private", + ], + licenses = ["notice"], # Apache 2.0 +) + +# To test it with Bazel, run the following command from the terminal of your desktop: +# bazel test tensorflow_lite_support/c/test/task/vision:image_classifier_test +cc_test_with_tflite( + name = "image_classifier_test", + srcs = ["image_classifier_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "//tensorflow_lite_support/c/task/vision:image_classifier", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], + deps = [ + "//tensorflow_lite_support/c:common", + "//tensorflow_lite_support/c/task/processor:classification_result", + "//tensorflow_lite_support/c/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc new file mode 100644 index 0000000..b398b7ad --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
@@ -0,0 +1,432 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/vision/image_classifier.h" + +#include <string.h> + +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow_lite_support/c/common.h" +#include "tensorflow_lite_support/c/task/processor/classification_result.h" +#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::testing::HasSubstr; +using ::tflite::support::StatusOr; +using ::tflite::task::JoinPath; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; +// Quantized model. +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; + +StatusOr<ImageData> LoadImage(const char* image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(ImageClassifierFromOptionsTest, FailsWithNullOptionsAndError) { + TfLiteSupportError* error = nullptr; + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(nullptr, &error); + + EXPECT_EQ(image_classifier, nullptr); + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("Expected non null options")); + + TfLiteSupportErrorDelete(error); +} + +TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) { + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, nullptr); + EXPECT_EQ(image_classifier, nullptr); + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); +} + +TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) { + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + + TfLiteSupportError* error = nullptr; + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, &error); + + EXPECT_EQ(image_classifier, nullptr); + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("`base_options.model_file`")); + + TfLiteSupportErrorDelete(error); +} + +TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) { + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata); + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, nullptr); + + EXPECT_NE(image_classifier, nullptr); + TfLiteImageClassifierDelete(image_classifier); +} + +TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) { + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata); + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + options.base_options.compute_settings.cpu_settings.num_threads = 3; + + TfLiteSupportError* error = nullptr; + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, &error); + + EXPECT_NE(image_classifier, nullptr); + EXPECT_EQ(error, nullptr); + + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); + if (error) + TfLiteSupportErrorDelete(error); +} + +TEST_F(ImageClassifierFromOptionsTest, + FailsWithClassNameDenyListAndClassNameAllowListAndError) { + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + + char* label_denylist[9] = {(char*)"brambling"}; + options.classification_options.label_denylist.list = label_denylist; + options.classification_options.label_denylist.length = 1; + + char* label_allowlist[12] = {(char*)"cheeseburger"}; + options.classification_options.label_allowlist.list = label_allowlist; + options.classification_options.label_allowlist.length = 1; + + TfLiteSupportError* error = nullptr; + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, &error); + + EXPECT_EQ(image_classifier, nullptr); + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("mutually exclusive options")); + + TfLiteSupportErrorDelete(error); +} + +TEST(ImageClassifierNullClassifierClassifyTest, + FailsWithNullImageClassifierAndError) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassify(nullptr, nullptr, &error); + + ImageDataFree(&image_data); + + EXPECT_EQ(classification_result, nullptr); + if (classification_result) + TfLiteClassificationResultDelete(classification_result); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("Expected non null image classifier")); + + TfLiteSupportErrorDelete(error); +} + +class ImageClassifierClassifyTest : public tflite_shims::testing::Test { + protected: + void SetUp() override { + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + image_classifier = TfLiteImageClassifierFromOptions(&options, nullptr); + ASSERT_NE(image_classifier, nullptr); + } + + void TearDown() override { TfLiteImageClassifierDelete(image_classifier); } + TfLiteImageClassifier* image_classifier; +}; + +TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, + .orientation = kTopLeft, + .dimension = {.width = image_data.width, .height = image_data.height}, + .buffer = image_data.pixel_data}; + + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr); + + ImageDataFree(&image_data); + + ASSERT_NE(classification_result, nullptr); + EXPECT_GE(classification_result->size, 1); + EXPECT_NE(classification_result->classifications, nullptr); + EXPECT_GE(classification_result->classifications->size, 1); + EXPECT_NE(classification_result->classifications->categories, nullptr); + EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label, + "cheeseburger"), + 0); + EXPECT_GE(classification_result->classifications->categories[0].score, 0.90); + + TfLiteClassificationResultDelete(classification_result); +} + +TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassify(image_classifier, nullptr, &error); + + ImageDataFree(&image_data); + + EXPECT_EQ(classification_result, nullptr); + if (classification_result) + TfLiteClassificationResultDelete(classification_result); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("Expected non null frame buffer")); + + TfLiteSupportErrorDelete(error); +} + +TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = {.format = kRGB, .orientation = kTopLeft}; + + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassify(image_classifier, &frame_buffer, &error); + + ImageDataFree(&image_data); + + EXPECT_EQ(classification_result, nullptr); + if (classification_result) + TfLiteClassificationResultDelete(classification_result); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("Invalid stride information")); + + TfLiteSupportErrorDelete(error); +} + +TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, + .orientation = kTopLeft, + .dimension = {.width = image_data.width, .height = image_data.height}, + .buffer = image_data.pixel_data}; + + TfLiteBoundingBox bounding_box = { + .origin_x = 0, .origin_y = 0, .width = 100, .height = 100}; + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassifyWithRoi(image_classifier, &frame_buffer, + &bounding_box, &error); + + ImageDataFree(&image_data); + + ASSERT_NE(classification_result, nullptr); + EXPECT_GE(classification_result->size, 1); + EXPECT_NE(classification_result->classifications, nullptr); + EXPECT_GE(classification_result->classifications->size, 1); + EXPECT_NE(classification_result->classifications->categories, nullptr); + EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label, + "bagel"), + 0); + EXPECT_GE(classification_result->classifications->categories[0].score, 0.30); + + TfLiteClassificationResultDelete(classification_result); +} + +TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, + .orientation = kTopLeft, + .dimension = {.width = image_data.width, .height = image_data.height}, + .buffer = image_data.pixel_data}; + + TfLiteBoundingBox bounding_box = { + .origin_x = 0, .origin_y = 0, .width = 250, .height = 250}; + TfLiteSupportError* error = nullptr; + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassifyWithRoi(image_classifier, &frame_buffer, + &bounding_box, &error); + + ImageDataFree(&image_data); + + EXPECT_EQ(classification_result, nullptr); + if (classification_result) + TfLiteClassificationResultDelete(classification_result); + + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); + EXPECT_NE(error->message, nullptr); + EXPECT_THAT(error->message, HasSubstr("Invalid crop coordinates")); + + TfLiteSupportErrorDelete(error); +} + +TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + SucceedsWithClassNameDenyList) { + char* denylisted_label_name = (char*)"cheeseburger"; + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + + char* label_denylist[12] = {denylisted_label_name}; + options.classification_options.label_denylist.list = label_denylist; + options.classification_options.label_denylist.length = 1; + + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, nullptr); + ASSERT_NE(image_classifier, nullptr); + + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, + .orientation = kTopLeft, + .dimension = {.width = image_data.width, .height = image_data.height}, + .buffer = image_data.pixel_data}; + + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr); + + ImageDataFree(&image_data); + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(classification_result, nullptr); + EXPECT_GE(classification_result->size, 1); + EXPECT_NE(classification_result->classifications, nullptr); + EXPECT_GE(classification_result->classifications->size, 1); + EXPECT_NE(classification_result->classifications->categories, nullptr); + EXPECT_NE(strcmp(classification_result->classifications->categories[0].label, + denylisted_label_name), + 0); + + TfLiteClassificationResultDelete(classification_result); +} + +TEST(ImageClassifierWithUserDefinedOptionsClassifyTest, + SucceedsWithClassNameAllowList) { + char* allowlisted_label_name = (char*)"cheeseburger"; + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata) + .data(); + + TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + + char* label_allowlist[12] = {allowlisted_label_name}; + options.classification_options.label_allowlist.list = label_allowlist; + options.classification_options.label_allowlist.length = 1; + + TfLiteImageClassifier* image_classifier = + TfLiteImageClassifierFromOptions(&options, nullptr); + ASSERT_NE(image_classifier, nullptr); + + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, + LoadImage("burger-224.png")); + + TfLiteFrameBuffer frame_buffer = { + .format = kRGB, + .orientation = kTopLeft, + .dimension = {.width = image_data.width, .height = image_data.height}, + .buffer = image_data.pixel_data}; + + TfLiteClassificationResult* classification_result = + TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr); + + ImageDataFree(&image_data); + if (image_classifier) + TfLiteImageClassifierDelete(image_classifier); + + ASSERT_NE(classification_result, nullptr); + EXPECT_GE(classification_result->size, 1); + EXPECT_NE(classification_result->classifications, nullptr); + EXPECT_GE(classification_result->classifications->size, 1); + EXPECT_NE(classification_result->classifications->categories, nullptr); + EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label, + allowlisted_label_name), + 0); + + TfLiteClassificationResultDelete(classification_result); +} + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/BUILD index b19bfde..3fe09242 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/BUILD
@@ -1,5 +1,5 @@ package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -9,17 +9,12 @@ "common.cc", ], hdrs = ["common.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", ], ) - -config_setting( - name = "tflite_use_c_api", - values = { - "copt": "-DTFLITE_USE_C_API", - }, - visibility = ["//tensorflow_lite_support:__subpackages__"], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc index ed373e9..09e9a83 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
@@ -15,7 +15,8 @@ #include "tensorflow_lite_support/cc/common.h" -#include "absl/strings/str_cat.h" +#include "absl/strings/cord.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl namespace tflite { namespace support { @@ -25,6 +26,8 @@ TfLiteSupportStatus tfls_code) { // NOTE: Ignores `message` if the canonical code is ok. absl::Status status = absl::Status(canonical_code, message); + // NOTE: Does nothing if the canonical code is ok. + status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code))); return status; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h index 15a9224..71dd920 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
@@ -16,8 +16,8 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ #define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl namespace tflite { namespace support { @@ -53,6 +53,12 @@ kInvalidArgumentError = 2, // Invalid FlatBuffer file or buffer specified. kInvalidFlatBufferError = 3, + // Model contains a builtin op that isn't supported by the OpResolver or + // delegates. + kUnsupportedBuiltinOp = 4, + // Model contains a custom op that isn't supported by the OpResolver or + // delegates. + kUnsupportedCustomOp = 5, // File I/O error codes.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/BUILD index 4a413617..39f4a45 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/BUILD
@@ -1,5 +1,7 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -8,8 +10,9 @@ hdrs = [ "statusor.h", ], + visibility = ["//visibility:public"], deps = [ - "//tensorflow_lite_support/cc/port/default:statusor", + "@com_google_absl//absl/status:statusor", ], ) @@ -18,31 +21,54 @@ hdrs = [ "status_macros.h", ], + visibility = ["//visibility:public"], deps = [ "//tensorflow_lite_support/cc/port/default:status_macros", ], ) cc_library( + name = "configuration_proto_inc", + hdrs = ["configuration_proto_inc.h"], + deps = ["@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto"], +) + +cc_library_with_tflite( name = "tflite_wrapper", hdrs = ["tflite_wrapper.h"], deps = ["//tensorflow_lite_support/cc/port/default:tflite_wrapper"], ) -# This is identical to the rule above, except that it gets built with -# '-DTFLITE_USE_C_API'. This rule is used for unit tests that verify things -# work correctly when built with TFLITE_USE_C_API defined. cc_library( - name = "tflite_wrapper_with_c_api_for_test", + name = "integral_types", + hdrs = ["integral_types.h"], + visibility = ["//tensorflow_lite_support:users"], +) + +cc_library( + name = "gtest_main", testonly = 1, - hdrs = ["tflite_wrapper.h"], + hdrs = [ + "benchmark.h", + "gmock.h", + "gtest.h", + "status_matchers.h", + ], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ - "//intelligence/mobile_acceleration/proto:allowlist_portable_proto", - "//intelligence/mobile_acceleration/support_library:tflite_wrapper_with_c_api_for_test", + "//tensorflow_lite_support/cc/port/default:status_matchers", + "@com_google_googletest//:gtest_main", ], ) cc_library( - name = "integral_types", - hdrs = ["integral_types.h"], + name = "proto2", + hdrs = [ + "proto2.h", + ], + deps = [ + "@com_google_protobuf//:protobuf", + ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/build_defs.bzl b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/build_defs.bzl index a8053db..73fbb5f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/build_defs.bzl +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/build_defs.bzl
@@ -1,30 +1,22 @@ """.bzl file for TFLite Support open source build configs.""" -load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library") - def provided_args(**kwargs): """Returns the keyword arguments omitting None arguments.""" return {k: v for k, v in kwargs.items() if v != None} -def support_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps = [], testonly = 0): - """Generate cc_proto_library for TFLite Support open source version. +def support_cc_proto_library(name, deps = [], visibility = None): + """Generates cc_proto_library for TFLite Support open source version. Args: name: the name of the cc_proto_library. - srcs: the .proto files of the cc_proto_library for Bazel use. + deps: a list of dependency labels for Bazel use; must be proto_library. visibility: visibility of this target. - deps: a list of dependency labels for Bazel use; must be cc_proto_library. - testonly: test only proto or not. """ - _ignore = [deps] - cc_proto_library(**provided_args( + + # Verified in the external path. + # buildifier: disable=native-cc-proto + native.cc_proto_library(**provided_args( name = name, - srcs = srcs, visibility = visibility, - deps = cc_deps, - testonly = testonly, - cc_libs = ["@com_google_protobuf//:protobuf"], - protoc = "@com_google_protobuf//:protoc", - default_runtime = "@com_google_protobuf//:protobuf", - alwayslink = 1, + deps = deps, ))
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/configuration_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/configuration_proto_inc.h new file mode 100644 index 0000000..d014973 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/configuration_proto_inc.h
@@ -0,0 +1,21 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_PORT_CONFIGURATION_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_PORT_CONFIGURATION_PROTO_INC_H_ + +#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h" + +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_PORT_CONFIGURATION_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/BUILD index 3f6e9e9..f506a7c3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/BUILD
@@ -1,30 +1,12 @@ package( default_visibility = [ - "//tensorflow_lite_support/cc/port:__pkg__", - "//tensorflow_lite_support/cc/test:__pkg__", + "//tensorflow_lite_support/cc/port:__subpackages__", + "//tensorflow_lite_support/cc/test:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) cc_library( - name = "statusor", - srcs = ["statusor.cc"], - hdrs = [ - "statusor.h", - "statusor_internals.h", - ], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:variant", - "@com_google_absl//absl/utility", - "@com_google_glog//:glog", - ], -) - -cc_library( name = "status_macros", hdrs = [ "status_macros.h", @@ -36,15 +18,47 @@ ) cc_library( + name = "status_matchers", + testonly = 1, + hdrs = ["status_matchers.h"], + deps = [ + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( name = "tflite_wrapper", srcs = ["tflite_wrapper.cc"], hdrs = [ "tflite_wrapper.h", ], deps = [ - "//tensorflow_lite_support/cc/port:status_macros", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@flatbuffers", "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:minimal_logging", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin", + "@org_tensorflow//tensorflow/lite/delegates:interpreter_utils", "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto", - ], + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_registry", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:flatbuffer_to_proto", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:proto_to_flatbuffer", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark", + "//tensorflow_lite_support/cc/port:status_macros", + ] + select({ + # We only intend to use TFLite mini-benchmark on arm-based Andorid and x86_64 Linux. + "@org_tensorflow//tensorflow:android_arm": [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_implementation", + ], + "@org_tensorflow//tensorflow:android_arm64": [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_implementation", + ], + "@org_tensorflow//tensorflow:linux_x86_64": [ + "@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_implementation", + ], + "//conditions:default": [], + }), )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h index 47476c9..cb145db 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
@@ -17,8 +17,8 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ #define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ -#include "absl/base/optimization.h" -#include "absl/status/status.h" +#include "absl/base/optimization.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl // Evaluates an expression that produces a `absl::Status`. If the status is not // ok, returns it from the current function.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor.cc deleted file mode 100644 index 058c0070..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor.cc +++ /dev/null
@@ -1,67 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// This file is forked from absl. - -#include "tensorflow_lite_support/cc/port/default/statusor.h" - -#include <utility> - -#include "absl/strings/str_cat.h" -#include "base/logging.h" - -namespace tflite { -namespace support { - -BadStatusOrAccess::BadStatusOrAccess(absl::Status status) - : status_(std::move(status)) {} - -BadStatusOrAccess::~BadStatusOrAccess() = default; - -const char* BadStatusOrAccess::what() const noexcept { - return "Bad StatusOr access"; -} - -const absl::Status& BadStatusOrAccess::status() const { - return status_; -} - -namespace internal_statusor { - -void Helper::HandleInvalidStatusCtorArg(absl::Status* status) { - const char* kMessage = - "An OK status is not a valid constructor argument to StatusOr<T>"; - LOG(DFATAL) << kMessage; - // In optimized builds, we will fall back to ::util::error::INTERNAL. - *status = absl::InternalError(kMessage); -} - -void Helper::Crash(const absl::Status& status) { - LOG(FATAL) << "Attempting to fetch value instead of handling error " - << status; - _Exit(1); -} - -void ThrowBadStatusOrAccess(absl::Status status) { -#ifdef ABSL_HAVE_EXCEPTIONS - throw BadStatusOrAccess(std::move(status)); -#else - LOG(FATAL) << "Attempting to fetch value instead of handling error " - << status; -#endif -} - -} // namespace internal_statusor -} // namespace support -} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor.h deleted file mode 100644 index b8d41c7..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor.h +++ /dev/null
@@ -1,584 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// This file is forked from absl. - -#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ - -#include <exception> -#include <initializer_list> -#include <new> -#include <string> -#include <type_traits> -#include <utility> - -#include "absl/base/optimization.h" -#include "absl/meta/type_traits.h" -#include "absl/status/status.h" -#include "absl/types/variant.h" -#include "absl/utility/utility.h" -#include "tensorflow_lite_support/cc/port/default/statusor_internals.h" - -namespace tflite { -namespace support { - -#ifndef SWIG -class BadStatusOrAccess : public std::exception { - public: - explicit BadStatusOrAccess(absl::Status status); - ~BadStatusOrAccess() override; - const char* what() const noexcept override; - const absl::Status& status() const; - - private: - absl::Status status_; -}; -#endif // !SWIG - -// Returned StatusOr objects may not be ignored. -// Note: Disabled for SWIG as it doesn't parse attributes correctly. Codesearch -// doesn't handle ifdefs as part of a class definitions (b/6995610), so we use a -// forward declaration. -#ifndef SWIG -template <typename T> -class ABSL_MUST_USE_RESULT StatusOr; -#endif - -template <typename T> -class StatusOr : private internal_statusor::StatusOrData<T>, - private internal_statusor::CopyCtorBase<T>, - private internal_statusor::MoveCtorBase<T>, - private internal_statusor::CopyAssignBase<T>, - private internal_statusor::MoveAssignBase<T> { - template <typename U> - friend class StatusOr; - - typedef internal_statusor::StatusOrData<T> Base; - - public: - typedef T value_type; - - // Constructs a new StatusOr with Status::UNKNOWN status. This is marked - // 'explicit' to try to catch cases like 'return {};', where people think - // tflite::support::StatusOr<std::vector<int>> will be initialized with an - // empty vector, instead of a Status::UNKNOWN status. - explicit StatusOr(); - - // StatusOr<T> is copy constructible if T is copy constructible. - StatusOr(const StatusOr&) = default; - // StatusOr<T> is copy assignable if T is copy constructible and copy - // assignable. - StatusOr& operator=(const StatusOr&) = default; - -#ifndef SWIG - - // StatusOr<T> is move constructible if T is move constructible. - StatusOr(StatusOr&&) = default; - // StatusOr<T> is moveAssignable if T is move constructible and move - // assignable. - StatusOr& operator=(StatusOr&&) = default; - - // Converting constructors from StatusOr<U>, when T is constructible from U. - // To avoid ambiguity, they are disabled if T is also constructible from - // StatusOr<U>. Explicit iff the corresponding construction of T from U is - // explicit. - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, const U&>, - std::is_convertible<const U&, T>, - absl::negation< - internal_statusor:: - IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value, - int> = 0> - StatusOr(const StatusOr<U>& other) // NOLINT - : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, const U&>, - absl::negation<std::is_convertible<const U&, T>>, - absl::negation< - internal_statusor:: - IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value, - int> = 0> - explicit StatusOr(const StatusOr<U>& other) - : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} - - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, U&&>, - std::is_convertible<U&&, T>, - absl::negation< - internal_statusor:: - IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value, - int> = 0> - StatusOr(StatusOr<U>&& other) // NOLINT - : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, U&&>, - absl::negation<std::is_convertible<U&&, T>>, - absl::negation< - internal_statusor:: - IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value, - int> = 0> - explicit StatusOr(StatusOr<U>&& other) - : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} - - // Conversion copy/move assignment operator, T must be constructible and - // assignable from U. Only enable if T cannot be directly assigned from - // StatusOr<U>. - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, const U&>, - std::is_assignable<T, const U&>, - absl::negation< - internal_statusor:: - IsConstructibleOrConvertibleOrAssignableFromStatusOr< - T, - U>>>::value, - int> = 0> - StatusOr& operator=(const StatusOr<U>& other) { - this->Assign(other); - return *this; - } - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, U&&>, - std::is_assignable<T, U&&>, - absl::negation< - internal_statusor:: - IsConstructibleOrConvertibleOrAssignableFromStatusOr< - T, - U>>>::value, - int> = 0> - StatusOr& operator=(StatusOr<U>&& other) { - this->Assign(std::move(other)); - return *this; - } - -#endif // SWIG - - // Constructs a new StatusOr with the given value. After calling this - // constructor, this->ok() will be true and the contained value may be - // retrieved with value(), operator*(), or operator->(). - // - // NOTE: Not explicit - we want to use StatusOr<T> as a return type - // so it is convenient and sensible to be able to do 'return T()' - // when the return type is StatusOr<T>. - // - // REQUIRES: T is copy constructible. - // TODO(b/113125838): Replace this constructor with a direct-initialization - // constructor. - StatusOr(const T& value); - - // Constructs a new StatusOr with the given non-ok status. After calling this - // constructor, this->ok() will be false and calls to value() will CHECK-fail. - // - // NOTE: Not explicit - we want to use StatusOr<T> as a return - // value, so it is convenient and sensible to be able to do 'return - // Status()' when the return type is StatusOr<T>. - // - // REQUIRES: !status.ok(). This requirement is DCHECKed. - // In optimized builds, passing util::OkStatus() here will have the effect - // of passing util::error::INTERNAL as a fallback. - StatusOr(const absl::Status& status); - StatusOr& operator=(const absl::Status& status); - -#ifndef SWIG - // Perfect-forwarding value assignment operator. - // If `*this` contains a `T` value before the call, the contained value is - // assigned from `std::forward<U>(v)`; Otherwise, it is directly-initialized - // from `std::forward<U>(v)`. - // This function does not participate in overload unless: - // 1. `std::is_constructible_v<T, U>` is true, - // 2. `std::is_assignable_v<T&, U>` is true. - // 3. `std::is_same_v<StatusOr<T>, std::remove_cvref_t<U>>` is false. - // 4. Assigning `U` to `T` is not ambiguous: - // If `U` is `StatusOr<V>` and `T` is constructible and assignable from - // both `StatusOr<V>` and `V`, the assignment is considered bug-prone and - // ambiguous thus will fail to compile. For example: - // StatusOr<bool> s1 = true; // s1.ok() && *s1 == true - // StatusOr<bool> s2 = false; // s2.ok() && *s2 == false - // s1 = s2; // ambiguous, `s1 = *s2` or `s1 = bool(s2)`? - template < - typename U = T, - typename = typename std::enable_if<absl::conjunction< - std::is_constructible<T, U&&>, - std::is_assignable<T&, U&&>, - internal_statusor::IsForwardingAssignmentValid<T, U&&>>::value>::type> - StatusOr& operator=(U&& v) { - this->Assign(std::forward<U>(v)); - return *this; - } - - // Similar to the `const T&` overload. - // - // REQUIRES: T is move constructible. - StatusOr(T&& value); - - // RValue versions of the operations declared above. - StatusOr(absl::Status&& status); - StatusOr& operator=(absl::Status&& status); - - // Constructs the inner value T in-place using the provided args, using the - // T(args...) constructor. - template <typename... Args> - explicit StatusOr(absl::in_place_t, Args&&... args); - template <typename U, typename... Args> - explicit StatusOr(absl::in_place_t, - std::initializer_list<U> ilist, - Args&&... args); - - // Constructs the inner value T in-place using the provided args, using the - // T(U) (direct-initialization) constructor. Only valid if T can be - // constructed from a U. Can accept move or copy constructors. Explicit if - // U is not convertible to T. To avoid ambiguity, this is disabled if U is - // a StatusOr<J>, where J is convertible to T. - // Style waiver for implicit conversion granted in cl/209187539. - template <typename U = T, - absl::enable_if_t< - absl::conjunction< - internal_statusor::IsDirectInitializationValid<T, U&&>, - std::is_constructible<T, U&&>, - std::is_convertible<U&&, T>>::value, - int> = 0> - StatusOr(U&& u) // NOLINT - : StatusOr(absl::in_place, std::forward<U>(u)) {} - - template <typename U = T, - absl::enable_if_t< - absl::conjunction< - internal_statusor::IsDirectInitializationValid<T, U&&>, - std::is_constructible<T, U&&>, - absl::negation<std::is_convertible<U&&, T>>>::value, - int> = 0> - explicit StatusOr(U&& u) // NOLINT - : StatusOr(absl::in_place, std::forward<U>(u)) {} - -#endif // SWIG - - // Returns this->status().ok() - ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); } - - // Returns a reference to our status. If this contains a T, then - // returns util::OkStatus(). -#ifdef SWIG - const ::util::Status& status() const; -#else // SWIG - const absl::Status& status() const&; - absl::Status status() &&; -#endif // SWIG - - // Returns a reference to the held value if `this->ok()`. Otherwise, throws - // `absl::BadStatusOrAccess` if exception is enabled, or `LOG(FATAL)` if - // exception is disabled. - // If you have already checked the status using `this->ok()` or - // `operator bool()`, you probably want to use `operator*()` or `operator->()` - // to access the value instead of `value`. - // Note: for value types that are cheap to copy, prefer simple code: - // - // T value = statusor.value(); - // - // Otherwise, if the value type is expensive to copy, but can be left - // in the StatusOr, simply assign to a reference: - // - // T& value = statusor.value(); // or `const T&` - // - // Otherwise, if the value type supports an efficient move, it can be - // used as follows: - // - // T value = std::move(statusor).value(); - // - // The `std::move` on statusor instead of on the whole expression enables - // warnings about possible uses of the statusor object after the move. -#ifdef SWIG - const T& value() const; -#else // SWIG - const T& value() const&; - T& value() &; - const T&& value() const&&; - T&& value() &&; -#endif // SWIG - -#ifndef SWIG - // Returns a reference to the current value. - // - // REQUIRES: this->ok() == true, otherwise the behavior is undefined. - // - // Use this->ok() or `operator bool()` to verify that there is a current - // value. Alternatively, see value() for a similar API that guarantees - // CHECK-failing if there is no current value. - const T& operator*() const&; - T& operator*() &; - const T&& operator*() const&&; - T&& operator*() &&; -#endif // SWIG - -#ifndef SWIG - // Returns a pointer to the current value. - // - // REQUIRES: this->ok() == true, otherwise the behavior is undefined. - // - // Use this->ok() or `operator bool()` to verify that there is a current - // value. - const T* operator->() const; - T* operator->(); -#endif // SWIG - -#ifndef SWIG - // Returns a copy of the current value if this->ok() == true. Otherwise - // returns a default value. - template <typename U> - T value_or(U&& default_value) const&; - template <typename U> - T value_or(U&& default_value) &&; -#endif // SWIG - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; - -#ifndef SWIG - // Reconstructs the inner value T in-place using the provided args, using the - // T(args...) constructor. Returns reference to the reconstructed `T`. - template <typename... Args> - T& emplace(Args&&... args) { - if (ok()) { - this->Clear(); - this->MakeValue(std::forward<Args>(args)...); - } else { - this->MakeValue(std::forward<Args>(args)...); - this->status_ = absl::OkStatus(); - } - return this->data_; - } - - template < - typename U, - typename... Args, - absl::enable_if_t< - std::is_constructible<T, std::initializer_list<U>&, Args&&...>::value, - int> = 0> - T& emplace(std::initializer_list<U> ilist, Args&&... args) { - if (ok()) { - this->Clear(); - this->MakeValue(ilist, std::forward<Args>(args)...); - } else { - this->MakeValue(ilist, std::forward<Args>(args)...); - this->status_ = absl::OkStatus(); - } - return this->data_; - } -#endif // SWIG - - private: -#ifndef SWIG - using internal_statusor::StatusOrData<T>::Assign; - template <typename U> - void Assign(const StatusOr<U>& other); - template <typename U> - void Assign(StatusOr<U>&& other); -#endif // SWIG -}; - -#ifndef SWIG -//////////////////////////////////////////////////////////////////////////////// -// Implementation details for StatusOr<T> - -template <typename T> -tflite::support::StatusOr<T>::StatusOr() - : Base(absl::Status(absl::StatusCode::kUnknown, "")) {} - -template <typename T> -tflite::support::StatusOr<T>::StatusOr(const T& value) : Base(value) {} - -template <typename T> -tflite::support::StatusOr<T>::StatusOr(const absl::Status& status) - : Base(status) {} - -template <typename T> -tflite::support::StatusOr<T>& StatusOr<T>::operator=( - const absl::Status& status) { - this->Assign(status); - return *this; -} - -template <typename T> -tflite::support::StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {} - -template <typename T> -tflite::support::StatusOr<T>::StatusOr(absl::Status&& status) - : Base(std::move(status)) {} - -template <typename T> -tflite::support::StatusOr<T>& StatusOr<T>::operator=(absl::Status&& status) { - this->Assign(std::move(status)); - return *this; -} - -template <typename T> -template <typename U> -inline void StatusOr<T>::Assign(const StatusOr<U>& other) { - if (other.ok()) { - this->Assign(other.value()); - } else { - this->Assign(other.status()); - } -} - -template <typename T> -template <typename U> -inline void StatusOr<T>::Assign(StatusOr<U>&& other) { - if (other.ok()) { - this->Assign(std::move(other).value()); - } else { - this->Assign(std::move(other).status()); - } -} -template <typename T> -template <typename... Args> -tflite::support::StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args) - : Base(absl::in_place, std::forward<Args>(args)...) {} - -template <typename T> -template <typename U, typename... Args> -tflite::support::StatusOr<T>::StatusOr(absl::in_place_t, - std::initializer_list<U> ilist, - Args&&... args) - : Base(absl::in_place, ilist, std::forward<Args>(args)...) {} - -template <typename T> -const absl::Status& StatusOr<T>::status() const& { - return this->status_; -} -template <typename T> -absl::Status StatusOr<T>::status() && { - return ok() ? absl::OkStatus() : std::move(this->status_); -} - -template <typename T> -const T& StatusOr<T>::value() const& { - if (!this->ok()) - internal_statusor::ThrowBadStatusOrAccess(this->status_); - return this->data_; -} - -template <typename T> -T& StatusOr<T>::value() & { - if (!this->ok()) - internal_statusor::ThrowBadStatusOrAccess(this->status_); - return this->data_; -} - -template <typename T> -const T&& StatusOr<T>::value() const&& { - if (!this->ok()) { - internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_)); - } - return std::move(this->data_); -} - -template <typename T> -T&& StatusOr<T>::value() && { - if (!this->ok()) { - internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_)); - } - return std::move(this->data_); -} - -template <typename T> -const T& StatusOr<T>::operator*() const& { - this->EnsureOk(); - return this->data_; -} - -template <typename T> -T& StatusOr<T>::operator*() & { - this->EnsureOk(); - return this->data_; -} - -template <typename T> -const T&& StatusOr<T>::operator*() const&& { - this->EnsureOk(); - return std::move(this->data_); -} - -template <typename T> -T&& StatusOr<T>::operator*() && { - this->EnsureOk(); - return std::move(this->data_); -} - -template <typename T> -const T* StatusOr<T>::operator->() const { - this->EnsureOk(); - return &this->data_; -} - -template <typename T> -T* StatusOr<T>::operator->() { - this->EnsureOk(); - return &this->data_; -} - -template <typename T> -template <typename U> -T StatusOr<T>::value_or(U&& default_value) const& { - if (ok()) { - return this->data_; - } - return std::forward<U>(default_value); -} - -template <typename T> -template <typename U> -T StatusOr<T>::value_or(U&& default_value) && { - if (ok()) { - return std::move(this->data_); - } - return std::forward<U>(default_value); -} - -template <typename T> -void StatusOr<T>::IgnoreError() const { - // no-op -} - -#endif // SWIG - -} // namespace support -} // namespace tflite - -#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h index 30b7e3f..81ec3c1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
@@ -20,9 +20,9 @@ #include <type_traits> #include <utility> -#include "absl/meta/type_traits.h" -#include "absl/status/status.h" -#include "absl/utility/utility.h" +#include "absl/meta/type_traits.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/utility/utility.h" // from @com_google_absl namespace tflite { namespace support {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc index 548e679..0b3e5d6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
@@ -15,45 +15,332 @@ #include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h" -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/interpreter_utils.h" +#include "tensorflow/lite/experimental/acceleration/configuration/flatbuffer_to_proto.h" +#include "tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.h" +#include "tensorflow/lite/minimal_logging.h" #include "tensorflow_lite_support/cc/port/status_macros.h" namespace tflite { namespace support { +namespace { +using tflite::delegates::DelegatePluginRegistry; +using tflite::delegates::InterpreterUtils; +using tflite::proto::ComputeSettings; +using tflite::proto::Delegate; +} // namespace + +/* static */ +absl::Status TfLiteInterpreterWrapper::SanityCheckComputeSettings( + const ComputeSettings& compute_settings) { + Delegate delegate = compute_settings.tflite_settings().delegate(); + if (delegate != Delegate::NONE && delegate != Delegate::GPU && + delegate != Delegate::HEXAGON && delegate != Delegate::NNAPI && + delegate != Delegate::XNNPACK && delegate != Delegate::EDGETPU_CORAL) { + return absl::UnimplementedError(absl::StrFormat( + "Using delegate '%s' is not supported.", Delegate_Name(delegate))); + } + return absl::OkStatus(); +} + +TfLiteInterpreterWrapper::TfLiteInterpreterWrapper( + const std::string& default_model_namespace, + const std::string& default_model_id) + : delegate_(nullptr, nullptr), + got_error_do_not_delegate_anymore_(false), + default_model_namespace_(default_model_namespace), + default_model_id_(default_model_id), + mini_benchmark_(nullptr) {} + +std::string TfLiteInterpreterWrapper::ModelNamespace() { + const auto& ns_from_acceleration = + compute_settings_.model_namespace_for_statistics(); + return ns_from_acceleration.empty() ? default_model_namespace_ + : ns_from_acceleration; +} + +std::string TfLiteInterpreterWrapper::ModelID() { + const auto& id_from_acceleration = + compute_settings_.model_identifier_for_statistics(); + return id_from_acceleration.empty() ? default_model_id_ + : id_from_acceleration; +} + +// This is the deprecated overload that doesn't take an +// InterpreterCreationResources parameter. absl::Status TfLiteInterpreterWrapper::InitializeWithFallback( std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)> interpreter_initializer, - const tflite::proto::ComputeSettings& compute_settings) { - if (compute_settings.has_preference() || - compute_settings.has_tflite_settings()) { - return absl::UnimplementedError( - "Acceleration via ComputeSettings is not supported yet."); + const ComputeSettings& compute_settings) { + return InitializeWithFallback( + [interpreter_initializer]( + const InterpreterCreationResources& resources, + std::unique_ptr<tflite::Interpreter>* interpreter_out) + -> absl::Status { + RETURN_IF_ERROR(interpreter_initializer(interpreter_out)); + if (*interpreter_out != nullptr && + resources.optional_delegate != nullptr) { + TfLiteStatus status = + (*interpreter_out) + ->ModifyGraphWithDelegate(resources.optional_delegate); + if (status != kTfLiteOk) { + *interpreter_out = nullptr; + RETURN_IF_ERROR( + absl::InvalidArgumentError("Applying delegate failed")); + } + } + return absl::OkStatus(); + }, + compute_settings); +} + +absl::Status TfLiteInterpreterWrapper::InitializeWithFallback( + std::function<absl::Status(const InterpreterCreationResources&, + std::unique_ptr<tflite::Interpreter>*)> + interpreter_initializer, + const ComputeSettings& compute_settings) { + // Store interpreter initializer if not already here. + if (interpreter_initializer_) { + return absl::FailedPreconditionError( + "InitializeWithFallback already called."); } - RETURN_IF_ERROR(interpreter_initializer(&interpreter_)); - return interpreter_->AllocateTensors() != kTfLiteOk - ? absl::InternalError( - "TFLite interpreter: AllocateTensors() failed.") - : absl::OkStatus(); + interpreter_initializer_ = std::move(interpreter_initializer); + + // Sanity check and copy ComputeSettings. + RETURN_IF_ERROR(SanityCheckComputeSettings(compute_settings)); + compute_settings_ = compute_settings; + if (compute_settings_.has_settings_to_test_locally()) { + flatbuffers::FlatBufferBuilder mini_benchmark_settings_fbb; + const auto* mini_benchmark_settings = + tflite::ConvertFromProto(compute_settings_.settings_to_test_locally(), + &mini_benchmark_settings_fbb); + mini_benchmark_ = tflite::acceleration::CreateMiniBenchmark( + *mini_benchmark_settings, ModelNamespace(), ModelID()); + const tflite::ComputeSettingsT from_minibenchmark = + mini_benchmark_->GetBestAcceleration(); + if (from_minibenchmark.tflite_settings != nullptr) { + TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO, "Using mini benchmark results\n"); + compute_settings_ = tflite::ConvertFromFlatbuffer( + from_minibenchmark, /*skip_mini_benchmark_settings=*/true); + } + // Trigger mini benchmark if it hasn't already run. Vast majority of cases + // should not actually do anything, since first runs are rare. + mini_benchmark_->TriggerMiniBenchmark(); + mini_benchmark_->MarkAndGetEventsToLog(); + } + + // Initialize fallback behavior. + fallback_on_compilation_error_ = + compute_settings_.tflite_settings() + .fallback_settings() + .allow_automatic_fallback_on_compilation_error() || + // Deprecated, keep supporting for backward compatibility. + compute_settings_.tflite_settings() + .nnapi_settings() + .fallback_settings() + .allow_automatic_fallback_on_compilation_error(); + fallback_on_execution_error_ = + compute_settings_.tflite_settings() + .fallback_settings() + .allow_automatic_fallback_on_execution_error() || + // Deprecated, keep supporting for backward compatibility. + compute_settings_.tflite_settings() + .nnapi_settings() + .fallback_settings() + .allow_automatic_fallback_on_execution_error(); + + return InitializeWithFallbackAndResize(); +} + +absl::Status TfLiteInterpreterWrapper::AllocateTensors() { + if (interpreter_->AllocateTensors() != kTfLiteOk) { + return absl::InternalError("AllocateTensors() failed."); + } + return absl::OkStatus(); +} + +// TODO(b/173406463): the `resize` parameter is going to be used by +// ResizeAndAllocateTensors functions, coming soon. +absl::Status TfLiteInterpreterWrapper::InitializeWithFallbackAndResize( + std::function<absl::Status(Interpreter*)> resize) { + InterpreterCreationResources resources{}; + if (got_error_do_not_delegate_anymore_ || + compute_settings_.tflite_settings().delegate() == Delegate::NONE) { + delegate_.reset(nullptr); + } else { + // Initialize delegate and add it to 'resources'. + RETURN_IF_ERROR(InitializeDelegate()); + resources.optional_delegate = delegate_.get(); + } + + absl::Status status = interpreter_initializer_(resources, &interpreter_); + if (resources.optional_delegate == nullptr) { + RETURN_IF_ERROR(status); + } + if (resources.optional_delegate != nullptr && !status.ok()) { + // Any error when constructing the interpreter is assumed to be a delegate + // compilation error. If a delegate compilation error occurs, stop + // delegation from happening in the future. + got_error_do_not_delegate_anymore_ = true; + delegate_.reset(nullptr); + if (fallback_on_compilation_error_) { + InterpreterCreationResources fallback_resources{}; + fallback_resources.optional_delegate = nullptr; + RETURN_IF_ERROR( + interpreter_initializer_(fallback_resources, &interpreter_)); + } else { + // If instructed not to fallback, return error. + return absl::InternalError(absl::StrFormat( + "ModifyGraphWithDelegate() failed for delegate '%s'.", + Delegate_Name(compute_settings_.tflite_settings().delegate()))); + } + } + + RETURN_IF_ERROR(resize(interpreter_.get())); + if (compute_settings_.tflite_settings().cpu_settings().num_threads() != -1) { + if (interpreter_->SetNumThreads( + compute_settings_.tflite_settings().cpu_settings().num_threads()) != + kTfLiteOk) { + return absl::InternalError("Failed setting number of CPU threads"); + } + } + SetTfLiteCancellation(); + + if (!delegate_) { + // Just allocate tensors and return. + return AllocateTensors(); + } + + // The call to ModifyGraphWithDelegate() leaves the interpreter in a usable + // state in case of failure: calling AllocateTensors() will silently fallback + // on CPU in such a situation. + return AllocateTensors(); +} + +absl::Status TfLiteInterpreterWrapper::InitializeDelegate() { + if (delegate_ == nullptr) { + Delegate which_delegate = compute_settings_.tflite_settings().delegate(); + const tflite::ComputeSettings* compute_settings = + tflite::ConvertFromProto(compute_settings_, &flatbuffers_builder_); + + if (which_delegate == Delegate::NNAPI) { + RETURN_IF_ERROR( + LoadDelegatePlugin("Nnapi", *compute_settings->tflite_settings())); + } else if (which_delegate == Delegate::HEXAGON) { + RETURN_IF_ERROR( + LoadDelegatePlugin("Hexagon", *compute_settings->tflite_settings())); + } else if (which_delegate == Delegate::GPU) { + RETURN_IF_ERROR( + LoadDelegatePlugin("Gpu", *compute_settings->tflite_settings())); + } else if (which_delegate == Delegate::EDGETPU) { + RETURN_IF_ERROR( + LoadDelegatePlugin("EdgeTpu", *compute_settings->tflite_settings())); + } else if (which_delegate == Delegate::EDGETPU_CORAL) { + RETURN_IF_ERROR(LoadDelegatePlugin("EdgeTpuCoral", + *compute_settings->tflite_settings())); + } else if (which_delegate == Delegate::XNNPACK) { + RETURN_IF_ERROR( + LoadDelegatePlugin("XNNPack", *compute_settings->tflite_settings())); + } + } + return absl::OkStatus(); } absl::Status TfLiteInterpreterWrapper::InvokeWithFallback( const std::function<absl::Status(tflite::Interpreter* interpreter)>& set_inputs) { RETURN_IF_ERROR(set_inputs(interpreter_.get())); - return interpreter_->Invoke() != kTfLiteOk - ? absl::InternalError("TFLite interpreter: Invoke() failed.") - : absl::OkStatus(); + // Reset cancel flag before calling `Invoke()`. + cancel_flag_.Set(false); + TfLiteStatus status = kTfLiteError; + if (fallback_on_execution_error_) { + status = InterpreterUtils::InvokeWithCPUFallback(interpreter_.get()); + } else { + status = interpreter_->Invoke(); + } + if (status == kTfLiteOk) { + return absl::OkStatus(); + } + // Assume InvokeWithoutFallback() is guarded under caller's synchronization. + // Assume the inference is cancelled successfully if Invoke() returns + // kTfLiteError and the cancel flag is `true`. + if (status == kTfLiteError && cancel_flag_.Get()) { + return absl::CancelledError("Invoke() cancelled."); + } + if (delegate_) { + // Mark that an error occurred so that later invocations immediately + // fallback to CPU. + got_error_do_not_delegate_anymore_ = true; + // InvokeWithCPUFallback returns `kTfLiteDelegateError` in case of + // *successful* fallback: convert it to an OK status. + if (status == kTfLiteDelegateError) { + return absl::OkStatus(); + } + } + return absl::InternalError("Invoke() failed."); } absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() { - return interpreter_->Invoke() != kTfLiteOk - ? absl::InternalError("TFLite interpreter: Invoke() failed.") - : absl::OkStatus(); + // Reset cancel flag before calling `Invoke()`. + cancel_flag_.Set(false); + TfLiteStatus status = interpreter_->Invoke(); + if (status != kTfLiteOk) { + // Assume InvokeWithoutFallback() is guarded under caller's synchronization. + // Assume the inference is cancelled successfully if Invoke() returns + // kTfLiteError and the cancel flag is `true`. + if (status == kTfLiteError && cancel_flag_.Get()) { + return absl::CancelledError("Invoke() cancelled."); + } + return absl::InternalError("Invoke() failed."); + } + return absl::OkStatus(); } void TfLiteInterpreterWrapper::Cancel() { - // NOP + cancel_flag_.Set(true); +} + +void TfLiteInterpreterWrapper::SetTfLiteCancellation() { + // Create a cancellation check function and set to the TFLite interpreter. + auto check_cancel_flag = [](void* data) { + auto* cancel_flag = reinterpret_cast<CancelFlag*>(data); + return cancel_flag->Get(); + }; + interpreter_->SetCancellationFunction(reinterpret_cast<void*>(&cancel_flag_), + check_cancel_flag); +} + +absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin( + const std::string& name, + const tflite::TFLiteSettings& tflite_settings) { + delegate_plugin_ = DelegatePluginRegistry::CreateByName( + absl::StrFormat("%sPlugin", name), tflite_settings); + + if (delegate_plugin_ == nullptr) { + return absl::InternalError(absl::StrFormat( + "Could not create %s plugin. Have you linked in the %s_plugin target?", + name, name)); + } + + delegate_ = delegate_plugin_->Create(); + if (delegate_ == nullptr) { + return absl::InternalError( + absl::StrFormat("Plugin did not create %s delegate.", name)); + } + + return absl::OkStatus(); +} + +bool TfLiteInterpreterWrapper::HasMiniBenchmarkCompleted() { + if (mini_benchmark_ != nullptr && + mini_benchmark_->NumRemainingAccelerationTests() == 0) { + return true; + } + return false; } } // namespace support
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h index 3fd489f7..9f32fa8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
@@ -16,40 +16,123 @@ #define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ #include <memory> +#include <string> #include <utility> -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h" +#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h" #include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" namespace tflite { namespace support { -// Wrapper for a TfLiteInterpreter that may be accelerated[1]. This is NOT yet -// implemented: this class only provides a first, minimal interface in the -// meanwhile. +// Options that are created by `TFLiteInterpreterWrapper` and will help to +// initialize Interpreter in the callback function. `TFLiteInterpreterWrapper` +// retains ownership of the included options, and will ensure that they remain +// valid for the duration of the created interpreter's lifetime. +struct InterpreterCreationResources { + // The delegate created, based on the parameters in `ComputeSettings`. + // `TfLiteInterpreterWrapper` exclusively owns the `TfLiteDelegate` object, + // and maintains it through out the lifetime of `TfLiteInterpreterWrapper`. + TfLiteDelegate* optional_delegate; + + // Number of threads to use, or -1 to use the default number of threads. + int num_threads = -1; + + // Apply the InterpreterCreationResources to the InterpreterBuilder. + // Note: caller is responsible for ensuring that arguments are valid, + // e.g. that num_threads >= -1. + void ApplyTo(tflite::InterpreterBuilder* interpreter_builder) const { + if (optional_delegate != nullptr) { + interpreter_builder->AddDelegate(optional_delegate); + } + if (num_threads != -1) { + // We ignore the TfLiteStatus return value here; caller is responsible + // for checking that num_threads is valid. + (void)interpreter_builder->SetNumThreads(num_threads); + } + } +}; + +// Wrapper for a TfLiteInterpreter that may be accelerated [1]. Meant to be +// substituted for `unique_ptr<tflite::Interpreter>` class members. // -// [1] See tensorflow/lite/experimental/acceleration for more details. +// This class is in charge of: +// * Picking, instantiating and configuring the right delegate for the provided +// ComputeSettings [2], +// * Providing methods to initialize and invoke the Interpreter with optional +// (controlled through the ComputeSettings) automatic fallback to CPU if any +// acceleration-related error occurs at compilation or runtime. +// * TODO(b/169474250) Cache interpreters for multiple input sizes to enable +// performant acceleration for the case where input size changes frequently. +// +// IMPORTANT: The only supported delegates are (as defined in [1]) NONE, GPU, +// HEXAGON and NNAPI. Trying to use this class with EDGETPU or XNNPACK delegates +// will cause an UnimplementedError to be thrown at initialization time. +// +// Like TfLiteInterpreter, this class is thread-compatible. Use from multiple +// threads must be guarded by synchronization outside this class. +// +// [1]: +// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto class TfLiteInterpreterWrapper { public: - TfLiteInterpreterWrapper() = default; + // Creates an instance to be associated with a TfLite model that could be + // identified by (`default_model_namespace`, `default_model_id`). Note the + // model identifier is generally used for the sake of logging. + TfLiteInterpreterWrapper(const std::string& default_model_namespace, + const std::string& default_model_id); + TfLiteInterpreterWrapper() + : TfLiteInterpreterWrapper("org.tensorflow.lite.support", + "unknown_model_id") {} virtual ~TfLiteInterpreterWrapper() = default; - // Calls `interpreter_initializer` and then `AllocateTensors`. Future - // implementation of this method will attempt to apply the provided - // `compute_settings` with a graceful fallback in case a failure occurs. - // Note: before this gets implemented, do NOT call this method with non-empty - // `compute_settings` otherwise an unimplemented error occurs. + // Calls `interpreter_initializer` to construct the Interpreter, then + // initializes it with the appropriate delegate (if any) specified through + // `compute_settings` and finally calls AllocateTensors() on it. + // + // Whether or not this function automatically falls back to using CPU in case + // initialization with a delegate fails depends on the FallbackSettings + // specified in the TFLiteSettings of the provided ComputeSettings: if the + // `allow_automatic_fallback_on_compilation_error` field is set to true, + // fallback will automatically happen; otherwise an InternalError will be + // thrown. + // This flag allows callers to rely on this function whether or not they + // actually want fallback to happen; if they don't, it will ensure that the + // configuration doesn't accidentally trigger fallback. + // + // IMPORTANT: Supported delegate type includes: NONE, NNAPI, GPU, HEXAGON, + // XNNPACK, EDGETPU (Google internal), and EDGETPU_CORAL. Specifying another + // delegate type may cause an UnimplementedError to be thrown. + absl::Status InitializeWithFallback( + std::function<absl::Status(const InterpreterCreationResources&, + std::unique_ptr<tflite::Interpreter>*)> + interpreter_initializer, + const tflite::proto::ComputeSettings& compute_settings); + + // Deprecated: Use the one above with `InterpreterCreationResources` instead. absl::Status InitializeWithFallback( std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)> interpreter_initializer, const tflite::proto::ComputeSettings& compute_settings); - // Calls `set_inputs` and then Invoke() on the interpreter. Future - // implementation of this method will perform a graceful fallback in case a - // failure occur due to the `compute_settings` provided at initialization - // time. + // Calls `set_inputs` and then Invoke() on the interpreter. + // + // Whether or not this function automatically falls back to using CPU in case + // invocation with a delegate fails depends on the FallbackSettings + // specified in the TFLiteSettings of the ComputeSettings provided at + // initialization: if the `allow_automatic_fallback_on_execution_error` + // field is set to true, fallback will automatically happen; otherwise an + // InternalError will be thrown. + // This flag allows callers to rely on this function whether or not they + // actually want fallback to happen; if they don't, it will ensure that the + // configuration doesn't accidentally trigger fallback. absl::Status InvokeWithFallback( const std::function<absl::Status(tflite::Interpreter* interpreter)>& set_inputs); @@ -58,8 +141,23 @@ // before-hand. absl::Status InvokeWithoutFallback(); - // Cancels the current running TFLite invocation on CPU. This method is not - // yet implemented though it is safe to use it as it acts as a NOP. + // Cancels the current TFLite **CPU** inference. + // + // IMPORTANT: If inference is entirely running on a delegate, this has no + // effect; if inference is partially delegated, only the CPU part is + // cancelled. + // + // Usually called on a different thread than the one Invoke() is running + // on. Calling Cancel() while InvokeWithFallback() or InvokeWithoutFallback() + // is running may cause these methods to return a `CancelledError` with empty + // results. Calling Cancel() at any other time doesn't have any effect. + // + // InvokeWithFallback() and InvokeWithoutFallback() reset the cancel flag + // right before the underlying Invoke() is called, so these two methods can be + // called again on the same instance after a call to Cancel(). + // + // Note that this is the only method that can be called from another thread + // without locking. void Cancel(); // Accesses the underlying interpreter for other methods. @@ -72,8 +170,109 @@ TfLiteInterpreterWrapper(const TfLiteInterpreterWrapper&) = delete; TfLiteInterpreterWrapper& operator=(const TfLiteInterpreterWrapper&) = delete; + // Whether an error has occurred with the delegate. + bool HasDelegateError() { return got_error_do_not_delegate_anymore_; } + + // Whether the on-device mini-benchmark has completed for those TfLite + // acceleration configurations that are specified in passed-in + // ComputeSettings. If it finishes, the next time this same InterpreterWrapper + // object is created (i.e. w/ the same model and the same + // mini-benchmark-related configurations), the best acceleration configuration + // will be picked up and used. + bool HasMiniBenchmarkCompleted(); + + const tflite::proto::ComputeSettings& compute_settings() const { + return compute_settings_; + } + + protected: + // The delegate used to accelerate inference. + Interpreter::TfLiteDelegatePtr delegate_; + // The corresponding delegate plugin. + std::unique_ptr<tflite::delegates::DelegatePluginInterface> delegate_plugin_; + private: + // Performs sanity checks on the provided ComputeSettings. + static absl::Status SanityCheckComputeSettings( + const tflite::proto::ComputeSettings& compute_settings); + + // Inner function for initializing an interpreter with fallback, optionally + // resizing input tensors by calling `resize` on the newly initialized + // interpreter. + absl::Status InitializeWithFallbackAndResize( + std::function<absl::Status(Interpreter* interpreter)> resize = + [](Interpreter* interpreter) { return absl::OkStatus(); }); + + // Initializes the delegate plugin and creates the delegate. + absl::Status InitializeDelegate(); + + // Wrapper around the interpreter's `AllocateTensors()` method converting the + // returned `TfLiteStatus` to an `absl::Status`. + absl::Status AllocateTensors(); + + absl::Status LoadDelegatePlugin(const std::string&, + const tflite::TFLiteSettings&); + + std::string ModelNamespace(); + std::string ModelID(); + + // The interpreter instance being used. std::unique_ptr<tflite::Interpreter> interpreter_; + // The function used to initialize the interpreter and store it into the + // provided `std::unique_ptr`. + // This is typically a wrapper function around `tflite::InterpreterBuilder`, + // giving the caller the opportunity to hook-up a custom `tflite::OpResolver` + // and / or `tflite::ErrorReporter`. + std::function<absl::Status(const InterpreterCreationResources&, + std::unique_ptr<Interpreter>*)> + interpreter_initializer_; + + // The ComputeSettings provided at initialization time. + // Note when TfLite mini-benchmark is enabled, it could be changed to the + // best TfLite acceleration setting selected. + tflite::proto::ComputeSettings compute_settings_; + + // Set to true if an occurs with the specified delegate (if any), causing + // future calls to fallback on CPU. + bool got_error_do_not_delegate_anymore_; + + // Fallback behavior as specified through the ComputeSettings. + bool fallback_on_compilation_error_; + bool fallback_on_execution_error_; + + std::string default_model_namespace_; + std::string default_model_id_; + + // Used to convert the ComputeSettings proto to FlatBuffer format. + flatbuffers::FlatBufferBuilder flatbuffers_builder_; + + // Cancellation flag definition. + struct CancelFlag { + // Mutex to guard the `cancel_flag`. + mutable absl::Mutex cancel_mutex; + + // A flag indicates if the caller cancels the TFLite interpreter invocation. + bool cancel_flag ABSL_GUARDED_BY(cancel_mutex) = false; + + // Returns `cancel_flag`. + bool Get() const ABSL_LOCKS_EXCLUDED(cancel_mutex) { + absl::MutexLock cancel_lock(&cancel_mutex); + return cancel_flag; + } + + // Sets `cancel_flag` to `value`. + void Set(bool value) ABSL_LOCKS_EXCLUDED(cancel_mutex) { + absl::MutexLock cancel_lock(&cancel_mutex); + cancel_flag = value; + } + }; + CancelFlag cancel_flag_; + + std::unique_ptr<tflite::acceleration::MiniBenchmark> mini_benchmark_; + + // Sets up the TFLite invocation cancellation by + // tflite::Interpreter::SetCancellationFunction(). + void SetTfLiteCancellation(); }; } // namespace support
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h index 76d9d50..dc6183be 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
@@ -37,6 +37,19 @@ #define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also. #define GG_LL_FORMAT_W L"ll" +const uint8 kuint8max{0xFF}; +const uint16 kuint16max{0xFFFF}; +const uint32 kuint32max{0xFFFFFFFF}; +const uint64 kuint64max{GG_ULONGLONG(0xFFFFFFFFFFFFFFFF)}; +const int8 kint8min{~0x7F}; +const int8 kint8max{0x7F}; +const int16 kint16min{~0x7FFF}; +const int16 kint16max{0x7FFF}; +const int32 kint32min{~0x7FFFFFFF}; +const int32 kint32max{0x7FFFFFFF}; +const int64 kint64min{GG_LONGLONG(~0x7FFFFFFFFFFFFFFF)}; +const int64 kint64max{GG_LONGLONG(0x7FFFFFFFFFFFFFFF)}; + typedef uint64 Fprint; static const Fprint kIllegalFprint = 0; static const Fprint kMaxFprint = GG_ULONGLONG(0xFFFFFFFFFFFFFFFF);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/status_matchers.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/status_matchers.h new file mode 100644 index 0000000..3794c0b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/status_matchers.h
@@ -0,0 +1,21 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MATCHERS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MATCHERS_H_ + +#include "tensorflow_lite_support/cc/port/default/status_matchers.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MATCHERS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/statusor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/statusor.h index f84c756..a80394d41 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/statusor.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/statusor.h
@@ -16,5 +16,17 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ #define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ -#include "tensorflow_lite_support/cc/port/default/statusor.h" +// This header file is used to manage the depended StatusOr library. It creates +// an extra layer that makes it easier to switch between the desired version of +// StatusOr. +#include "absl/status/statusor.h" // from @com_google_absl + +namespace tflite { +namespace support { + +template <typename T> +using StatusOr = absl::StatusOr<T>; + +} // namespace support +} // namespace tflite #endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/README.md b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/README.md index bd756a2..b91c264 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/README.md +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/README.md
@@ -15,8 +15,8 @@ Use the C++ API to answer questions as follows: ```cc -using tflite::task::text::qa::BertQuestionAnswerer; -using tflite::task::text::qa::QaAnswer; +using tflite::task::text::BertQuestionAnswerer; +using tflite::task::text::QaAnswer; // Create API handler with Mobile Bert model. auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswererFromFile("/path/to/mobileBertModel", "/path/to/vocab"); // Or create API handler with Albert model.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/BUILD new file mode 100644 index 0000000..d50d3e98 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/BUILD
@@ -0,0 +1,65 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_library_with_tflite", +) + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library_with_tflite( + name = "audio_classifier", + srcs = ["audio_classifier.cc"], + hdrs = ["audio_classifier.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/processor:classification_postprocessor", + "//tensorflow_lite_support/cc/task/processor:audio_preprocessor", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/core:classification_head", + "//tensorflow_lite_support/cc/task/core:label_map_item", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/c:c_api_types", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], + alwayslink = 1, +) + +cc_library_with_tflite( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/processor:embedding_postprocessor", + "//tensorflow_lite_support/cc/task/processor:audio_preprocessor", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:base_task_api", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/audio/proto:audio_embedder_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite/c:common", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc new file mode 100644 index 0000000..4be3e53 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
@@ -0,0 +1,151 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h" + +#include <initializer_list> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.pb.h" +#include "tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/core/classification_head.h" +#include "tensorflow_lite_support/cc/task/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/classification_postprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace audio { + +namespace { + +using ::absl::StatusCode; +using ::tflite::AudioProperties; +using ::tflite::ContentProperties; +using ::tflite::ContentProperties_AudioProperties; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::audio::Class; +using ::tflite::task::audio::ClassificationResult; +using ::tflite::task::core::AssertAndReturnTypedTensor; +using ::tflite::task::core::LabelMapItem; +using ::tflite::task::core::TaskAPIFactory; +using ::tflite::task::core::TfLiteEngine; + +} // namespace + +StatusOr<std::unique_ptr<processor::ClassificationPostprocessor>> +CreatePostprocessor(TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + AudioClassifierOptions* options) { + auto new_options = std::make_unique<processor::ClassificationOptions>(); + new_options->set_display_names_locale(options->display_names_locale()); + new_options->set_max_results(options->max_results()); + new_options->set_score_threshold(options->score_threshold()); + new_options->mutable_class_name_allowlist()->Swap( + options->mutable_class_name_allowlist()); + new_options->mutable_class_name_denylist()->Swap( + options->mutable_class_name_denylist()); + return processor::ClassificationPostprocessor::Create(engine, output_indices, + std::move(new_options)); +} + +/* static */ +StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::CreateFromOptions( + const AudioClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + // Copy options to ensure the ExternalFile outlives the constructed object. + auto options_copy = absl::make_unique<AudioClassifierOptions>(options); + + ASSIGN_OR_RETURN(auto audio_classifier, + TaskAPIFactory::CreateFromBaseOptions<AudioClassifier>( + &options_copy->base_options(), std::move(resolver))); + + RETURN_IF_ERROR(audio_classifier->Init(std::move(options_copy))); + + return audio_classifier; +} + +/* static */ +absl::Status AudioClassifier::SanityCheckOptions( + const AudioClassifierOptions& options) { + if (!options.has_base_options()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Missing mandatory `base_options` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +absl::Status AudioClassifier::Init( + std::unique_ptr<AudioClassifierOptions> options) { + // Set options. + options_ = std::move(options); + + // Create preprocessor, assuming having only 1 input tensor. + ASSIGN_OR_RETURN(preprocessor_, processor::AudioPreprocessor::Create( + GetTfLiteEngine(), {0})); + + // Assuming all output tensors share the same option. This is an limitation in + // the current API design. + int output_count = + GetTfLiteEngine()->OutputCount(GetTfLiteEngine()->interpreter()); + postprocessors_.reserve(output_count); + for (int i = 0; i < output_count; i++) { + ASSIGN_OR_RETURN(auto processor, CreatePostprocessor(GetTfLiteEngine(), {i}, + options_.get())); + postprocessors_.emplace_back(std::move(processor)); + } + + return absl::OkStatus(); +} + +tflite::support::StatusOr<ClassificationResult> AudioClassifier::Classify( + const AudioBuffer& audio_buffer) { + return InferWithFallback(audio_buffer); +} + +tflite::support::StatusOr<audio::ClassificationResult> +AudioClassifier::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const AudioBuffer& audio_buffer) { + audio::ClassificationResult result; + for (auto& processor : postprocessors_) { + auto* classification = result.add_classifications(); + // ClassificationPostprocessor doesn't set head name for backward + // compatibility, so we set it here manually. + classification->set_head_name(processor->GetHeadName()); + RETURN_IF_ERROR(processor->Postprocess(classification)); + } + + return result; +} + +} // namespace audio +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.h new file mode 100644 index 0000000..cef48d8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.h
@@ -0,0 +1,133 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_AUDIO_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_AUDIO_CLASSIFIER_H_ + +#include <memory> + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.pb.h" +#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/classification_head.h" +#include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/classification_postprocessor.h" + +namespace tflite { +namespace task { +namespace audio { + +// Performs classification on audio clips. +// +// This API expects a TFLite model with metadata. +// +// Input tensor: +// (kTfLiteFloat32) +// - input audio buffer of size `[batch * samples]`. +// - batch inference is not supported (`batch` is required to be 1). +// - for multi-channel models, the channels need be interleaved. +// At least one output tensor with: +// (kTfLiteFloat32) +// - `[1 x N]` array with `N` represents the class number. +// - optional (but recommended) label map(s) as AssociatedFile-s with type +// TENSOR_AXIS_LABELS, containing one label per line. The first such +// AssociatedFile (if any) is used to fill the `class_name` field of the +// results. The `display_name` field is filled from the AssociatedFile (if +// any) whose locale matches the `display_names_locale` field of the +// `ImageClassifierOptions` used at creation time ("en" by default, i.e. +// English). If none of these are available, only the `index` field of the +// results will be filled. +// +// An example of such model can be found at: +// https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1 + +// A CLI demo tool is available for easily trying out this API, and provides +// example usage. See: +// https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/audio/desktop +class AudioClassifier + : public tflite::task::core::BaseTaskApi<ClassificationResult, + const AudioBuffer&> { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates an AudioClassifier from the provided options. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. + static tflite::support::StatusOr<std::unique_ptr<AudioClassifier>> + CreateFromOptions( + const AudioClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + + // Performs classification on the provided audio buffer. + // + // The input `audio_buffer` are the raw buffer captured by the required format + // which can retrieved by GetRequiredAudioFormat(). + tflite::support::StatusOr<ClassificationResult> Classify( + const AudioBuffer& audio_buffer); + + // Returns the required input audio format if it is set. Otherwise, returns + // kMetadataNotFoundError. + // TODO(b/182625132): Add unit test after the format is populated from model + // metadata. + tflite::support::StatusOr<AudioBuffer::AudioFormat> GetRequiredAudioFormat() { + return preprocessor_->GetRequiredAudioFormat(); + } + + // Returns the required input buffer size in number of float elements. + int GetRequiredInputBufferSize() { + return preprocessor_->GetRequiredInputBufferSize(); + } + + private: + // Performs sanity checks on the provided AudioClassifierOptions. + static absl::Status SanityCheckOptions(const AudioClassifierOptions& options); + + // Initializes the AudioClassifier from the provided AudioClassifierOptions, + // whose ownership is transferred to this object. + absl::Status Init(std::unique_ptr<AudioClassifierOptions> options); + + // Passes through the input audio buffer into model's input tensor. + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const AudioBuffer& audio_buffer) override { + return preprocessor_->Preprocess(audio_buffer); + } + + // Post-processing to transform the raw model outputs into classification + // results. + tflite::support::StatusOr<ClassificationResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const AudioBuffer& audio_buffer) override; + + // The options used to build this AudioClassifier. + std::unique_ptr<AudioClassifierOptions> options_; + + std::unique_ptr<tflite::task::processor::AudioPreprocessor> preprocessor_ = + nullptr; + + std::vector< + std::unique_ptr<tflite::task::processor::ClassificationPostprocessor>> + postprocessors_; +}; + +} // namespace audio +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_AUDIO_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc new file mode 100644 index 0000000..c67b67dd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
@@ -0,0 +1,119 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/audio/audio_embedder.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/audio/proto/audio_embedder_options.pb.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" + +namespace tflite { +namespace task { +namespace audio { + +/* static */ +tflite::support::StatusOr<std::unique_ptr<AudioEmbedder>> +AudioEmbedder::CreateFromOptions(const AudioEmbedderOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + auto options_copy = absl::make_unique<AudioEmbedderOptions>(options); + + ASSIGN_OR_RETURN(auto audio_embedder, + core::TaskAPIFactory::CreateFromBaseOptions<AudioEmbedder>( + &options_copy->base_options(), std::move(resolver))); + + RETURN_IF_ERROR(audio_embedder->Init(std::move(options_copy))); + return audio_embedder; +} + +/* static */ +absl::Status AudioEmbedder::SanityCheckOptions( + const AudioEmbedderOptions& options) { + if (!options.has_base_options()) { + return support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Missing mandatory `base_options` field", + support::TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +absl::Status AudioEmbedder::Init( + std::unique_ptr<AudioEmbedderOptions> options) { + options_ = std::move(options); + + // Create preprocessor, assuming having only 1 input tensor. + ASSIGN_OR_RETURN(preprocessor_, + tflite::task::processor::AudioPreprocessor::Create( + GetTfLiteEngine(), {0})); + + // Create postprocessors, assuming that all output tensors are embedding + // outputs. + int post_processors_count = + GetTfLiteEngine()->OutputCount(GetTfLiteEngine()->interpreter()); + postprocessors_.reserve(post_processors_count); + + for (int i = 0; i < post_processors_count; i++) { + std::unique_ptr<processor::EmbeddingOptions> option = nullptr; + if (options_->embedding_options_size() == 0) { + // Default options. + option = std::make_unique<processor::EmbeddingOptions>(); + } else if (options_->embedding_options_size() == 1) { + // Share the first options. + option = std::make_unique<processor::EmbeddingOptions>( + options_->embedding_options(0)); + } else if (options_->embedding_options_size() == post_processors_count) { + option = std::make_unique<processor::EmbeddingOptions>( + // Use the corresponding options for the tensor. + options_->embedding_options(i)); + } else { + return support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid embedding_options. It should have size of either 0, 1 or " + "number of output tensors.", + support::TfLiteSupportStatus::kInvalidArgumentError); + } + ASSIGN_OR_RETURN(auto processor, + processor::EmbeddingPostprocessor::Create( + GetTfLiteEngine(), {i}, std::move(option))); + postprocessors_.emplace_back(std::move(processor)); + } + return absl::OkStatus(); +} + +tflite::support::StatusOr<tflite::task::processor::EmbeddingResult> +AudioEmbedder::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const AudioBuffer& audio_buffer) { + tflite::task::processor::EmbeddingResult result; + for (int i = 0; i < postprocessors_.size(); i++) { + auto processor = postprocessors_.at(i).get(); + RETURN_IF_ERROR(processor->Postprocess(result.add_embeddings())); + } + return result; +} + +tflite::support::StatusOr<tflite::task::processor::EmbeddingResult> +AudioEmbedder::Embed(const AudioBuffer& audio_buffer) { + return InferWithFallback(audio_buffer); +} + +} // namespace audio +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h new file mode 100644 index 0000000..a3d4c57 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
@@ -0,0 +1,92 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_AUDIO_EMBEDDER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_AUDIO_EMBEDDER_H_ + +#include <memory> + +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/proto/audio_embedder_options.pb.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" +namespace tflite { +namespace task { +namespace audio { +class AudioEmbedder : public tflite::task::core::BaseTaskApi< + tflite::task::processor::EmbeddingResult, + const AudioBuffer&> { + public: + // Use base class constructor. + using BaseTaskApi::BaseTaskApi; + + // Creates an AudioEmbedder from the provided options. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. + static tflite::support::StatusOr<std::unique_ptr<AudioEmbedder>> + CreateFromOptions( + const AudioEmbedderOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + + // Performs actual feature vector extraction on the provided AudioBuffer. + tflite::support::StatusOr<tflite::task::processor::EmbeddingResult> Embed( + const AudioBuffer& audio_buffer); + + // Returns the required input audio format if it is set. Otherwise, returns + // kMetadataNotFoundError. + // TODO(b/182625132): Add unit test after the format is populated from model + // metadata. + tflite::support::StatusOr<AudioBuffer::AudioFormat> GetRequiredAudioFormat() { + return preprocessor_->GetRequiredAudioFormat(); + } + + // Returns the required input buffer size in number of float elements. + int GetRequiredInputBufferSize() { + return preprocessor_->GetRequiredInputBufferSize(); + } + + private: + static absl::Status SanityCheckOptions(const AudioEmbedderOptions& options); + + absl::Status Init(std::unique_ptr<AudioEmbedderOptions> options); + + // Passes through the input audio buffer into model's input tensor. + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const AudioBuffer& audio_buffer) override { + return preprocessor_->Preprocess(audio_buffer); + } + + // Transforms the raw model outputs into embedding results. + tflite::support::StatusOr<tflite::task::processor::EmbeddingResult> + Postprocess(const std::vector<const TfLiteTensor*>& output_tensors, + const AudioBuffer& audio_buffer) override; + + std::unique_ptr<AudioEmbedderOptions> options_ = nullptr; + + // Processors + std::unique_ptr<tflite::task::processor::AudioPreprocessor> preprocessor_ = + nullptr; + std::vector<std::unique_ptr<tflite::task::processor::EmbeddingPostprocessor>> + postprocessors_; +}; + +} // namespace audio +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_AUDIO_EMBEDDER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/BUILD new file mode 100644 index 0000000..1b785da --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/BUILD
@@ -0,0 +1,23 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "audio_buffer", + hdrs = [ + "audio_buffer.h", + ], + visibility = [ + "//visibility:public", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h new file mode 100644 index 0000000..d922e48 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
@@ -0,0 +1,76 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_CORE_AUDIO_BUFFER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_CORE_AUDIO_BUFFER_H_ + +#include <memory> + +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace task { +namespace audio { + +// Provides a view into the provided backing buffer and the audio format +// metadata. +// TODO(b/182675479): Support quantized input format. +class AudioBuffer { + public: + // Audio format metadata. + struct AudioFormat { + int channels; + int sample_rate; + }; + + // Factory method for creating an AudioBuffer object. The internal buffer does + // not take the ownership of the input backing buffer. + static tflite::support::StatusOr<std::unique_ptr<AudioBuffer>> Create( + const float* audio_buffer, + int buffer_size, + const AudioFormat& audio_format) { + return absl::make_unique<AudioBuffer>(audio_buffer, buffer_size, + audio_format); + } + + // AudioBuffer for internal use only. Uses the factory method to construct + // AudioBuffer instance. The internal buffer does not take the ownership of + // the input backing buffer. + AudioBuffer(const float* audio_buffer, + int buffer_size, + const AudioFormat& audio_format) + : audio_buffer_(audio_buffer), + buffer_size_(buffer_size), + audio_format_(audio_format) {} + + // Accessors + AudioFormat GetAudioFormat() const { return audio_format_; } + int GetBufferSize() const { return buffer_size_; } + const float* GetFloatBuffer() const { return audio_buffer_; } + + private: + const float* audio_buffer_; + int buffer_size_; + AudioFormat audio_format_; +}; + +} // namespace audio +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_CORE_AUDIO_BUFFER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/BUILD new file mode 100644 index 0000000..63d78fb4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/BUILD
@@ -0,0 +1,53 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +# AudioClassifier protos. + +proto_library( + name = "audio_classifier_options_proto", + srcs = ["audio_classifier_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", + ], +) + +cc_proto_library( + name = "audio_classifier_options_cc_proto", + deps = [ + ":audio_classifier_options_proto", + ], +) + +cc_library( + name = "classifications_proto_inc", + hdrs = ["classifications_proto_inc.h"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:classifications_cc_proto", + ], +) + +cc_library( + name = "class_proto_inc", + hdrs = ["class_proto_inc.h"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:class_cc_proto", + ], +) + +proto_library( + name = "audio_embedder_options_proto", + srcs = ["audio_embedder_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_proto", + ], +) + +cc_proto_library( + name = "audio_embedder_options_cc_proto", + deps = [":audio_embedder_options_proto"], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.proto new file mode 100644 index 0000000..a17bb15 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.proto
@@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.audio; + +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; + +// Options for setting up an AudioClassifier. +// Next Id: 7 +message AudioClassifierOptions { + // Base options for configuring the external model file. + optional tflite.task.core.BaseOptions base_options = 1; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 2 [default = "en"]; + + // The maximum number of top-scored classification results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + optional int32 max_results = 3 [default = -1]; + + // Score threshold, overrides the ones provided in the model metadata + // (if any). Results below this value are rejected. + optional float score_threshold = 4; + + // Optional allowlist of class names. If non-empty, classifications whose + // class name is not in this set will be filtered out. Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_denylist. + repeated string class_name_allowlist = 5; + + // Optional denylist of class names. If non-empty, classifications whose + // class name is in this set will be filtered out. Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_allowlist. + repeated string class_name_denylist = 6; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/audio_embedder_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/audio_embedder_options.proto new file mode 100644 index 0000000..0271c2f4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/audio_embedder_options.proto
@@ -0,0 +1,36 @@ + +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.audio; + +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; +import "tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto"; + +// Options for setting up an AudioEmbedder. +// Next Id: 3 +message AudioEmbedderOptions { + // Base options for configuring the external model file. + optional tflite.task.core.BaseOptions base_options = 1; + + // Options for each embedding head. If the model contains N heads (embedding + // output tensor), then the len(embedding_options) needs to either + // 0: All output tensors are processed using the *default* EmbeddingOptions. + // 1: All output tensors are processed using the *same* EmbeddingOptions. + // N: Output tensors are processed using the *corresponding* EmbeddingOptions. + repeated tflite.task.processor.EmbeddingOptions embedding_options = 2; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h new file mode 100644 index 0000000..f965e5a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h
@@ -0,0 +1,29 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_PROTO_CLASS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_PROTO_CLASS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/class.pb.h" + +namespace tflite { +namespace task { +namespace audio { + +using Class = ::tflite::task::core::Class; + +} // namespace audio +} // namespace task +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_PROTO_CLASSIFICATIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h new file mode 100644 index 0000000..00ac3fd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h
@@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_PROTO_CLASSIFICATIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_PROTO_CLASSIFICATIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/classifications.pb.h" + +namespace tflite { +namespace task { +namespace audio { + +using Classifications = ::tflite::task::core::Classifications; +using ClassificationResult = ::tflite::task::core::ClassificationResult; + +} // namespace audio +} // namespace task +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_PROTO_CLASSIFICATIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/BUILD new file mode 100644 index 0000000..9d2b542 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/BUILD
@@ -0,0 +1,27 @@ +# Description: some utilitiy function for audio processing. + +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "wav_io", + srcs = [ + "wav_io.cc", + ], + hdrs = ["wav_io.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], + deps = [ + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc new file mode 100644 index 0000000..9ae3fbec --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
@@ -0,0 +1,232 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions to write audio in WAV format. +// This file is forked from `tensorflow/core/lib/wav/wav_io.cc`. + +#include "tensorflow_lite_support/cc/task/audio/utils/wav_io.h" + +#include <math.h> +#include <string.h> + +#include <algorithm> +#include <cinttypes> +#include <cstdint> +#include <fstream> +#include <limits> + +#include "absl/base/casts.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace task { +namespace audio { +namespace { + +constexpr char kRiffChunkId[] = "RIFF"; +constexpr char kRiffType[] = "WAVE"; +constexpr char kFormatChunkId[] = "fmt "; +constexpr char kDataChunkId[] = "data"; + +inline float Int16SampleToFloat(int16_t data) { + constexpr float kMultiplier = 1.0f / (1 << 15); + return data * kMultiplier; +} + +} // namespace + +std::string ReadFile(const std::string filepath) { + std::ifstream fs(filepath); + if (!fs.is_open()) { + return ""; + } + std::string contents((std::istreambuf_iterator<char>(fs)), + (std::istreambuf_iterator<char>())); + return contents; +} + +// Handles moving the data index forward, validating the arguments, and avoiding +// overflow or underflow. +absl::Status IncrementOffset(int old_offset, + size_t increment, + size_t max_size, + int* new_offset) { + if (old_offset < 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Negative offsets are not allowed: %d", old_offset)); + } + if (old_offset > max_size) { + return absl::InvalidArgumentError(absl::StrFormat( + "Initial offset is outside data range: %d", old_offset)); + } + *new_offset = old_offset + increment; + if (*new_offset > max_size) { + return absl::InvalidArgumentError( + "Data too short when trying to read string"); + } + // See above for the check that the input offset is positive. If it's negative + // here then it means that there's been an overflow in the arithmetic. + if (*new_offset < 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Offset too large, overflowed: %d", *new_offset)); + } + return absl::OkStatus(); +} + +absl::Status ExpectText(const std::string& data, + const std::string& expected_text, + int* offset) { + int new_offset; + RETURN_IF_ERROR( + IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset)); + const std::string found_text(data.begin() + *offset, + data.begin() + new_offset); + if (found_text != expected_text) { + return absl::InvalidArgumentError(absl::StrCat( + "Header mismatch: Expected", expected_text, " but found ", found_text)); + } + *offset = new_offset; + return absl::OkStatus(); +} + +absl::Status ReadString(const std::string& data, + int expected_length, + std::string* value, + int* offset) { + int new_offset; + RETURN_IF_ERROR( + IncrementOffset(*offset, expected_length, data.size(), &new_offset)); + *value = std::string(data.begin() + *offset, data.begin() + new_offset); + *offset = new_offset; + return absl::OkStatus(); +} + +absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string, + std::vector<float>* float_values, + uint32_t* sample_count, + uint16_t* channel_count, + uint32_t* sample_rate) { + int offset = 0; + RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset)); + uint32_t total_file_size; + RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &total_file_size, &offset)); + RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset)); + RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset)); + uint32_t format_chunk_size; + RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &format_chunk_size, &offset)); + if ((format_chunk_size != 16) && (format_chunk_size != 18)) { + return absl::InvalidArgumentError(absl::StrFormat( + "Bad format chunk size for WAV: Expected 16 or 18, but got %" PRIu32, + format_chunk_size)); + } + uint16_t audio_format; + RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, &audio_format, &offset)); + if (audio_format != 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Bad audio format for WAV: Expected 1 (PCM), but got %" PRIu16, + audio_format)); + } + RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, channel_count, &offset)); + if (*channel_count < 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Bad number of channels for WAV: Expected at least 1, but got %" PRIu16, + *channel_count)); + } + RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, sample_rate, &offset)); + uint32_t bytes_per_second; + RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &bytes_per_second, &offset)); + uint16_t bytes_per_sample; + RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, &bytes_per_sample, &offset)); + // Confusingly, bits per sample is defined as holding the number of bits for + // one channel, unlike the definition of sample used elsewhere in the WAV + // spec. For example, bytes per sample is the memory needed for all channels + // for one point in time. + uint16_t bits_per_sample; + RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, &bits_per_sample, &offset)); + if (bits_per_sample != 16) { + return absl::InvalidArgumentError( + absl::StrFormat("Can only read 16-bit WAV files, but received %" PRIu16, + bits_per_sample)); + } + const uint32_t expected_bytes_per_sample = + ((bits_per_sample * *channel_count) + 7) / 8; + if (bytes_per_sample != expected_bytes_per_sample) { + return absl::InvalidArgumentError( + absl::StrFormat("Bad bytes per sample in WAV header: Expected %" PRIu32 + " but got %" PRIu16, + expected_bytes_per_sample, bytes_per_sample)); + } + const uint32_t expected_bytes_per_second = bytes_per_sample * *sample_rate; + if (bytes_per_second != expected_bytes_per_second) { + return absl::InvalidArgumentError( + absl::StrFormat("Bad bytes per second in WAV header: Expected %" PRIu32 + " but got %" PRIu32 " (sample_rate=%" PRIu32 + ", bytes_per_sample=%" PRIu16 ")", + expected_bytes_per_second, bytes_per_second, + *sample_rate, bytes_per_sample)); + } + if (format_chunk_size == 18) { + // Skip over this unused section. + offset += 2; + } + + bool was_data_found = false; + while (offset < wav_string.size()) { + std::string chunk_id; + RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset)); + uint32_t chunk_size; + RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &chunk_size, &offset)); + if (chunk_size > std::numeric_limits<int32_t>::max()) { + return absl::InvalidArgumentError(absl::StrFormat( + "WAV data chunk '%s' is too large: %" PRIu32 + " bytes, but the limit is %d", + chunk_id.c_str(), chunk_size, std::numeric_limits<int32_t>::max())); + } + if (chunk_id == kDataChunkId) { + if (was_data_found) { + return absl::InvalidArgumentError( + "More than one data chunk found in WAV"); + } + was_data_found = true; + *sample_count = chunk_size / bytes_per_sample; + const uint32_t data_count = *sample_count * *channel_count; + int unused_new_offset = 0; + // Validate that the data exists before allocating space for it + // (prevent easy OOM errors). + RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16_t) * data_count, + wav_string.size(), &unused_new_offset)); + float_values->resize(data_count); + for (int i = 0; i < data_count; ++i) { + int16_t single_channel_value = 0; + RETURN_IF_ERROR( + ReadValue<int16_t>(wav_string, &single_channel_value, &offset)); + (*float_values)[i] = Int16SampleToFloat(single_channel_value); + } + } else { + offset += chunk_size; + } + } + if (!was_data_found) { + return absl::InvalidArgumentError("No data chunk found in WAV"); + } + return absl::OkStatus(); +} + +} // namespace audio +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h new file mode 100644 index 0000000..9aca5d06 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
@@ -0,0 +1,102 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions to write audio in WAV format. +// This file is forked from `tensorflow/core/lib/wav/wav_io.h`. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_ + +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_ + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace task { +namespace audio { + +// Byte order definition provided by gcc. MSVC doesn't define those so +// we define them here. +// We assume that all windows platform out there are little endian. +#if defined(_MSC_VER) && !defined(__clang__) +#define __ORDER_LITTLE_ENDIAN__ 0x4d2 +#define __ORDER_BIG_ENDIAN__ 0x10e1 +#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__ +#endif + +namespace port { + +// TODO(jeff,sanjay): Make portable +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +} // namespace port + +// Decodes the little-endian signed 16-bit PCM WAV file data (aka LIN16 +// encoding) into a float Tensor. The channels are encoded as the lowest +// dimension of the tensor, with the number of frames as the second. This means +// that a four frame stereo signal will have the shape [4, 2]. The sample rate +// is read from the file header, and an error is returned if the format is not +// supported. +// The results are output as floats within the range -1 to 1, +absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string, + std::vector<float>* float_values, + uint32_t* sample_count, + uint16_t* channel_count, + uint32_t* sample_rate); + +// Everything below here is only exposed publicly for testing purposes. + +// Handles moving the data index forward, validating the arguments, and avoiding +// overflow or underflow. +absl::Status IncrementOffset(int old_offset, + size_t increment, + size_t max_size, + int* new_offset); + +// This function is only exposed in the header for testing purposes, as a +// template that needs to be instantiated. Reads a typed numeric value from a +// stream of data. +template <class T> +absl::Status ReadValue(const std::string& data, T* value, int* offset) { + int new_offset; + RETURN_IF_ERROR( + IncrementOffset(*offset, sizeof(T), data.size(), &new_offset)); + if (port::kLittleEndian) { + memcpy(value, data.data() + *offset, sizeof(T)); + } else { + *value = 0; + const uint8_t* data_buf = + reinterpret_cast<const uint8_t*>(data.data() + *offset); + int shift = 0; + for (int i = 0; i < sizeof(T); ++i, shift += 8) { + *value = *value | (data_buf[i] << shift); + } + } + *offset = new_offset; + return absl::OkStatus(); +} + +// Load the content of the file into std::string. +std::string ReadFile(const std::string filepath); + +} // namespace audio +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/BUILD index 13164ed..46a16bd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/BUILD
@@ -1,19 +1,16 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) -cc_library( +cc_library_with_tflite( name = "tflite_engine", srcs = ["tflite_engine.cc"], hdrs = ["tflite_engine.h"], - deps = [ - ":external_file_handler", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/c:common", - "@org_tensorflow//tensorflow/lite/core/api", + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:common", # The dependency on builtin_ops here is only for the default # value of the OpResolver parameter: # std::unique_ptr<tflite::IterableOpResolver> resolver = @@ -21,88 +18,70 @@ # When linking statically, if the client of this library doesn't use # the default argument, this dependency does not cause all the builtin ops # to get included in the executable. - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - "@org_tensorflow//tensorflow/lite/tools:verifier", - "@org_tensorflow//tensorflow/lite:kernel_api", - ] + select({ - "//tensorflow_lite_support/cc:tflite_use_c_api": [ - "@org_tensorflow//tensorflow/lite/c:c_api", - "@org_tensorflow//tensorflow/lite/c:c_api_experimental", - "@org_tensorflow//tensorflow/lite:stderr_reporter", - ], - "//conditions:default": ["@org_tensorflow//tensorflow/lite:framework"], - }) + [ - "//tensorflow_lite_support/cc:common", - "//tensorflow_lite_support/cc/port:status_macros", + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "@org_tensorflow//tensorflow/lite/core/shims:framework", + "@org_tensorflow//tensorflow/lite/core/shims:verifier", "//tensorflow_lite_support/cc/port:tflite_wrapper", - "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", - "//tensorflow_lite_support/metadata/cc:metadata_extractor", ], -) - -# This is a duplicate of the above 'tflite_engine' target that is used for -# testing with TFLITE_USE_C_API defined. It should be the same as the target -# above, except that it adds -# testonly = 1, -# defines = ["TFLITE_USE_C_API"], -# and that it resolves the conditional deps from the 'select' as if -# "//tensorflow_lite_support/cc:tflite_use_c_api" was enabled. -# This allows testing the TFLITE_USE_C_API case even when -# '--copt=-DTFLITE_USE_C_API' wasn't passed on the build command line. -cc_library( - name = "tflite_engine_with_c_api_for_test", - testonly = 1, - srcs = ["tflite_engine.cc"], - hdrs = ["tflite_engine.h"], - defines = ["TFLITE_USE_C_API"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ + ":error_reporter", ":external_file_handler", "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:configuration_proto_inc", "//tensorflow_lite_support/cc/port:status_macros", - "//tensorflow_lite_support/cc/port:tflite_wrapper_with_c_api_for_test", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/lite:kernel_api", - "@org_tensorflow//tensorflow/lite:stderr_reporter", - "@org_tensorflow//tensorflow/lite/c:c_api", - "@org_tensorflow//tensorflow/lite/c:c_api_experimental", - "@org_tensorflow//tensorflow/lite/c:common", - "@org_tensorflow//tensorflow/lite/core/api", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - "@org_tensorflow//tensorflow/lite/tools:verifier", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], ) -cc_library( +cc_library_with_tflite( name = "base_task_api", hdrs = ["base_task_api.h"], - deps = [ + tflite_deps = [ ":tflite_engine", + "//tensorflow_lite_support/cc/port:tflite_wrapper", + ], + visibility = [ + "//tensorflow_lite_support:internal", + ], + deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/port:tflite_wrapper", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/lite/c:common", ], ) -cc_library( +cc_library_with_tflite( name = "task_api_factory", hdrs = ["task_api_factory.h"], - deps = [ + tflite_deps = [ ":base_task_api", ":tflite_engine", + ], + visibility = [ + "//tensorflow_lite_support:internal", + ], + deps = [ + "//tensorflow_lite_support/cc/port:configuration_proto_inc", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@org_tensorflow//tensorflow/lite/c:common", - "@org_tensorflow//tensorflow/lite/core/api", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:op_macros", ], ) @@ -111,10 +90,18 @@ name = "task_utils", srcs = ["task_utils.cc"], hdrs = ["task_utils.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/metadata:metadata_schema_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@flatbuffers", "@org_tensorflow//tensorflow/lite:string_util", "@org_tensorflow//tensorflow/lite:type_to_tflitetype", @@ -126,13 +113,18 @@ cc_library( name = "category", hdrs = ["category.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], ) cc_library( name = "external_file_handler", srcs = ["external_file_handler.cc"], hdrs = ["external_file_handler.h"], - visibility = ["//visibility:public"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:integral_types", @@ -146,3 +138,64 @@ "@com_google_absl//absl/strings:str_format", ], ) + +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], + deps = [ + "@org_tensorflow//tensorflow/lite:minimal_logging", + "@org_tensorflow//tensorflow/lite:stateful_error_reporter", + ], +) + +cc_library( + name = "label_map_item", + srcs = ["label_map_item.cc"], + hdrs = ["label_map_item.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "classification_head", + srcs = ["classification_head.cc"], + hdrs = ["classification_head.h"], + deps = [ + ":label_map_item", + ":score_calibration", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "score_calibration", + srcs = ["score_calibration.cc"], + hdrs = ["score_calibration.h"], + deps = [ + ":label_map_item", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h index a27f785..effd42f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
@@ -18,8 +18,8 @@ #include <utility> -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow/lite/c/common.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" @@ -38,14 +38,19 @@ virtual ~BaseUntypedTaskApi() = default; - TfLiteEngine* GetTfLiteEngine() { return engine_.get(); } - const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); } - const metadata::ModelMetadataExtractor* GetMetadataExtractor() const { return engine_->metadata_extractor(); } protected: + // TODO(b/200258103): It's a short term solution. In the future we will forbid + // Tasks exposing the underlying TfLiteEngine. Please try not rely on this + // function. + // + // Returns a raw pointer to the underlying TfLiteEngine. + TfLiteEngine* GetTfLiteEngine() { return engine_.get(); } + + private: std::unique_ptr<TfLiteEngine> engine_; }; @@ -58,6 +63,24 @@ BaseTaskApi(const BaseTaskApi&) = delete; BaseTaskApi& operator=(const BaseTaskApi&) = delete; + int32_t GetInputCount() { + return GetTfLiteEngine()->interpreter()->inputs().size(); + } + + const TfLiteIntArray* GetInputShape(int index) { + auto interpreter = GetTfLiteEngine()->interpreter(); + return interpreter->tensor(interpreter->inputs()[index])->dims; + } + + int32_t GetOutputCount() { + return GetTfLiteEngine()->interpreter()->outputs().size(); + } + + const TfLiteIntArray* GetOutputShape(int index) { + auto interpreter = GetTfLiteEngine()->interpreter(); + return interpreter->tensor(interpreter->outputs()[index])->dims; + } + // Cancels the current running TFLite invocation on CPU. // // Usually called on a different thread than the one inference is running on. @@ -69,7 +92,7 @@ // partially delegated on CPU, logs a warning message and only cancels the // invocation running on CPU. Other invocation which depends on the output of // the CPU invocation will not be executed. - void Cancel() { engine_->Cancel(); } + void Cancel() { GetTfLiteEngine()->Cancel(); } protected: // Subclasses need to populate input_tensors from api_inputs. @@ -84,18 +107,20 @@ InputTypes... api_inputs) = 0; // Returns (the addresses of) the model's inputs. - std::vector<TfLiteTensor*> GetInputTensors() { return engine_->GetInputs(); } + std::vector<TfLiteTensor*> GetInputTensors() { + return GetTfLiteEngine()->GetInputs(); + } // Returns (the addresses of) the model's outputs. std::vector<const TfLiteTensor*> GetOutputTensors() { - return engine_->GetOutputs(); + return GetTfLiteEngine()->GetOutputs(); } // Performs inference using tflite::support::TfLiteInterpreterWrapper // InvokeWithoutFallback(). tflite::support::StatusOr<OutputType> Infer(InputTypes... args) { tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper = - engine_->interpreter_wrapper(); + GetTfLiteEngine()->interpreter_wrapper(); // Note: AllocateTensors() is already performed by the interpreter wrapper // at InitInterpreter time (see TfLiteEngine). RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); @@ -115,7 +140,7 @@ // CPU where applicable. tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) { tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper = - engine_->interpreter_wrapper(); + GetTfLiteEngine()->interpreter_wrapper(); // Note: AllocateTensors() is already performed by the interpreter wrapper // at InitInterpreter time (see TfLiteEngine). RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/category.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/category.h index a99f994..9fadd6a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/category.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/category.h
@@ -15,6 +15,8 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ + +#include <cmath> #include <string> namespace tflite { @@ -29,7 +31,9 @@ : class_name(class_name), score(score) {} friend bool operator==(const Category& lhs, const Category& rhs) { - return lhs.score == rhs.score && lhs.class_name == rhs.class_name; + constexpr const double kScoreTolerance = 1e-6; + return lhs.class_name == rhs.class_name && + abs((double)(lhs.score - rhs.score)) <= kScoreTolerance; } friend bool operator!=(const Category& lhs, const Category& rhs) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.cc new file mode 100644 index 0000000..fe22176 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.cc
@@ -0,0 +1,114 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/core/classification_head.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace core { + +using ::absl::StatusCode; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<ClassificationHead> BuildClassificationHead( + const tflite::metadata::ModelMetadataExtractor& metadata_extractor, + const tflite::TensorMetadata& output_tensor_metadata, + absl::string_view display_names_locale) { + ClassificationHead head; + if (output_tensor_metadata.name() != nullptr) { + head.name = output_tensor_metadata.name()->str(); + } + + // Build label map, if present. + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_LABELS); + if (!labels_filename.empty()) { + ASSIGN_OR_RETURN(absl::string_view labels_file, + metadata_extractor.GetAssociatedFile(labels_filename)); + const std::string display_names_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_LABELS, + display_names_locale); + absl::string_view display_names_file; + if (!display_names_filename.empty()) { + ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( + display_names_filename)); + } + ASSIGN_OR_RETURN(head.label_map_items, + BuildLabelMapFromFiles(labels_file, display_names_file)); + } + + // Set score threshold, if present. + ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_thresholding_process_unit, + ModelMetadataExtractor::FindFirstProcessUnit( + output_tensor_metadata, + tflite::ProcessUnitOptions_ScoreThresholdingOptions)); + if (score_thresholding_process_unit != nullptr) { + head.score_threshold = + score_thresholding_process_unit->options_as_ScoreThresholdingOptions() + ->global_score_threshold(); + } + + // Build score calibration parameters, if present. + ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_calibration_process_unit, + ModelMetadataExtractor::FindFirstProcessUnit( + output_tensor_metadata, + tflite::ProcessUnitOptions_ScoreCalibrationOptions)); + if (score_calibration_process_unit != nullptr) { + if (labels_filename.empty()) { + return CreateStatusWithPayload( + StatusCode::kNotFound, + "Using ScoreCalibrationOptions requires a label map to be provided " + "as TENSOR_AXIS_LABELS associated file.", + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + const std::string score_calibration_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); + if (score_calibration_filename.empty()) { + return CreateStatusWithPayload( + StatusCode::kNotFound, + "Found ScoreCalibrationOptions but missing required associated " + "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + ASSIGN_OR_RETURN( + absl::string_view score_calibration_file, + metadata_extractor.GetAssociatedFile(score_calibration_filename)); + ASSIGN_OR_RETURN(SigmoidCalibrationParameters sigmoid_params, + BuildSigmoidCalibrationParams( + *score_calibration_process_unit + ->options_as_ScoreCalibrationOptions(), + score_calibration_file, head.label_map_items)); + head.calibration_params = sigmoid_params; + } + + return head; +} + +} // namespace core +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h new file mode 100644 index 0000000..c91552f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
@@ -0,0 +1,110 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CLASSIFICATION_HEAD_ITEM_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CLASSIFICATION_HEAD_ITEM_H_ + +#include <string> +#include <vector> + +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/core/score_calibration.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace core { + +// A single classifier head for a classifier model, associated with a +// corresponding output tensor. +struct ClassificationHead { + ClassificationHead() : score_threshold(0) {} + + explicit ClassificationHead( + const std::vector<tflite::task::core::LabelMapItem>&& label_map_items) + : label_map_items(label_map_items), score_threshold(0) {} + + // An optional name that usually indicates what this set of classes represent, + // e.g. "flowers". + std::string name; + // The label map representing the list of supported classes, aka labels. + // + // This must be in direct correspondence with the associated output tensor, + // i.e.: + // + // - The number of classes must match with the dimension of the corresponding + // output tensor, + // - The i-th item in the label map is assumed to correspond to the i-th + // output value in the output tensor. + // + // This requires to put in place dedicated sanity checks before running + // inference. + std::vector<tflite::task::core::LabelMapItem> label_map_items; + // Recommended score threshold typically in [0,1[. Classification results with + // a score below this value are considered low-confidence and should be + // rejected from returned results. + float score_threshold; + // Optional score calibration parameters (one set of parameters per class in + // the label map). This is primarily meant for multi-label classifiers made of + // independent sigmoids. + // + // Such parameters are usually tuned so that calibrated scores can be compared + // to a default threshold common to all classes to achieve a given amount of + // precision. + // + // Example: 60% precision for threshold = 0.5. + absl::optional<tflite::task::core::SigmoidCalibrationParameters> + calibration_params; +}; + +// Builds a classification head using the provided metadata extractor, for the +// given output tensor metadata. Returns an error in case the head cannot be +// built (e.g. missing associated file for score calibration parameters). +// +// Optionally it is possible to specify which locale should be used (e.g. "en") +// to fill the label map display names, if any, and provided the corresponding +// associated file is present in the metadata. If no locale is specified, or if +// there is no associated file for the provided locale, display names are just +// left empty and no error is returned. +// +// E.g. (metatada displayed in JSON format below): +// +// ... +// "associated_files": [ +// { +// "name": "labels.txt", +// "type": "TENSOR_AXIS_LABELS" +// }, +// { +// "name": "labels-en.txt", +// "type": "TENSOR_AXIS_LABELS", +// "locale": "en" +// }, +// ... +// +// See metadata schema TENSOR_AXIS_LABELS for more details. +tflite::support::StatusOr<ClassificationHead> BuildClassificationHead( + const tflite::metadata::ModelMetadataExtractor& metadata_extractor, + const tflite::TensorMetadata& output_tensor_metadata, + absl::string_view display_names_locale = absl::string_view()); + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CLASSIFICATION_HEAD_ITEM_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc new file mode 100644 index 0000000..a626ce6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
@@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/error_reporter.h" + +#include <cstdarg> +#include <cstdio> +#include <cstring> + +#include "tensorflow/lite/minimal_logging.h" + +namespace tflite { +namespace task { +namespace core { + +int ErrorReporter::Report(const char* format, va_list args) { + std::strcpy(second_last_message_, last_message_); // NOLINT + last_message_[0] = '\0'; + int num_characters = vsnprintf(last_message_, kBufferSize, format, args); + // To mimic tflite::StderrReporter. + tflite::logging_internal::MinimalLogger::Log(TFLITE_LOG_ERROR, "%s", + last_message_); + return num_characters; +} + +std::string ErrorReporter::message() { + return last_message_; +} + +std::string ErrorReporter::previous_message() { + return second_last_message_; +} + +} // namespace core +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.h new file mode 100644 index 0000000..9f05350 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.h
@@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_ERROR_REPORTER_H_ + +#include <cstdarg> +#include <string> + +#include "tensorflow/lite/stateful_error_reporter.h" + +namespace tflite { +namespace task { +namespace core { + +// An ErrorReporter that logs to stderr and captures the last two messages. +class ErrorReporter : public tflite::StatefulErrorReporter { + public: + ErrorReporter() { + last_message_[0] = '\0'; + second_last_message_[0] = '\0'; + } + + // We declared two functions with name 'Report', so that the variadic Report + // function in tflite::ErrorReporter is hidden. + // See https://isocpp.org/wiki/faq/strange-inheritance#hiding-rule. + using tflite::ErrorReporter::Report; + + int Report(const char* format, std::va_list args) override; + + std::string message() override; + std::string previous_message(); + + private: + static constexpr int kBufferSize = 1024; + char last_message_[kBufferSize]; + char second_last_message_[kBufferSize]; +}; + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_ERROR_REPORTER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc index 55b662f..e91a54f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
@@ -15,14 +15,12 @@ #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" -#include <errno.h> -#include <fcntl.h> #include <stddef.h> #include <memory> #include <string> -#include "absl/memory/memory.h" -#include "absl/strings/str_format.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h" @@ -30,15 +28,12 @@ namespace tflite { namespace task { namespace core { -namespace { using ::absl::StatusCode; using ::tflite::support::CreateStatusWithPayload; using ::tflite::support::StatusOr; using ::tflite::support::TfLiteSupportStatus; -} // namespace - /* static */ StatusOr<std::unique_ptr<ExternalFileHandler>> ExternalFileHandler::CreateFromExternalFile(const ExternalFile* external_file) { @@ -56,21 +51,15 @@ if (!external_file_.file_content().empty()) { return absl::OkStatus(); } - return CreateStatusWithPayload( - StatusCode::kInvalidArgument, - "ExternalFile must have 'file_content' set, loading from" - "'file_name' is not supported.", - TfLiteSupportStatus::kInvalidArgumentError); + + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "ExternalFile must specify 'file_content' " + "to be compatible with Chromium.", + TfLiteSupportStatus::kInvalidArgumentError); } absl::string_view ExternalFileHandler::GetFileContent() { - if (!external_file_.file_content().empty()) { - return external_file_.file_content(); - } else { - return absl::string_view(static_cast<const char*>(buffer_) + - buffer_offset_ - buffer_aligned_offset_, - buffer_size_); - } + return external_file_.file_content(); } ExternalFileHandler::~ExternalFileHandler() = default;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h index ad292dc..48c6281 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
@@ -18,8 +18,8 @@ #include <memory> -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" @@ -64,20 +64,6 @@ // Reference to the input ExternalFile. const ExternalFile& external_file_; - - // Points to the memory buffer mapped from the file descriptor of the - // ExternalFile, if provided by path or file descriptor. - void* buffer_{}; - - // The mapped memory buffer offset, if any. - int64 buffer_offset_{}; - // The size in bytes of the mapped memory buffer, if any. - int64 buffer_size_{}; - - // As mmap(2) requires the offset to be a multiple of sysconf(_SC_PAGE_SIZE): - - // The aligned mapped memory buffer offset, if any. - int64 buffer_aligned_offset_{}; }; } // namespace core
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc new file mode 100644 index 0000000..72e4b67 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
@@ -0,0 +1,128 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/core/label_map_item.h" + +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" + +namespace tflite { +namespace task { +namespace core { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( + absl::string_view labels_file, + absl::string_view display_names_file) { + if (labels_file.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected non-empty labels file.", + TfLiteSupportStatus::kInvalidArgumentError); + } + std::vector<absl::string_view> labels = absl::StrSplit(labels_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. In such a situation, StrSplit() will + // produce a vector with an empty string as final element. Also note that in + // case `labels_file` is entirely empty, StrSplit() will produce a vector with + // one single empty substring, so there's no out-of-range risk here. + if (labels[labels.size() - 1].empty()) { + labels.pop_back(); + } + + std::vector<LabelMapItem> label_map_items; + label_map_items.reserve(labels.size()); + for (int i = 0; i < labels.size(); ++i) { + label_map_items.emplace_back(LabelMapItem{.name = std::string(labels[i])}); + } + + if (!display_names_file.empty()) { + std::vector<std::string> display_names = + absl::StrSplit(display_names_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. See above. + if (display_names[display_names.size() - 1].empty()) { + display_names.pop_back(); + } + if (display_names.size() != labels.size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Mismatch between number of labels (%d) and display names (%d).", + labels.size(), display_names.size()), + TfLiteSupportStatus::kMetadataNumLabelsMismatchError); + } + for (int i = 0; i < display_names.size(); ++i) { + label_map_items[i].display_name = display_names[i]; + } + } + return label_map_items; +} + +absl::Status LabelHierarchy::InitializeFromLabelMap( + std::vector<LabelMapItem> label_map_items) { + parents_map_.clear(); + for (const LabelMapItem& label : label_map_items) { + for (const std::string& child_name : label.child_name) { + parents_map_[child_name].insert(label.name); + } + } + if (parents_map_.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Input labelmap is not hierarchical: there " + "is no parent-child relationship.", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +bool LabelHierarchy::HaveAncestorDescendantRelationship( + const std::string& ancestor_name, + const std::string& descendant_name) const { + absl::flat_hash_set<std::string> ancestors; + GetAncestors(descendant_name, &ancestors); + return ancestors.contains(ancestor_name); +} + +absl::flat_hash_set<std::string> LabelHierarchy::GetParents( + const std::string& name) const { + absl::flat_hash_set<std::string> parents; + auto it = parents_map_.find(name); + if (it != parents_map_.end()) { + for (const std::string& parent_name : it->second) { + parents.insert(parent_name); + } + } + return parents; +} + +void LabelHierarchy::GetAncestors( + const std::string& name, + absl::flat_hash_set<std::string>* ancestors) const { + const absl::flat_hash_set<std::string> parents = GetParents(name); + for (const std::string& parent_name : parents) { + auto it = ancestors->insert(parent_name); + if (it.second) { + GetAncestors(parent_name, ancestors); + } + } +} + +} // namespace core +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h new file mode 100644 index 0000000..d8e1f70d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
@@ -0,0 +1,96 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_LABEL_MAP_ITEM_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_LABEL_MAP_ITEM_H_ + +#include <string> +#include <vector> + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/container/flat_hash_set.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace task { +namespace core { + +// Structure mapping a numerical class index output to a Knowledge Graph entity +// ID or any other string label representing this class. Optionally it is +// possible to specify an additional display name (in a given language) which is +// typically used for display purposes. +struct LabelMapItem { + // E.g. name = "/m/02xwb" + std::string name; + // E.g. display_name = "Fruit" + std::string display_name; + // Optional list of children (e.g. subcategories) used to represent a + // hierarchy. + std::vector<std::string> child_name; +}; + +// Builds a label map from labels and (optional) display names file contents, +// both expected to contain one label per line. Those are typically obtained +// from TFLite Model Metadata TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS +// associated files. +// Returns an error e.g. if there's a mismatch between the number of labels and +// display names. +tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( + absl::string_view labels_file, + absl::string_view display_names_file); + +// A class that represents a hierarchy of labels as specified in a label map. +// +// For example, it is useful to determine if one label is a descendant of +// another label or not. This can be used to implement labels pruning based on +// hierarchy, e.g. if both "fruit" and "banana" have been inferred by a given +// classifier model prune "fruit" from the final results as "banana" is a more +// fine-grained descendant. +class LabelHierarchy { + public: + LabelHierarchy() = default; + + // Initializes the hierarchy of labels from a given label map vector. Returns + // an error status in case of failure, typically if the input label map does + // not contain any hierarchical relations between labels. + absl::Status InitializeFromLabelMap( + std::vector<LabelMapItem> label_map_items); + + // Returns true if `descendant_name` is a descendant of `ancestor_name` in the + // hierarchy of labels. Invalid names, i.e. names which do not exist in the + // label map used at initialization time, are ignored. + bool HaveAncestorDescendantRelationship( + const std::string& ancestor_name, + const std::string& descendant_name) const; + + private: + // Retrieve and return all parent names, if any, for the input label name. + absl::flat_hash_set<std::string> GetParents(const std::string& name) const; + + // Retrieve all ancestor names, if any, for the input label name. + void GetAncestors(const std::string& name, + absl::flat_hash_set<std::string>* ancestors) const; + + // Label name (key) to parent names (value) direct mapping. + absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> + parents_map_; +}; + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_LABEL_MAP_ITEM_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/BUILD index 7418e5b2..6c97e2b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/BUILD
@@ -12,9 +12,8 @@ srcs = ["external_file.proto"], ) -support_cc_proto_library( +cc_proto_library( name = "external_file_cc_proto", - srcs = ["external_file.proto"], deps = [ ":external_file_proto", ], @@ -25,3 +24,56 @@ hdrs = ["external_file_proto_inc.h"], deps = [":external_file_cc_proto"], ) + +proto_library( + name = "base_options_proto", + srcs = ["base_options.proto"], + deps = [ + ":external_file_proto", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto", + ], +) + +support_cc_proto_library( + name = "base_options_cc_proto", + deps = [ + ":base_options_proto", + ], +) + +cc_library( + name = "base_options_proto_inc", + hdrs = ["base_options_proto_inc.h"], + deps = [ + ":base_options_cc_proto", + ":external_file_proto_inc", + "//tensorflow_lite_support/cc/port:configuration_proto_inc", + ], +) + +proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [ + ":class_proto", + ], +) + +cc_proto_library( + name = "classifications_cc_proto", + deps = [ + ":classifications_proto", + ], +) + +proto_library( + name = "class_proto", + srcs = ["class.proto"], +) + +cc_proto_library( + name = "class_cc_proto", + deps = [ + ":class_proto", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/base_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/base_options.proto new file mode 100644 index 0000000..9d529d6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/base_options.proto
@@ -0,0 +1,51 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.core; + +import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto"; + +import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; + +// Base options for task libraries. +// Next Id: 4 +message BaseOptions { + // The external model file, as a single standalone TFLite file. It could be + // packed with TFLite Model Metadata[1] and associated files if exist. Fail to + // provide the necessary metadata and associated files might result in errors. + // Check the documentation for each task about the specific requirement. + // [1]: https://www.tensorflow.org/lite/convert/metadata + + optional core.ExternalFile model_file = 1; + + // Advanced settings specifying how to accelerate the model inference using + // dedicated delegates. Supported delegate type includes: + // NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal), + // and EDGETPU_CORAL. + // + // IMPORTANT: in order to use a delegate, the appropriate delegate plugin + // needs to be linked at build time. + // + // For example, `gpu_plugin` for GPU from: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD + // To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from: + // https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11 + // + // See settings definition at: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto + optional tflite.proto.ComputeSettings compute_settings = 2; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h new file mode 100644 index 0000000..4c53a2f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h
@@ -0,0 +1,23 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_BASE_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/class.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/class.proto new file mode 100644 index 0000000..83d381a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/class.proto
@@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.core; + +// A single classification result. +message Class { + // The index of the class in the corresponding label map, usually packed in + // the TFLite Model Metadata [1]. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional int32 index = 1; + // The score for this class e.g. (but not necessarily) a probability in [0,1]. + optional float score = 2; + // A human readable name of the class filled from the label map. + optional string display_name = 3; + // An ID for the class, not necessarily human-readable (e.g. a Google + // Knowledge Graph ID [1]), filled from the label map. + // + // [1]: https://developers.google.com/knowledge-graph + optional string class_name = 4; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/classifications.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/classifications.proto new file mode 100644 index 0000000..ee1f099e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/classifications.proto
@@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.core; + +import "tensorflow_lite_support/cc/task/core/proto/class.proto"; + +// List of predicted classes (aka labels) for a given classifier head. +message Classifications { + // The array of predicted classes, usually sorted by descending scores (e.g. + // from high to low probability). + repeated Class classes = 1; + // The index of the classifier head these classes refer to. This is useful for + // multi-head models. + optional int32 head_index = 2; + // The name of the classifier head, which is the corresponding tensor metadata + // name. See + // https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545 + optional string head_name = 3; +} + +// Contains one set of results per classifier head. +message ClassificationResult { + repeated Classifications classifications = 1; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt deleted file mode 100644 index dafb0fd..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt +++ /dev/null
@@ -1,2 +0,0 @@ -allow_all: true -optimize_mode: LITE_RUNTIME
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc new file mode 100644 index 0000000..e7faeba --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
@@ -0,0 +1,225 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/core/score_calibration.h" + +#include <cmath> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "absl/types/optional.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace task { +namespace core { +namespace { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +// Used to prevent log(<=0.0) in ClampedLog() calls. +constexpr float kLogScoreMinimum = 1e-16; + +// Returns the following, depending on x: +// x => threshold: log(x) +// x < threshold: 2 * log(thresh) - log(2 * thresh - x) +// This form (a) is anti-symmetric about the threshold and (b) has continuous +// value and first derivative. This is done to prevent taking the log of values +// close to 0 which can lead to floating point errors and is better than simple +// clamping since it preserves order for scores less than the threshold. +float ClampedLog(float x, float threshold) { + if (x < threshold) { + return 2.0 * std::log(static_cast<double>(threshold)) - + log(2.0 * threshold - x); + } + return std::log(static_cast<double>(x)); +} + +// Applies the specified score transformation to the provided score. +// Currently supports the following, +// IDENTITY : f(x) = x +// LOG : f(x) = log(x) +// INVERSE_LOGISTIC : f(x) = log(x) - log(1-x) +float ApplyScoreTransformation(float score, const ScoreTransformation& type) { + switch (type) { + case ScoreTransformation::kIDENTITY: + return score; + case ScoreTransformation::kINVERSE_LOGISTIC: + return (ClampedLog(score, kLogScoreMinimum) - + ClampedLog(1.0 - score, kLogScoreMinimum)); + case ScoreTransformation::kLOG: + return ClampedLog(score, kLogScoreMinimum); + } +} + +// Builds a single Sigmoid from the label name and associated CSV file line. +StatusOr<Sigmoid> SigmoidFromLabelAndLine(absl::string_view label, + absl::string_view line) { + std::vector<absl::string_view> str_params = absl::StrSplit(line, ','); + if (str_params.size() != 3 && str_params.size() != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected 3 or 4 parameters per line in score " + "calibration file, got %d.", + str_params.size()), + TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError); + } + std::vector<float> float_params(4); + for (int i = 0; i < str_params.size(); ++i) { + if (!absl::SimpleAtof(str_params[i], &float_params[i])) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Could not parse score calibration parameter as float: %s.", + str_params[i]), + TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError); + } + } + Sigmoid sigmoid; + sigmoid.label = std::string(label); + sigmoid.scale = float_params[0]; + sigmoid.slope = float_params[1]; + sigmoid.offset = float_params[2]; + if (str_params.size() == 4) { + sigmoid.min_uncalibrated_score = float_params[3]; + } + return sigmoid; +} + +// Converts a tflite::ScoreTransformationType to its +// tflite::task::vision::ScoreTransformation equivalent. +ScoreTransformation ConvertScoreTransformationType( + tflite::ScoreTransformationType type) { + switch (type) { + case tflite::ScoreTransformationType_IDENTITY: + return ScoreTransformation::kIDENTITY; + case tflite::ScoreTransformationType_LOG: + return ScoreTransformation::kLOG; + case tflite::ScoreTransformationType_INVERSE_LOGISTIC: + return ScoreTransformation::kINVERSE_LOGISTIC; + } +} + +} // namespace + +std::ostream& operator<<(std::ostream& os, const Sigmoid& s) { + os << s.label << "," << s.slope << "," << s.offset << "," << s.scale; + if (s.min_uncalibrated_score.has_value()) { + os << "," << s.min_uncalibrated_score.value(); + } + return os; +} + +ScoreCalibration::ScoreCalibration() {} +ScoreCalibration::~ScoreCalibration() {} + +absl::Status ScoreCalibration::InitializeFromParameters( + const SigmoidCalibrationParameters& params) { + sigmoid_parameters_ = std::move(params); + // Fill in the map from label -> sigmoid. + sigmoid_parameters_map_.clear(); + for (const auto& sigmoid : sigmoid_parameters_.sigmoid) { + sigmoid_parameters_map_.insert_or_assign(sigmoid.label, sigmoid); + } + return absl::OkStatus(); +} + +float ScoreCalibration::ComputeCalibratedScore(const std::string& label, + float uncalibrated_score) const { + absl::optional<Sigmoid> sigmoid = FindSigmoidParameters(label); + if (!sigmoid.has_value() || + (sigmoid.value().min_uncalibrated_score.has_value() && + uncalibrated_score < sigmoid.value().min_uncalibrated_score.value())) { + return sigmoid_parameters_.default_score; + } + + float transformed_score = ApplyScoreTransformation( + uncalibrated_score, sigmoid_parameters_.score_transformation); + float scale_shifted_score = + transformed_score * sigmoid.value().slope + sigmoid.value().offset; + + // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0 + // and exp(x) / (1+exp(x)) when scale_shifted_score < 0. + if (scale_shifted_score >= 0.0) { + return sigmoid.value().scale / + (1.0 + std::exp(static_cast<double>(-scale_shifted_score))); + } else { + float score_exp = std::exp(static_cast<double>(scale_shifted_score)); + return sigmoid.value().scale * score_exp / (1.0 + score_exp); + } +} + +absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters( + const std::string& label) const { + auto it = sigmoid_parameters_map_.find(label); + if (it != sigmoid_parameters_map_.end()) { + return it->second; + } else if (sigmoid_parameters_.default_sigmoid.has_value()) { + return sigmoid_parameters_.default_sigmoid.value(); + } + return absl::nullopt; +} + +StatusOr<SigmoidCalibrationParameters> BuildSigmoidCalibrationParams( + const tflite::ScoreCalibrationOptions& score_calibration_options, + absl::string_view score_calibration_file, + const std::vector<LabelMapItem>& label_map_items) { + // Split file lines and perform sanity checks. + if (score_calibration_file.empty()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Expected non-empty score calibration file."); + } + std::vector<absl::string_view> lines = + absl::StrSplit(score_calibration_file, '\n'); + if (label_map_items.size() != lines.size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of labels (%d) and score " + "calibration parameters (%d).", + label_map_items.size(), lines.size()), + TfLiteSupportStatus::kMetadataNumLabelsMismatchError); + } + // Initialize SigmoidCalibrationParameters with its class-agnostic parameters. + SigmoidCalibrationParameters sigmoid_params = {}; + sigmoid_params.score_transformation = ConvertScoreTransformationType( + score_calibration_options.score_transformation()); + sigmoid_params.default_score = score_calibration_options.default_score(); + std::vector<Sigmoid> sigmoid_vector; + // Fill sigmoids for each class with parameters in the file. + for (int i = 0; i < label_map_items.size(); ++i) { + if (lines[i].empty()) { + continue; + } + ASSIGN_OR_RETURN(Sigmoid sigmoid, SigmoidFromLabelAndLine( + label_map_items[i].name, lines[i])); + sigmoid_vector.emplace_back(std::move(sigmoid)); + } + sigmoid_params.sigmoid = std::move(sigmoid_vector); + + return sigmoid_params; +} + +} // namespace core +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h new file mode 100644 index 0000000..6e2b308b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
@@ -0,0 +1,149 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_SCORE_CALIBRATION_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_SCORE_CALIBRATION_H_ + +#include <iostream> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "absl/types/optional.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/label_map_item.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace core { + +// Sigmoid structure. +struct Sigmoid { + Sigmoid() : scale(1.0) {} + Sigmoid(std::string label, + float slope, + float offset, + float scale = 1.0, + absl::optional<float> min_uncalibrated_score = absl::nullopt) + : label(label), + slope(slope), + offset(offset), + scale(scale), + min_uncalibrated_score(min_uncalibrated_score) {} + + bool operator==(const Sigmoid& other) const { + return label == other.label && slope == other.slope && + offset == other.offset && scale == other.scale && + min_uncalibrated_score == other.min_uncalibrated_score; + } + + // Unique label corresponding to the sigmoid parameters. + std::string label; + float slope; + float offset; + float scale; + absl::optional<float> min_uncalibrated_score; +}; + +std::ostream& operator<<(std::ostream& os, const Sigmoid& s); + +// Transformation function to use for computing transformation scores. +enum class ScoreTransformation { + kIDENTITY, // f(x) = x + kLOG, // f(x) = log(x) + kINVERSE_LOGISTIC // f(x) = log(x) - log(1 - x) +}; + +// Sigmoid calibration parameters. +struct SigmoidCalibrationParameters { + SigmoidCalibrationParameters() + : default_score(0.0), + score_transformation(ScoreTransformation::kIDENTITY) {} + explicit SigmoidCalibrationParameters( + std::vector<Sigmoid> sigmoid, + ScoreTransformation score_transformation = ScoreTransformation::kIDENTITY, + absl::optional<Sigmoid> default_sigmoid = absl::nullopt, + float default_score = 0.0) + : sigmoid(sigmoid), + default_sigmoid(default_sigmoid), + default_score(default_score), + score_transformation(score_transformation) {} + // A vector of Sigmoid associated to the ScoreCalibration instance. + std::vector<Sigmoid> sigmoid; + // If set, this sigmoid will be applied to any non-matching labels. + absl::optional<Sigmoid> default_sigmoid; + // The default score for non-matching labels. Only used if default_sigmoid + // isn't set. + float default_score; + // Function for computing a transformation score prior to sigmoid fitting. + ScoreTransformation score_transformation; +}; + +// This class is used to calibrate predicted scores so that scores are +// comparable across labels. Depending on the particular calibration parameters +// being used, the calibrated scores can also be approximately interpreted as a +// likelihood of being correct. For a given TF Lite model, such parameters are +// typically obtained from TF Lite Metadata (see ScoreCalibrationOptions). +class ScoreCalibration { + public: + ScoreCalibration(); + ~ScoreCalibration(); + + // Transfers input parameters and construct a label to sigmoid map. + absl::Status InitializeFromParameters( + const SigmoidCalibrationParameters& params); + + // Returns a calibrated score given a label string and uncalibrated score. The + // calibrated score will be in the range [0.0, 1.0] and can loosely be + // interpreted as a likelihood of the label being correct. + float ComputeCalibratedScore(const std::string& label, + float uncalibrated_score) const; + + private: + // Finds the sigmoid parameters corresponding to the provided label. + absl::optional<Sigmoid> FindSigmoidParameters(const std::string& label) const; + + // Parameters for internal states. + SigmoidCalibrationParameters sigmoid_parameters_; + + // Maps label strings to the particular sigmoid stored in sigmoid_parameters_. + absl::flat_hash_map<std::string, Sigmoid> sigmoid_parameters_map_; +}; + +// Builds SigmoidCalibrationParameters using data obtained from TF Lite Metadata +// (see ScoreCalibrationOptions in metadata schema). +// +// The provided `score_calibration_file` represents the contents of the score +// calibration associated file (TENSOR_AXIS_SCORE_CALIBRATION), i.e. one set of +// parameters (scale, slope, etc) per line. Each line must be in 1:1 +// correspondence with `label_map_items`, so as to associate each sigmoid to its +// corresponding label name. Returns an error if no valid parameters could be +// built (e.g. malformed parameters). +tflite::support::StatusOr<SigmoidCalibrationParameters> +BuildSigmoidCalibrationParams( + const tflite::ScoreCalibrationOptions& score_calibration_options, + absl::string_view score_calibration_file, + const std::vector<LabelMapItem>& label_map_items); + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_SCORE_CALIBRATION_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h index 76a3440..f42d703 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
@@ -18,18 +18,22 @@ #include <memory> -#include "absl/status/status.h" +#include "absl/base/macros.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" namespace tflite { namespace task { namespace core { + template <typename T> using EnableIfBaseUntypedTaskApiSubclass = typename std::enable_if< std::is_base_of<BaseUntypedTaskApi, T>::value>::type*; @@ -40,57 +44,132 @@ TaskAPIFactory() = delete; template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + ABSL_DEPRECATED( + "Use CreateFromBaseOptions and configure model input from " + "tensorflow_lite_support/cc/task/core/proto/base_options.proto") static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer( const char* buffer_data, size_t buffer_size, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - int num_threads = 1) { + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); - RETURN_IF_ERROR(engine->BuildModelFromFlatBuffer(buffer_data, buffer_size)); - return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + RETURN_IF_ERROR(engine->BuildModelFromFlatBuffer(buffer_data, buffer_size, + compute_settings)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads, + compute_settings); } template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + ABSL_DEPRECATED( + "Use CreateFromBaseOptions and configure model input from " + "tensorflow_lite_support/cc/task/core/proto/base_options.proto") static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFile( - const string& file_name, + const std::string& file_name, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - int num_threads = 1) { + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); - RETURN_IF_ERROR(engine->BuildModelFromFile(file_name)); - return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + RETURN_IF_ERROR(engine->BuildModelFromFile(file_name, compute_settings)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads, + compute_settings); } template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + ABSL_DEPRECATED( + "Use CreateFromBaseOptions and configure model input from " + "tensorflow_lite_support/cc/task/core/proto/base_options.proto") static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFileDescriptor( int file_descriptor, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - int num_threads = 1) { + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); - RETURN_IF_ERROR(engine->BuildModelFromFileDescriptor(file_descriptor)); - return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + RETURN_IF_ERROR(engine->BuildModelFromFileDescriptor(file_descriptor, + compute_settings)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads, + compute_settings); } template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> - static tflite::support::StatusOr<std::unique_ptr<T>> - CreateFromExternalFileProto( - const ExternalFile* external_file, - std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - int num_threads = 1) { + ABSL_DEPRECATED( + "Use CreateFromBaseOptions and configure model input from " + "tensorflow_lite_support/cc/task/core/proto/base_options.proto") + static tflite::support:: + StatusOr<std::unique_ptr<T>> CreateFromExternalFileProto( + const ExternalFile* external_file, + std::unique_ptr<tflite::OpResolver> resolver = absl::make_unique< + tflite_shims::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); - RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto(external_file)); - return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto(external_file, + compute_settings)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads, + compute_settings); + } + + // Creates a Task API from the provided BaseOptions. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBaseOptions( + const BaseOptions* base_options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()) { + if (!base_options->has_model_file()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Missing mandatory `model_file` field in `base_options`", + tflite::support::TfLiteSupportStatus::kInvalidArgumentError); + } + + int num_threads = base_options->compute_settings() + .tflite_settings() + .cpu_settings() + .num_threads(); + if (num_threads == 0 || num_threads < -1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "`num_threads` must be greater than 0 or equal to -1.", + tflite::support::TfLiteSupportStatus::kInvalidArgumentError); + } + + auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto( + &base_options->model_file(), base_options->compute_settings())); + return CreateFromTfLiteEngine<T>(std::move(engine), + base_options->compute_settings()); } private: template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine( std::unique_ptr<TfLiteEngine> engine, - int num_threads) { - RETURN_IF_ERROR(engine->InitInterpreter(num_threads)); + int num_threads, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { + tflite::proto::ComputeSettings settings_copy = + tflite::proto::ComputeSettings(compute_settings); + settings_copy.mutable_tflite_settings() + ->mutable_cpu_settings() + ->set_num_threads(num_threads); + return CreateFromTfLiteEngine<T>(std::move(engine), settings_copy); + } + + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine( + std::unique_ptr<TfLiteEngine> engine, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()) { + RETURN_IF_ERROR(engine->InitInterpreter(compute_settings)); return absl::make_unique<T>(std::move(engine)); } };
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.cc index de733ae..e40c2ac 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.cc
@@ -17,7 +17,7 @@ #include <fstream> -#include "absl/strings/str_cat.h" +#include "absl/strings/str_cat.h" // from @com_google_absl namespace tflite { namespace task { @@ -61,6 +61,22 @@ return buffer; } +int FindIndexByMetadataTensorName( + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + tensor_metadatas, + const std::string& name) { + if (tensor_metadatas == nullptr) { + return -1; + } + for (int i = 0; i < tensor_metadatas->size(); i++) { + if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) { + return i; + } + } + // Returns -1 if not found. + return -1; +} + } // namespace core } // namespace task } // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h index ced3dbc..7cde474 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
@@ -21,119 +21,164 @@ #include <numeric> #include <vector> -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/type_to_tflitetype.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" namespace tflite { namespace task { namespace core { -// Checks if data type of tensor is T and returns the pointer casted to T if -// applicable, returns nullptr if tensor type is not T. -// See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. -template <typename T> -T* TypedTensor(const TfLiteTensor* tensor_ptr) { - if (tensor_ptr->type == typeToTfLiteType<T>()) { - return reinterpret_cast<T*>(tensor_ptr->data.raw); - } - return nullptr; -} - // Checks and returns type of a tensor, fails if tensor type is not T. template <typename T> -T* AssertAndReturnTypedTensor(const TfLiteTensor* tensor) { - if (T* v = TypedTensor<T>(tensor)) - return v; - // TODO(b/150903834): throw exceptions instead - TF_LITE_ASSERT(tensor->data.raw); - TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name, - ". Requested ", - TfLiteTypeGetName(typeToTfLiteType<T>()), ", got ", - TfLiteTypeGetName(tensor->type), ".") - .c_str()); +tflite::support::StatusOr<T*> AssertAndReturnTypedTensor( + const TfLiteTensor* tensor) { + if (!tensor->data.raw) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Tensor (%s) has no raw data.", tensor->name)); + } + + // Checks if data type of tensor is T and returns the pointer casted to T if + // applicable, returns nullptr if tensor type is not T. + // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. + if (tensor->type == typeToTfLiteType<T>()) { + return reinterpret_cast<T*>(tensor->data.raw); + } + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.", + tensor->name, typeToTfLiteType<T>(), tensor->bytes)); } // Populates tensor with array of data, fails if data type doesn't match tensor // type or has not the same number of elements. -template <typename T> -inline void PopulateTensor(const T* data, - int num_elements, - TfLiteTensor* tensor) { - T* v = AssertAndReturnTypedTensor<T>(tensor); +// Note: std::negation is not used because it is from C++17, where the code will +// be compiled using C++14 in OSS. +template < + typename T, + typename = std::enable_if_t<std::is_same<T, std::string>::value == false>> +inline absl::Status PopulateTensor(const T* data, + int num_elements, + TfLiteTensor* tensor) { + T* v; + ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor)); size_t bytes = num_elements * sizeof(T); - // TODO(b/150903834): throw exceptions instead - TF_LITE_ASSERT(tensor->bytes == bytes); + if (tensor->bytes != bytes) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes, + bytes)); + } memcpy(v, data, bytes); + return absl::OkStatus(); } // Populates tensor with vector of data, fails if data type doesn't match tensor // type or has not the same number of elements. template <typename T> -inline void PopulateTensor(const std::vector<T>& data, TfLiteTensor* tensor) { +inline absl::Status PopulateTensor(const std::vector<T>& data, + TfLiteTensor* tensor) { return PopulateTensor<T>(data.data(), data.size(), tensor); } template <> -inline void PopulateTensor<std::string>(const std::vector<std::string>& data, - TfLiteTensor* tensor) { +inline absl::Status PopulateTensor<std::string>( + const std::vector<std::string>& data, + TfLiteTensor* tensor) { if (tensor->type != kTfLiteString) { - TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name, - ". Requested STRING, got ", - TfLiteTypeGetName(tensor->type), ".") - .c_str()); + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Type mismatch for tensor %s. Required STRING, got %d.", + tensor->name, tensor->bytes)); } tflite::DynamicBuffer input_buf; for (const auto& value : data) { input_buf.AddString(value.data(), value.length()); } input_buf.WriteToTensorAsVector(tensor); + return absl::OkStatus(); } // Populates tensor one data item, fails if data type doesn't match tensor // type. template <typename T> -inline void PopulateTensor(const T& data, TfLiteTensor* tensor) { - T* v = AssertAndReturnTypedTensor<T>(tensor); +inline absl::Status PopulateTensor(const T& data, TfLiteTensor* tensor) { + T* v; + ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor)); *v = data; + return absl::OkStatus(); } template <> -inline void PopulateTensor<std::string>(const std::string& data, - TfLiteTensor* tensor) { +inline absl::Status PopulateTensor<std::string>(const std::string& data, + TfLiteTensor* tensor) { tflite::DynamicBuffer input_buf; input_buf.AddString(data.data(), data.length()); input_buf.WriteToTensorAsVector(tensor); + return absl::OkStatus(); } // Populates a vector from the tensor, fails if data type doesn't match tensor // type. template <typename T> -inline void PopulateVector(const TfLiteTensor* tensor, std::vector<T>* data) { - AssertAndReturnTypedTensor<T>(tensor); - const T* results = GetTensorData<T>(tensor); +inline absl::Status PopulateVector(const TfLiteTensor* tensor, + std::vector<T>* data) { + const T* v; + ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor)); size_t num = tensor->bytes / sizeof(tensor->type); data->reserve(num); for (size_t i = 0; i < num; i++) { - data->emplace_back(results[i]); + data->emplace_back(v[i]); } + return absl::OkStatus(); } template <> -inline void PopulateVector<std::string>(const TfLiteTensor* tensor, - std::vector<std::string>* data) { - AssertAndReturnTypedTensor<std::string>(tensor); +inline absl::Status PopulateVector<std::string>( + const TfLiteTensor* tensor, + std::vector<std::string>* data) { + if (tensor->type != typeToTfLiteType<std::string>()) { + return absl::InvalidArgumentError("not of type string"); + } + int num = GetStringCount(tensor); data->reserve(num); for (int i = 0; i < num; i++) { const auto& strref = tflite::GetString(tensor, i); data->emplace_back(strref.str, strref.len); } + return absl::OkStatus(); +} + +// Populates vector to a repeated field. +// Note: std::negation is not used because it is from C++17, where the code will +// be compiled using C++14 in OSS. +template < + class TRepeatedField, + class T = float, + typename = std::enable_if_t<std::is_same<T, std::string>::value == false>> +inline absl::Status PopulateVectorToRepeated(const TfLiteTensor* tensor, + TRepeatedField* data) { + const T* v; + ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor)); + size_t num = tensor->bytes / sizeof(tensor->type); + data->Resize(num, T()); + T* pdata = data->mutable_data(); + for (size_t i = 0; i < num; i++) { + pdata[i] = v[i]; + } + return absl::OkStatus(); } // Returns the reversely sorted indices of a vector. @@ -158,6 +203,14 @@ // Loads binary content of a file into a string. std::string LoadBinaryContent(const char* filename); +// Gets the index from a vector of tensors with name specified inside metadata. +// The range of the return value should be [0, output_tensor_size). If not +// found, returns -1. +int FindIndexByMetadataTensorName( + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + tensor_metadatas, + const std::string& name); + // Gets the tensor from a vector of tensors with name specified inside metadata. template <typename TensorType> static TensorType* FindTensorByName( @@ -169,12 +222,8 @@ tensor_metadatas->size() != tensors.size()) { return nullptr; } - for (flatbuffers::uoffset_t i = 0; i < tensor_metadatas->size(); i++) { - if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) { - return tensors[i]; - } - } - return nullptr; + int i = FindIndexByMetadataTensorName(tensor_metadatas, name); + return i == -1 ? nullptr : tensors[i]; } } // namespace core
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc index 6230e5c..41e06389 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
@@ -15,63 +15,37 @@ #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" +#include <memory> + +#include "absl/strings/match.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/core/shims/cc/tools/verifier.h" #include "tensorflow/lite/stderr_reporter.h" -#include "tensorflow/lite/tools/verifier.h" #include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" -#if TFLITE_USE_C_API -#include "tensorflow/lite/c/c_api_experimental.h" -#else -#include "tensorflow/lite/kernels/register.h" -#endif - namespace tflite { namespace task { namespace core { -#ifdef __ANDROID__ -// https://github.com/opencv/opencv/issues/14906 -// "ios_base::Init" object is not a part of Android's "iostream" header (in case -// of clang toolchain, NDK 20). -// -// Ref1: -// https://en.cppreference.com/w/cpp/io/ios_base/Init -// The header <iostream> behaves as if it defines (directly or indirectly) -// an instance of std::ios_base::Init with static storage duration -// -// Ref2: -// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74 -static std::ios_base::Init s_iostream_initializer; -#endif - using ::absl::StatusCode; +using ::tflite::proto::ComputeSettings; using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::InterpreterCreationResources; using ::tflite::support::TfLiteSupportStatus; -int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) { - return std::vsnprintf(error_message, sizeof(error_message), format, args); -} - bool TfLiteEngine::Verifier::Verify(const char* data, int length, tflite::ErrorReporter* reporter) { - return tflite::Verify(data, length, *op_resolver_, reporter); + return tflite_shims::Verify(data, length, reporter); } -#if TFLITE_USE_C_API TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver) - : model_(nullptr, TfLiteModelDelete), - resolver_(std::move(resolver)), - verifier_(resolver_.get()) {} -#else -TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver) - : model_(), resolver_(std::move(resolver)), verifier_(resolver_.get()) {} -#endif + : model_(), resolver_(std::move(resolver)), verifier_() {} std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() { Interpreter* interpreter = this->interpreter(); @@ -95,46 +69,33 @@ return tensors; } -// The following function is adapted from the code in -// tflite::FlatBufferModel::VerifyAndBuildFromBuffer. -void TfLiteEngine::VerifyAndBuildModelFromBuffer(const char* buffer_data, - size_t buffer_size) { -#if TFLITE_USE_C_API - // First verify with the base flatbuffers verifier. - // This verifies that the model is a valid flatbuffer model. - flatbuffers::Verifier base_verifier( - reinterpret_cast<const uint8_t*>(buffer_data), buffer_size); - if (!VerifyModelBuffer(base_verifier)) { - TF_LITE_REPORT_ERROR(&error_reporter_, - "The model is not a valid Flatbuffer buffer"); - model_ = nullptr; - return; - } - // Next verify with the extra verifier. This verifies that the model only - // uses operators supported by the OpResolver. - if (!verifier_.Verify(buffer_data, buffer_size, &error_reporter_)) { - model_ = nullptr; - return; - } - // Build the model. - model_.reset(TfLiteModelCreate(buffer_data, buffer_size)); -#else - model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( - buffer_data, buffer_size, &verifier_, &error_reporter_); -#endif +void TfLiteEngine::VerifyAndBuildModelFromBuffer( + const char* buffer_data, + size_t buffer_size, + TfLiteVerifier* extra_verifier) { + model_ = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer( + buffer_data, buffer_size, extra_verifier, &error_reporter_); } -absl::Status TfLiteEngine::InitializeFromModelFileHandler() { +absl::Status TfLiteEngine::InitializeFromModelFileHandler( + const tflite::proto::ComputeSettings& compute_settings) { const char* buffer_data = model_file_handler_->GetFileContent().data(); size_t buffer_size = model_file_handler_->GetFileContent().size(); - VerifyAndBuildModelFromBuffer(buffer_data, buffer_size); + VerifyAndBuildModelFromBuffer(buffer_data, buffer_size, &verifier_); if (model_ == nullptr) { + static constexpr char kInvalidFlatbufferMessage[] = + "The model is not a valid Flatbuffer"; // To be replaced with a proper switch-case when TF Lite model builder // returns a `TfLiteStatus` code capturing this type of error. - if (absl::StrContains(error_reporter_.error_message, - "The model is not a valid Flatbuffer")) { + if (absl::StrContains(error_reporter_.message(), + kInvalidFlatbufferMessage)) { return CreateStatusWithPayload( - StatusCode::kInvalidArgument, error_reporter_.error_message, + StatusCode::kInvalidArgument, error_reporter_.message(), + TfLiteSupportStatus::kInvalidFlatBufferError); + } else if (absl::StrContains(error_reporter_.message(), + "Error loading model from buffer")) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, kInvalidFlatbufferMessage, TfLiteSupportStatus::kInvalidFlatBufferError); } else { // TODO(b/154917059): augment status with another `TfLiteStatus` code when @@ -144,7 +105,7 @@ StatusCode::kUnknown, absl::StrCat( "Could not build model from the provided pre-loaded flatbuffer: ", - error_reporter_.error_message)); + error_reporter_.message())); } } @@ -156,125 +117,125 @@ return absl::OkStatus(); } -absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data, - size_t buffer_size) { +absl::Status TfLiteEngine::BuildModelFromFlatBuffer( + const char* buffer_data, + size_t buffer_size, + const tflite::proto::ComputeSettings& compute_settings) { if (model_) { return CreateStatusWithPayload(StatusCode::kInternal, "Model already built"); } - external_file_.set_file_content(std::string(buffer_data, buffer_size)); + external_file_ = std::make_unique<ExternalFile>(); + external_file_->set_file_content(std::string(buffer_data, buffer_size)); ASSIGN_OR_RETURN( model_file_handler_, - ExternalFileHandler::CreateFromExternalFile(&external_file_)); - return InitializeFromModelFileHandler(); + ExternalFileHandler::CreateFromExternalFile(external_file_.get())); + return InitializeFromModelFileHandler(compute_settings); } -absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) { +absl::Status TfLiteEngine::BuildModelFromFile( + const std::string& file_name, + const tflite::proto::ComputeSettings& compute_settings) { if (model_) { return CreateStatusWithPayload(StatusCode::kInternal, "Model already built"); } - external_file_.set_file_name(file_name); + if (external_file_ == nullptr) { + external_file_ = std::make_unique<ExternalFile>(); + } + external_file_->set_file_name(file_name); ASSIGN_OR_RETURN( model_file_handler_, - ExternalFileHandler::CreateFromExternalFile(&external_file_)); - return InitializeFromModelFileHandler(); + ExternalFileHandler::CreateFromExternalFile(external_file_.get())); + return InitializeFromModelFileHandler(compute_settings); } -absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) { +absl::Status TfLiteEngine::BuildModelFromFileDescriptor( + int file_descriptor, + const tflite::proto::ComputeSettings& compute_settings) { if (model_) { return CreateStatusWithPayload(StatusCode::kInternal, "Model already built"); } - external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor); + if (external_file_ == nullptr) { + external_file_ = std::make_unique<ExternalFile>(); + } + external_file_->mutable_file_descriptor_meta()->set_fd(file_descriptor); ASSIGN_OR_RETURN( model_file_handler_, - ExternalFileHandler::CreateFromExternalFile(&external_file_)); - return InitializeFromModelFileHandler(); + ExternalFileHandler::CreateFromExternalFile(external_file_.get())); + return InitializeFromModelFileHandler(compute_settings); } absl::Status TfLiteEngine::BuildModelFromExternalFileProto( - const ExternalFile* external_file) { + const ExternalFile* external_file, + const tflite::proto::ComputeSettings& compute_settings) { if (model_) { return CreateStatusWithPayload(StatusCode::kInternal, "Model already built"); } ASSIGN_OR_RETURN(model_file_handler_, ExternalFileHandler::CreateFromExternalFile(external_file)); - return InitializeFromModelFileHandler(); + return InitializeFromModelFileHandler(compute_settings); +} + +absl::Status TfLiteEngine::BuildModelFromExternalFileProto( + std::unique_ptr<ExternalFile> external_file) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + external_file_ = std::move(external_file); + ASSIGN_OR_RETURN( + model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(external_file_.get())); + // Dummy proto. InitializeFromModelFileHandler doesn't use this proto. + tflite::proto::ComputeSettings compute_settings; + return InitializeFromModelFileHandler(compute_settings); } absl::Status TfLiteEngine::InitInterpreter(int num_threads) { tflite::proto::ComputeSettings compute_settings; - return InitInterpreter(compute_settings, num_threads); + compute_settings.mutable_tflite_settings() + ->mutable_cpu_settings() + ->set_num_threads(num_threads); + return InitInterpreter(compute_settings); } -#if TFLITE_USE_C_API -const TfLiteRegistration* FindBuiltinOp(void* user_data, - TfLiteBuiltinOperator builtin_op, - int version) { - OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data); - tflite::BuiltinOperator op = static_cast<tflite::BuiltinOperator>(builtin_op); - return op_resolver->FindOp(op, version); -} - -const TfLiteRegistration* FindCustomOp(void* user_data, - const char* custom_op, - int version) { - OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data); - return op_resolver->FindOp(custom_op, version); -} -#endif - +// TODO(b/183798104): deprecate num_threads in VK task protos. +// Deprecated. Use the following method, and configure `num_threads` through +// `compute_settings`, i.e. in `CPUSettings`: +// absl::Status TfLiteEngine::InitInterpreter( +// const tflite::proto::ComputeSettings& compute_settings) absl::Status TfLiteEngine::InitInterpreter( const tflite::proto::ComputeSettings& compute_settings, int num_threads) { + ComputeSettings settings_copy = ComputeSettings(compute_settings); + settings_copy.mutable_tflite_settings() + ->mutable_cpu_settings() + ->set_num_threads(num_threads); + return InitInterpreter(settings_copy); +} + +absl::Status TfLiteEngine::InitInterpreter( + const tflite::proto::ComputeSettings& compute_settings) { if (model_ == nullptr) { return CreateStatusWithPayload( StatusCode::kInternal, "TF Lite FlatBufferModel is null. Please make sure to call one of the " "BuildModelFrom methods before calling InitInterpreter."); } -#if TFLITE_USE_C_API - std::function<absl::Status(TfLiteDelegate*, - std::unique_ptr<Interpreter, InterpreterDeleter>*)> - initializer = - [this, num_threads]( - TfLiteDelegate* optional_delegate, - std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out) - -> absl::Status { - std::unique_ptr<TfLiteInterpreterOptions, - void (*)(TfLiteInterpreterOptions*)> - options{TfLiteInterpreterOptionsCreate(), - TfLiteInterpreterOptionsDelete}; - TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp, - FindCustomOp, resolver_.get()); - TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads); - if (optional_delegate != nullptr) { - TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate); - } - interpreter_out->reset( - TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get())); - if (*interpreter_out == nullptr) { - return CreateStatusWithPayload( - StatusCode::kAborted, - absl::StrCat("Could not build the TF Lite interpreter: " - "TfLiteInterpreterCreateWithSelectedOps failed: ", - error_reporter_.error_message)); - } - return absl::OkStatus(); - }; -#else auto initializer = - [this, num_threads]( - std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out) + [this](const InterpreterCreationResources& resources, + std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out) -> absl::Status { - if (tflite::InterpreterBuilder(*model_, *resolver_)( - interpreter_out, num_threads) != kTfLiteOk) { + tflite_shims::InterpreterBuilder interpreter_builder(*model_, *resolver_); + resources.ApplyTo(&interpreter_builder); + if (interpreter_builder(interpreter_out) != kTfLiteOk) { return CreateStatusWithPayload( StatusCode::kUnknown, absl::StrCat("Could not build the TF Lite interpreter: ", - error_reporter_.error_message)); + error_reporter_.message())); } if (*interpreter_out == nullptr) { return CreateStatusWithPayload(StatusCode::kInternal, @@ -282,14 +243,24 @@ } return absl::OkStatus(); }; -#endif absl::Status status = interpreter_.InitializeWithFallback(initializer, compute_settings); - - if (!status.ok() && - !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) { - status = CreateStatusWithPayload(status.code(), status.message()); + if (!status.ok()) { + if (absl::StrContains(error_reporter_.previous_message(), + "Encountered unresolved custom op")) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + error_reporter_.previous_message(), + TfLiteSupportStatus::kUnsupportedCustomOp); + } else if (absl::StrContains(error_reporter_.previous_message(), + "Didn't find op for builtin opcode")) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, error_reporter_.previous_message(), + TfLiteSupportStatus::kUnsupportedBuiltinOp); + } else if (!status.GetPayload(tflite::support::kTfLiteSupportPayload) + .has_value()) { + return CreateStatusWithPayload(status.code(), status.message()); + } } return status; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h index bc55f6b0..0cbaa73 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
@@ -18,31 +18,21 @@ #include <memory> -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/c/common.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/core/shims/c/common.h" +#include "tensorflow/lite/core/shims/cc/interpreter.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/core/shims/cc/model.h" +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" #include "tensorflow_lite_support/cc/port/tflite_wrapper.h" +#include "tensorflow_lite_support/cc/task/core/error_reporter.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" -// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API -// rather than the TF Lite C++ API. -// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and -// elsewhere and instead use the C API unconditionally, once we have a suitable -// replacement for the features of tflite::support::TfLiteInterpreterWrapper. -#if TFLITE_USE_C_API -#include "tensorflow/lite/c/c_api.h" -#include "tensorflow/lite/core/api/verifier.h" -#include "tensorflow/lite/tools/verifier.h" -#else -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/model.h" -#endif - namespace tflite { namespace task { namespace core { @@ -52,77 +42,42 @@ class TfLiteEngine { public: // Types. - using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper; -#if TFLITE_USE_C_API - using Model = struct TfLiteModel; - using Interpreter = struct TfLiteInterpreter; - using ModelDeleter = void (*)(Model*); - using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter; -#else - using Model = tflite::FlatBufferModel; - using Interpreter = tflite::Interpreter; + using InterpreterWrapper = ::tflite::support::TfLiteInterpreterWrapper; + using Model = ::tflite_shims::FlatBufferModel; + using Interpreter = ::tflite_shims::Interpreter; using ModelDeleter = std::default_delete<Model>; using InterpreterDeleter = std::default_delete<Interpreter>; -#endif // Constructors. explicit TfLiteEngine( std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); // Model is neither copyable nor movable. TfLiteEngine(const TfLiteEngine&) = delete; TfLiteEngine& operator=(const TfLiteEngine&) = delete; // Accessors. static int32_t InputCount(const Interpreter* interpreter) { -#if TFLITE_USE_C_API - return TfLiteInterpreterGetInputTensorCount(interpreter); -#else return interpreter->inputs().size(); -#endif } static int32_t OutputCount(const Interpreter* interpreter) { -#if TFLITE_USE_C_API - return TfLiteInterpreterGetOutputTensorCount(interpreter); -#else return interpreter->outputs().size(); -#endif } static TfLiteTensor* GetInput(Interpreter* interpreter, int index) { -#if TFLITE_USE_C_API - return TfLiteInterpreterGetInputTensor(interpreter, index); -#else return interpreter->tensor(interpreter->inputs()[index]); -#endif } // Same as above, but const. static const TfLiteTensor* GetInput(const Interpreter* interpreter, int index) { -#if TFLITE_USE_C_API - return TfLiteInterpreterGetInputTensor(interpreter, index); -#else return interpreter->tensor(interpreter->inputs()[index]); -#endif } static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) { -#if TFLITE_USE_C_API - // We need a const_cast here, because the TF Lite C API only has a non-const - // version of GetOutputTensor (in part because C doesn't support overloading - // on const). - return const_cast<TfLiteTensor*>( - TfLiteInterpreterGetOutputTensor(interpreter, index)); -#else return interpreter->tensor(interpreter->outputs()[index]); -#endif } // Same as above, but const. static const TfLiteTensor* GetOutput(const Interpreter* interpreter, int index) { -#if TFLITE_USE_C_API - return TfLiteInterpreterGetOutputTensor(interpreter, index); -#else return interpreter->tensor(interpreter->outputs()[index]); -#endif } std::vector<TfLiteTensor*> GetInputs(); @@ -140,51 +95,61 @@ // whose ownership remains with the caller, and which must outlive the current // object. This performs extra verification on the input data using // tflite::Verify. - absl::Status BuildModelFromFlatBuffer(const char* buffer_data, - size_t buffer_size); + absl::Status BuildModelFromFlatBuffer( + const char* buffer_data, + size_t buffer_size, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()); // Builds the TF Lite model from a given file. - absl::Status BuildModelFromFile(const std::string& file_name); + absl::Status BuildModelFromFile( + const std::string& file_name, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()); // Builds the TF Lite model from a given file descriptor using mmap(2). - absl::Status BuildModelFromFileDescriptor(int file_descriptor); + absl::Status BuildModelFromFileDescriptor( + int file_descriptor, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()); // Builds the TFLite model from the provided ExternalFile proto, which must // outlive the current object. absl::Status BuildModelFromExternalFileProto( - const ExternalFile* external_file); + const ExternalFile* external_file, + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()); + + // Builds the TFLite model from the provided ExternalFile proto, and take + // ownership of ExternalFile proto. + absl::Status BuildModelFromExternalFileProto( + std::unique_ptr<ExternalFile> external_file); // Initializes interpreter with encapsulated model. // Note: setting num_threads to -1 has for effect to let TFLite runtime set // the value. absl::Status InitInterpreter(int num_threads = 1); - // Same as above, but allows specifying `compute_settings` for acceleration. + // Initializes interpreter with acceleration configurations. + absl::Status InitInterpreter( + const tflite::proto::ComputeSettings& compute_settings); + + // Deprecated. Use the following method, and configure `num_threads` through + // `compute_settings`, i.e. in `CPUSettings`: + // absl::Status TfLiteEngine::InitInterpreter( + // const tflite::proto::ComputeSettings& compute_settings) absl::Status InitInterpreter( const tflite::proto::ComputeSettings& compute_settings, - int num_threads = 1); + int num_threads); // Cancels the on-going `Invoke()` call if any and if possible. This method // can be called from a different thread than the one where `Invoke()` is // running. - void Cancel() { -#if TFLITE_USE_C_API - // NOP. -#else - interpreter_.Cancel(); -#endif - } + void Cancel() { interpreter_.Cancel(); } protected: - // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the - // error into a string so that it can be used to complement tensorflow::Status + // Custom error reporter capturing and printing to stderr low-level TF Lite // error messages. - struct ErrorReporter : public tflite::ErrorReporter { - // Last error message captured by this error reporter. - char error_message[256]; - int Report(const char* format, va_list args) override; - }; - // Custom error reporter capturing low-level TF Lite error messages. ErrorReporter error_reporter_; private: @@ -192,13 +157,9 @@ // the FlatBuffer data provided as input. class Verifier : public tflite::TfLiteVerifier { public: - explicit Verifier(const tflite::OpResolver* op_resolver) - : op_resolver_(op_resolver) {} bool Verify(const char* data, int length, tflite::ErrorReporter* reporter) override; - // The OpResolver to be used to build the TF Lite interpreter. - const tflite::OpResolver* op_resolver_; }; // Verifies that the supplied buffer refers to a valid flatbuffer model, @@ -206,13 +167,23 @@ // that was passed to the TfLiteEngine constructor, and then builds // the model from the buffer and stores it in 'model_'. void VerifyAndBuildModelFromBuffer(const char* buffer_data, - size_t buffer_size); + size_t buffer_size, + TfLiteVerifier* extra_verifier = nullptr); // Gets the buffer from the file handler; verifies and builds the model // from the buffer; if successful, sets 'model_metadata_extractor_' to be // a TF Lite Metadata extractor for the model; and calculates an appropriate // return Status, - absl::Status InitializeFromModelFileHandler(); + // TODO(b/192726981): Remove `compute_settings` as it's not in use. + absl::Status InitializeFromModelFileHandler( + const tflite::proto::ComputeSettings& compute_settings = + tflite::proto::ComputeSettings()); + + // ExternalFile and corresponding ExternalFileHandler for models loaded from + // disk or file descriptor. + // Make sure ExternalFile proto outlives the model and the interpreter. + std::unique_ptr<ExternalFile> external_file_; + std::unique_ptr<ExternalFileHandler> model_file_handler_; // TF Lite model and interpreter for actual inference. std::unique_ptr<Model, ModelDeleter> model_; @@ -230,11 +201,6 @@ // Extra verifier for FlatBuffer input data. Verifier verifier_; - - // ExternalFile and corresponding ExternalFileHandler for models loaded from - // disk or file descriptor. - ExternalFile external_file_; - std::unique_ptr<ExternalFileHandler> model_file_handler_; }; } // namespace core
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD new file mode 100644 index 0000000..c412dee --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD
@@ -0,0 +1,169 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library_with_tflite( + name = "processor", + srcs = ["processor.cc"], + hdrs = ["processor.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:common", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library_with_tflite( + name = "image_preprocessor", + srcs = ["image_preprocessor.cc"], + hdrs = ["image_preprocessor.h"], + tflite_deps = [ + ":processor", + "//tensorflow_lite_support/cc/task/vision/utils:image_tensor_specs", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "@com_google_absl//absl/memory", + ], +) + +cc_library_with_tflite( + name = "classification_postprocessor", + srcs = ["classification_postprocessor.cc"], + hdrs = ["classification_postprocessor.h"], + tflite_deps = [ + ":processor", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:classification_head", + "//tensorflow_lite_support/cc/task/core:label_map_item", + "//tensorflow_lite_support/cc/task/core:score_calibration", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/processor/proto:class_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:classifications_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/c:c_api_types", + ], +) + +cc_library_with_tflite( + name = "embedding_postprocessor", + srcs = ["embedding_postprocessor.cc"], + hdrs = ["embedding_postprocessor.h"], + tflite_deps = [ + ":processor", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto", + "@com_google_absl//absl/status", + ], +) + +cc_library_with_tflite( + name = "audio_preprocessor", + srcs = ["audio_preprocessor.cc"], + hdrs = ["audio_preprocessor.h"], + tflite_deps = [ + ":processor", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc/task/core:task_utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library_with_tflite( + name = "regex_preprocessor", + srcs = ["regex_preprocessor.cc"], + hdrs = ["regex_preprocessor.h"], + tflite_deps = [ + ":text_preprocessor_header", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer", + "@com_google_absl//absl/status", + ], +) + +cc_library_with_tflite( + name = "bert_preprocessor", + srcs = ["bert_preprocessor.cc"], + hdrs = ["bert_preprocessor.h"], + tflite_deps = [ + ":text_preprocessor_header", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", + "//tensorflow_lite_support/cc/utils:common_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library_with_tflite( + name = "text_preprocessor_header", + hdrs = ["text_preprocessor.h"], + tflite_deps = [ + ":processor", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/status", + ], +) + +cc_library_with_tflite( + name = "text_preprocessor", + srcs = ["text_preprocessor.cc"], + hdrs = ["text_preprocessor.h"], + tflite_deps = [ + ":processor", + ":bert_preprocessor", + ":regex_preprocessor", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/status", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc new file mode 100644 index 0000000..254d068 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
@@ -0,0 +1,167 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +namespace tflite { +namespace task { +namespace processor { + +namespace { +// Looks up AudioProperty from metadata. If no error occurs, the returned value +// is guaranteed to be valid (not null). +tflite::support::StatusOr<const AudioProperties*> GetAudioPropertiesSafe( + const TensorMetadata* tensor_metadata, + int input_index) { + if (tensor_metadata->content() == nullptr || + tensor_metadata->content()->content_properties() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Missing audio format metadata in the model metadata.", + tflite::support::TfLiteSupportStatus::kMetadataNotFoundError); + } + + ContentProperties type = + tensor_metadata->content()->content_properties_type(); + + if (type != ContentProperties_AudioProperties) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Expected AudioProperties for tensor ", + tensor_metadata->name() + ? tensor_metadata->name()->str() + : absl::StrFormat("#%d", input_index), + ", got ", EnumNameContentProperties(type), "."), + tflite::support::TfLiteSupportStatus:: + kMetadataInvalidContentPropertiesError); + } + + auto props = + tensor_metadata->content()->content_properties_as_AudioProperties(); + if (props == nullptr) { + return support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrCat("Expected AudioProperties for tensor", + tensor_metadata->name() + ? tensor_metadata->name()->str() + : absl::StrFormat("#%d", input_index), + ", got nullptr"), + tflite::support::TfLiteSupportStatus:: + kMetadataInvalidContentPropertiesError); + } + return props; +} +} // namespace + +/* static */ +tflite::support::StatusOr<std::unique_ptr<AudioPreprocessor>> +AudioPreprocessor::Create(tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> input_indices) { + ASSIGN_OR_RETURN(auto processor, + Processor::Create<AudioPreprocessor>( + /* num_expected_tensors = */ 1, engine, input_indices)); + + RETURN_IF_ERROR(processor->Init()); + return processor; +} + +absl::Status AudioPreprocessor::Init() { + RETURN_IF_ERROR(SetAudioFormatFromMetadata()); + RETURN_IF_ERROR(CheckAndSetInputs()); + return absl::OkStatus(); +} + +absl::Status AudioPreprocessor::SetAudioFormatFromMetadata() { + ASSIGN_OR_RETURN( + const AudioProperties* props, + GetAudioPropertiesSafe(GetTensorMetadata(), tensor_indices_.at(0))); + audio_format_.channels = props->channels(); + audio_format_.sample_rate = props->sample_rate(); + if (audio_format_.channels <= 0 || audio_format_.sample_rate <= 0) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Missing audio format metadata in the model.", + tflite::support::TfLiteSupportStatus::kMetadataNotFoundError); + } + + return absl::OkStatus(); +} + +absl::Status AudioPreprocessor::CheckAndSetInputs() { + input_buffer_size_ = 1; + for (int i = 0; i < GetTensor()->dims->size; i++) { + if (GetTensor()->dims->data[i] < 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid size: %d for input tensor dimension: %d.", + GetTensor()->dims->data[i], i), + tflite::support::TfLiteSupportStatus:: + kInvalidInputTensorDimensionsError); + } + input_buffer_size_ *= GetTensor()->dims->data[i]; + } + // Check if the input buffer size is divisible by the required audio channels. + // This needs to be done after loading metadata and input. + if (input_buffer_size_ % audio_format_.channels != 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Model input tensor size (%d) should be a " + "multiplier of the number of channels (%d).", + input_buffer_size_, audio_format_.channels), + tflite::support::TfLiteSupportStatus::kMetadataInconsistencyError); + } + return absl::OkStatus(); +} + +absl::Status AudioPreprocessor::Preprocess( + const ::tflite::task::audio::AudioBuffer& audio_buffer) { + if (audio_buffer.GetAudioFormat().channels != audio_format_.channels) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Input audio buffer channel number %d does not match " + "the model required audio channel number %d.", + audio_buffer.GetAudioFormat().channels, + audio_format_.channels)); + } + if (audio_buffer.GetAudioFormat().sample_rate != audio_format_.sample_rate) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Input audio sample rate %d does not match " + "the model required audio sample rate %d.", + audio_buffer.GetAudioFormat().sample_rate, + audio_format_.sample_rate)); + } + if (audio_buffer.GetBufferSize() != input_buffer_size_) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Input audio buffer size %d does not match the model required " + "input size %d.", + audio_buffer.GetBufferSize(), input_buffer_size_), + tflite::support::TfLiteSupportStatus::kInvalidArgumentError); + } + return tflite::task::core::PopulateTensor(audio_buffer.GetFloatBuffer(), + input_buffer_size_, GetTensor()); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.h new file mode 100644 index 0000000..7cc337b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.h
@@ -0,0 +1,79 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_AUDIO_PREPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_AUDIO_PREPROCESSOR_H_ + +#include <initializer_list> + +#include "absl/memory/memory.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/processor.h" + +namespace tflite { +namespace task { +namespace processor { + +// Processes input audio and populates the associate input tensor. +// Requirement for the input tensor: +// +// Input tensor: +// (kTfLiteFloat32) +// - input audio buffer of size `[batch * samples]`. +// - batch inference is not supported (`batch` is required to be 1). +// - for multi-channel models, the channels need be interleaved. +class AudioPreprocessor : public Preprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<AudioPreprocessor>> Create( + tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> input_indices); + + // Processes the provided AudioBuffer and populates tensor values. + // + // The input `audio_buffer` are the raw buffer captured by the required format + // which can retrieved by GetRequiredAudioFormat(). + ::absl::Status Preprocess( + const tflite::task::audio::AudioBuffer& audio_buffer); + + // Returns the required input audio format if it is set. Otherwise, returns + // kMetadataNotFoundError. + tflite::task::audio::AudioBuffer::AudioFormat GetRequiredAudioFormat() { + return audio_format_; + } + + // Returns the required input buffer size in number of float elements. + int GetRequiredInputBufferSize() { return input_buffer_size_; } + + private: + using Preprocessor::Preprocessor; + + ::absl::Status Init(); + ::absl::Status SetAudioFormatFromMetadata(); + ::absl::Status CheckAndSetInputs(); + + // Expected input audio format by the model. + tflite::task::audio::AudioBuffer::AudioFormat audio_format_; + + // Expected input audio buffer size in number of float elements. + int input_buffer_size_; +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_AUDIO_PREPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc new file mode 100644 index 0000000..96c5e0d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc
@@ -0,0 +1,157 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/processor/bert_preprocessor.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/ascii.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" +#include "tensorflow_lite_support/cc/utils/common_utils.h" + +namespace tflite { +namespace task { +namespace processor { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::task::core::FindIndexByMetadataTensorName; +using ::tflite::task::core::PopulateTensor; + +constexpr int kTokenizerProcessUnitIndex = 0; +constexpr char kIdsTensorName[] = "ids"; +constexpr char kMaskTensorName[] = "mask"; +constexpr char kSegmentIdsTensorName[] = "segment_ids"; +constexpr char kClassificationToken[] = "[CLS]"; +constexpr char kSeparator[] = "[SEP]"; + +/* static */ +StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create( + tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> input_tensor_indices) { + ASSIGN_OR_RETURN(auto processor, Processor::Create<BertPreprocessor>( + /* num_expected_tensors = */ 3, engine, + input_tensor_indices, + /* requires_metadata = */ false)); + RETURN_IF_ERROR(processor->Init()); + return processor; +} + +absl::Status BertPreprocessor::Init() { + // Try if RegexTokenzier can be found. + // BertTokenzier is packed in the processing unit of the InputTensors in + // SubgraphMetadata. + const tflite::ProcessUnit* tokenzier_metadata = + GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); + // Identify the tensor index for three Bert input tensors. + auto tensors_metadata = GetMetadataExtractor()->GetInputTensorMetadata(); + int ids_tensor_index = + FindIndexByMetadataTensorName(tensors_metadata, kIdsTensorName); + ids_tensor_index_ = + ids_tensor_index == -1 ? tensor_indices_[0] : ids_tensor_index; + int mask_tensor_index = + FindIndexByMetadataTensorName(tensors_metadata, kMaskTensorName); + mask_tensor_index_ = + mask_tensor_index == -1 ? tensor_indices_[1] : mask_tensor_index; + int segment_ids_tensor_index = + FindIndexByMetadataTensorName(tensors_metadata, kSegmentIdsTensorName); + segment_ids_tensor_index_ = segment_ids_tensor_index == -1 + ? tensor_indices_[2] + : segment_ids_tensor_index; + + if (GetLastDimSize(ids_tensor_index_) != GetLastDimSize(mask_tensor_index_) || + GetLastDimSize(ids_tensor_index_) != + GetLastDimSize(segment_ids_tensor_index_)) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("The three input tensors in Bert models are " + "expected to have same length, but got ids_tensor " + "(%d), mask_tensor (%d), segment_ids_tensor (%d).", + GetLastDimSize(ids_tensor_index_), + GetLastDimSize(mask_tensor_index_), + GetLastDimSize(segment_ids_tensor_index_)), + TfLiteSupportStatus::kInvalidNumOutputTensorsError); + } + bert_max_seq_len_ = GetLastDimSize(ids_tensor_index_); + + ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit( + tokenzier_metadata, GetMetadataExtractor())); + return absl::OkStatus(); +} + +absl::Status BertPreprocessor::Preprocess(const std::string& input_text) { + auto* ids_tensor = + engine_->GetInput(engine_->interpreter(), ids_tensor_index_); + auto* mask_tensor = + engine_->GetInput(engine_->interpreter(), mask_tensor_index_); + auto* segment_ids_tensor = + engine_->GetInput(engine_->interpreter(), segment_ids_tensor_index_); + + std::string processed_input = input_text; + absl::AsciiStrToLower(&processed_input); + + TokenizerResult input_tokenize_results; + input_tokenize_results = tokenizer_->Tokenize(processed_input); + + // 2 accounts for [CLS], [SEP] + absl::Span<const std::string> query_tokens = + absl::MakeSpan(input_tokenize_results.subwords.data(), + input_tokenize_results.subwords.data() + + std::min(static_cast<size_t>(bert_max_seq_len_ - 2), + input_tokenize_results.subwords.size())); + + std::vector<std::string> tokens; + tokens.reserve(2 + query_tokens.size()); + // Start of generating the features. + tokens.push_back(kClassificationToken); + // For query input. + for (const auto& query_token : query_tokens) { + tokens.push_back(query_token); + } + // For Separation. + tokens.push_back(kSeparator); + + std::vector<int> input_ids(bert_max_seq_len_, 0); + std::vector<int> input_mask(bert_max_seq_len_, 0); + // Convert tokens back into ids and set mask + for (int i = 0; i < tokens.size(); ++i) { + tokenizer_->LookupId(tokens[i], &input_ids[i]); + input_mask[i] = 1; + } + // |<--------bert_max_seq_len_--------->| + // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 + // input_masks 1 1 1... 1 1 0 0... 0 + // segment_ids 0 0 0... 0 0 0 0... 0 + + RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor)); + RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor)); + RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0), + segment_ids_tensor)); + return absl::OkStatus(); +} + +int BertPreprocessor::GetLastDimSize(int tensor_index) { + auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index); + return tensor->dims->data[tensor->dims->size - 1]; +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h new file mode 100644 index 0000000..85bba20 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h
@@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_BERT_PREPROCESOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_BERT_PREPROCESOR_H_ + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +namespace tflite { +namespace task { +namespace processor { + +// Processes input text and populates the associated bert input tensors. +// Requirements for the input tensors: +// - The 3 input tensors should be populated with the metadata tensor names, +// "ids", "mask", and "segment_ids", respectively. +// - The input_process_units metadata should contain WordPiece or +// Sentencepiece Tokenizer metadata. +class BertPreprocessor : public TextPreprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<BertPreprocessor>> Create( + tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> input_tensor_indices); + + absl::Status Preprocess(const std::string& text); + + private: + using TextPreprocessor::TextPreprocessor; + + absl::Status Init(); + + int GetLastDimSize(int tensor_index); + + std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; + int ids_tensor_index_; + int mask_tensor_index_; + int segment_ids_tensor_index_; + int bert_max_seq_len_; +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_BERT_PREPROCESOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc new file mode 100644 index 0000000..6396200 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
@@ -0,0 +1,222 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/processor/classification_postprocessor.h" + +#include <memory> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +namespace tflite { +namespace task { +namespace processor { + +namespace { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::BuildClassificationHead; +using ::tflite::task::core::LabelMapItem; +using ::tflite::task::core::ScoreCalibration; + +} // namespace + +/* static */ +tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>> +ClassificationPostprocessor::Create( + core::TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + std::unique_ptr<ClassificationOptions> options) { + ASSIGN_OR_RETURN(auto processor, + Processor::Create<ClassificationPostprocessor>( + /* num_expected_tensors = */ 1, engine, output_indices)); + + RETURN_IF_ERROR(processor->Init(std::move(options))); + return processor; +} + +absl::Status ClassificationPostprocessor::Init( + std::unique_ptr<ClassificationOptions> options) { + // Sanity check options + if (options->max_results() == 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Invalid `max_results` option: value must be != 0", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (options->class_name_allowlist_size() > 0 && + options->class_name_denylist_size() > 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "`class_name_allowlist` and `class_name_denylist` are mutually " + "exclusive options.", + TfLiteSupportStatus::kInvalidArgumentError); + } + + ASSIGN_OR_RETURN(classification_head_, + BuildClassificationHead(*engine_->metadata_extractor(), + *GetTensorMetadata(), + options->display_names_locale())); + + // Sanity check output tensors + const TfLiteTensor* output_tensor = GetTensor(); + const int num_dimensions = output_tensor->dims->size; + if (num_dimensions == 4) { + if (output_tensor->dims->data[1] != 1 || + output_tensor->dims->data[2] != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Unexpected WxH sizes for output index %d: got " + "%dx%d, expected 1x1.", + tensor_indices_.at(0), output_tensor->dims->data[2], + output_tensor->dims->data[1]), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + } else if (num_dimensions != 2) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Unexpected number of dimensions for output index %d: got %dD, " + "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, " + "H=1).", + tensor_indices_.at(0), num_dimensions), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + if (output_tensor->dims->data[0] != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("The output array is expected to have a batch size " + "of 1. Got %d for output index %d.", + output_tensor->dims->data[0], tensor_indices_.at(0)), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + int num_classes = output_tensor->dims->data[num_dimensions - 1]; + // If label map is not set, build a default one based on model + // introspection. This happens if a model with partial or no metadata was + // provided through the `model_file_with_metadata` options field. + if (classification_head_.label_map_items.empty()) { + classification_head_.label_map_items.reserve(num_classes); + for (int class_index = 0; class_index < num_classes; ++class_index) { + classification_head_.label_map_items.emplace_back(LabelMapItem{}); + } + } + int num_label_map_items = classification_head_.label_map_items.size(); + if (num_classes != num_label_map_items) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Got %d class(es) for output index %d, expected %d " + "according to the label map.", + output_tensor->dims->data[num_dimensions - 1], + tensor_indices_.at(0), num_label_map_items), + TfLiteSupportStatus::kMetadataInconsistencyError); + } + if (output_tensor->type != kTfLiteUInt8 && + output_tensor->type != kTfLiteFloat32) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Type mismatch for output tensor %s. Requested one " + "of these types: " + "kTfLiteUint8/kTfLiteFloat32, got %s.", + output_tensor->name, + TfLiteTypeGetName(output_tensor->type)), + TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + + // Set class name set + if (options->class_name_denylist_size() != 0 || + options->class_name_allowlist_size() != 0) { + // Before processing class names allowlist or denylist from the input + // options create a set with _all_ known class names from the label map(s). + absl::flat_hash_set<std::string> head_class_names; + for (const auto& item : classification_head_.label_map_items) { + if (!item.name.empty()) { + head_class_names.insert(item.name); + } + } + + if (head_class_names.empty()) { + std::string name = classification_head_.name; + if (name.empty()) { + name = absl::StrFormat("#%d", tensor_indices_.at(0)); + } + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Using `class_name_allowlist` or `class_name_denylist` " + "requires labels to be present but none was found for " + "classification head: %s", + name), + TfLiteSupportStatus::kMetadataMissingLabelsError); + } + + class_name_set_.is_allowlist = options->class_name_allowlist_size() > 0; + const auto& class_names = class_name_set_.is_allowlist + ? options->class_name_allowlist() + : options->class_name_denylist(); + + // Note: duplicate or unknown classes are just ignored. + class_name_set_.values.clear(); + for (const auto& class_name : class_names) { + if (!head_class_names.contains(class_name)) { + continue; + } + class_name_set_.values.insert(class_name); + } + + if (class_name_set_.values.empty()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Invalid class names specified via `class_name_%s`: none match " + "with model labels.", + class_name_set_.is_allowlist ? "allowlsit" : "denylist"), + TfLiteSupportStatus::kInvalidArgumentError); + } + } + + // Set score calibration + if (classification_head_.calibration_params.has_value()) { + score_calibration_ = absl::make_unique<ScoreCalibration>(); + if (score_calibration_ == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Could not create score calibration object."); + } + + RETURN_IF_ERROR(score_calibration_->InitializeFromParameters( + classification_head_.calibration_params.value())); + } + + num_results_ = + options->max_results() >= 0 + ? std::min( + static_cast<int>(classification_head_.label_map_items.size()), + options->max_results()) + : classification_head_.label_map_items.size(); + score_threshold_ = options->has_score_threshold() + ? options->score_threshold() + : classification_head_.score_threshold; + + return absl::OkStatus(); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.h new file mode 100644 index 0000000..517974c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.h
@@ -0,0 +1,230 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_CLASSIFICATION_POSTPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_CLASSIFICATION_POSTPROCESSOR_H_ + +#include <initializer_list> + +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/classification_head.h" +#include "tensorflow_lite_support/cc/task/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/core/score_calibration.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/processor/processor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/class.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" + +namespace tflite { +namespace task { +namespace processor { + +// This Postprocessor expects one output tensor with: +// (kTfLiteUInt8/kTfLiteFloat32) +// - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or +// `[1 x 1 x 1 x N]` +// - optional (but recommended) label map(s) as AssociatedFile-s with type +// TENSOR_AXIS_LABELS, containing one label per line. The first such +// AssociatedFile (if any) is used to fill the `class_name` field of the +// results. The `display_name` field is filled from the AssociatedFile (if +// any) whose locale matches the `display_names_locale` field of the +// `ImageClassifierOptions` used at creation time ("en" by default, i.e. +// English). If none of these are available, only the `index` field of the +// results will be filled. +class ClassificationPostprocessor : public Postprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>> + Create(core::TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + std::unique_ptr<ClassificationOptions> options); + + // Convert the tensor output to classification class. + // Note that this method doesn't add head_name for backward compatibility. + // Head name can be retrieved by `GetHeadName` method. + template <typename T> + absl::Status Postprocess(T* classifications); + + const std::string GetHeadName() const { return classification_head_.name; } + + private: + using Postprocessor::Postprocessor; + + absl::Status Init(std::unique_ptr<ClassificationOptions> options); + // Given a ClassificationResult object containing class indices, fills the + // name and display name from the label map(s). + template <typename T> + absl::Status FillResultsFromLabelMaps(T* classifications); + + // The list of classification heads associated with the corresponding output + // tensors. Built from TFLite Model Metadata. + ::tflite::task::core::ClassificationHead classification_head_{}; + + // Set of allowlisted or denylisted class names. + struct ClassNameSet { + absl::flat_hash_set<std::string> values; + bool is_allowlist; + }; + + // Allowlisted or denylisted class names based on provided options at + // construction time. These are used to filter out results during + // post-processing. + ClassNameSet class_name_set_; + + // Score calibration parameters, if any. Built from TFLite Model + // Metadata. + std::unique_ptr<core::ScoreCalibration> score_calibration_; + + // Number of classes returned by `Postprocess` method. + int num_results_; + + // Recommended score threshold typically in [0,1[. Classification results with + // a score below this value are considered low-confidence and should be + // rejected from returned results. + float score_threshold_; +}; + +template <typename T> +absl::Status ClassificationPostprocessor::Postprocess(T* classifications) { + const auto& head = classification_head_; + classifications->set_head_index(tensor_indices_.at(0)); + + std::vector<std::pair<int, float>> score_pairs; + score_pairs.reserve(head.label_map_items.size()); + + const TfLiteTensor* output_tensor = GetTensor(); + if (output_tensor->type == kTfLiteUInt8) { + ASSIGN_OR_RETURN(const uint8* output_data, + core::AssertAndReturnTypedTensor<uint8>(output_tensor)); + for (int j = 0; j < head.label_map_items.size(); ++j) { + score_pairs.emplace_back( + j, output_tensor->params.scale * (static_cast<int>(output_data[j]) - + output_tensor->params.zero_point)); + } + } else { + ASSIGN_OR_RETURN(const float* output_data, + core::AssertAndReturnTypedTensor<float>(output_tensor)); + for (int j = 0; j < head.label_map_items.size(); ++j) { + score_pairs.emplace_back(j, output_data[j]); + } + } + + // Optional score calibration. + if (score_calibration_ != nullptr) { + for (auto& score_pair : score_pairs) { + const std::string& class_name = + head.label_map_items[score_pair.first].name; + + // In ComputeCalibratedScore, score_pair.second is set to the + // default_score value from metadata [1] if the category (1) has no + // score calibration data or (2) has a very low confident uncalibrated + // score, i.e. lower than the `min_uncalibrated_score` threshold. + // Otherwise, score_pair.second is calculated based on the selected + // score transformation function, and the value is guaranteed to be in + // the range of [0, scale], where scale is a label-dependent sigmoid + // parameter. + // + // [1]: + // https://github.com/tensorflow/tflite-support/blob/af26cb6952ccdeee0e849df2b93dbe7e57f6bc48/tensorflow_lite_support/metadata/metadata_schema.fbs#L453 + score_pair.second = score_calibration_->ComputeCalibratedScore( + class_name, score_pair.second); + } + } + + if (class_name_set_.values.empty()) { + // Partially sort in descending order (higher score is better). + absl::c_partial_sort( + score_pairs, score_pairs.begin() + num_results_, + [](const std::pair<int, float>& a, const std::pair<int, float>& b) { + return a.second > b.second; + }); + + for (int j = 0; j < num_results_; ++j) { + float score = score_pairs[j].second; + if (score < score_threshold_) { + break; + } + auto* cl = classifications->add_classes(); + cl->set_index(score_pairs[j].first); + cl->set_score(score); + } + } else { + // Sort in descending order (higher score is better). + absl::c_sort(score_pairs, [](const std::pair<int, float>& a, + const std::pair<int, float>& b) { + return a.second > b.second; + }); + + for (int j = 0; j < head.label_map_items.size(); ++j) { + float score = score_pairs[j].second; + if (score < score_threshold_ || + classifications->classes_size() >= num_results_) { + break; + } + + const int class_index = score_pairs[j].first; + const std::string& class_name = head.label_map_items[class_index].name; + + bool class_name_found = class_name_set_.values.contains(class_name); + + if ((!class_name_found && class_name_set_.is_allowlist) || + (class_name_found && !class_name_set_.is_allowlist)) { + continue; + } + + auto* cl = classifications->add_classes(); + cl->set_index(class_index); + cl->set_score(score); + } + } + return FillResultsFromLabelMaps(classifications); +} + +template <typename T> +absl::Status ClassificationPostprocessor::FillResultsFromLabelMaps( + T* classifications) { + int head_index = classifications->head_index(); + const auto& label_map_items = classification_head_.label_map_items; + for (int j = 0; j < classifications->classes_size(); ++j) { + auto* current_class = classifications->mutable_classes(j); + int current_class_index = current_class->index(); + if (current_class_index < 0 || + current_class_index >= label_map_items.size()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid class index (%d) with respect to label " + "map size (%d) for head #%d.", + current_class_index, label_map_items.size(), + head_index), + support::TfLiteSupportStatus::kMetadataInconsistencyError); + } + const std::string& name = label_map_items[current_class_index].name; + if (!name.empty()) { + current_class->set_class_name(name); + } + const std::string& display_name = + label_map_items[current_class_index].display_name; + if (!display_name.empty()) { + current_class->set_display_name(display_name); + } + } + return absl::OkStatus(); +} + +} // namespace processor +} // namespace task +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_CLASSIFICATION_POSTPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.cc new file mode 100644 index 0000000..83b123f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.cc
@@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" + +namespace tflite { +namespace task { +namespace processor { + +/* static */ +tflite::support::StatusOr<std::unique_ptr<EmbeddingPostprocessor>> +EmbeddingPostprocessor::Create(core::TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + std::unique_ptr<EmbeddingOptions> options) { + ASSIGN_OR_RETURN(auto processor, + Processor::Create<EmbeddingPostprocessor>( + /* num_expected_tensors = */ 1, engine, output_indices, + /* requires_metadata = */ false)); + + RETURN_IF_ERROR(processor->Init(std::move(options))); + return processor; +} + +absl::Status EmbeddingPostprocessor::Init( + std::unique_ptr<EmbeddingOptions> options) { + options_ = std::move(options); + + int output_index = tensor_indices_.at(0); + auto* output_tensor = GetTensor(); + int num_dimensions = output_tensor->dims->size; + if (num_dimensions == 4) { + if (output_tensor->dims->data[1] != 1 || + output_tensor->dims->data[2] != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Unexpected WxH sizes for output index %d: got " + "%dx%d, expected 1x1.", + output_index, output_tensor->dims->data[2], + output_tensor->dims->data[1]), + support::TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + } else if (num_dimensions != 2) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Unexpected number of dimensions for output index %d: got %dD, " + "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, " + "W=1, " + "H=1).", + output_index, num_dimensions), + support::TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + if (output_tensor->dims->data[0] != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("The output array is expected to have a batch size " + "of 1. Got %d for output index %d.", + output_tensor->dims->data[0], output_index), + support::TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + embedding_dimension_ = output_tensor->dims->data[num_dimensions - 1]; + if (output_tensor->type != kTfLiteUInt8 && + output_tensor->type != kTfLiteFloat32) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Type mismatch for output tensor %s. Requested one " + "of these types: " + "kTfLiteUint8/kTfLiteFloat32, got %s.", + output_tensor->name, + TfLiteTypeGetName(output_tensor->type)), + support::TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + return absl::OkStatus(); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h new file mode 100644 index 0000000..78cef8ab5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
@@ -0,0 +1,213 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_EMBEDDING_POSTPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_EMBEDDING_POSTPROCESSOR_H_ +#include <initializer_list> + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/processor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" + +namespace tflite { +namespace task { +namespace processor { + +// This postprocessor works with the following output tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +class EmbeddingPostprocessor : public Postprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<EmbeddingPostprocessor>> + Create(core::TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + std::unique_ptr<EmbeddingOptions> options); + + template <typename T> + absl::Status Postprocess(T* embedding); + + // Utility function to compute cosine similarity [1] between two feature + // vectors. May return an InvalidArgumentError if e.g. the feature vectors are + // of different types (quantized vs. float), have different sizes, or have a + // an L2-norm of 0. + // + // [1]: https://en.wikipedia.org/wiki/Cosine_similarity + template <typename T> + static tflite::support::StatusOr<double> CosineSimilarity(const T& u, + const T& v); + + int GetEmbeddingDimension() const { return embedding_dimension_; } + + private: + using Postprocessor::Postprocessor; + + absl::Status Init(std::unique_ptr<EmbeddingOptions> options); + + std::unique_ptr<EmbeddingOptions> options_; + + int embedding_dimension_ = 0; + + // Performs actual cosine similarity computation. + template <typename T> + static tflite::support::StatusOr<double> + ComputeCosineSimilarity(const T* u, const T* v, int num_elements); + + template <typename T> + void NormalizeFeatureVector(T* feature_vector) const; + + template <typename T> + void QuantizeFeatureVector(T* feature_vector) const; +}; + +template <typename T> +absl::Status EmbeddingPostprocessor::Postprocess(T* embedding) { + embedding->set_output_index(tensor_indices_.at(0)); + auto* feature_vector = embedding->mutable_feature_vector(); + if (GetTensor()->type == kTfLiteUInt8) { + const uint8* output_data = + engine_->interpreter()->typed_output_tensor<uint8>( + tensor_indices_.at(0)); + // Get the zero_point and scale parameters from the tensor metadata. + const int output_tensor_index = + engine_->interpreter()->outputs()[tensor_indices_.at(0)]; + const TfLiteTensor* output_tensor = + engine_->interpreter()->tensor(output_tensor_index); + for (int j = 0; j < embedding_dimension_; ++j) { + feature_vector->add_value_float(output_tensor->params.scale * + (static_cast<int>(output_data[j]) - + output_tensor->params.zero_point)); + } + } else { + // Float + const float* output_data = + engine_->interpreter()->typed_output_tensor<float>( + tensor_indices_.at(0)); + for (int j = 0; j < embedding_dimension_; ++j) { + feature_vector->add_value_float(output_data[j]); + } + } + if (options_->l2_normalize()) { + NormalizeFeatureVector(feature_vector); + } + if (options_->quantize()) { + QuantizeFeatureVector(feature_vector); + } + return absl::OkStatus(); +} + +template <typename T> +void EmbeddingPostprocessor::NormalizeFeatureVector(T* feature_vector) const { + float squared_l2_norm = 0.0f; + for (const float val : feature_vector->value_float()) { + squared_l2_norm += val * val; + } + if (squared_l2_norm == 0.0f) { + return; + } + const float inv_l2_norm = 1.0f / std::sqrt(squared_l2_norm); + for (int i = 0; i < feature_vector->value_float().size(); ++i) { + feature_vector->set_value_float( + i, feature_vector->value_float(i) * inv_l2_norm); + } +} + +template <typename T> +void EmbeddingPostprocessor::QuantizeFeatureVector(T* feature_vector) const { + auto* quantized_values = feature_vector->mutable_value_string(); + quantized_values->resize(feature_vector->value_float().size()); + for (int i = 0; i < feature_vector->value_float().size(); ++i) { + int value = static_cast<int>(roundf(feature_vector->value_float(i) * 128)); + (*quantized_values)[i] = + static_cast<char>(std::max(-128, std::min(value, 127))); + } + feature_vector->clear_value_float(); +} + +/* static */ +template <typename T> +tflite::support::StatusOr<double> +EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, + const T* v, + int num_elements) { + if (num_elements <= 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Cannot compute cosine similarity on empty feature vectors", + support::TfLiteSupportStatus::kInvalidArgumentError); + } + double dot_product = 0.0; + double norm_u = 0.0; + double norm_v = 0.0; + for (int i = 0; i < num_elements; ++i) { + dot_product += u[i] * v[i]; + norm_u += u[i] * u[i]; + norm_v += v[i] * v[i]; + } + if (norm_u <= 0.0 || norm_v <= 0.0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Cannot compute cosine similarity on feature vector with 0 norm", + support::TfLiteSupportStatus::kInvalidArgumentError); + } + return dot_product / std::sqrt(norm_u * norm_v); +} + +/* static */ +template <typename T> +tflite::support::StatusOr<double> EmbeddingPostprocessor::CosineSimilarity( + const T& u, + const T& v) { + if (u.has_value_string() && v.has_value_string()) { + if (u.value_string().size() != v.value_string().size()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Cannot compute cosine similarity on quantized " + "feature vectors of different sizes (%d vs %d)", + u.value_string().size(), v.value_string().size()), + support::TfLiteSupportStatus::kInvalidArgumentError); + } + return ComputeCosineSimilarity( + reinterpret_cast<const int8_t*>(&u.value_string()[0]), + reinterpret_cast<const int8_t*>(&v.value_string()[0]), + u.value_string().size()); + } + if (!u.has_value_string() && !v.has_value_string()) { + if (u.value_float_size() != v.value_float_size()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Cannot compute cosine similarity on float " + "feature vectors of different sizes (%d vs %d)", + u.value_float_size(), v.value_float_size()), + support::TfLiteSupportStatus::kInvalidArgumentError); + } + return ComputeCosineSimilarity( + u.value_float().data(), v.value_float().data(), u.value_float().size()); + } + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Cannot compute cosine similarity between quantized and float " + "feature vectors", + support::TfLiteSupportStatus::kInvalidArgumentError); +} + +} // namespace processor +} // namespace task +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_EMBEDDING_POSTPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc new file mode 100644 index 0000000..310a1f5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
@@ -0,0 +1,232 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/processor/image_preprocessor.h" + +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h" + +namespace tflite { +namespace task { +namespace processor { + +namespace { +// Number of bytes required for 8-bit per pixel RGB color space. +static constexpr int kRgbPixelBytes = 3; + +using ::tflite::task::vision::BoundingBox; +using ::tflite::task::vision::FrameBuffer; +} // namespace + +/* static */ +tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>> +ImagePreprocessor::Create( + core::TfLiteEngine* engine, + const std::initializer_list<int> input_indices, + const vision::FrameBufferUtils::ProcessEngine& process_engine) { + ASSIGN_OR_RETURN(auto processor, + Processor::Create<ImagePreprocessor>( + /* num_expected_tensors = */ 1, engine, input_indices, + /* requires_metadata = */ false)); + + RETURN_IF_ERROR(processor->Init(process_engine)); + return processor; +} + +// Returns false if image preprocessing could be skipped, true otherwise. +bool ImagePreprocessor::IsImagePreprocessingNeeded( + const FrameBuffer& frame_buffer, + const BoundingBox& roi) { + // Is crop required? + if (roi.origin_x() != 0 || roi.origin_y() != 0 || + roi.width() != frame_buffer.dimension().width || + roi.height() != frame_buffer.dimension().height) { + return true; + } + + // Are image transformations required? + if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft || + frame_buffer.format() != FrameBuffer::Format::kRGB || + frame_buffer.dimension().width != input_specs_.image_width || + frame_buffer.dimension().height != input_specs_.image_height) { + return true; + } + + return false; +} + +absl::Status ImagePreprocessor::Init( + const vision::FrameBufferUtils::ProcessEngine& process_engine) { + frame_buffer_utils_ = vision::FrameBufferUtils::Create(process_engine); + + ASSIGN_OR_RETURN(input_specs_, vision::BuildInputImageTensorSpecs( + *engine_->interpreter(), + *engine_->metadata_extractor())); + + if (input_specs_.color_space != tflite::ColorSpaceType_RGB) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kUnimplemented, + "ImagePreprocessor only supports RGB color space for now."); + } + + // Determine if the input shape is resizable. + const TfLiteIntArray* dims_signature = GetTensor()->dims_signature; + + // Some fixed-shape models do not have dims_signature. + if (dims_signature != nullptr && dims_signature->size > 2) { + // Only the HxW dimensions support mutability. + is_height_mutable_ = dims_signature->data[1] == -1; + is_width_mutable_ = dims_signature->data[2] == -1; + } + return absl::OkStatus(); +} + +absl::Status ImagePreprocessor::Preprocess(const FrameBuffer& frame_buffer) { + BoundingBox roi; + roi.set_width(frame_buffer.dimension().width); + roi.set_height(frame_buffer.dimension().height); + return Preprocess(frame_buffer, roi); +} + +absl::Status ImagePreprocessor::Preprocess(const FrameBuffer& frame_buffer, + const BoundingBox& roi) { + // Input data to be normalized (if needed) and used for inference. In most + // cases, this is the result of image preprocessing. In case no image + // preprocessing is needed (see below), this points to the input frame + // buffer raw data. + const uint8* input_data; + size_t input_data_byte_size; + + // Optional buffers in case image preprocessing is needed. + std::unique_ptr<FrameBuffer> preprocessed_frame_buffer; + std::vector<uint8> preprocessed_data; + + if (IsImagePreprocessingNeeded(frame_buffer, roi)) { + // Preprocess input image to fit model requirements. + // For now RGB is the only color space supported, which is ensured by + // `InitInternal`. + input_specs_.image_width = + is_width_mutable_ ? roi.width() : input_specs_.image_width; + input_specs_.image_height = + is_height_mutable_ ? roi.height() : input_specs_.image_height; + + FrameBuffer::Dimension to_buffer_dimension = {input_specs_.image_width, + input_specs_.image_height}; + input_data_byte_size = + GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB); + preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0); + input_data = preprocessed_data.data(); + + FrameBuffer::Plane preprocessed_plane = { + /*buffer=*/preprocessed_data.data(), + /*stride=*/{input_specs_.image_width * kRgbPixelBytes, kRgbPixelBytes}}; + preprocessed_frame_buffer = FrameBuffer::Create( + {preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB, + FrameBuffer::Orientation::kTopLeft); + + RETURN_IF_ERROR(frame_buffer_utils_->Preprocess( + frame_buffer, roi, preprocessed_frame_buffer.get())); + } else { + // Input frame buffer already targets model requirements: skip image + // preprocessing. For RGB, the data is always stored in a single plane. + input_data = frame_buffer.plane(0).buffer; + input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes * + frame_buffer.dimension().height; + } + + // If dynamic, it will re-dim the entire graph as per the input. + if (is_height_mutable_ || is_width_mutable_) { + engine_->interpreter()->ResizeInputTensorStrict( + 0, {GetTensor()->dims->data[0], input_specs_.image_height, + input_specs_.image_width, GetTensor()->dims->data[3]}); + + engine_->interpreter()->AllocateTensors(); + } + // Then normalize pixel data (if needed) and populate the input tensor. + switch (input_specs_.tensor_type) { + case kTfLiteUInt8: + if (GetTensor()->bytes != input_data_byte_size) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Size mismatch or unsupported padding bytes between pixel data " + "and input tensor."); + } + // No normalization required: directly populate data. + RETURN_IF_ERROR(tflite::task::core::PopulateTensor( + input_data, input_data_byte_size / sizeof(uint8), GetTensor())); + break; + case kTfLiteFloat32: { + if (GetTensor()->bytes / sizeof(float) != + input_data_byte_size / sizeof(uint8)) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Size mismatch or unsupported padding bytes between pixel data " + "and input tensor."); + } + // Normalize and populate. + ASSIGN_OR_RETURN( + float* normalized_input_data, + tflite::task::core::AssertAndReturnTypedTensor<float>(GetTensor())); + const tflite::task::vision::NormalizationOptions& normalization_options = + input_specs_.normalization_options.value(); + for (int i = 0; i < normalization_options.num_values; i++) { + if (std::abs(normalization_options.std_values[i]) < + std::numeric_limits<float>::epsilon()) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "NormalizationOptions.std_values can't be 0. Please check if the " + "tensor metadata has been populated correctly."); + } + } + if (normalization_options.num_values == 1) { + float mean_value = normalization_options.mean_values[0]; + float inv_std_value = (1.0f / normalization_options.std_values[0]); + for (size_t i = 0; i < input_data_byte_size / sizeof(uint8); + i++, input_data++, normalized_input_data++) { + *normalized_input_data = + inv_std_value * (static_cast<float>(*input_data) - mean_value); + } + } else { + std::array<float, 3> inv_std_values = { + 1.0f / normalization_options.std_values[0], + 1.0f / normalization_options.std_values[1], + 1.0f / normalization_options.std_values[2]}; + for (size_t i = 0; i < input_data_byte_size / sizeof(uint8); + i++, input_data++, normalized_input_data++) { + *normalized_input_data = inv_std_values[i % 3] * + (static_cast<float>(*input_data) - + normalization_options.mean_values[i % 3]); + } + } + break; + } + case kTfLiteInt8: + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kUnimplemented, + "kTfLiteInt8 input type is not implemented yet."); + default: + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, "Unexpected input tensor type."); + } + + return absl::OkStatus(); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.h new file mode 100644 index 0000000..9b7d46f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.h
@@ -0,0 +1,103 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_PREPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_PREPROCESSOR_H_ + +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/processor/processor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h" + +namespace tflite { +namespace task { +namespace processor { + +// Process input image and populate the associate input tensor. +// Requirement for the input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - only RGB inputs are supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +class ImagePreprocessor : public Preprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>> Create( + core::TfLiteEngine* engine, + const std::initializer_list<int> input_indices, + const vision::FrameBufferUtils::ProcessEngine& process_engine = + vision::FrameBufferUtils::ProcessEngine::kLibyuv); + + // Processes the provided FrameBuffer and populate tensor values. + // + // The FrameBuffer can be of any size and any of the supported formats, i.e. + // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before + // inference in order to (and in this order): + // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to + // the dimensions of the model input tensor, + // - convert it to the colorspace of the input tensor (i.e. RGB, which is the + // only supported colorspace for now), + // - rotate it according to its `Orientation` so that inference is performed + // on an "upright" image. + // + // NOTE: In case the model has dynamic input shape, the method would re-dim + // the entire graph based on the dimensions of the image. + absl::Status Preprocess(const vision::FrameBuffer& frame_buffer); + + // Same as above, except based on the input region of interest. + // + // IMPORTANT: as a consequence of cropping occurring first, the provided + // region of interest is expressed in the unrotated frame of reference + // coordinates system, i.e. in `[0, frame_buffer.width) x [0, + // frame_buffer.height)`, which are the dimensions of the underlying + // `frame_buffer` data before any `Orientation` flag gets applied. Also, the + // region of interest is not clamped, so this method will return a non-ok + // status if the region is out of these bounds. + absl::Status Preprocess(const vision::FrameBuffer& frame_buffer, + const vision::BoundingBox& roi); + + // Returns the spec of model. Passing in an image with this spec will speed up + // the inference as it bypasses image cropping and resizing. + const vision::ImageTensorSpecs& GetInputSpecs() const { return input_specs_; } + + private: + using Preprocessor::Preprocessor; + + // Returns false if image preprocessing could be skipped, true otherwise. + bool IsImagePreprocessingNeeded(const vision::FrameBuffer& frame_buffer, + const vision::BoundingBox& roi); + + absl::Status Init( + const vision::FrameBufferUtils::ProcessEngine& process_engine); + + // Parameters related to the input tensor which represents an image. + vision::ImageTensorSpecs input_specs_; + + // Utils for input image preprocessing (resizing, colorspace conversion, etc). + std::unique_ptr<vision::FrameBufferUtils> frame_buffer_utils_; + + // Is true if the model expects dynamic image shape, false otherwise. + bool is_height_mutable_ = false; + bool is_width_mutable_ = false; +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_PREPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.cc new file mode 100644 index 0000000..1fc47f7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.cc
@@ -0,0 +1,70 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/processor/processor.h" + +#include <iterator> +#include <sstream> + +namespace tflite { +namespace task { +namespace processor { + +constexpr char Preprocessor::kInputTypeName[]; +constexpr char Postprocessor::kOutputTypeName[]; + +absl::Status Processor::SanityCheck(int num_expected_tensors, + bool requires_metadata) { + const char* tensor_type_name = GetTensorTypeName(); + if (tensor_indices_.size() != num_expected_tensors) { + return support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Processor can handle %d tensors, " + "got: %d tensors.", + num_expected_tensors, tensor_indices_.size())); + } + + int tensor_count = GetModelTensorCount(); + for (int i = 0; i < tensor_indices_.size(); i++) { + int index = tensor_indices_.at(i); + if (index < 0 || index >= tensor_count) { + return support::CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid tensor_index: %d. Model has %d %s tensors.", + index, tensor_count, tensor_type_name)); + } + if (requires_metadata) { + if (GetTensorMetadata(i) == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("%s tensor %d is missing TensorMetadata.", + tensor_type_name, index), + support::TfLiteSupportStatus::kMetadataNotFoundError); + } + } + } + + return absl::OkStatus(); +} + +std::string Processor::GetTensorIndexString() { + std::stringstream stream; + std::copy(tensor_indices_.begin(), tensor_indices_.end(), + std::ostream_iterator<int>(stream, " ")); + return stream.str(); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h new file mode 100644 index 0000000..b3c4360 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
@@ -0,0 +1,201 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_PROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_PROCESSOR_H_ + +#include <initializer_list> +#include <vector> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow/lite/core/shims/c/common.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace task { +namespace processor { + +// Abstract base class for all Processors. +// Shares the common logics to handle tflite_engine and metadata. +class Processor { + public: + Processor() = default; + virtual ~Processor() = default; + + // Processor is neither copyable nor movable. + Processor(const Processor&) = delete; + Processor& operator=(const Processor&) = delete; + + template <typename T> + using EnableIfProcessorSubclass = + typename std::enable_if<std::is_base_of<Processor, T>::value>::type*; + + // Factory method to create a subclass of Processor. + // + // Example usage: + // auto processor = Processor::Create<MyPreprocessor>( + // num_expected_tensors, engine, tensor_indices); + template <typename T, EnableIfProcessorSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> Create( + int num_expected_tensors, + tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> tensor_indices, + bool requires_metadata = true) { + auto processor = absl::make_unique<T>(engine, tensor_indices); + RETURN_IF_ERROR( + processor->SanityCheck(num_expected_tensors, requires_metadata)); + return processor; + } + + // `tensor_indices` is the indices of the input tensors or output tensors + // that the processor should process. For example, a model may have 4 input + // tensors, and a preprocessor can process the first and third tensor, + // then `tensor_indices` should be {0, 2}. + explicit Processor(core::TfLiteEngine* engine, + const std::initializer_list<int> tensor_indices) + : engine_(engine), tensor_indices_(tensor_indices) {} + + // Checks if tensor counts and metadata of the model matches what required + // by the processor in general. + absl::Status SanityCheck(int num_expected_tensors, + bool requires_metadata = true); + + protected: + // Gets the associated tensor. + // `i` refers to the element index in `tensor_indices`. For example, + // assume `tensor_indices` is {3, 6, 8}, to access second tensor in + // `tensor_indices`, which is the 6th tensor of the model inputs or ourputs, + // `i` should be 1. + virtual TfLiteTensor* GetTensor(int i) const = 0; + + // Gets the associated tensor metadata. + // `i` refers to the element index in `tensor_indices`. For example, + // assume `tensor_indices` is {3, 6, 8}, to access second tensor in + // `tensor_indices`, which is the 6th tensor of the model inputs or ourputs, + // `i` should be 1. + virtual const tflite::TensorMetadata* GetTensorMetadata(int i = 0) const = 0; + + inline const tflite::metadata::ModelMetadataExtractor* GetMetadataExtractor() + const { + return engine_->metadata_extractor(); + } + + // Gets the tesnor indices in string format. + std::string GetTensorIndexString(); + + core::TfLiteEngine* engine_; + const std::vector<int> tensor_indices_; + + private: + // Gets the number of input or ourput tensors of the TfLiteEngine that this + // processor holds. + virtual int GetModelTensorCount() const = 0; + + // Either "input" or "output". + virtual const char* GetTensorTypeName() const = 0; +}; + +// Abstract base class for all Preprocessors. +// Preprocessor is a helper class that converts input structured data (such as +// image) to raw bytes and populates the associated tensors in the +// interpreter. +// +// As a convention, child class needs to implement a factory `Create` method +// to initialize and bind tensors. +// +// Example usage: +// auto processor = MyPreprocessor::Create( +// /* input_tensors */ {0}, engine, option); +// // Populate the associate tensors. +// processor->Preprocess(...); +class Preprocessor : public Processor { + protected: + using Processor::Processor; + + // Get the associated input tensor. + // Note: Caller is responsible for passing in a valid `i`. + inline TfLiteTensor* GetTensor(int i = 0) const override { + return engine_->GetInput(engine_->interpreter(), tensor_indices_.at(i)); + } + + // Get the associated input metadata. + // Note: Caller is responsible for passing in a valid `i`. + inline const tflite::TensorMetadata* GetTensorMetadata( + int i = 0) const override { + return GetMetadataExtractor()->GetInputTensorMetadata( + tensor_indices_.at(i)); + } + + private: + static constexpr char kInputTypeName[] = "input"; + + inline int GetModelTensorCount() const override { + return engine_->InputCount(engine_->interpreter()); + } + + inline const char* GetTensorTypeName() const override { + return kInputTypeName; + } +}; + +// Abstract base class for all Postprocessors. +// Postprocessor is a helper class to convert tensor value to structured +// data. +// As a convention, child class needs to implement a factory `Create` method +// to initialize and bind tensors. +// +// Example usage: +// auto processor = MyPostprocessor::Create( +// /* output_tensors */ {0}, engine, option); +// // Populate the associate tensors. +// auto value = processor->Postprocess(); +class Postprocessor : public Processor { + protected: + using Processor::Processor; + + // Get the associated output tensor. + // Note: Caller is responsible for passing in a valid `i`. + inline TfLiteTensor* GetTensor(int i = 0) const override { + return engine_->GetOutput(engine_->interpreter(), tensor_indices_.at(i)); + } + + // Get the associated output metadata. + // Note: Caller is responsible for passing in a valid `i`. + inline const tflite::TensorMetadata* GetTensorMetadata( + int i = 0) const override { + return GetMetadataExtractor()->GetOutputTensorMetadata( + tensor_indices_.at(i)); + } + + private: + static constexpr char kOutputTypeName[] = "output"; + + inline int GetModelTensorCount() const override { + return engine_->OutputCount(engine_->interpreter()); + } + + inline const char* GetTensorTypeName() const override { + return kOutputTypeName; + } +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_PROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD new file mode 100644 index 0000000..7e1e203 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD
@@ -0,0 +1,67 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [":class_proto"], +) + +cc_proto_library( + name = "classifications_cc_proto", + deps = [ + ":classifications_proto", + ], +) + +proto_library( + name = "class_proto", + srcs = ["class.proto"], +) + +cc_proto_library( + name = "class_cc_proto", + deps = [ + ":class_proto", + ], +) + +proto_library( + name = "embedding_proto", + srcs = ["embedding.proto"], +) + +cc_proto_library( + name = "embedding_cc_proto", + deps = [ + ":embedding_proto", + ], +) + +proto_library( + name = "embedding_options_proto", + srcs = ["embedding_options.proto"], +) + +cc_proto_library( + name = "embedding_options_cc_proto", + deps = [ + ":embedding_options_proto", + ], +) + +proto_library( + name = "classification_options_proto", + srcs = ["classification_options.proto"], +) + +cc_proto_library( + name = "classification_options_cc_proto", + deps = [ + ":classification_options_proto", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/class.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/class.proto new file mode 100644 index 0000000..f8b13f6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/class.proto
@@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.processor; + +// A single classification result. +message Class { + // The index of the class in the corresponding label map, usually packed in + // the TFLite Model Metadata [1]. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional int32 index = 1; + // The score for this class e.g. (but not necessarily) a probability in [0,1]. + optional float score = 2; + // A human readable name of the class filled from the label map. + optional string display_name = 3; + // An ID for the class, not necessarily human-readable (e.g. a Google + // Knowledge Graph ID [1]), filled from the label map. + // + // [1]: https://developers.google.com/knowledge-graph + optional string class_name = 4; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/classification_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/classification_options.proto new file mode 100644 index 0000000..0546800 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/classification_options.proto
@@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.processor; + +// Options for classification processor. +// Next Id: 6 +message ClassificationOptions { + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 1 [default = "en"]; + + // The maximum number of top-scored classification results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + optional int32 max_results = 2 [default = -1]; + + // Score threshold, overrides the ones provided in the model metadata + // (if any). Results below this value are rejected. + optional float score_threshold = 3; + + // Optional allowlist of class names. If non-empty, classifications whose + // class name is not in this set will be filtered out. Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_denylist. + repeated string class_name_allowlist = 4; + + // Optional denylist of class names. If non-empty, classifications whose + // class name is in this set will be filtered out. Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_allowlist. + repeated string class_name_denylist = 5; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/classifications.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/classifications.proto new file mode 100644 index 0000000..aaa4e7e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/classifications.proto
@@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.processor; + +import "tensorflow_lite_support/cc/task/processor/proto/class.proto"; + +// List of predicted classes (aka labels) for a given classifier head. +message Classifications { + // The array of predicted classes, usually sorted by descending scores (e.g. + // from high to low probability). + repeated Class classes = 1; + // The index of the classifier head these classes refer to. This is useful for + // multi-head models. + optional int32 head_index = 2; + // The name of the classifier head, which is the corresponding tensor metadata + // name. See + // https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545 + optional string head_name = 3; +} + +// Contains one set of results per classifier head. +message ClassificationResult { + repeated Classifications classifications = 1; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/embedding.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/embedding.proto new file mode 100644 index 0000000..5c55c88 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/embedding.proto
@@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.processor; + +// Defines a dense feature vector. Only one of the two fields is ever present. +// Feature vectors are assumed to be one-dimensional and L2-normalized. +message FeatureVector { + // Raw output of the embedding layer. Only provided if `quantize` is set to + // false in the EmbeddingOptions, which is the case by default. + repeated float value_float = 1 [packed = true]; + // Scalar-quantized embedding. Only provided if `quantize` is set to true in + // the ImageEmbedderOptions. + optional bytes value_string = 2; +} + +// Result produced by one of the embedder model output layers. +message Embedding { + // The output feature vector. + optional FeatureVector feature_vector = 1; + // The index of the model output layer that produced this feature vector. + optional int32 output_index = 2; +} + +// Embeddings produced by the Embedder. +message EmbeddingResult { + // The embeddings produced by each of the model output layers. + // + // Except in advanced cases, the embedding model has a single output layer, + // and this list is thus made of a single element. + repeated Embedding embeddings = 1; + // Reserved tags. + reserved 2; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto new file mode 100644 index 0000000..28bd2e11 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto
@@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.processor; + +// Options for embedding processor. +// Next Id: 3 +message EmbeddingOptions { + // Whether to normalize the returned feature vector with L2 norm. Use this + // option only if the model does not already contain a native L2_NORMALIZATION + // TF Lite Op. In most cases, this is already the case and L2 norm is thus + // achieved through TF Lite inference. + optional bool l2_normalize = 1; + + // Whether the returned embedding should be quantized to bytes via scalar + // quantization. Embeddings are implicitly assumed to be unit-norm and + // therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + // the l2_normalize option if this is not the case. + optional bool quantize = 2; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc new file mode 100644 index 0000000..58b77b6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
@@ -0,0 +1,216 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/processor/regex_preprocessor.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +namespace tflite { +namespace task { +namespace processor { + +namespace { + +using ::absl::StatusCode; +using ::tflite::TensorMetadata; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::RegexTokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::task::core::PopulateTensor; + +StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile( + const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>* + associated_files, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor) { + if (associated_files == nullptr || associated_files->size() < 1 || + associated_files->Get(0)->name() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid vocab_file from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + metadata_extractor->GetAssociatedFile( + associated_files->Get(0)->name()->str())); + return vocab_buffer; +} + +} // namespace + +/* static */ +StatusOr<std::unique_ptr<RegexPreprocessor>> RegexPreprocessor::Create( + tflite::task::core::TfLiteEngine* engine, + int input_tensor_index) { + ASSIGN_OR_RETURN(auto processor, Processor::Create<RegexPreprocessor>( + /* num_expected_tensors = */ 1, engine, + {input_tensor_index}, + /* requires_metadata = */ false)); + RETURN_IF_ERROR(processor->Init()); + return processor; +} + +absl::Status RegexPreprocessor::Init() { + // Check if the input is a STRING. If so, no tokenizer is needed. + if (GetTensor()->type == kTfLiteString) { + return absl::OkStatus(); + } + // Try if RegexTokenzier metadata can be found. + ASSIGN_OR_RETURN(const auto tokenzier_metadata, + TryFindRegexTokenizerMetadata()); + + ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromMetadata( + tokenzier_metadata, GetMetadataExtractor())); + return absl::OkStatus(); +} + +StatusOr<const tflite::ProcessUnit*> +RegexPreprocessor::TryFindRegexTokenizerMetadata() { + // RegexTokenizer is packed in the processing unit of the input tensor. + const TensorMetadata* tensor_metadata = GetTensorMetadata(); + if (tensor_metadata == nullptr) { + return nullptr; + } + + ASSIGN_OR_RETURN( + auto tokenizer_metadata, + GetMetadataExtractor()->FindFirstProcessUnit( + *tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions)); + + if (tokenizer_metadata != nullptr) { + // RegexTokenizer is found. Check if the tensor type matches. + auto input_tensor = GetTensor(); + if (input_tensor->type != kTfLiteInt32) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for input tensor ", input_tensor->name, + ". Requested INT32 for RegexTokenizer, got ", + TfLiteTypeGetName(input_tensor->type), "."), + TfLiteSupportStatus::kInvalidInputTensorTypeError); + } + } + return tokenizer_metadata; +} + +StatusOr<std::unique_ptr<RegexTokenizer>> +RegexPreprocessor::CreateTokenizerFromMetadata( + const tflite::ProcessUnit* tokenizer_metadata, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor) { + if (metadata_extractor == nullptr || tokenizer_metadata == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No metadata or input process unit found.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + if (tokenizer_metadata->options_type() == + ProcessUnitOptions_RegexTokenizerOptions) { + const tflite::RegexTokenizerOptions* options = + tokenizer_metadata->options_as<RegexTokenizerOptions>(); + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + CheckAndLoadFirstAssociatedFile(options->vocab_file(), + metadata_extractor)); + if (options->delim_regex_pattern() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid delim_regex_pattern from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + std::unique_ptr<RegexTokenizer> regex_tokenizer = + absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(), + vocab_buffer.data(), + vocab_buffer.size()); + + int unknown_token_id = 0; + if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "RegexTokenizer doesn't have <UNKNOWN> token.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + int pad_token_id = 0; + if (!regex_tokenizer->GetPadToken(&pad_token_id)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "RegexTokenizer doesn't have <PAD> token.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + return std::move(regex_tokenizer); + } else { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + absl::StrCat("Incorrect options_type:", + tokenizer_metadata->options_type()), + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } +} + +absl::Status RegexPreprocessor::Preprocess(const std::string& input_text) { + if (tokenizer_ == nullptr) { + return PopulateTensor(input_text, GetTensor()); + } else { + return RegexPreprocess(input_text); + } +} + +absl::Status RegexPreprocessor::RegexPreprocess(const std::string& input_text) { + TfLiteTensor* input_tensor = GetTensor(); + + // |<-------sentence_length-------->| + // input_tensor <START>, t1, t2... <PAD>, <PAD>... + // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's + // not found in tokenizer vocab. + TokenizerResult result = tokenizer_->Tokenize(input_text); + + size_t max_sentence_length = input_tensor->dims->size == 2 + ? input_tensor->dims->data[1] + : input_tensor->dims->data[0]; + + int unknown_token_id = 0; + tokenizer_->GetUnknownToken(&unknown_token_id); + + int pad_token_id = 0; + tokenizer_->GetPadToken(&pad_token_id); + + std::vector<int> input_tokens(max_sentence_length, pad_token_id); + int start_token_id = 0; + size_t input_token_index = 0; + if (tokenizer_->GetStartToken(&start_token_id)) { + input_tokens[0] = start_token_id; + input_token_index = 1; + } + + for (size_t i = 0; (i < result.subwords.size()) && + (input_token_index < max_sentence_length); + ++i, ++input_token_index) { + const std::string& token = result.subwords[i]; + int token_id = 0; + if (tokenizer_->LookupId(token, &token_id)) { + input_tokens[input_token_index] = token_id; + } else { + input_tokens[input_token_index] = unknown_token_id; + } + } + return PopulateTensor(input_tokens, input_tensor); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h new file mode 100644 index 0000000..bdd4e5e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
@@ -0,0 +1,65 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_REGEX_PREPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_REGEX_PREPROCESSOR_H_ + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h" +#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" + +namespace tflite { +namespace task { +namespace processor { + +// Processes input text and populates the associated input tensor. +// Requirements for the input tensor: +// A string tensor of type, kTfLiteString +// or +// An int32 tensor of type, kTfLiteInt32: contains the tokenized indices of +// a string input. A RegexTokenizer needs to be set up in the input tensor's +// metadata. +class RegexPreprocessor : public TextPreprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<RegexPreprocessor>> Create( + tflite::task::core::TfLiteEngine* engine, + int input_tensor_index); + + absl::Status Preprocess(const std::string& text); + + private: + using TextPreprocessor::TextPreprocessor; + + absl::Status Init(); + + tflite::support::StatusOr<const tflite::ProcessUnit*> + TryFindRegexTokenizerMetadata(); + + absl::Status RegexPreprocess(const std::string& input_text); + + tflite::support::StatusOr< + std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer>> + CreateTokenizerFromMetadata( + const tflite::ProcessUnit* tokenizer_metadata, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor); + + std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_; +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_REGEX_PREPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/text_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/text_preprocessor.cc new file mode 100644 index 0000000..30525229 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/text_preprocessor.cc
@@ -0,0 +1,58 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/processor/bert_preprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/regex_preprocessor.h" + +namespace tflite { +namespace task { +namespace processor { + +namespace { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +} // namespace + +/* static */ StatusOr<std::unique_ptr<TextPreprocessor>> +TextPreprocessor::Create( + tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> input_tensor_indices) { + switch (input_tensor_indices.size()) { + case 1: + return RegexPreprocessor::Create(engine, *input_tensor_indices.begin()); + case 3: + return BertPreprocessor::Create(engine, input_tensor_indices); + default: + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "TextPrerocessor accepts either 1 input tesnor (for Regex " + "tokenizer or String tensor) or 3 input tensors (for Bert " + "tokenizer), but got %d tensors.", + input_tensor_indices.size())); + } +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/text_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/text_preprocessor.h new file mode 100644 index 0000000..5781bf4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/text_preprocessor.h
@@ -0,0 +1,56 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_TEXT_PREPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_TEXT_PREPROCESSOR_H_ + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/processor.h" + +namespace tflite { +namespace task { +namespace processor { + +// Processes input text and populates the associated input tensors. +// Requirements for the input tensors (either one of the following): +// - One input tensor: +// A string tensor of type, kTfLiteString +// or +// An int32 tensor of type, kTfLiteInt32: contains the tokenized indices of +// a string input. A RegexTokenizer needs to be set up in the input tensor's +// metadata. +// - Three input tensors (input tensors of a Bert model): +// The 3 input tensors should be populated with metadata tensor names, +// "ids", "mask", and "segment_ids", respectively. The input_process_units +// metadata should contain WordPiece or Sentencepiece Tokenizer +// metadata. +class TextPreprocessor : public Preprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<TextPreprocessor>> Create( + tflite::task::core::TfLiteEngine* engine, + const std::initializer_list<int> input_tensor_indices); + + absl::Status Preprocess(const std::string& text); + + protected: + using Preprocessor::Preprocessor; +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_TEXT_PREPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/BUILD new file mode 100644 index 0000000..05b7c5e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/BUILD
@@ -0,0 +1,116 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_library_with_tflite", +) + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library_with_tflite( + name = "bert_nl_classifier", + srcs = [ + "bert_nl_classifier.cc", + ], + hdrs = [ + "bert_nl_classifier.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite:string", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + ], +) + +cc_library_with_tflite( + name = "question_answerer", + hdrs = [ + "question_answerer.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], +) + +cc_library_with_tflite( + name = "bert_question_answerer", + srcs = [ + "bert_question_answerer.cc", + ], + hdrs = [ + "bert_question_answerer.h", + ], + tflite_deps = [ + ":question_answerer", + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/text/proto:bert_question_answerer_options_proto_inc", + "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library_with_tflite( + name = "universal_sentence_encoder_qa", + srcs = [ + "universal_sentence_encoder_qa.cc", + ], + hdrs = [ + "universal_sentence_encoder_qa.h", + ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto", + "//tensorflow_lite_support/cc/task/text/proto:retrieval_cc_proto", + "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc new file mode 100644 index 0000000..ac8fa548 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
@@ -0,0 +1,208 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" + +#include <limits.h> +#include <stddef.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/ascii.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/string_type.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" + +namespace tflite { +namespace task { +namespace text { + +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::task::core::FindTensorByName; +using ::tflite::task::core::PopulateTensor; + +namespace { +constexpr char kIdsTensorName[] = "ids"; +constexpr char kMaskTensorName[] = "mask"; +constexpr char kSegmentIdsTensorName[] = "segment_ids"; +constexpr char kScoreTensorName[] = "probability"; +constexpr char kClassificationToken[] = "[CLS]"; +constexpr char kSeparator[] = "[SEP]"; +constexpr int kTokenizerProcessUnitIndex = 0; + +absl::Status SanityCheckOptions(const BertNLClassifierOptions& options) { + if (!options.has_base_options()) { + return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Missing mandatory `base_options` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +int GetLastDimSize(const TfLiteTensor* tensor) { + return tensor->dims->data[tensor->dims->size - 1]; +} + +} // namespace + +absl::Status BertNLClassifier::Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, + const std::string& input) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + auto* ids_tensor = + FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName); + auto* mask_tensor = + FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName); + auto* segment_ids_tensor = FindTensorByName( + input_tensors, input_tensor_metadatas, kSegmentIdsTensorName); + + if (GetLastDimSize(ids_tensor) != GetLastDimSize(mask_tensor) || + GetLastDimSize(ids_tensor) != GetLastDimSize(segment_ids_tensor)) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("The three input tensors in BertNLClassifier models " + "are expected to have same length, but got ids_tensor " + "(%d), mask_tensor (%d), segment_ids_tensor (%d).", + GetLastDimSize(ids_tensor), GetLastDimSize(mask_tensor), + GetLastDimSize(ids_tensor)), + TfLiteSupportStatus::kInvalidNumOutputTensorsError); + } + + int max_seq_len = GetLastDimSize(ids_tensor); + + std::string processed_input = input; + absl::AsciiStrToLower(&processed_input); + + TokenizerResult input_tokenize_results; + input_tokenize_results = tokenizer_->Tokenize(processed_input); + + // 2 accounts for [CLS], [SEP] + absl::Span<const std::string> query_tokens = + absl::MakeSpan(input_tokenize_results.subwords.data(), + input_tokenize_results.subwords.data() + + std::min(static_cast<size_t>(max_seq_len - 2), + input_tokenize_results.subwords.size())); + + std::vector<std::string> tokens; + tokens.reserve(2 + query_tokens.size()); + // Start of generating the features. + tokens.push_back(kClassificationToken); + // For query input. + for (const auto& query_token : query_tokens) { + tokens.push_back(query_token); + } + // For Separation. + tokens.push_back(kSeparator); + + std::vector<int> input_ids(max_seq_len, 0); + std::vector<int> input_mask(max_seq_len, 0); + // Convert tokens back into ids and set mask + for (int i = 0; i < tokens.size(); ++i) { + tokenizer_->LookupId(tokens[i], &input_ids[i]); + input_mask[i] = 1; + } + // |<-----------max_seq_len--------->| + // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 + // input_masks 1 1 1... 1 1 0 0... 0 + // segment_ids 0 0 0... 0 0 0 0... 0 + + RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor)); + RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor)); + RETURN_IF_ERROR( + PopulateTensor(std::vector<int>(max_seq_len, 0), segment_ids_tensor)); + + return absl::OkStatus(); +} + +StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& /*input*/) { + if (output_tensors.size() != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("BertNLClassifier models are expected to have only 1 " + "output, found %d", + output_tensors.size()), + TfLiteSupportStatus::kInvalidNumOutputTensorsError); + } + const TfLiteTensor* scores = FindTensorByName( + output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), + kScoreTensorName); + + // optional labels extracted from metadata + return BuildResults(scores, /*labels=*/nullptr); +} + +StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromOptions( + const BertNLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + auto options_copy = absl::make_unique<BertNLClassifierOptions>(options); + + ASSIGN_OR_RETURN( + auto bert_nl_classifier, + core::TaskAPIFactory::CreateFromBaseOptions<BertNLClassifier>( + &options_copy->base_options(), std::move(resolver))); + RETURN_IF_ERROR(bert_nl_classifier->Initialize(std::move(options_copy))); + return std::move(bert_nl_classifier); +} + +absl::Status BertNLClassifier::Initialize( + std::unique_ptr<BertNLClassifierOptions> options) { + options_ = std::move(options); + // Set up mandatory tokenizer from metadata. + const ProcessUnit* tokenizer_process_unit = + GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); + if (tokenizer_process_unit == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No input process unit found from metadata.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(tokenizer_, + CreateTokenizerFromProcessUnit(tokenizer_process_unit, + GetMetadataExtractor())); + + // Set up optional label vector from metadata. + TrySetLabelFromMetadata( + GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex)) + .IgnoreError(); + + return absl::OkStatus(); +} + +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h new file mode 100644 index 0000000..91bcfe5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
@@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ + +#include <stddef.h> + +#include <memory> +#include <string> +#include <vector> + +#include "absl/base/macros.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/string_type.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" +#include "tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +namespace tflite { +namespace task { +namespace text { + +// Classifier API for NLClassification tasks with Bert models, categorizes +// string into different classes. +// +// The API expects a Bert based TFLite model with metadata populated. +// The metadata should contain the following information: +// - input_process_units for Wordpiece/Sentencepiece Tokenizer +// - 3 input tensors with names "ids", "mask" and "segment_ids" +// - 1 output tensor of type float32[1, 2], with a optionally attached label +// file. If a label file is attached, the file should be a plain text file +// with one label per line, the number of labels should match the number of +// categories the model outputs. + +class BertNLClassifier : public tflite::task::text::nlclassifier::NLClassifier { + public: + using tflite::task::text::nlclassifier::NLClassifier::NLClassifier; + + // Factory function to create a BertNLClassifier from BertNLClassifierOptions. + static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromOptions( + const BertNLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + + // Factory function to create a BertNLClassifier from TFLite model with + // metadata. + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromFile( + const std::string& path_to_model_with_metadata, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()) { + BertNLClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + path_to_model_with_metadata); + return CreateFromOptions(options, std::move(resolver)); + } + + // Factory function to create a BertNLClassifier from in memory buffer of a + // TFLite model with metadata. + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromBuffer( + const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()) { + BertNLClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_content( + model_with_metadata_buffer_data, model_with_metadata_buffer_size); + return CreateFromOptions(options, std::move(resolver)); + } + + // Factory function to create a BertNLClassifier from the file descriptor of a + // TFLite model with metadata. + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromFd( + int fd, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()) { + BertNLClassifierOptions options; + options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta() + ->set_fd(fd); + return CreateFromOptions(options, std::move(resolver)); + } + + protected: + // Run tokenization on input text and construct three input tensors ids, mask + // and segment_ids for the model input. + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::string& input) override; + + // Extract model output and create results with label file attached in + // metadata. If no label file is attached, use output score index as labels. + tflite::support::StatusOr<std::vector<core::Category>> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& input) override; + + private: + // Initialize the API with the tokenizer and label files set in the metadata. + absl::Status Initialize(std::unique_ptr<BertNLClassifierOptions> options); + + std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; + + std::unique_ptr<BertNLClassifierOptions> options_; +}; + +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc new file mode 100644 index 0000000..591b70e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
@@ -0,0 +1,423 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_join.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl +#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace text { + +constexpr char kIdsTensorName[] = "ids"; +constexpr char kMaskTensorName[] = "mask"; +constexpr char kSegmentIdsTensorName[] = "segment_ids"; +constexpr char kEndLogitsTensorName[] = "end_logits"; +constexpr char kStartLogitsTensorName[] = "start_logits"; + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::BertTokenizer; +using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; +using ::tflite::support::text::tokenizer::SentencePieceTokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::task::core::FindTensorByName; +using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::PopulateVector; +using ::tflite::task::core::ReverseSortIndices; + +namespace { +constexpr int kTokenizerProcessUnitIndex = 0; + +absl::Status SanityCheckOptions(const BertQuestionAnswererOptions& options) { + if (!options.has_base_options()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Missing mandatory `base_options` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} +} // namespace + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateFromOptions( + const BertQuestionAnswererOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + // Copy options to ensure the ExternalFile outlives the duration of this + // created BertQuestionAnswerer object. + auto options_copy = absl::make_unique<BertQuestionAnswererOptions>(options); + + ASSIGN_OR_RETURN( + auto bert_question_answerer, + core::TaskAPIFactory::CreateFromBaseOptions<BertQuestionAnswerer>( + &options_copy->base_options(), std::move(resolver))); + RETURN_IF_ERROR( + bert_question_answerer->InitializeFromMetadata(std::move(options_copy))); + return std::move(bert_question_answerer); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateFromFile( + const std::string& path_to_model_with_metadata) { + BertQuestionAnswererOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + path_to_model_with_metadata); + return CreateFromOptions(options); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateFromBuffer( + const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size) { + BertQuestionAnswererOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_content( + model_with_metadata_buffer_data, model_with_metadata_buffer_size); + return CreateFromOptions(options); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd( + int fd) { + BertQuestionAnswererOptions options; + options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta() + ->set_fd(fd); + return CreateFromOptions(options); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( + const std::string& path_to_model, + const std::string& path_to_vocab) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( + path_to_model, + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeBertTokenizer(path_to_vocab); + return std::move(api_to_init); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( + const char* model_buffer_data, + size_t model_buffer_size, + const char* vocab_buffer_data, + size_t vocab_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( + model_buffer_data, model_buffer_size, + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data, + vocab_buffer_size); + return std::move(api_to_init); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile( + const std::string& path_to_model, + const std::string& path_to_spmodel) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( + path_to_model, + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel); + return std::move(api_to_init); +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( + const char* model_buffer_data, + size_t model_buffer_size, + const char* spmodel_buffer_data, + size_t spmodel_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( + model_buffer_data, model_buffer_size, + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data, + spmodel_buffer_size); + return std::move(api_to_init); +} + +std::vector<QaAnswer> BertQuestionAnswerer::Answer( + const std::string& context, + const std::string& question) { + // The BertQuestionAnswererer implementation for Preprocess() and + // Postprocess() never returns errors: just call value(). + return Infer(context, question).value(); +} + +absl::Status BertQuestionAnswerer::Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, + const std::string& context, + const std::string& query) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + TfLiteTensor* ids_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kIdsTensorName) + : input_tensors[0]; + TfLiteTensor* mask_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kMaskTensorName) + : input_tensors[1]; + TfLiteTensor* segment_ids_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kSegmentIdsTensorName) + : input_tensors[2]; + + token_to_orig_map_.clear(); + + // The orig_tokens is used for recovering the answer string from the index, + // while the processed_tokens is lower-cased and used to generate input of + // the model. + orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty()); + std::vector<std::string> processed_tokens(orig_tokens_); + + std::string processed_query = query; + if (kUseLowerCase) { + for (auto& token : processed_tokens) { + absl::AsciiStrToLower(&token); + } + absl::AsciiStrToLower(&processed_query); + } + + TokenizerResult query_tokenize_results; + query_tokenize_results = tokenizer_->Tokenize(processed_query); + + std::vector<std::string> query_tokens = query_tokenize_results.subwords; + if (query_tokens.size() > kMaxQueryLen) { + query_tokens.resize(kMaxQueryLen); + } + + // Example: + // context: tokenize me please + // all_doc_tokens: token ##ize me plea ##se + // token_to_orig_index: [0, 0, 1, 2, 2] + + std::vector<std::string> all_doc_tokens; + std::vector<int> token_to_orig_index; + for (size_t i = 0; i < processed_tokens.size(); i++) { + const std::string& token = processed_tokens[i]; + std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords; + for (const std::string& sub_token : sub_tokens) { + token_to_orig_index.emplace_back(i); + all_doc_tokens.emplace_back(sub_token); + } + } + + // -3 accounts for [CLS], [SEP] and [SEP]. + int max_context_len = kMaxSeqLen - query_tokens.size() - 3; + if (all_doc_tokens.size() > max_context_len) { + all_doc_tokens.resize(max_context_len); + } + + std::vector<std::string> tokens; + tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size()); + std::vector<int> segment_ids; + segment_ids.reserve(kMaxSeqLen); + + // Start of generating the features. + tokens.emplace_back("[CLS]"); + segment_ids.emplace_back(0); + + // For query input. + for (const auto& query_token : query_tokens) { + tokens.emplace_back(query_token); + segment_ids.emplace_back(0); + } + + // For Separation. + tokens.emplace_back("[SEP]"); + segment_ids.emplace_back(0); + + // For Text Input. + for (int i = 0; i < all_doc_tokens.size(); i++) { + auto& doc_token = all_doc_tokens[i]; + tokens.emplace_back(doc_token); + segment_ids.emplace_back(1); + token_to_orig_map_[tokens.size()] = token_to_orig_index[i]; + } + + // For ending mark. + tokens.emplace_back("[SEP]"); + segment_ids.emplace_back(1); + + std::vector<int> input_ids(tokens.size()); + input_ids.reserve(kMaxSeqLen); + // Convert tokens back into ids + for (int i = 0; i < tokens.size(); i++) { + auto& token = tokens[i]; + tokenizer_->LookupId(token, &input_ids[i]); + } + + std::vector<int> input_mask; + input_mask.reserve(kMaxSeqLen); + input_mask.insert(input_mask.end(), tokens.size(), 1); + + int zeros_to_pad = kMaxSeqLen - input_ids.size(); + input_ids.insert(input_ids.end(), zeros_to_pad, 0); + input_mask.insert(input_mask.end(), zeros_to_pad, 0); + segment_ids.insert(segment_ids.end(), zeros_to_pad, 0); + + // input_ids INT32[1, 384] + RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor)); + // input_mask INT32[1, 384] + RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor)); + // segment_ids INT32[1, 384] + RETURN_IF_ERROR(PopulateTensor(segment_ids, segment_ids_tensor)); + + return absl::OkStatus(); +} + +StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& /*lowercased_context*/, + const std::string& /*lowercased_query*/) { + auto* output_tensor_metadatas = + GetMetadataExtractor()->GetOutputTensorMetadata(); + + const TfLiteTensor* end_logits_tensor = + output_tensor_metadatas + ? FindTensorByName(output_tensors, output_tensor_metadatas, + kEndLogitsTensorName) + : output_tensors[0]; + const TfLiteTensor* start_logits_tensor = + output_tensor_metadatas + ? FindTensorByName(output_tensors, output_tensor_metadatas, + kStartLogitsTensorName) + : output_tensors[1]; + + std::vector<float> end_logits; + std::vector<float> start_logits; + + // end_logits FLOAT[1, 384] + RETURN_IF_ERROR(PopulateVector(end_logits_tensor, &end_logits)); + // start_logits FLOAT[1, 384] + RETURN_IF_ERROR(PopulateVector(start_logits_tensor, &start_logits)); + + auto start_indices = ReverseSortIndices(start_logits); + auto end_indices = ReverseSortIndices(end_logits); + + std::vector<QaAnswer::Pos> orig_results; + for (int start_index = 0; start_index < kPredictAnsNum; start_index++) { + for (int end_index = 0; end_index < kPredictAnsNum; end_index++) { + int start = start_indices[start_index]; + int end = end_indices[end_index]; + + if (!token_to_orig_map_.contains(start + kOutputOffset) || + !token_to_orig_map_.contains(end + kOutputOffset) || end < start || + (end - start + 1) > kMaxAnsLen) { + continue; + } + orig_results.emplace_back( + QaAnswer::Pos(start, end, start_logits[start] + end_logits[end])); + } + } + + std::sort(orig_results.begin(), orig_results.end()); + + std::vector<QaAnswer> answers; + for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) { + auto orig_pos = orig_results[i]; + answers.emplace_back( + orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end) + : "", + orig_pos); + } + + return answers; +} + +std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) { + int start_index = token_to_orig_map_[start + kOutputOffset]; + int end_index = token_to_orig_map_[end + kOutputOffset]; + + return absl::StrJoin(orig_tokens_.begin() + start_index, + orig_tokens_.begin() + end_index + 1, " "); +} + +absl::Status BertQuestionAnswerer::InitializeFromMetadata( + std::unique_ptr<BertQuestionAnswererOptions> options) { + options_ = std::move(options); + + const ProcessUnit* tokenizer_process_unit = + GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); + if (tokenizer_process_unit == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No input process unit found from metadata.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(tokenizer_, + CreateTokenizerFromProcessUnit(tokenizer_process_unit, + GetMetadataExtractor())); + return absl::OkStatus(); +} + +void BertQuestionAnswerer::InitializeBertTokenizer( + const std::string& path_to_vocab) { + tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab); +} + +void BertQuestionAnswerer::InitializeBertTokenizerFromBinary( + const char* vocab_buffer_data, + size_t vocab_buffer_size) { + tokenizer_ = + absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size); +} + +void BertQuestionAnswerer::InitializeSentencepieceTokenizer( + const std::string& path_to_spmodel) { + tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel); +} + +void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary( + const char* spmodel_buffer_data, + size_t spmodel_buffer_size) { + tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data, + spmodel_buffer_size); +} + +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h new file mode 100644 index 0000000..52ec835 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
@@ -0,0 +1,160 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_ + +#include "absl/base/macros.h" // from @com_google_absl +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/question_answerer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" + +namespace tflite { +namespace task { +namespace text { + +// BertQA task API, performs tokenization for models (BERT, Albert, etc.) in +// preprocess and returns most possible answers. +// +// In particular, the branch of BERT models use WordPiece tokenizer, and the +// branch of Albert models use SentencePiece tokenizer, respectively. +// +// The API expects a Bert based TFLite model with metadata populated. +// The metadata should contain the following information: +// - input_process_units for Wordpiece/Sentencepiece Tokenizer. Wordpiece +// Tokenizer can be used for a MobileBert[0] model, Sentencepiece +// Tokenizer Tokenizer can be used for an Albert[1] model. +// - 3 input tensors with names "ids", "mask" and "segment_ids". +// - 2 output tensors with names "end_logits" and "start_logits". +// [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +// [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +// +// See the public documentation for more information: +// https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer + +class BertQuestionAnswerer : public QuestionAnswerer { + public: + // TODO(b/150904655): add support to parameterize. + static constexpr int kMaxQueryLen = 64; + static constexpr int kMaxSeqLen = 384; + static constexpr int kPredictAnsNum = 5; + static constexpr int kMaxAnsLen = 32; + // TODO(b/151954803): clarify the offset usage + static constexpr int kOutputOffset = 1; + static constexpr int kNumLiteThreads = 4; + static constexpr bool kUseLowerCase = true; + + // Factory function to create a `BertQuestionAnswerer` from + // `BertQuestionAnswererOptions`. + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromOptions( + const BertQuestionAnswererOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromFile(const std::string& path_to_model_with_metadata); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromBuffer(const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromFd(int fd); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateBertQuestionAnswererFromFile(const std::string& path_to_model, + const std::string& path_to_vocab); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data, + size_t model_buffer_size, + const char* vocab_buffer_data, + size_t vocab_buffer_size); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model, + const std::string& path_to_spmodel); + + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data, + size_t model_buffer_size, + const char* spmodel_buffer_data, + size_t spmodel_buffer_size); + + explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) + : QuestionAnswerer(std::move(engine)) {} + + // Answers question based on the context. Could be empty if no answer was + // found from the given context. + std::vector<QaAnswer> Answer(const std::string& context, + const std::string& question) override; + + private: + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::string& lowercased_context, + const std::string& lowercased_query) override; + + tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& lowercased_context, + const std::string& lowercased_query) override; + + // Initialize API with a BertTokenizer from the vocabulary file. + void InitializeBertTokenizer(const std::string& path_to_vocab); + // Initialize API with a BertTokenizer from the vocabulary buffer. + void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data, + size_t vocab_buffer_size); + + // Initialize API with a SentencepieceTokenizer from the model file. + void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel); + // Initialize API with a SentencepieceTokenizer from the model buffer. + void InitializeSentencepieceTokenizerFromBinary( + const char* spmodel_buffer_data, + size_t spmodel_buffer_size); + + // Initialize the API with the tokenizer set in the metadata. + absl::Status InitializeFromMetadata( + std::unique_ptr<BertQuestionAnswererOptions> options); + + std::string ConvertIndexToString(int start, int end); + + std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; + // Maps index of input token to index of untokenized word from original input. + absl::flat_hash_map<size_t, size_t> token_to_orig_map_; + // Original tokens of context. + std::vector<std::string> orig_tokens_; + std::unique_ptr<BertQuestionAnswererOptions> options_; +}; + +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD index 33b6f6a..8a5ae2b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD
@@ -1,15 +1,14 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_library_with_tflite", +) + package( default_visibility = ["//tensorflow_lite_support:users"], licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "bert_nl_classifier_c_api.h", - "nl_classifier_c_api.h", - "nl_classifier_c_api_common.h", -]) - -cc_library( +cc_library_with_tflite( name = "nl_classifier", srcs = [ "nl_classifier.cc", @@ -17,102 +16,30 @@ hdrs = [ "nl_classifier.h", ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/core:base_task_api", "//tensorflow_lite_support/cc/task/core:category", - "//tensorflow_lite_support/cc/task/core:task_api_factory", "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/text/proto:nl_classifier_options_proto_inc", "//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer", "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", "//tensorflow_lite_support/cc/utils:common_utils", "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@flatbuffers", "@org_tensorflow//tensorflow/lite:string", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/core/api", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", ], ) - -cc_library( - name = "nl_classifier_c_api", - srcs = [ - "nl_classifier_c_api.cc", - ], - hdrs = [ - "nl_classifier_c_api.h", - "nl_classifier_c_api_common.h", - ], - visibility = ["//tensorflow_lite_support:__subpackages__"], - deps = [ - ":nl_classifier", - ":nl_classifier_c_api_common", - "//tensorflow_lite_support/cc/task/core:category", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "bert_nl_classifier", - srcs = [ - "bert_nl_classifier.cc", - ], - hdrs = [ - "bert_nl_classifier.h", - ], - deps = [ - ":nl_classifier", - "//tensorflow_lite_support/cc:common", - "//tensorflow_lite_support/cc/port:status_macros", - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/core:category", - "//tensorflow_lite_support/cc/task/core:task_api_factory", - "//tensorflow_lite_support/cc/task/core:task_utils", - "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", - "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", - "//tensorflow_lite_support/metadata/cc:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite:string", - "@org_tensorflow//tensorflow/lite/c:common", - "@org_tensorflow//tensorflow/lite/core/api", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - -cc_library( - name = "bert_nl_classifier_c_api", - srcs = [ - "bert_nl_classifier_c_api.cc", - ], - hdrs = [ - "bert_nl_classifier_c_api.h", - "nl_classifier_c_api_common.h", - ], - visibility = ["//tensorflow_lite_support:__subpackages__"], - deps = [ - ":bert_nl_classifier", - ":nl_classifier_c_api_common", - "//tensorflow_lite_support/cc/task/core:category", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "nl_classifier_c_api_common", - srcs = [ - "nl_classifier_c_api_common.cc", - ], - hdrs = [ - "nl_classifier_c_api_common.h", - ], - visibility = ["//tensorflow_lite_support:__subpackages__"], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc deleted file mode 100644 index e992b1c..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ /dev/null
@@ -1,198 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" - -#include <stddef.h> - -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "absl/status/status.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_format.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/string_type.h" -#include "tensorflow_lite_support/cc/common.h" -#include "tensorflow_lite_support/cc/port/status_macros.h" -#include "tensorflow_lite_support/cc/task/core/category.h" -#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" -#include "tensorflow_lite_support/cc/task/core/task_utils.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" -#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" -#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" -#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" - -namespace tflite { -namespace task { -namespace text { -namespace nlclassifier { - -using ::tflite::support::CreateStatusWithPayload; -using ::tflite::support::StatusOr; -using ::tflite::support::TfLiteSupportStatus; -using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; -using ::tflite::support::text::tokenizer::TokenizerResult; -using ::tflite::task::core::FindTensorByName; -using ::tflite::task::core::PopulateTensor; - -namespace { -constexpr char kIdsTensorName[] = "ids"; -constexpr char kMaskTensorName[] = "mask"; -constexpr char kSegmentIdsTensorName[] = "segment_ids"; -constexpr char kScoreTensorName[] = "probability"; -constexpr char kClassificationToken[] = "[CLS]"; -constexpr char kSeparator[] = "[SEP]"; -constexpr int kTokenizerProcessUnitIndex = 0; -} // namespace - -absl::Status BertNLClassifier::Preprocess( - const std::vector<TfLiteTensor*>& input_tensors, - const std::string& input) { - auto* input_tensor_metadatas = - GetMetadataExtractor()->GetInputTensorMetadata(); - auto* ids_tensor = - FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName); - auto* mask_tensor = - FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName); - auto* segment_ids_tensor = FindTensorByName( - input_tensors, input_tensor_metadatas, kSegmentIdsTensorName); - - std::string processed_input = input; - absl::AsciiStrToLower(&processed_input); - - TokenizerResult input_tokenize_results; - input_tokenize_results = tokenizer_->Tokenize(processed_input); - - // 2 accounts for [CLS], [SEP] - absl::Span<const std::string> query_tokens = - absl::MakeSpan(input_tokenize_results.subwords.data(), - input_tokenize_results.subwords.data() + - std::min(static_cast<size_t>(kMaxSeqLen - 2), - input_tokenize_results.subwords.size())); - - std::vector<std::string> tokens; - tokens.reserve(2 + query_tokens.size()); - // Start of generating the features. - tokens.push_back(kClassificationToken); - // For query input. - for (const auto& query_token : query_tokens) { - tokens.push_back(query_token); - } - // For Separation. - tokens.push_back(kSeparator); - - std::vector<int> input_ids(kMaxSeqLen, 0); - std::vector<int> input_mask(kMaxSeqLen, 0); - // Convert tokens back into ids and set mask - for (int i = 0; i < tokens.size(); ++i) { - tokenizer_->LookupId(tokens[i], &input_ids[i]); - input_mask[i] = 1; - } - // |<-----------kMaxSeqLen---------->| - // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 - // input_masks 1 1 1... 1 1 0 0... 0 - // segment_ids 0 0 0... 0 0 0 0... 0 - - PopulateTensor(input_ids, ids_tensor); - PopulateTensor(input_mask, mask_tensor); - PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor); - - return absl::OkStatus(); -} - -StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess( - const std::vector<const TfLiteTensor*>& output_tensors, - const std::string& /*input*/) { - if (output_tensors.size() != 1) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat("BertNLClassifier models are expected to have only 1 " - "output, found %d", - output_tensors.size()), - TfLiteSupportStatus::kInvalidNumOutputTensorsError); - } - const TfLiteTensor* scores = FindTensorByName( - output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), - kScoreTensorName); - - // optional labels extracted from metadata - return BuildResults(scores, /*labels=*/nullptr); -} - -StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFile( - const std::string& path_to_model_with_metadata, - std::unique_ptr<tflite::OpResolver> resolver) { - std::unique_ptr<BertNLClassifier> bert_nl_classifier; - ASSIGN_OR_RETURN(bert_nl_classifier, - core::TaskAPIFactory::CreateFromFile<BertNLClassifier>( - path_to_model_with_metadata, std::move(resolver))); - RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata()); - return std::move(bert_nl_classifier); -} - -StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromBuffer( - const char* model_with_metadata_buffer_data, - size_t model_with_metadata_buffer_size, - std::unique_ptr<tflite::OpResolver> resolver) { - std::unique_ptr<BertNLClassifier> bert_nl_classifier; - ASSIGN_OR_RETURN(bert_nl_classifier, - core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>( - model_with_metadata_buffer_data, - model_with_metadata_buffer_size, std::move(resolver))); - RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata()); - return std::move(bert_nl_classifier); -} - -StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd( - int fd, - std::unique_ptr<tflite::OpResolver> resolver) { - std::unique_ptr<BertNLClassifier> bert_nl_classifier; - ASSIGN_OR_RETURN( - bert_nl_classifier, - core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>( - fd, std::move(resolver))); - RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata()); - return std::move(bert_nl_classifier); -} - -absl::Status BertNLClassifier::InitializeFromMetadata() { - // Set up mandatory tokenizer. - const ProcessUnit* tokenizer_process_unit = - GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); - if (tokenizer_process_unit == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "No input process unit found from metadata.", - TfLiteSupportStatus::kMetadataInvalidTokenizerError); - } - ASSIGN_OR_RETURN(tokenizer_, - CreateTokenizerFromProcessUnit(tokenizer_process_unit, - GetMetadataExtractor())); - - // Set up optional label vector. - TrySetLabelFromMetadata( - GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex)) - .IgnoreError(); - return absl::OkStatus(); -} - -} // namespace nlclassifier -} // namespace text -} // namespace task -} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h deleted file mode 100644 index e78085d..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h +++ /dev/null
@@ -1,106 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ - -#include <stddef.h> - -#include <memory> -#include <string> -#include <vector> - -#include "absl/status/status.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/string_type.h" -#include "tensorflow_lite_support/cc/task/core/category.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" -#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" - -namespace tflite { -namespace task { -namespace text { -namespace nlclassifier { - -// Classifier API for NLClassification tasks with Bert models, categorizes -// string into different classes. -// -// The API expects a Bert based TFLite model with metadata populated. -// The metadata should contain the following information: -// - input_process_units for Wordpiece/Sentencepiece Tokenizer -// - 3 input tensors with names "ids", "mask" and "segment_ids" -// - 1 output tensor of type float32[1, 2], with a optionally attached label -// file. If a label file is attached, the file should be a plain text file -// with one label per line, the number of labels should match the number of -// categories the model outputs. - -class BertNLClassifier : public NLClassifier { - public: - using NLClassifier::NLClassifier; - // Max number of tokens to pass to the model. - static constexpr int kMaxSeqLen = 512; - - // Factory function to create a BertNLClassifier from TFLite model with - // metadata. - static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> - CreateFromFile( - const std::string& path_to_model_with_metadata, - std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); - - // Factory function to create a BertNLClassifier from in memory buffer of a - // TFLite model with metadata. - static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> - CreateFromBuffer( - const char* model_with_metadata_buffer_data, - size_t model_with_metadata_buffer_size, - std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); - - // Factory function to create a BertNLClassifier from the file descriptor of a - // TFLite model with metadata. - static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> - CreateFromFd( - int fd, - std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); - - protected: - // Run tokenization on input text and construct three input tensors ids, mask - // and segment_ids for the model input. - absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, - const std::string& input) override; - - // Extract model output and create results with label file attached in - // metadata. If no label file is attached, use output score index as labels. - tflite::support::StatusOr<std::vector<core::Category>> Postprocess( - const std::vector<const TfLiteTensor*>& output_tensors, - const std::string& input) override; - - private: - // Initialize the API with the tokenizer and label files set in the metadata. - absl::Status InitializeFromMetadata(); - - std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; -}; - -} // namespace nlclassifier -} // namespace text -} // namespace task -} // namespace tflite - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc deleted file mode 100644 index f42e2b5..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc +++ /dev/null
@@ -1,72 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h" - -#include <memory> - -#include "absl/strings/string_view.h" -#include "tensorflow_lite_support/cc/task/core/category.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" - -using CategoryCPP = ::tflite::task::core::Category; -using BertNLClassifierCPP = - ::tflite::task::text::nlclassifier::BertNLClassifier; - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -struct BertNLClassifier { - std::unique_ptr<BertNLClassifierCPP> impl; -}; - -BertNLClassifier* BertNLClassifierFromFile(const char* model_path) { - auto classifier_status = - BertNLClassifierCPP::CreateFromFile(std::string(model_path)); - if (classifier_status.ok()) { - return new BertNLClassifier{.impl = std::unique_ptr<BertNLClassifierCPP>( - dynamic_cast<BertNLClassifierCPP*>( - classifier_status.value().release()))}; - } else { - return nullptr; - } -} - -Categories* BertNLClassifierClassify(const BertNLClassifier* classifier, - const char* text) { - std::vector<CategoryCPP> results = - classifier->impl->Classify(absl::string_view(text).data()); - size_t size = results.size(); - auto* categories = new Category[size]; - - for (size_t i = 0; i < size; ++i) { - categories[i].text = strdup(results[i].class_name.c_str()); - categories[i].score = results[i].score; - } - - auto* c_categories = new Categories; - c_categories->size = size; - c_categories->categories = categories; - return c_categories; -} - -void BertNLClassifierDelete(BertNLClassifier* classifier) { - delete classifier; -} - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h deleted file mode 100644 index 674521b9..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h +++ /dev/null
@@ -1,61 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_ - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" -// -------------------------------------------------------------------------- -/// C API for BertNLClassifier. -/// -/// The API leans towards simplicity and uniformity instead of convenience, as -/// most usage will be by language-specific wrappers. It provides largely the -/// same set of functionality as that of the C++ TensorFlow Lite -/// `BertNLClassifier` API, but is useful for shared libraries where having -/// a stable ABI boundary is important. -/// -/// Usage: -/// <pre><code> -/// // Create the model and interpreter options. -/// BertNLClassifier* classifier = -/// BertNLClassifierFromFile("/path/to/model.tflite"); -/// -/// // classification. -/// Categories* categories = Classify(classifier, context, question); -/// -/// // Dispose of the API object. -/// BertNLClassifierrDelete(classifier); - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef struct BertNLClassifier BertNLClassifier; - -// Creates BertNLClassifier from model path, returns nullptr if the file -// doesn't exist or is not a well formatted TFLite model path. -extern BertNLClassifier* BertNLClassifierFromFile(const char* model_path); - -// Invokes the encapsulated TFLite model and classifies the input text. -extern struct Categories* BertNLClassifierClassify( - const BertNLClassifier* classifier, - const char* text); - -extern void BertNLClassifierDelete(BertNLClassifier* classifier); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc index 693ea6f..6986bcc 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -21,10 +21,10 @@ #include <utility> #include <vector> -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" +#include "absl/algorithm/container.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -59,11 +59,24 @@ using ::tflite::task::core::Dequantize; using ::tflite::task::core::GetStringAtIndex; using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::TaskAPIFactory; +// To differenciate it with the struct option, +// tflite::task::text::nl_classifier::NLClassifierOptions. +using NLClassifierProtoOptions = ::tflite::task::text::NLClassifierOptions; namespace { constexpr int kRegexTokenizerInputTensorIndex = 0; constexpr int kRegexTokenizerProcessUnitIndex = 0; +absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) { + if (!options.has_base_options()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Missing mandatory `base_options` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile( const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>* associated_files, @@ -133,13 +146,13 @@ "RegexTokenizer doesn't have <PAD> token.", TfLiteSupportStatus::kMetadataInvalidTokenizerError); } - return regex_tokenizer; + return std::move(regex_tokenizer); } } // namespace const NLClassifierOptions& NLClassifier::GetOptions() const { - return options_; + return struct_options_; } absl::Status NLClassifier::TrySetLabelFromMetadata( @@ -191,7 +204,7 @@ const std::string& input) { TfLiteTensor* input_tensor = FindTensorWithNameOrIndex( input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), - options_.input_tensor_name, options_.input_tensor_index); + struct_options_.input_tensor_name, struct_options_.input_tensor_index); if (input_tensor == nullptr) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -236,9 +249,9 @@ } } - PopulateTensor(input_tokens, input_tensor); + RETURN_IF_ERROR(PopulateTensor(input_tokens, input_tensor)); } else { - PopulateTensor(input, input_tensor); + RETURN_IF_ERROR(PopulateTensor(input, input_tensor)); } return absl::OkStatus(); } @@ -249,12 +262,12 @@ return BuildResults( FindTensorWithNameOrIndex( output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), - options_.output_score_tensor_name, - options_.output_score_tensor_index), + struct_options_.output_score_tensor_name, + struct_options_.output_score_tensor_index), FindTensorWithNameOrIndex( - output_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), - options_.output_label_tensor_name, - options_.output_label_tensor_index)); + output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), + struct_options_.output_label_tensor_name, + struct_options_.output_label_tensor_index)); } std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores, @@ -298,8 +311,23 @@ return predictions; } + +absl::Status NLClassifier::Initialize( + std::unique_ptr<tflite::task::text::NLClassifierOptions> options) { + proto_options_ = std::move(options); + + RETURN_IF_ERROR(Initialize(NLClassifierOptions{ + .input_tensor_index = proto_options_->input_tensor_index(), + .output_score_tensor_index = proto_options_->output_score_tensor_index(), + .output_label_tensor_index = proto_options_->output_label_tensor_index(), + .input_tensor_name = proto_options_->input_tensor_name(), + .output_score_tensor_name = proto_options_->output_score_tensor_name(), + .output_label_tensor_name = proto_options_->output_label_tensor_name()})); + return absl::OkStatus(); +} + absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) { - options_ = options; + struct_options_ = options; // input tensor should be type STRING auto input_tensor = FindTensorWithNameOrIndex( GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(), @@ -399,6 +427,24 @@ return absl::OkStatus(); } +/* static */ +StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromOptions( + const NLClassifierProtoOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + // Copy options to ensure the ExternalFile outlives the duration of this + // created NLClassifier object. + auto options_copy = absl::make_unique<NLClassifierProtoOptions>(options); + + ASSIGN_OR_RETURN(auto nl_classifier, + TaskAPIFactory::CreateFromBaseOptions<NLClassifier>( + &options_copy->base_options(), std::move(resolver))); + RETURN_IF_ERROR(nl_classifier->Initialize(std::move(options_copy))); + + return nl_classifier; +} + StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromBufferAndOptions( const char* model_buffer_data,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h index 4055a93..331a6e4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -23,16 +23,18 @@ #include <string> #include <vector> -#include "absl/status/status.h" +#include "absl/base/macros.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" #include "tensorflow/lite/string_type.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/base_task_api.h" #include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h" #include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" namespace tflite { @@ -87,30 +89,49 @@ public: using BaseTaskApi::BaseTaskApi; - // Creates a NLClassifier from TFLite model buffer. + // Creates an NLClassifier from the provided options. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. + // + // This is a forward compatible method that uses + // `tflite::task::text::NLClassifierOptions`. Factory create methods with + // `tflite::task::text::nlclassifier::NLClassifierOptions` will be deprecated. + // + // TODO(b/182537114): unify the classification options (support the common + // classification options) and results across vision/text/audio. + static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromOptions( + const tflite::task::text::NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + + // Creates an NLClassifier from TFLite model buffer. + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> CreateFromBufferAndOptions( const char* model_buffer_data, size_t model_buffer_size, const NLClassifierOptions& options = {}, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); - // Creates a NLClassifier from TFLite model file. + // Creates an NLClassifier from TFLite model file. + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> CreateFromFileAndOptions( const std::string& path_to_model, const NLClassifierOptions& options = {}, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); - // Creates a NLClassifier from TFLite model file descriptor. + // Creates an NLClassifier from TFLite model file descriptor. + ABSL_DEPRECATED("Prefer using `CreateFromOptions`") static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> CreateFromFdAndOptions( int fd, const NLClassifierOptions& options = {}, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); // Performs classification on a string input, returns classified results. std::vector<core::Category> Classify(const std::string& text); @@ -119,7 +140,18 @@ static constexpr int kOutputTensorIndex = 0; static constexpr int kOutputTensorLabelFileIndex = 0; - absl::Status Initialize(const NLClassifierOptions& options); + // Initialize NLClassifier with the proto NLClassifierOptions. + absl::Status Initialize( + std::unique_ptr<tflite::task::text::NLClassifierOptions> options); + + ABSL_DEPRECATED( + "Prefer using `tflite::task::text::NLClassifierOptions` and " + "`Initialize(std::unique_ptr<tflite::task::text::NLClassifierOptions> " + "options)`") + absl::Status Initialize(const NLClassifierOptions& options = {}); + ABSL_DEPRECATED( + "Prefer using `tflite::task::text::NLClassifierOptions` and " + "`CreateFromOptions`") const NLClassifierOptions& GetOptions() const; // Try to extract attached label file from metadata and initialize @@ -170,7 +202,12 @@ bool HasRegexTokenizerMetadata(); absl::Status SetupRegexTokenizer(); - NLClassifierOptions options_; + std::unique_ptr<tflite::task::text::NLClassifierOptions> proto_options_; + + // Deprecated: using the proto_options_ + // (tflite::task::text::NLClassifierOptions). + NLClassifierOptions struct_options_; + // labels vector initialized from output tensor's associated file, if one // exists. std::unique_ptr<std::vector<std::string>> labels_vector_;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc deleted file mode 100644 index f498ae9e..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc +++ /dev/null
@@ -1,92 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h" - -#include <memory> - -#include "absl/strings/string_view.h" -#include "tensorflow_lite_support/cc/task/core/category.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" - -using CategoryCPP = ::tflite::task::core::Category; -using NLClassifierCPP = ::tflite::task::text::nlclassifier::NLClassifier; -using NLClassifierOptionsCPP = - ::tflite::task::text::nlclassifier::NLClassifierOptions; - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -struct NLClassifier { - std::unique_ptr<NLClassifierCPP> impl; -}; - -NLClassifier* NLClassifierFromFileAndOptions( - const char* model_path, - const NLClassifierOptions* options) { - auto classifier_status = NLClassifierCPP::CreateFromFileAndOptions( - std::string(model_path), - { - .input_tensor_index = options->input_tensor_index, - .output_score_tensor_index = options->output_score_tensor_index, - .output_label_tensor_index = options->output_label_tensor_index, - .input_tensor_name = !options->input_tensor_name - ? "" - : std::string(options->input_tensor_name), - .output_score_tensor_name = - !options->output_score_tensor_name - ? "" - : std::string(options->output_score_tensor_name), - .output_label_tensor_name = - !options->output_label_tensor_name - ? "" - : std::string(options->output_label_tensor_name), - }); - - if (classifier_status.ok()) { - return new NLClassifier{ - .impl = std::unique_ptr<NLClassifierCPP>(dynamic_cast<NLClassifierCPP*>( - classifier_status.value().release()))}; - } else { - return nullptr; - } -} - -Categories* NLClassifierClassify(const NLClassifier* classifier, - const char* text) { - std::vector<CategoryCPP> results = - classifier->impl->Classify(absl::string_view(text).data()); - size_t size = results.size(); - auto* categories = new Category[size]; - - for (size_t i = 0; i < size; ++i) { - categories[i].text = strdup(results[i].class_name.c_str()); - categories[i].score = results[i].score; - } - - auto* c_categories = new Categories; - c_categories->size = size; - c_categories->categories = categories; - return c_categories; -} - -void NLClassifierDelete(NLClassifier* classifier) { - delete classifier; -} - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h deleted file mode 100644 index 25e2c23..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h +++ /dev/null
@@ -1,71 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_ - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" -// -------------------------------------------------------------------------- -/// C API for NLClassifier. -/// -/// The API leans towards simplicity and uniformity instead of convenience, as -/// most usage will be by language-specific wrappers. It provides largely the -/// same set of functionality as that of the C++ TensorFlow Lite `NLClassifier` -/// API, but is useful for shared libraries where having a stable ABI boundary -/// is important. -/// -/// Usage: -/// <pre><code> -/// // Create the model and interpreter options. -/// NLClassifier* classifier = NLClassifierFromFileAndOptions( -/// "/path/to/model.tflite"); -/// -/// // classification. -/// Categories* categories = Classify(classifier, context, question); -/// -/// // Dispose of the API object. -/// NLClassifierDelete(classifier); - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef struct NLClassifier NLClassifier; - -struct NLClassifierOptions { - int input_tensor_index; - int output_score_tensor_index; - int output_label_tensor_index; - const char* input_tensor_name; - const char* output_score_tensor_name; - const char* output_label_tensor_name; -}; - -// Creates NLClassifier from model path and options, returns nullptr if the file -// doesn't exist or is not a well formatted TFLite model path. -extern NLClassifier* NLClassifierFromFileAndOptions( - const char* model_path, - const struct NLClassifierOptions* options); - -// Invokes the encapsulated TFLite model and classifies the input text. -extern struct Categories* NLClassifierClassify(const NLClassifier* classifier, - const char* text); - -extern void NLClassifierDelete(NLClassifier* classifier); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc deleted file mode 100644 index 817ef3d..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc +++ /dev/null
@@ -1,29 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -void NLClassifierCategoriesDelete(Categories* categories) { - delete[] categories->categories; - delete categories; -} - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h deleted file mode 100644 index 663c873..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h +++ /dev/null
@@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_ - -// Common structs shared between NLClassifier APIs -// -/// // Dispose of the Categories object. -/// NLClassifierCategoriesDelete(categories); - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -struct Category { - char* text; - double score; -}; - -struct Categories { - int size; - struct Category* categories; -}; - -extern void NLClassifierCategoriesDelete(struct Categories* categories); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/BUILD new file mode 100644 index 0000000..abc1f24 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/BUILD
@@ -0,0 +1,104 @@ +load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +proto_library( + name = "nl_classifier_options_proto", + srcs = ["nl_classifier_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", + ], +) + +support_cc_proto_library( + name = "nl_classifier_options_cc_proto", + deps = [ + ":nl_classifier_options_proto", + ], +) + +cc_library( + name = "nl_classifier_options_proto_inc", + hdrs = ["nl_classifier_options_proto_inc.h"], + deps = [ + ":nl_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + ], +) + +proto_library( + name = "bert_nl_classifier_options_proto", + srcs = ["bert_nl_classifier_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", + ], +) + +cc_proto_library( + name = "bert_nl_classifier_options_cc_proto", + deps = [ + ":bert_nl_classifier_options_proto", + ], +) + +cc_library( + name = "bert_nl_classifier_options_proto_inc", + hdrs = ["bert_nl_classifier_options_proto_inc.h"], + deps = [ + ":bert_nl_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + ], +) + +proto_library( + name = "bert_question_answerer_options_proto", + srcs = ["bert_question_answerer_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", + ], +) + +cc_proto_library( + name = "bert_question_answerer_options_cc_proto", + deps = [ + ":bert_question_answerer_options_proto", + ], +) + +cc_library( + name = "bert_question_answerer_options_proto_inc", + hdrs = ["bert_question_answerer_options_proto_inc.h"], + deps = [ + ":bert_question_answerer_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + ], +) + +proto_library( + name = "retrieval_proto", + srcs = ["retrieval.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_proto", + ], +) + +cc_proto_library( + name = "retrieval_cc_proto", + deps = [ + ":retrieval_proto", + ], +) + +cc_library( + name = "retrieval_proto_inc", + hdrs = ["retrieval_proto_inc.h"], + deps = [ + ":retrieval_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:embeddings_proto_inc", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto new file mode 100644 index 0000000..505ccef --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto
@@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.text; + +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; + +// Options for setting up a BertNLClassifier. +// Next Id: 3 +message BertNLClassifierOptions { + // Base options for configuring BertNLClassifier, such as specifying the + // TfLite model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 1; + + // Max number of tokens to pass to the model. + // + // Deprecated: max_seq_len is now read from the model (i.e. input tensor size) + // automatically. + optional int32 max_seq_len = 2 [default = 128]; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h new file mode 100644 index 0000000..59cfddb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h
@@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_BERT_NL_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_BERT_NL_CLASSIFIER_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.pb.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_BERT_NL_CLASSIFIER_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options.proto new file mode 100644 index 0000000..2b5bcc5f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options.proto
@@ -0,0 +1,28 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.text; + +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; + +// Options for setting up a BertQuestionAnswerer. +// Next Id: 2 +message BertQuestionAnswererOptions { + // Base options for configuring BertQuestionAnswerer, such as specifying the + // TfLite model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 1; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options_proto_inc.h new file mode 100644 index 0000000..993d420 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options_proto_inc.h
@@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_BERT_QUESTION_ANSWERER_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_BERT_QUESTION_ANSWERER_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/proto/bert_question_answerer_options.pb.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_BERT_QUESTION_ANSWERER_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options.proto new file mode 100644 index 0000000..6ebb459e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options.proto
@@ -0,0 +1,111 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.text; + +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; + +// Options for setting up an NLClassifier. +// Next Id: 8 +message NLClassifierOptions { + // Base options for configuring NLClassifier, such as specifying the + // TfLite model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 1; + + // **************************************************** + // Configure the input/output tensors for NLClassifier: + // + // - No special configuration is needed if the model has only one input tensor + // and one output tensor. + // + // - When the model has multiple input or output tensors, use the following + // configurations to specifiy the desired tensors: + // -- tensor names: `input_tensor_name`, `output_score_tensor_name`, + // `output_label_tensor_name` + // -- tensor indices: `input_tensor_index`, `output_score_tensor_index`, + // `output_label_tensor_index` + // Tensor names has higher priorities than tensor indices in locating the + // tensors. It means the tensors will be first located according to tensor + // names. If not found, then the tensors will be located according to tensor + // indices. + // + // - Failing to match the input text tensor or output score tensor with + // neither tensor names nor tensor indices will trigger a runtime error. + // However, failing to locate the output label tensor will not trigger an + // error because the label tensor is optional. + // **************************************************** + + // Name of the input text tensor, if the model has multiple inputs. Only the + // input tensor specified will be used for inference; other input tensors will + // be ignored. Default to "INPUT". + // + // See the "Configure the input/output tensors for NLClassifier" section above + // for more details. + optional string input_tensor_name = 2 [default = "INPUT"]; + + // Name of the output score tensor, if the model has multiple outputs. Default + // to "OUTPUT_SCORE". + // + // See the "Configure the input/output tensors for NLClassifier" section above + // for more details. + optional string output_score_tensor_name = 3 [default = "OUTPUT_SCORE"]; + + // Name of the output label tensor, if the model has multiple outputs. Default + // to "OUTPUT_LABEL". + // + // See the "Configure the input/output tensors for NLClassifier" section above + // for more details. + // + // By default, label file should be packed with + // the output score tensor through Model Metadata. See the MetadataWriter for + // NLClassifier [1]. NLClassifier reads and parses labels from the label + // file automatically. However, some models may output a specific label tensor + // instead. In this case, NLClassifier reads labels from the output label + // tensor. + // + // [1]: + // https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers + optional string output_label_tensor_name = 4 [default = "OUTPUT_LABEL"]; + + // Index of the input text tensor among all input tensors, if the model has + // multiple inputs. Only the input tensor specified will be used for + // inference; other input tensors will be ignored. Default to 0. + // + // See the "Configure the input/output tensors for NLClassifier" section above + // for more details. + optional int32 input_tensor_index = 5 [default = 0]; + + // Index of the output score tensor among all output tensors, if the model has + // multiple outputs. Default to 0. + // + // See the "Configure the input/output tensors for NLClassifier" section above + // for more details. + optional int32 output_score_tensor_index = 6 [default = 0]; + + // Index of the optional output label tensor among all output tensors, if the + // model has multiple outputs. + // + // See the comment above `output_label_tensor_name` for more information about + // what the output label tensor is. + // + // See the "Configure the input/output tensors for NLClassifier" section above + // for more details. + // + // `output_label_tensor_index` defaults to -1, meaning to disable searching + // the output label tensor as it might be optional. + optional int32 output_label_tensor_index = 7 [default = -1]; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h new file mode 100644 index 0000000..9c2a2ac --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h
@@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_NL_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_NL_CLASSIFIER_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/proto/nl_classifier_options.pb.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_NL_CLASSIFIER_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/retrieval.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/retrieval.proto new file mode 100644 index 0000000..21c0fb2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/retrieval.proto
@@ -0,0 +1,89 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +syntax = "proto2"; + +package tflite.task.text; + +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; +import "tensorflow_lite_support/cc/task/processor/proto/embedding.proto"; + +// Input message for retrieval models that encode strings into vectors. +// The input is a tuple of query text, and one or more response text with their +// context. +// +// Retrieve expects both query_text and responses are not empty. +message RetrievalInput { + // Input text as the user query (e.g. "When is Father's Day?"). + optional string query_text = 1; + + // A list of response entries. + repeated ResponseEntry responses = 2; + // Next Id: 3 +} + +// Response entry includes raw text or text_encoding. +message ResponseEntry { + // We allow users to pass: + // (1) raw_text to encode + // Or (2) cached text_encoding (save computation and faster). + oneof response_options { + RawText raw_text = 1; + tflite.task.processor.FeatureVector text_encoding = 2; + // Next Id: 3 + } + + // Raw text contains the text and its context. + message RawText { + // Text for the response (e.g. "In the US, it falls on the third Sunday in + // June."). + optional string text = 1; + + // Context for the response, such as the surrounding text or background + // information (e.g. "Father's Day is a celebration honoring fathers and + // celebrating fatherhood, paternal bonds, and the influence of fathers + // in society."). + optional string context = 2; + // Next Id: 3 + } +} + +// The result for response entry. +message ResponseResult { + // The encoded vector for the response. + optional tflite.task.processor.FeatureVector encoding = 1; + + // The score measured by encodings of query and response. + optional float score = 2; + // Next Id: 3 +} + +// Output message for retrieval. The retrieval model encodes query and response +// into respective vectors, and calculates their similarity score. +message RetrievalOutput { + // The encoded vector for the query. + optional tflite.task.processor.FeatureVector query_encoding = 1; + + // Results corresponding to responses in the same input order. + repeated ResponseResult response_results = 2; + // Next Id: 3 +} + +// Options for setting up retrieval models. +message RetrievalOptions { + // Base options for configuring retrieval models, such as specifying the + // TfLite model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 1; + // Next Id: 2 +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/retrieval_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/retrieval_proto_inc.h new file mode 100644 index 0000000..7d7bb8f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/retrieval_proto_inc.h
@@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_RETRIEVAL_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_RETRIEVAL_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h" + +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_PROTO_RETRIEVAL_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/BUILD deleted file mode 100644 index 49ad5a1..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/BUILD +++ /dev/null
@@ -1,61 +0,0 @@ -package( - default_visibility = ["//tensorflow_lite_support:users"], - licenses = ["notice"], # Apache 2.0 -) - -exports_files([ - "bert_qa_c_api.h", -]) - -cc_library( - name = "question_answerer", - hdrs = [ - "question_answerer.h", - ], - deps = [ - "//tensorflow_lite_support/cc/task/core:base_task_api", - "//tensorflow_lite_support/cc/task/core:tflite_engine", - ], -) - -cc_library( - name = "bert_question_answerer", - srcs = [ - "bert_question_answerer.cc", - ], - hdrs = [ - "bert_question_answerer.h", - ], - deps = [ - ":question_answerer", - "//tensorflow_lite_support/cc/port:status_macros", - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/core:base_task_api", - "//tensorflow_lite_support/cc/task/core:task_api_factory", - "//tensorflow_lite_support/cc/task/core:task_utils", - "//tensorflow_lite_support/cc/task/core:tflite_engine", - "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer", - "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer", - "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", - "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", - "//tensorflow_lite_support/metadata:metadata_schema_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "bert_qa_c_api", - srcs = [ - "bert_qa_c_api.cc", - ], - hdrs = [ - "bert_qa_c_api.h", - ], - visibility = ["//tensorflow_lite_support:__subpackages__"], - deps = [ - ":bert_question_answerer", - ":question_answerer", - ], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc deleted file mode 100644 index 3dc7a28..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc +++ /dev/null
@@ -1,80 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h" - -#include <memory> - -#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" -#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h" - -using BertQuestionAnswererCPP = ::tflite::task::text::qa::BertQuestionAnswerer; -using QaAnswerCPP = ::tflite::task::text::qa::QaAnswer; - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -struct BertQuestionAnswerer { - std::unique_ptr<BertQuestionAnswererCPP> impl; -}; - -BertQuestionAnswerer* BertQuestionAnswererFromFile(const char* model_path) { - auto bert_qa_status = - BertQuestionAnswererCPP::CreateFromFile(std::string(model_path)); - if (bert_qa_status.ok()) { - return new BertQuestionAnswerer{ - .impl = std::unique_ptr<BertQuestionAnswererCPP>( - dynamic_cast<BertQuestionAnswererCPP*>( - bert_qa_status.value().release()))}; - } else { - return nullptr; - } -} - -QaAnswers* BertQuestionAnswererAnswer( - const BertQuestionAnswerer* question_answerer, - const char* context, - const char* question) { - std::vector<QaAnswerCPP> answers = question_answerer->impl->Answer( - absl::string_view(context).data(), absl::string_view(question).data()); - size_t size = answers.size(); - auto* qa_answers = new QaAnswer[size]; - - for (size_t i = 0; i < size; ++i) { - qa_answers[i].start = answers[i].pos.start; - qa_answers[i].end = answers[i].pos.end; - qa_answers[i].logit = answers[i].pos.logit; - qa_answers[i].text = strdup(answers[i].text.c_str()); - } - - auto* c_answers = new QaAnswers; - c_answers->size = size; - c_answers->answers = qa_answers; - return c_answers; -} - -void BertQuestionAnswererDelete(BertQuestionAnswerer* bert_question_answerer) { - delete bert_question_answerer; -} - -void BertQuestionAnswererQaAnswersDelete(QaAnswers* qa_answers) { - delete[] qa_answers->answers; - delete qa_answers; -} - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h deleted file mode 100644 index 6cd27ee..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h +++ /dev/null
@@ -1,79 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_ - -// -------------------------------------------------------------------------- -/// C API for BertQuestionAnswerer. -/// -/// The API leans towards simplicity and uniformity instead of convenience, as -/// most usage will be by language-specific wrappers. It provides largely the -/// same set of functionality as that of the C++ TensorFlow Lite -/// `BertQuestionAnswerer` API, but is useful for shared libraries where having -/// a stable ABI boundary is important. -/// -/// Usage: -/// <pre><code> -/// // Create the model and interpreter options. -/// BertQuestionAnswerer* qa_answerer = -/// BertQuestionAnswererFromFile("/path/to/model.tflite"); -/// -/// // answer a question. -/// QaAnswers* answers = Answer(qa_answerer, context, question); -/// -/// // Dispose of the API and QaAnswers objects. -/// BertQuestionAnswererDelete(qa_answerer); -/// BertQuestionAnswererQaAnswersDelete(answers); - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef struct BertQuestionAnswerer BertQuestionAnswerer; - -struct QaAnswer { - int start; - int end; - float logit; - char* text; -}; - -struct QaAnswers { - int size; - struct QaAnswer* answers; -}; - -// Creates BertQuestionAnswerer from model path, returns nullptr if the file -// doesn't exist or is not a well formatted TFLite model path. -extern BertQuestionAnswerer* BertQuestionAnswererFromFile( - const char* model_path); - -// Invokes the encapsulated TFLite model and answers a question based on -// context. -extern struct QaAnswers* BertQuestionAnswererAnswer( - const BertQuestionAnswerer* question_answerer, - const char* context, - const char* question); - -extern void BertQuestionAnswererDelete( - BertQuestionAnswerer* bert_question_answerer); - -extern void BertQuestionAnswererQaAnswersDelete(struct QaAnswers* qa_answers); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc deleted file mode 100644 index 286f6f6..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc +++ /dev/null
@@ -1,403 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" - -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "tensorflow_lite_support/cc/port/status_macros.h" -#include "tensorflow_lite_support/cc/task/core/task_utils.h" -#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" -#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" -#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" - -namespace tflite { -namespace task { -namespace text { -namespace qa { - -constexpr char kIdsTensorName[] = "ids"; -constexpr char kMaskTensorName[] = "mask"; -constexpr char kSegmentIdsTensorName[] = "segment_ids"; -constexpr char kEndLogitsTensorName[] = "end_logits"; -constexpr char kStartLogitsTensorName[] = "start_logits"; - -using ::tflite::support::CreateStatusWithPayload; -using ::tflite::support::StatusOr; -using ::tflite::support::TfLiteSupportStatus; -using ::tflite::support::text::tokenizer::BertTokenizer; -using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; -using ::tflite::support::text::tokenizer::SentencePieceTokenizer; -using ::tflite::support::text::tokenizer::TokenizerResult; -using ::tflite::task::core::FindTensorByName; -using ::tflite::task::core::PopulateTensor; -using ::tflite::task::core::PopulateVector; -using ::tflite::task::core::ReverseSortIndices; - -namespace { -constexpr int kTokenizerProcessUnitIndex = 0; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> -BertQuestionAnswerer::CreateFromFile( - const std::string& path_to_model_with_metadata) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( - path_to_model_with_metadata, - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - RETURN_IF_ERROR(api_to_init->InitializeFromMetadata()); - return api_to_init; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> -BertQuestionAnswerer::CreateFromBuffer( - const char* model_with_metadata_buffer_data, - size_t model_with_metadata_buffer_size) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( - model_with_metadata_buffer_data, model_with_metadata_buffer_size, - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - RETURN_IF_ERROR(api_to_init->InitializeFromMetadata()); - return api_to_init; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd( - int fd) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromFileDescriptor<BertQuestionAnswerer>( - fd, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - RETURN_IF_ERROR(api_to_init->InitializeFromMetadata()); - return api_to_init; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> -BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( - const std::string& path_to_model, - const std::string& path_to_vocab) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( - path_to_model, - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - api_to_init->InitializeBertTokenizer(path_to_vocab); - return api_to_init; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> -BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( - const char* model_buffer_data, - size_t model_buffer_size, - const char* vocab_buffer_data, - size_t vocab_buffer_size) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( - model_buffer_data, model_buffer_size, - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data, - vocab_buffer_size); - return api_to_init; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> -BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile( - const std::string& path_to_model, - const std::string& path_to_spmodel) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( - path_to_model, - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel); - return api_to_init; -} - -StatusOr<std::unique_ptr<QuestionAnswerer>> -BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( - const char* model_buffer_data, - size_t model_buffer_size, - const char* spmodel_buffer_data, - size_t spmodel_buffer_size) { - std::unique_ptr<BertQuestionAnswerer> api_to_init; - ASSIGN_OR_RETURN( - api_to_init, - core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( - model_buffer_data, model_buffer_size, - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), - kNumLiteThreads)); - api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data, - spmodel_buffer_size); - return api_to_init; -} - -std::vector<QaAnswer> BertQuestionAnswerer::Answer( - const std::string& context, - const std::string& question) { - // The BertQuestionAnswererer implementation for Preprocess() and - // Postprocess() never returns errors: just call value(). - return Infer(context, question).value(); -} - -absl::Status BertQuestionAnswerer::Preprocess( - const std::vector<TfLiteTensor*>& input_tensors, - const std::string& context, - const std::string& query) { - auto* input_tensor_metadatas = - GetMetadataExtractor()->GetInputTensorMetadata(); - TfLiteTensor* ids_tensor = - input_tensor_metadatas - ? FindTensorByName(input_tensors, input_tensor_metadatas, - kIdsTensorName) - : input_tensors[0]; - TfLiteTensor* mask_tensor = - input_tensor_metadatas - ? FindTensorByName(input_tensors, input_tensor_metadatas, - kMaskTensorName) - : input_tensors[1]; - TfLiteTensor* segment_ids_tensor = - input_tensor_metadatas - ? FindTensorByName(input_tensors, input_tensor_metadatas, - kSegmentIdsTensorName) - : input_tensors[2]; - - token_to_orig_map_.clear(); - - // The orig_tokens is used for recovering the answer string from the index, - // while the processed_tokens is lower-cased and used to generate input of - // the model. - orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty()); - std::vector<std::string> processed_tokens(orig_tokens_); - - std::string processed_query = query; - if (kUseLowerCase) { - for (auto& token : processed_tokens) { - absl::AsciiStrToLower(&token); - } - absl::AsciiStrToLower(&processed_query); - } - - TokenizerResult query_tokenize_results; - query_tokenize_results = tokenizer_->Tokenize(processed_query); - - std::vector<std::string> query_tokens = query_tokenize_results.subwords; - if (query_tokens.size() > kMaxQueryLen) { - query_tokens.resize(kMaxQueryLen); - } - - // Example: - // context: tokenize me please - // all_doc_tokens: token ##ize me plea ##se - // token_to_orig_index: [0, 0, 1, 2, 2] - - std::vector<std::string> all_doc_tokens; - std::vector<int> token_to_orig_index; - for (size_t i = 0; i < processed_tokens.size(); i++) { - const std::string& token = processed_tokens[i]; - std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords; - for (const std::string& sub_token : sub_tokens) { - token_to_orig_index.emplace_back(i); - all_doc_tokens.emplace_back(sub_token); - } - } - - // -3 accounts for [CLS], [SEP] and [SEP]. - int max_context_len = kMaxSeqLen - query_tokens.size() - 3; - if (all_doc_tokens.size() > max_context_len) { - all_doc_tokens.resize(max_context_len); - } - - std::vector<std::string> tokens; - tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size()); - std::vector<int> segment_ids; - segment_ids.reserve(kMaxSeqLen); - - // Start of generating the features. - tokens.emplace_back("[CLS]"); - segment_ids.emplace_back(0); - - // For query input. - for (const auto& query_token : query_tokens) { - tokens.emplace_back(query_token); - segment_ids.emplace_back(0); - } - - // For Separation. - tokens.emplace_back("[SEP]"); - segment_ids.emplace_back(0); - - // For Text Input. - for (int i = 0; i < all_doc_tokens.size(); i++) { - auto& doc_token = all_doc_tokens[i]; - tokens.emplace_back(doc_token); - segment_ids.emplace_back(1); - token_to_orig_map_[tokens.size()] = token_to_orig_index[i]; - } - - // For ending mark. - tokens.emplace_back("[SEP]"); - segment_ids.emplace_back(1); - - std::vector<int> input_ids(tokens.size()); - input_ids.reserve(kMaxSeqLen); - // Convert tokens back into ids - for (int i = 0; i < tokens.size(); i++) { - auto& token = tokens[i]; - tokenizer_->LookupId(token, &input_ids[i]); - } - - std::vector<int> input_mask; - input_mask.reserve(kMaxSeqLen); - input_mask.insert(input_mask.end(), tokens.size(), 1); - - int zeros_to_pad = kMaxSeqLen - input_ids.size(); - input_ids.insert(input_ids.end(), zeros_to_pad, 0); - input_mask.insert(input_mask.end(), zeros_to_pad, 0); - segment_ids.insert(segment_ids.end(), zeros_to_pad, 0); - - // input_ids INT32[1, 384] - PopulateTensor(input_ids, ids_tensor); - // input_mask INT32[1, 384] - PopulateTensor(input_mask, mask_tensor); - // segment_ids INT32[1, 384] - PopulateTensor(segment_ids, segment_ids_tensor); - - return absl::OkStatus(); -} - -StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess( - const std::vector<const TfLiteTensor*>& output_tensors, - const std::string& /*lowercased_context*/, - const std::string& /*lowercased_query*/) { - auto* output_tensor_metadatas = - GetMetadataExtractor()->GetOutputTensorMetadata(); - - const TfLiteTensor* end_logits_tensor = - output_tensor_metadatas - ? FindTensorByName(output_tensors, output_tensor_metadatas, - kEndLogitsTensorName) - : output_tensors[0]; - const TfLiteTensor* start_logits_tensor = - output_tensor_metadatas - ? FindTensorByName(output_tensors, output_tensor_metadatas, - kStartLogitsTensorName) - : output_tensors[1]; - - std::vector<float> end_logits; - std::vector<float> start_logits; - - // end_logits FLOAT[1, 384] - PopulateVector(end_logits_tensor, &end_logits); - // start_logits FLOAT[1, 384] - PopulateVector(start_logits_tensor, &start_logits); - - auto start_indices = ReverseSortIndices(start_logits); - auto end_indices = ReverseSortIndices(end_logits); - - std::vector<QaAnswer::Pos> orig_results; - for (int start_index = 0; start_index < kPredictAnsNum; start_index++) { - for (int end_index = 0; end_index < kPredictAnsNum; end_index++) { - int start = start_indices[start_index]; - int end = end_indices[end_index]; - - if (!token_to_orig_map_.contains(start + kOutputOffset) || - !token_to_orig_map_.contains(end + kOutputOffset) || end < start || - (end - start + 1) > kMaxAnsLen) { - continue; - } - orig_results.emplace_back( - QaAnswer::Pos(start, end, start_logits[start] + end_logits[end])); - } - } - - std::sort(orig_results.begin(), orig_results.end()); - - std::vector<QaAnswer> answers; - for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) { - auto orig_pos = orig_results[i]; - answers.emplace_back( - orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end) - : "", - orig_pos); - } - - return answers; -} - -std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) { - int start_index = token_to_orig_map_[start + kOutputOffset]; - int end_index = token_to_orig_map_[end + kOutputOffset]; - - return absl::StrJoin(orig_tokens_.begin() + start_index, - orig_tokens_.begin() + end_index + 1, " "); -} - -absl::Status BertQuestionAnswerer::InitializeFromMetadata() { - const ProcessUnit* tokenizer_process_unit = - GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); - if (tokenizer_process_unit == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "No input process unit found from metadata.", - TfLiteSupportStatus::kMetadataInvalidTokenizerError); - } - ASSIGN_OR_RETURN(tokenizer_, - CreateTokenizerFromProcessUnit(tokenizer_process_unit, - GetMetadataExtractor())); - return absl::OkStatus(); -} - -void BertQuestionAnswerer::InitializeBertTokenizer( - const std::string& path_to_vocab) { - tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab); -} - -void BertQuestionAnswerer::InitializeBertTokenizerFromBinary( - const char* vocab_buffer_data, - size_t vocab_buffer_size) { - tokenizer_ = - absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size); -} - -void BertQuestionAnswerer::InitializeSentencepieceTokenizer( - const std::string& path_to_spmodel) { - tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel); -} - -void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary( - const char* spmodel_buffer_data, - size_t spmodel_buffer_size) { - tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data, - spmodel_buffer_size); -} - -} // namespace qa -} // namespace text -} // namespace task -} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h deleted file mode 100644 index 54f4b102..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h +++ /dev/null
@@ -1,171 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "tensorflow_lite_support/cc/port/statusor.h" -#include "tensorflow_lite_support/cc/task/core/base_task_api.h" -#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" -#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" -#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h" -#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" -#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" - -namespace tflite { -namespace task { -namespace text { -namespace qa { - -// BertQA task API, performs tokenization for models (BERT, Albert, etc.) in -// preprocess and returns most possible answers. -// -// In particular, the branch of BERT models use WordPiece tokenizer, and the -// branch of Albert models use SentencePiece tokenizer, respectively. -// -// Factory methods: -// CreateFromFile(path_to_model_with_metadata) -// CreateFromBuffer(model_with_metadata_buffer_data, -// model_with_metadata_buffer_size) -// CreateFromFd(file_descriptor_to_model_with_metadata) -// Generic API to create the QuestionAnswerer for bert models with metadata -// populated. The API expects a Bert based TFLite model with metadata -// containing the following information: -// - input_process_units for Wordpiece/Sentencepiece Tokenizer. Wordpiece -// Tokenizer can be used for a MobileBert[0] model, Sentencepiece -// Tokenizer Tokenizer can be used for an Albert[1] model -// - 3 input tensors with names "ids", "mask" and "segment_ids" -// - 2 output tensors with names "end_logits" and "start_logits" -// [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 -// [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 -// -// CreateBertQuestionAnswererFromFile(path_to_model, path_to_vocab) -// Creates a BertQuestionAnswerer from TFLite model file and vocab file for -// WordPiece tokenizer. Used in C++ environment. -// One suitable model is: -// https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 -// -// CreateBertQuestionAnswererFromBuffer(model_buffer_data, model_buffer_size, -// vocab_buffer_data, vocab_buffer_size) -// Creates a BertQuestionAnswerer from TFLite model buffer and vocab file -// buffer for WordPiece tokenizer. Used in Jave (JNI) environment. -// -// CreateAlbertQuestionAnswererFromFile(path_to_model, path_to_spmodel) -// Creates an AlbertQuestionAnswerer from TFLite model file and -// SentencePiece model file. Used in C++ environment. -// One suitable model is: -// https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 -// -// CreateAlbertQuestionAnswererFromBuffer(model_buffer_data, -// model_buffer_size, -// spmodel_buffer_data, -// spmodel_buffer_size) -// Creates an AlbertQuestionAnswerer from TFLite model file buffer and -// SentencePiece model file buffer. Used in Jave (JNI) environment. -// - -class BertQuestionAnswerer : public QuestionAnswerer { - public: - // TODO(b/150904655): add support to parameterize. - static constexpr int kMaxQueryLen = 64; - static constexpr int kMaxSeqLen = 384; - static constexpr int kPredictAnsNum = 5; - static constexpr int kMaxAnsLen = 32; - // TODO(b/151954803): clarify the offset usage - static constexpr int kOutputOffset = 1; - static constexpr int kNumLiteThreads = 4; - static constexpr bool kUseLowerCase = true; - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateFromFile(const std::string& path_to_model_with_metadata); - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateFromBuffer(const char* model_with_metadata_buffer_data, - size_t model_with_metadata_buffer_size); - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateFromFd(int fd); - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateBertQuestionAnswererFromFile(const std::string& path_to_model, - const std::string& path_to_vocab); - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data, - size_t model_buffer_size, - const char* vocab_buffer_data, - size_t vocab_buffer_size); - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model, - const std::string& path_to_spmodel); - - static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> - CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data, - size_t model_buffer_size, - const char* spmodel_buffer_data, - size_t spmodel_buffer_size); - - explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) - : QuestionAnswerer(std::move(engine)) {} - - // Answers question based on the context. Could be empty if no answer was - // found from the given context. - std::vector<QaAnswer> Answer(const std::string& context, - const std::string& question) override; - - private: - absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, - const std::string& lowercased_context, - const std::string& lowercased_query) override; - - tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess( - const std::vector<const TfLiteTensor*>& output_tensors, - const std::string& lowercased_context, - const std::string& lowercased_query) override; - - // Initialize API with a BertTokenizer from the vocabulary file. - void InitializeBertTokenizer(const std::string& path_to_vocab); - // Initialize API with a BertTokenizer from the vocabulary buffer. - void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data, - size_t vocab_buffer_size); - - // Initialize API with a SentencepieceTokenizer from the model file. - void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel); - // Initialize API with a SentencepieceTokenizer from the model buffer. - void InitializeSentencepieceTokenizerFromBinary( - const char* spmodel_buffer_data, - size_t spmodel_buffer_size); - - // Initialize the API with the tokenizer set in the metadata. - absl::Status InitializeFromMetadata(); - - std::string ConvertIndexToString(int start, int end); - - std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; - // Maps index of input token to index of untokenized word from original input. - absl::flat_hash_map<size_t, size_t> token_to_orig_map_; - // Original tokens of context. - std::vector<std::string> orig_tokens_; -}; - -} // namespace qa -} // namespace text -} // namespace task -} // namespace tflite - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/question_answerer.h deleted file mode 100644 index b4bd641f..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/qa/question_answerer.h +++ /dev/null
@@ -1,65 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ -#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ - -#include <string> -#include <utility> -#include <vector> - -#include "tensorflow_lite_support/cc/task/core/base_task_api.h" -#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" - -namespace tflite { -namespace task { -namespace text { -namespace qa { - -// Struct for the Answer to QuestionAnswerer. -struct QaAnswer { - // struct to represent the logit and offset of the answer related to context. - struct Pos { - Pos(int arg_start, int arg_end, float arg_logit) - : start(arg_start), end(arg_end), logit(arg_logit) {} - int start, end; - float logit; - bool operator<(const Pos& rhs) const { return rhs.logit < logit; } - }; - - QaAnswer(std::string arg_text, Pos arg_pos) - : text(std::move(arg_text)), pos(arg_pos) {} - std::string text; - Pos pos; -}; - -// Interface for an Question-Answer API. -class QuestionAnswerer : public core::BaseTaskApi<std::vector<QaAnswer>, - const std::string&, - const std::string&> { - public: - explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) - : BaseTaskApi(std::move(engine)) {} - - virtual std::vector<QaAnswer> Answer(const std::string& context, - const std::string& question) = 0; -}; - -} // namespace qa -} // namespace text -} // namespace task -} // namespace tflite - -#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h new file mode 100644 index 0000000..df21662 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
@@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_QUESTION_ANSWERER_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace task { +namespace text { + +// Struct for the Answer to QuestionAnswerer. +struct QaAnswer { + // struct to represent the logit and offset of the answer related to context. + struct Pos { + Pos(int arg_start, int arg_end, float arg_logit) + : start(arg_start), end(arg_end), logit(arg_logit) {} + int start, end; + float logit; + bool operator<(const Pos& rhs) const { return rhs.logit < logit; } + }; + + QaAnswer(std::string arg_text, Pos arg_pos) + : text(std::move(arg_text)), pos(arg_pos) {} + std::string text; + Pos pos; +}; + +// Interface for an Question-Answer API. +class QuestionAnswerer : public core::BaseTaskApi<std::vector<QaAnswer>, + const std::string&, + const std::string&> { + public: + explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) + : BaseTaskApi(std::move(engine)) {} + + virtual std::vector<QaAnswer> Answer(const std::string& context, + const std::string& question) = 0; +}; + +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_QUESTION_ANSWERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc new file mode 100644 index 0000000..2937a175 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
@@ -0,0 +1,330 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h" + +namespace tflite { +namespace ops { +namespace custom { +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); +TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR(); +} // namespace custom +} // namespace ops +} // namespace tflite + +namespace tflite { +namespace task { +namespace text { +namespace retrieval { + +using ::absl::Status; +using ::absl::StatusCode; +using internal::QAInput; +using internal::QAOutput; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::FindTensorByName; +using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::PopulateVectorToRepeated; +using ::tflite::task::core::TaskAPIFactory; +using FeatureVector = UniversalSentenceEncoderQA::FeatureVector; + +namespace { +constexpr char kQueryTextTensorName[] = "inp_text"; +constexpr char kResponseTextTensorName[] = "res_text"; +constexpr char kResponseContextTensorName[] = "res_context"; +constexpr char kQueryEncodingTensorName[] = "query_encoding"; +constexpr char kResponseEncodingTensorName[] = "response_encoding"; + +// Sanity check for options to ensure required fields. +absl::Status SanityCheckOptions(const RetrievalOptions& options) { + if (!options.has_base_options()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Missing mandatory `base_options` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +// Copy vector from model output. +inline absl::Status CopyVector(const TfLiteTensor* src, FeatureVector* target) { + return PopulateVectorToRepeated(src, target->mutable_value_float()); +} + +// Dot product of two vectors. Returns error status if size is mismatched. +template <class TCollection, class T = float> +tflite::support::StatusOr<T> Dot(const TCollection& a, const TCollection& b) { + if (a.size() != b.size()) { + return Status( + StatusCode::kInvalidArgument, + absl::StrFormat("mismatched vector size %d != %d", a.size(), b.size())); + } + auto dist = T(); + for (size_t i = 0; i < a.size(); ++i) { + dist += T(a[i]) * T(b[i]); + } + return dist; +} +} // namespace + +namespace internal { +struct QAInput { + std::string query_text; + std::string response_text; + std::string response_context; +}; + +struct QAOutput { + // Directly populate from raw tensor pointers to avoid extra copy. + const TfLiteTensor* query_encoding; // not owned. + const TfLiteTensor* response_encoding; // not owned. +}; +} // namespace internal + +// Creates custom op resolver for USE QA task. +std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> +CreateQACustomOpResolver() { + auto resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(); + resolver->AddCustom( + "TFSentencepieceTokenizeOp", + ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER()); + resolver->AddCustom( + "RaggedTensorToTensor", + ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR()); + return resolver; +} + +constexpr int UniversalSentenceEncoderQA::kFinalEmbeddingSize; + +StatusOr<RetrievalOutput> UniversalSentenceEncoderQA::Retrieve( + const RetrievalInput& input) { + if (input.query_text().empty()) { + return Status(StatusCode::kInvalidArgument, "query text cannot be empty."); + } + if (input.responses().empty()) { + return Status(StatusCode::kInvalidArgument, "responses cannot be empty."); + } + + RetrievalOutput output; + // Run inference. + // (1) Query is only encoded for once. + // (2) If responses are raw text, run model to get encoded vectors; otherwise, + // the encoded vector is kept from the input when given. + for (size_t i = 0; i < input.responses_size(); ++i) { + const auto& resp = input.responses(i); + + if (resp.has_raw_text()) { + // If response is in th raw text, encode both query and response. + const auto out = Run(input.query_text(), resp.raw_text().text(), + resp.raw_text().context()); + + // Only encode query for the first time. + if (i == 0) { + RETURN_IF_ERROR( + CopyVector(out.query_encoding, output.mutable_query_encoding())); + } + + // For each answer, set the response result. + auto r = output.mutable_response_results()->Add(); + RETURN_IF_ERROR(CopyVector(out.response_encoding, r->mutable_encoding())); + } else { + // If response is already encoded, encode query only and keep response + // encoding. + + // Only encode query for the first time. + if (i == 0) { + const auto& q = EncodeQuery(input.query_text()); + *output.mutable_query_encoding() = q.value(); + } + + // For each answer, set the response result from text_encoding + auto r = output.mutable_response_results()->Add(); + *r->mutable_encoding() = resp.text_encoding(); + } + } + + // Calculate scores. + for (size_t i = 0; i < output.response_results_size(); ++i) { + auto* r = output.mutable_response_results(i); + // TODO(tianlin): For a large size of results, it is more efficient to use + // matrix multiplication. + const auto& score = Similarity(output.query_encoding(), r->encoding()); + if (!score.ok()) { + return score.status(); + } + r->set_score(score.value()); + } + return output; +} + +StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeQuery( + absl::string_view query_text) { + if (query_text.empty()) { + return Status(StatusCode::kInvalidArgument, "query text cannot be empty."); + } + + const auto& output = Run(query_text, "", ""); + FeatureVector v; + RETURN_IF_ERROR(CopyVector(output.query_encoding, &v)); + return v; +} + +StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeResponse( + absl::string_view response_text, + absl::string_view response_context) { + if (response_text.empty() && response_context.empty()) { + return Status( + StatusCode::kInvalidArgument, + "either response text or context should be set to non-empty."); + } + + const auto& output = Run("", response_text, response_context); + FeatureVector v; + RETURN_IF_ERROR(CopyVector(output.response_encoding, &v)); + return v; +} + +StatusOr<float> UniversalSentenceEncoderQA::Similarity(const FeatureVector& a, + const FeatureVector& b) { + const auto& av = a.value_float(); + const auto& bv = b.value_float(); + return Dot(av, bv); +} + +std::vector<size_t> UniversalSentenceEncoderQA::Top( + const RetrievalOutput& output, + size_t k) { + // Ensure k in [0, total_size). + // If k == 0, it means that all outputs are ranked. + if (k == 0) { + k = output.response_results_size(); + } else { + k = std::min(k, size_t(output.response_results_size())); + } + + std::vector<size_t> pos(output.response_results_size()); + for (size_t i = 0; i < output.response_results_size(); ++i) { + pos[i] = i; + } + const auto greater_score = [&output](size_t i, size_t j) { + return output.response_results(i).score() > + output.response_results(j).score(); + }; + std::partial_sort(pos.begin(), pos.begin() + k, pos.end(), greater_score); + + // Return sorted. + return std::vector<size_t>(pos.begin(), pos.begin() + k); +} + +Status UniversalSentenceEncoderQA::Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, + const QAInput& input) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + TfLiteTensor* query_text_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kQueryTextTensorName) + : input_tensors[0]; + TfLiteTensor* response_text_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kResponseTextTensorName) + : input_tensors[2]; + TfLiteTensor* response_context_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kResponseContextTensorName) + : input_tensors[1]; + + RETURN_IF_ERROR(PopulateTensor(input.query_text, query_text_tensor)); + RETURN_IF_ERROR(PopulateTensor(input.response_text, response_text_tensor)); + RETURN_IF_ERROR( + PopulateTensor(input.response_context, response_context_tensor)); + + return absl::OkStatus(); +} + +StatusOr<QAOutput> UniversalSentenceEncoderQA::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const QAInput& /*input*/) { + auto* output_tensor_metadatas = + GetMetadataExtractor()->GetOutputTensorMetadata(); + + const TfLiteTensor* output_query_encoding_tensor = + output_tensor_metadatas + ? FindTensorByName(output_tensors, output_tensor_metadatas, + kQueryEncodingTensorName) + : output_tensors[0]; + const TfLiteTensor* output_response_encoding_tensor = + output_tensor_metadatas + ? FindTensorByName(output_tensors, output_tensor_metadatas, + kResponseEncodingTensorName) + : output_tensors[1]; + + QAOutput output; + output.query_encoding = output_query_encoding_tensor; + output.response_encoding = output_response_encoding_tensor; + return output; +} + +internal::QAOutput UniversalSentenceEncoderQA::Run( + absl::string_view query_text, + absl::string_view response_text, + absl::string_view response_context) { + QAInput input; + input.query_text = query_text; + input.response_text = response_text; + input.response_context = response_context; + return Infer(input).value(); +} + +StatusOr<std::unique_ptr<UniversalSentenceEncoderQA>> +UniversalSentenceEncoderQA::CreateFromOption( + const RetrievalOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + // Copy options to ensure the ExternalFile outlives the duration of this + // created object. + auto options_copy = absl::make_unique<RetrievalOptions>(options); + + ASSIGN_OR_RETURN( + auto encoder, + TaskAPIFactory::CreateFromBaseOptions<UniversalSentenceEncoderQA>( + &options_copy->base_options(), std::move(resolver))); + encoder->proto_options_ = std::move(options_copy); + return std::move(encoder); +} + +} // namespace retrieval +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h new file mode 100644 index 0000000..0269033 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
@@ -0,0 +1,107 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_UNIVERSAL_SENTENCE_ENCODER_QA_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_UNIVERSAL_SENTENCE_ENCODER_QA_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" +#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h" + +namespace tflite { +namespace task { +namespace text { +namespace retrieval { + +// QAInput and QAOutput for UniversalSentenceEncoderQA internally. +namespace internal { +struct QAInput; +struct QAOutput; +} // namespace internal + +// Creates custom op resolver for USE QA task. +std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> +CreateQACustomOpResolver(); + +// Universal Sentence Encoder (USE) Question Answerer. The model uses USE as the +// backbone and answers a question. +class UniversalSentenceEncoderQA + : public core::BaseTaskApi<internal::QAOutput, const internal::QAInput&> { + public: + using BaseTaskApi::BaseTaskApi; + using FeatureVector = ::tflite::task::processor::FeatureVector; + + // TODO(b/198995952): add support to parameterize. + static constexpr int kFinalEmbeddingSize = 100; + + static tflite::support::StatusOr<std::unique_ptr<UniversalSentenceEncoderQA>> + CreateFromOption(const tflite::task::text::RetrievalOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + CreateQACustomOpResolver()); + + // Retrieves output from the input by running TFLite engine. + // Returns an error, if either query_text or responses is empty. + tflite::support::StatusOr<RetrievalOutput> Retrieve( + const RetrievalInput& input); + + // Encodes query from the text. + // Returns an error, if query text is empty. + tflite::support::StatusOr<FeatureVector> EncodeQuery( + absl::string_view query_text); + + // Encodes response from the text and/or context. + // Returns an error, if both text and context are empty. + tflite::support::StatusOr<FeatureVector> EncodeResponse( + absl::string_view response_text, + absl::string_view response_context); + + // Calculates similarity between two encoded vectors (require same size). + static tflite::support::StatusOr<float> Similarity(const FeatureVector& a, + const FeatureVector& b); + + // Gets top k corresponding to output response scores in descending order. + // If k == 0, all responses are ranked. + static std::vector<size_t> Top(const RetrievalOutput& output, size_t k = 0); + + private: + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const internal::QAInput& input) override; + + tflite::support::StatusOr<internal::QAOutput> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const internal::QAInput& input) override; + + internal::QAOutput Run(absl::string_view query_text, + absl::string_view response_text, + absl::string_view response_context); + + std::unique_ptr<tflite::task::text::RetrievalOptions> proto_options_; +}; + +} // namespace retrieval +} // namespace text +} // namespace task +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_UNIVERSAL_SENTENCE_ENCODER_QA_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/BUILD index d426486f..ab6d7b9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/BUILD
@@ -1,3 +1,5 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( default_visibility = [ "//visibility:public", @@ -5,19 +7,28 @@ licenses = ["notice"], # Apache 2.0 ) -cc_library( +# IMPORTANT: in order to use hardware acceleration delegates, configurable through the +# `compute_settings` field of the ObjectDetectorOptions, you must additionally link to +# the appropriate delegate plugin target (e.g. `gpu_plugin` for GPU) from: +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD +# To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from: +# https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11 +cc_library_with_tflite( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core:external_file_handler", - "//tensorflow_lite_support/cc/task/core:task_api_factory", "//tensorflow_lite_support/cc/task/core:task_utils", - "//tensorflow_lite_support/cc/task/core:tflite_engine", - "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/core:label_map_item", "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", @@ -25,6 +36,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:score_calibration", "//tensorflow_lite_support/metadata:metadata_schema_cc", "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/container:flat_hash_set", @@ -32,26 +44,35 @@ "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite:framework", + "@com_google_glog//:glog", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/core/api", ], ) -cc_library( +# IMPORTANT: in order to use hardware acceleration delegates, configurable through the +# `compute_settings` field of the ImageClassifierOptions, you must additionally link to +# the appropriate delegate plugin target (e.g. `gpu_plugin` for GPU) from: +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD +# To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from: +# https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11 +cc_library_with_tflite( name = "image_classifier", srcs = ["image_classifier.cc"], hdrs = ["image_classifier.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core:external_file_handler", - "//tensorflow_lite_support/cc/task/core:task_api_factory", "//tensorflow_lite_support/cc/task/core:task_utils", - "//tensorflow_lite_support/cc/task/core:tflite_engine", - "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", "//tensorflow_lite_support/cc/task/vision/core:classification_head", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/core:label_map_item", @@ -69,26 +90,34 @@ "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@flatbuffers", - "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/core/api", ], ) -cc_library( +# IMPORTANT: in order to use hardware acceleration delegates, configurable through the +# `compute_settings` field of the ImageSegmenterOptions, you must additionally link to +# the appropriate delegate plugin target (e.g. `gpu_plugin` for GPU) from: +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD +# To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from: +# https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11 +cc_library_with_tflite( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core:external_file_handler", - "//tensorflow_lite_support/cc/task/core:task_api_factory", "//tensorflow_lite_support/cc/task/core:task_utils", - "//tensorflow_lite_support/cc/task/core:tflite_engine", - "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/core:label_map_item", "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", @@ -106,3 +135,41 @@ "@org_tensorflow//tensorflow/lite/core/api", ], ) + +# IMPORTANT: in order to use hardware acceleration delegates, configurable through the +# `compute_settings` field of the ImageEmbedderOptions, you must additionally link to +# the appropriate delegate plugin target (e.g. `gpu_plugin` for GPU) from: +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/BUILD +# To use EDGETPU_CORAL, link to `edgetpu_coral_plugin` from: +# https://github.com/tensorflow/tflite-support/blob/a58a4f9225c411fa9ba29f821523e6e283988d23/tensorflow_lite_support/acceleration/configuration/BUILD#L11 +cc_library_with_tflite( + name = "image_embedder", + srcs = ["image_embedder.cc"], + hdrs = ["image_embedder.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + "//tensorflow_lite_support/cc/task/processor:embedding_postprocessor", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:embeddings_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_embedder_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/BUILD index 1df86cb..9043269 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/BUILD
@@ -1,28 +1,33 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( default_visibility = [ - "//visibility:public", + "//tensorflow_lite_support:internal", ], licenses = ["notice"], # Apache 2.0 ) exports_files(srcs = ["base_vision_task_api.h"]) -cc_library( +cc_library_with_tflite( name = "base_vision_task_api", hdrs = [ "base_vision_task_api.h", ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/utils:image_tensor_specs", + "//tensorflow_lite_support/cc/task/processor:image_preprocessor", + ], deps = [ ":frame_buffer", "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros", - "//tensorflow_lite_support/cc/task/core:base_task_api", "//tensorflow_lite_support/cc/task/core:task_utils", - "//tensorflow_lite_support/cc/task/core:tflite_engine", "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", - "//tensorflow_lite_support/cc/task/vision/utils:image_tensor_specs", "//tensorflow_lite_support/metadata:metadata_schema_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -35,6 +40,7 @@ name = "frame_buffer", srcs = ["frame_buffer.cc"], hdrs = ["frame_buffer.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:statusor", @@ -51,6 +57,9 @@ name = "label_map_item", srcs = ["label_map_item.cc"], hdrs = ["label_map_item.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:statusor",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h index c787876b..d3557fc 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
@@ -21,9 +21,9 @@ #include <utility> #include <vector> -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/time/clock.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/time/clock.h" // from @com_google_absl #include "tensorflow/lite/c/common.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" @@ -31,6 +31,7 @@ #include "tensorflow_lite_support/cc/task/core/base_task_api.h" #include "tensorflow_lite_support/cc/task/core/task_utils.h" #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/image_preprocessor.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" @@ -56,19 +57,15 @@ BaseVisionTaskApi(const BaseVisionTaskApi&) = delete; BaseVisionTaskApi& operator=(const BaseVisionTaskApi&) = delete; - // Number of bytes required for 8-bit per pixel RGB color space. - static constexpr int kRgbPixelBytes = 3; - // Sets the ProcessEngine used for image pre-processing. Must be called before // any inference is performed. Can be called between inferences to override // the current process engine. void SetProcessEngine(const FrameBufferUtils::ProcessEngine& process_engine) { - frame_buffer_utils_ = FrameBufferUtils::Create(process_engine); + process_engine_ = process_engine; } protected: - using tflite::task::core:: - BaseTaskApi<OutputType, const FrameBuffer&, const BoundingBox&>::engine_; + FrameBufferUtils::ProcessEngine process_engine_; // Checks input tensor and metadata (if any) are valid, or return an error // otherwise. This must be called once at initialization time, before running @@ -76,19 +73,10 @@ // Note: the underlying interpreter and metadata extractor are assumed to be // already successfully initialized before calling this method. virtual absl::Status CheckAndSetInputs() { - ASSIGN_OR_RETURN( - ImageTensorSpecs input_specs, - BuildInputImageTensorSpecs(*engine_->interpreter(), - *engine_->metadata_extractor())); - - if (input_specs.color_space != tflite::ColorSpaceType_RGB) { - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kUnimplemented, - "BaseVisionTaskApi only supports RGB color space for now."); - } - - input_specs_ = absl::make_unique<ImageTensorSpecs>(input_specs); - + // BaseTaskApi always assume having a single input. + ASSIGN_OR_RETURN(preprocessor_, + ::tflite::task::processor::ImagePreprocessor::Create( + this->GetTfLiteEngine(), {0}, process_engine_)); return absl::OkStatus(); } @@ -115,153 +103,28 @@ absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, const FrameBuffer& frame_buffer, const BoundingBox& roi) override { - if (input_specs_ == nullptr) { + if (preprocessor_ == nullptr) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Uninitialized preprocessor: CheckAndSetInputs must be called " + "at initialization time."); + } + if (GetInputSpecs().image_height == 0 && GetInputSpecs().image_width == 0) { return tflite::support::CreateStatusWithPayload( absl::StatusCode::kInternal, "Uninitialized input tensor specs: CheckAndSetInputs must be called " "at initialization time."); } - - if (frame_buffer_utils_ == nullptr) { - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kInternal, - "Uninitialized frame buffer utils: SetProcessEngine must be called " - "at initialization time."); - } - - if (input_tensors.size() != 1) { - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kInternal, "A single input tensor is expected."); - } - - // Input data to be normalized (if needed) and used for inference. In most - // cases, this is the result of image preprocessing. In case no image - // preprocessing is needed (see below), this points to the input frame - // buffer raw data. - const uint8* input_data; - size_t input_data_byte_size; - - // Optional buffers in case image preprocessing is needed. - std::unique_ptr<FrameBuffer> preprocessed_frame_buffer; - std::vector<uint8> preprocessed_data; - - if (IsImagePreprocessingNeeded(frame_buffer, roi)) { - // Preprocess input image to fit model requirements. - // For now RGB is the only color space supported, which is ensured by - // `CheckAndSetInputs`. - FrameBuffer::Dimension to_buffer_dimension = {input_specs_->image_width, - input_specs_->image_height}; - input_data_byte_size = - GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB); - preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0); - input_data = preprocessed_data.data(); - - FrameBuffer::Plane preprocessed_plane = { - /*buffer=*/preprocessed_data.data(), - /*stride=*/{input_specs_->image_width * kRgbPixelBytes, - kRgbPixelBytes}}; - preprocessed_frame_buffer = FrameBuffer::Create( - {preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB, - FrameBuffer::Orientation::kTopLeft); - - RETURN_IF_ERROR(frame_buffer_utils_->Preprocess( - frame_buffer, roi, preprocessed_frame_buffer.get())); - } else { - // Input frame buffer already targets model requirements: skip image - // preprocessing. For RGB, the data is always stored in a single plane. - input_data = frame_buffer.plane(0).buffer; - input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes * - frame_buffer.dimension().height; - } - - // Then normalize pixel data (if needed) and populate the input tensor. - switch (input_specs_->tensor_type) { - case kTfLiteUInt8: - if (input_tensors[0]->bytes != input_data_byte_size) { - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kInternal, - "Size mismatch or unsupported padding bytes between pixel data " - "and input tensor."); - } - // No normalization required: directly populate data. - tflite::task::core::PopulateTensor( - input_data, input_data_byte_size / sizeof(uint8), input_tensors[0]); - break; - case kTfLiteFloat32: { - if (input_tensors[0]->bytes / sizeof(float) != - input_data_byte_size / sizeof(uint8)) { - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kInternal, - "Size mismatch or unsupported padding bytes between pixel data " - "and input tensor."); - } - // Normalize and populate. - float* normalized_input_data = - tflite::task::core::AssertAndReturnTypedTensor<float>( - input_tensors[0]); - const tflite::task::vision::NormalizationOptions& - normalization_options = input_specs_->normalization_options.value(); - if (normalization_options.num_values == 1) { - float mean_value = normalization_options.mean_values[0]; - float inv_std_value = (1.0f / normalization_options.std_values[0]); - for (size_t i = 0; i < input_data_byte_size / sizeof(uint8); - i++, input_data++, normalized_input_data++) { - *normalized_input_data = - inv_std_value * (static_cast<float>(*input_data) - mean_value); - } - } else { - std::array<float, 3> inv_std_values = { - 1.0f / normalization_options.std_values[0], - 1.0f / normalization_options.std_values[1], - 1.0f / normalization_options.std_values[2]}; - for (size_t i = 0; i < input_data_byte_size / sizeof(uint8); - i++, input_data++, normalized_input_data++) { - *normalized_input_data = inv_std_values[i % 3] * - (static_cast<float>(*input_data) - - normalization_options.mean_values[i % 3]); - } - } - break; - } - case kTfLiteInt8: - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kUnimplemented, - "kTfLiteInt8 input type is not implemented yet."); - default: - return tflite::support::CreateStatusWithPayload( - absl::StatusCode::kInternal, "Unexpected input tensor type."); - } - - return absl::OkStatus(); + return preprocessor_->Preprocess(frame_buffer, roi); } - // Utils for input image preprocessing (resizing, colorspace conversion, etc). - std::unique_ptr<FrameBufferUtils> frame_buffer_utils_; - - // Parameters related to the input tensor which represents an image. - std::unique_ptr<ImageTensorSpecs> input_specs_; + // Returns the spec for the input image. + const vision::ImageTensorSpecs& GetInputSpecs() const { + return preprocessor_->GetInputSpecs(); + } private: - // Returns false if image preprocessing could be skipped, true otherwise. - bool IsImagePreprocessingNeeded(const FrameBuffer& frame_buffer, - const BoundingBox& roi) { - // Is crop required? - if (roi.origin_x() != 0 || roi.origin_y() != 0 || - roi.width() != frame_buffer.dimension().width || - roi.height() != frame_buffer.dimension().height) { - return true; - } - - // Are image transformations required? - if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft || - frame_buffer.format() != FrameBuffer::Format::kRGB || - frame_buffer.dimension().width != input_specs_->image_width || - frame_buffer.dimension().height != input_specs_->image_height) { - return true; - } - - return false; - } + std::unique_ptr<processor::ImagePreprocessor> preprocessor_ = nullptr; }; } // namespace vision
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.cc index 962cb34b..b5b57f2 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.cc
@@ -14,7 +14,7 @@ ==============================================================================*/ #include "tensorflow_lite_support/cc/task/vision/core/classification_head.h" -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h index 07cd8b9..2e1aa6d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
@@ -18,8 +18,8 @@ #include <string> #include <vector> -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" #include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h index 42ac080..2936f5a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
@@ -22,12 +22,12 @@ #include <utility> #include <vector> -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "absl/types/optional.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/time/clock.h" // from @com_google_absl +#include "absl/time/time.h" // from @com_google_absl +#include "absl/types/optional.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h" @@ -74,7 +74,16 @@ class FrameBuffer { public: // Colorspace formats. - enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY }; + enum class Format { + kRGBA, + kRGB, + kNV12, + kNV21, + kYV12, + kYV21, + kGRAY, + kUNKNOWN + }; // Stride information. struct Stride { @@ -85,6 +94,13 @@ // pixels in bytes. It may be larger than the size of a single pixel to // account for interleaved image data or padded formats. int pixel_stride_bytes; + + bool operator==(const Stride& other) const { + return row_stride_bytes == other.row_stride_bytes && + pixel_stride_bytes == other.pixel_stride_bytes; + } + + bool operator!=(const Stride& other) const { return !operator==(other); } }; // YUV data structure.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc index ee2d9ad..67fe0753 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
@@ -15,8 +15,8 @@ ==============================================================================*/ #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" namespace tflite {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h index 8a95362..20c316b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
@@ -18,10 +18,10 @@ #include <string> #include <vector> -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/container/flat_hash_set.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" namespace tflite {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc index 313f28b..36ab3c3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
@@ -15,11 +15,10 @@ #include "tensorflow_lite_support/cc/task/vision/image_classifier.h" -#include "absl/algorithm/container.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/interpreter.h" +#include "absl/algorithm/container.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/status_macros.h" @@ -47,26 +46,6 @@ using ::tflite::task::core::TaskAPIFactory; using ::tflite::task::core::TfLiteEngine; -// Default score value used as a fallback for classes that (1) have no score -// calibration data or (2) have a very low confident uncalibrated score, i.e. -// lower than the `min_uncalibrated_score` threshold. -// -// (1) This happens when the ScoreCalibration does not cover all the classes -// listed in the label map. This can be used to enforce the blacklisting of -// given classes so that they are never returned. -// -// (2) This is an optional threshold provided part of the calibration data. It -// is used to mitigate false alarms on some classes. -// -// In both cases, a class that gets assigned a score of -1 is never returned as -// it gets discarded by the `score_threshold` check (see post-processing logic). -constexpr float kDefaultCalibratedScore = -1.0f; - -// Calibrated scores should be in the [0, 1] range, otherwise an error is -// returned at post-processing time. -constexpr float kMinCalibratedScore = 0.0f; -constexpr float kMaxCalibratedScore = 1.0f; - } // namespace /* static */ @@ -78,10 +57,25 @@ // Copy options to ensure the ExternalFile outlives the constructed object. auto options_copy = absl::make_unique<ImageClassifierOptions>(options); - ASSIGN_OR_RETURN(auto image_classifier, - TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>( - &options_copy->model_file_with_metadata(), - std::move(resolver), options_copy->num_threads())); + std::unique_ptr<ImageClassifier> image_classifier; + if (options_copy->has_model_file_with_metadata()) { + ASSIGN_OR_RETURN( + image_classifier, + TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>( + &options_copy->model_file_with_metadata(), std::move(resolver), + options_copy->num_threads(), options_copy->compute_settings())); + } else if (options_copy->base_options().has_model_file()) { + ASSIGN_OR_RETURN(image_classifier, + TaskAPIFactory::CreateFromBaseOptions<ImageClassifier>( + &options_copy->base_options(), std::move(resolver))); + } else { + // Should never happen because of SanityCheckOptions. + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 0."), + TfLiteSupportStatus::kInvalidArgumentError); + } RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy))); @@ -91,10 +85,14 @@ /* static */ absl::Status ImageClassifier::SanityCheckOptions( const ImageClassifierOptions& options) { - if (!options.has_model_file_with_metadata()) { + int num_input_models = (options.base_options().has_model_file() ? 1 : 0) + + (options.has_model_file_with_metadata() ? 1 : 0); + if (num_input_models != 1) { return CreateStatusWithPayload( StatusCode::kInvalidArgument, - "Missing mandatory `model_file_with_metadata` field", + absl::StrFormat("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found %d.", + num_input_models), TfLiteSupportStatus::kInvalidArgumentError); } if (options.max_results() == 0) { @@ -103,14 +101,6 @@ "Invalid `max_results` option: value must be != 0", TfLiteSupportStatus::kInvalidArgumentError); } - if (options.score_threshold() < 0 || options.score_threshold() >= 1) { - return CreateStatusWithPayload( - StatusCode::kInvalidArgument, - absl::StrFormat( - "`score_threshold` out of range: %f. Valid range is [0,1[.", - options.score_threshold()), - TfLiteSupportStatus::kInvalidArgumentError); - } if (options.class_name_whitelist_size() > 0 && options.class_name_blacklist_size() > 0) { return CreateStatusWithPayload( @@ -161,11 +151,11 @@ } absl::Status ImageClassifier::CheckAndSetOutputs() { - num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter()); + num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter()); // Perform sanity checks and extract metadata. const ModelMetadataExtractor* metadata_extractor = - engine_->metadata_extractor(); + GetTfLiteEngine()->metadata_extractor(); const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata(); @@ -220,7 +210,7 @@ int num_quantized_outputs = 0; for (int i = 0; i < num_outputs_; ++i) { const TfLiteTensor* output_tensor = - TfLiteEngine::GetOutput(engine_->interpreter(), i); + TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), i); const int num_dimensions = output_tensor->dims->size; if (num_dimensions == 4) { if (output_tensor->dims->data[1] != 1 || @@ -370,12 +360,6 @@ continue; } - // Use a specific default score instead of the one specified by default in - // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore` - // documentation for more details. - classification_heads_[i].calibration_params->default_score = - kDefaultCalibratedScore; - score_calibrations_[i] = absl::make_unique<ScoreCalibration>(); if (score_calibrations_[i] == nullptr) { return CreateStatusWithPayload( @@ -427,16 +411,16 @@ const TfLiteTensor* output_tensor = output_tensors[i]; if (has_uint8_outputs_) { - const uint8* output_data = - AssertAndReturnTypedTensor<uint8>(output_tensor); + ASSIGN_OR_RETURN(const uint8* output_data, + AssertAndReturnTypedTensor<uint8>(output_tensor)); for (int j = 0; j < head.label_map_items.size(); ++j) { score_pairs.emplace_back(j, output_tensor->params.scale * (static_cast<int>(output_data[j]) - output_tensor->params.zero_point)); } } else { - const float* output_data = - AssertAndReturnTypedTensor<float>(output_tensor); + ASSIGN_OR_RETURN(const float* output_data, + AssertAndReturnTypedTensor<float>(output_tensor)); for (int j = 0; j < head.label_map_items.size(); ++j) { score_pairs.emplace_back(j, output_data[j]); } @@ -447,23 +431,20 @@ for (auto& score_pair : score_pairs) { const std::string& class_name = head.label_map_items[score_pair.first].name; + + // In ComputeCalibratedScore, score_pair.second is set to the + // default_score value from metadata [1] if the category (1) has no + // score calibration data or (2) has a very low confident uncalibrated + // score, i.e. lower than the `min_uncalibrated_score` threshold. + // Otherwise, score_pair.second is calculated based on the selected + // score transformation function, and the value is guaranteed to be in + // the range of [0, scale], where scale is a label-dependent sigmoid + // parameter. + // + // [1]: + // https://github.com/tensorflow/tflite-support/blob/af26cb6952ccdeee0e849df2b93dbe7e57f6bc48/tensorflow_lite_support/metadata/metadata_schema.fbs#L453 score_pair.second = score_calibrations_[i]->ComputeCalibratedScore( class_name, score_pair.second); - if (score_pair.second > kMaxCalibratedScore) { - return CreateStatusWithPayload( - StatusCode::kInternal, - absl::StrFormat("calibrated score is too high: got %f, expected " - "%f as maximum.", - score_pair.second, kMaxCalibratedScore)); - } - if (score_pair.second != kDefaultCalibratedScore && - score_pair.second < kMinCalibratedScore) { - return CreateStatusWithPayload( - StatusCode::kInternal, - absl::StrFormat("calibrated score is too low: got %f, expected " - "%f as minimum.", - score_pair.second, kMinCalibratedScore)); - } } }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h index 4c3affe..eb0c13e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
@@ -19,10 +19,11 @@ #include <memory> #include <vector> -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" +#include "absl/container/flat_hash_set.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" @@ -80,7 +81,7 @@ CreateFromOptions( const ImageClassifierOptions& options, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); // Performs actual classification on the provided FrameBuffer. //
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc new file mode 100644 index 0000000..943a39b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
@@ -0,0 +1,161 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/image_embedder.h" + +#include <algorithm> + +#include "absl/container/node_hash_set.h" // from @com_google_absl +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" + +namespace tflite { +namespace task { +namespace vision { + +namespace { +using ::tflite::task::core::TaskAPIFactory; + +tflite::support::StatusOr<std::unique_ptr<processor::EmbeddingPostprocessor>> +CreatePostprocessor(core::TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + const ImageEmbedderOptions& options) { + auto new_options = std::make_unique<processor::EmbeddingOptions>(); + new_options->set_l2_normalize(options.l2_normalize()); + new_options->set_quantize(options.quantize()); + return processor::EmbeddingPostprocessor::Create(engine, output_indices, + std::move(new_options)); +} +} // namespace + +/* static */ +tflite::support::StatusOr<double> ImageEmbedder::CosineSimilarity( + const FeatureVector& u, + const FeatureVector& v) { + return processor::EmbeddingPostprocessor::CosineSimilarity(u, v); +} + +/* static */ +tflite::support::StatusOr<std::unique_ptr<ImageEmbedder>> +ImageEmbedder::CreateFromOptions(const ImageEmbedderOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + // Copy options to ensure the ExternalFile-s outlive the constructed object. + auto options_copy = absl::make_unique<ImageEmbedderOptions>(options); + + ASSIGN_OR_RETURN( + auto image_embedder, + TaskAPIFactory::CreateFromExternalFileProto<ImageEmbedder>( + &options_copy->model_file_with_metadata(), std::move(resolver), + options_copy->num_threads(), options_copy->compute_settings())); + + RETURN_IF_ERROR(image_embedder->Init(std::move(options_copy))); + + return image_embedder; +} + +absl::Status ImageEmbedder::PreInit() { + SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv); + return absl::OkStatus(); +} + +absl::Status ImageEmbedder::PostInit() { + // Nothing to do. + return absl::OkStatus(); +} + +absl::Status ImageEmbedder::Init( + std::unique_ptr<ImageEmbedderOptions> options) { + // Set options. + options_ = std::move(options); + + // Perform pre-initialization actions. + RETURN_IF_ERROR(PreInit()); + + // Sanity check and set inputs and outputs. + RETURN_IF_ERROR(CheckAndSetInputs()); + + // Perform post-initialization actions. + RETURN_IF_ERROR(PostInit()); + + // ImageEmbedder assumes that all output tensors share the same + // embedding option. + postprocessors_.reserve(GetTfLiteEngine()->interpreter()->outputs().size()); + for (int i = 0; i < GetTfLiteEngine()->interpreter()->outputs().size(); i++) { + ASSIGN_OR_RETURN(auto processor, + CreatePostprocessor(GetTfLiteEngine(), {i}, *options_)); + postprocessors_.emplace_back(std::move(processor)); + } + + return absl::OkStatus(); +} + +tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed( + const FrameBuffer& frame_buffer) { + BoundingBox roi; + roi.set_width(frame_buffer.dimension().width); + roi.set_height(frame_buffer.dimension().height); + return Embed(frame_buffer, roi); +} + +tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed( + const FrameBuffer& frame_buffer, + const BoundingBox& roi) { + return InferWithFallback(frame_buffer, roi); +} + +tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const FrameBuffer& /*frame_buffer*/, + const BoundingBox& /*roi*/) { + EmbeddingResult result; + for (int i = 0; i < postprocessors_.size(); ++i) { + RETURN_IF_ERROR( + postprocessors_.at(i)->Postprocess(result.add_embeddings())); + } + + return result; +} + +Embedding ImageEmbedder::GetEmbeddingByIndex(const EmbeddingResult& result, + int output_index) { + if (output_index < 0 || output_index >= postprocessors_.size()) { + return Embedding(); + } + return result.embeddings(output_index); +} + +int ImageEmbedder::GetEmbeddingDimension(int output_index) const { + if (output_index < 0 || output_index >= postprocessors_.size()) { + return -1; + } + return postprocessors_.at(output_index)->GetEmbeddingDimension(); +} + +int ImageEmbedder::GetNumberOfOutputLayers() const { + return postprocessors_.size(); +} + +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h new file mode 100644 index 0000000..93e2455 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
@@ -0,0 +1,152 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_EMBEDDER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_EMBEDDER_H_ + +#include <vector> + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" +#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h" + +namespace tflite { +namespace task { +namespace vision { + +// Performs dense feature vector extraction on images. +// +// The API expects a TFLite model with optional, but strongly recommended, +// TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - only RGB inputs are supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// At least one output tensor with: +// (kTfLiteUInt8/kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +// +// TODO(b/180502532): add pointer to example model. +// +// A CLI demo tool is available for easily trying out this API, and provides +// example usage. See: +// examples/task/vision/desktop/image_embedder_demo.cc +class ImageEmbedder + : public tflite::task::vision::BaseVisionTaskApi<EmbeddingResult> { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an ImageEmbedder from the provided options. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. + static tflite::support::StatusOr<std::unique_ptr<ImageEmbedder>> + CreateFromOptions( + const ImageEmbedderOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); + + // Performs actual feature vector extraction on the provided FrameBuffer. + // + // The FrameBuffer can be of any size and any of the supported formats, i.e. + // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before + // inference in order to (and in this order): + // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to + // the dimensions of the model input tensor, + // - convert it to the colorspace of the input tensor (i.e. RGB, which is the + // only supported colorspace for now), + // - rotate it according to its `Orientation` so that inference is performed + // on an "upright" image. + tflite::support::StatusOr<EmbeddingResult> Embed( + const FrameBuffer& frame_buffer); + + // Same as above, except the inference is performed only on the provided + // region of interest. Note that the region of interest is not clamped, so + // this method will fail if the region is out of bounds of the input image. + tflite::support::StatusOr<EmbeddingResult> Embed( + const FrameBuffer& frame_buffer, + const BoundingBox& roi); + + // Returns the Embedding output by the output_index'th layer. In (the most + // common) case where a single embedding is produced, you can just call + // GetEmbeddingByIndex(result, 0). + // Returns an empty Embedding if `output_index` is out of bounds. + Embedding GetEmbeddingByIndex(const EmbeddingResult& result, + int output_index); + + // Returns the dimensionality of the embedding output by the output_index'th + // output layer. Returns -1 if `output_index` is out of bounds. + int GetEmbeddingDimension(int output_index) const; + + // Returns the number of output layers of the model. + int GetNumberOfOutputLayers() const; + + // Utility function to compute cosine similarity [1] between two feature + // vectors. May return an InvalidArgumentError if e.g. the feature vectors are + // of different types (quantized vs. float), have different sizes, or have a + // an L2-norm of 0. + // + // [1]: https://en.wikipedia.org/wiki/Cosine_similarity + static tflite::support::StatusOr<double> CosineSimilarity( + const FeatureVector& u, + const FeatureVector& v); + + protected: + // The options used to build this ImageEmbedder. + std::unique_ptr<ImageEmbedderOptions> options_; + + // Post-processing to transform the raw model outputs into embedding results. + tflite::support::StatusOr<EmbeddingResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const FrameBuffer& frame_buffer, + const BoundingBox& roi) override; + + // Performs pre-initialization actions. + virtual absl::Status PreInit(); + // Performs post-initialization actions. + virtual absl::Status PostInit(); + + // Initializes the ImageEmbedder. + absl::Status Init(std::unique_ptr<ImageEmbedderOptions> options); + + // Performs scalar quantization on a feature vector whose elements are + // assumed to lie in the range [-1.0, 1.0] (values outside this range will be + // clamped to -128 or 127). + void QuantizeFeatureVector(FeatureVector* feature_vector) const; + + private: + std::vector<std::unique_ptr<processor::EmbeddingPostprocessor>> + postprocessors_; +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_EMBEDDER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc index 08ea3b9..20a34a9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
@@ -17,9 +17,9 @@ #include <algorithm> -#include "absl/memory/memory.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow_lite_support/cc/common.h" @@ -52,7 +52,6 @@ // segmentation masks are stored with 8 bit per pixel (flattened byte array). constexpr uint32 kMaxNumClasses = 256; -// TODO(b/) // The colormap used to fill `ColoredLabel`-s, as a flattened array of 256 {R, // G, B} components. constexpr uint8 kColorMap[768] = { @@ -138,10 +137,14 @@ /* static */ absl::Status ImageSegmenter::SanityCheckOptions( const ImageSegmenterOptions& options) { - if (!options.has_model_file_with_metadata()) { + int num_input_models = (options.base_options().has_model_file() ? 1 : 0) + + (options.has_model_file_with_metadata() ? 1 : 0); + if (num_input_models != 1) { return CreateStatusWithPayload( StatusCode::kInvalidArgument, - "Missing mandatory `model_file_with_metadata` field", + absl::StrFormat("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found %d.", + num_input_models), TfLiteSupportStatus::kInvalidArgumentError); } if (options.output_type() == ImageSegmenterOptions::UNSPECIFIED) { @@ -167,10 +170,25 @@ // Copy options to ensure the ExternalFile outlives the constructed object. auto options_copy = absl::make_unique<ImageSegmenterOptions>(options); - ASSIGN_OR_RETURN(auto image_segmenter, - TaskAPIFactory::CreateFromExternalFileProto<ImageSegmenter>( - &options_copy->model_file_with_metadata(), - std::move(resolver), options_copy->num_threads())); + std::unique_ptr<ImageSegmenter> image_segmenter; + if (options_copy->has_model_file_with_metadata()) { + ASSIGN_OR_RETURN( + image_segmenter, + TaskAPIFactory::CreateFromExternalFileProto<ImageSegmenter>( + &options_copy->model_file_with_metadata(), std::move(resolver), + options_copy->num_threads(), options_copy->compute_settings())); + } else if (options_copy->base_options().has_model_file()) { + ASSIGN_OR_RETURN(image_segmenter, + TaskAPIFactory::CreateFromBaseOptions<ImageSegmenter>( + &options_copy->base_options(), std::move(resolver))); + } else { + // Should never happen because of SanityCheckOptions. + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 0."), + TfLiteSupportStatus::kInvalidArgumentError); + } RETURN_IF_ERROR(image_segmenter->Init(std::move(options_copy))); @@ -203,7 +221,8 @@ absl::Status ImageSegmenter::CheckAndSetOutputs() { // First, sanity checks on the model itself. - const TfLiteEngine::Interpreter* interpreter = engine_->interpreter(); + const TfLiteEngine::Interpreter* interpreter = + GetTfLiteEngine()->interpreter(); // Check the number of output tensors. if (TfLiteEngine::OutputCount(interpreter) != 1) { @@ -257,7 +276,7 @@ // Build label map from metadata, if available. const ModelMetadataExtractor* metadata_extractor = - engine_->metadata_extractor(); + GetTfLiteEngine()->metadata_extractor(); const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata(); if (output_tensor_metadata != nullptr) { @@ -374,8 +393,9 @@ int class_index = 0; float max_confidence = 0.0f; for (int d = 0; d < output_depth_; ++d) { - const float confidence = - GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d); + ASSIGN_OR_RETURN( + const float confidence, + GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d)); if (confidence > max_confidence) { class_index = d; max_confidence = confidence; @@ -401,8 +421,10 @@ /*to_x=*/&tensor_x, /*to_y=*/&tensor_y); for (int d = 0; d < output_depth_; ++d) { - confidence_masks->mutable_confidence_mask(d)->add_value( + ASSIGN_OR_RETURN( + float confidence, GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d)); + confidence_masks->mutable_confidence_mask(d)->add_value(confidence); } } } @@ -411,17 +433,20 @@ return result; } -float ImageSegmenter::GetOutputConfidence(const TfLiteTensor& output_tensor, - int x, - int y, - int depth) { +StatusOr<float> ImageSegmenter::GetOutputConfidence( + const TfLiteTensor& output_tensor, + int x, + int y, + int depth) { int index = output_width_ * output_depth_ * y + output_depth_ * x + depth; if (has_uint8_outputs_) { - const uint8* data = AssertAndReturnTypedTensor<uint8>(&output_tensor); + ASSIGN_OR_RETURN(const uint8* data, + AssertAndReturnTypedTensor<uint8>(&output_tensor)); return output_tensor.params.scale * (static_cast<int>(data[index]) - output_tensor.params.zero_point); } else { - const float* data = AssertAndReturnTypedTensor<float>(&output_tensor); + ASSIGN_OR_RETURN(const float* data, + AssertAndReturnTypedTensor<float>(&output_tensor)); return data[index]; } }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h index e1d484c..e255110 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
@@ -19,8 +19,9 @@ #include <memory> #include <vector> -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" #include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h" @@ -78,7 +79,7 @@ CreateFromOptions( const ImageSegmenterOptions& options, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); // Performs actual segmentation on the provided FrameBuffer. // @@ -147,10 +148,11 @@ // Returns the output confidence at coordinates {x, y, depth}, dequantizing // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true). - float GetOutputConfidence(const TfLiteTensor& output_tensor, - int x, - int y, - int depth); + tflite::support::StatusOr<float> GetOutputConfidence( + const TfLiteTensor& output_tensor, + int x, + int y, + int depth); // Prebuilt list of ColoredLabel attached to each Segmentation result. The // i-th item in this list corresponds to the i-th label map item.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc index e0f3149..3eb5126 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
@@ -19,11 +19,12 @@ #include <limits> #include <vector> -#include "absl/memory/memory.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" +#include <glog/logging.h> +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/interpreter.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/task/core/task_api_factory.h" @@ -33,6 +34,7 @@ #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h" #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" @@ -55,12 +57,27 @@ using ::tflite::support::StatusOr; using ::tflite::support::TfLiteSupportStatus; using ::tflite::task::core::AssertAndReturnTypedTensor; +using ::tflite::task::core::FindIndexByMetadataTensorName; using ::tflite::task::core::TaskAPIFactory; using ::tflite::task::core::TfLiteEngine; // The expected number of dimensions of the 4 output tensors, representing in -// that order: locations, classes, scores, num_results. +// that order: locations, categories, scores, num_results. The order is +// coming from the TFLite custom NMS op for object detection post-processing +// shown in +// https://github.com/tensorflow/tensorflow/blob/1c419b231b622bd9e9685682545e9064f0fbb42a/tensorflow/lite/kernels/detection_postprocess.cc#L47. static constexpr int kOutputTensorsExpectedDims[4] = {3, 2, 2, 1}; +constexpr int kDefaultLocationsIndex = 0; +constexpr int kDefaultClassesIndex = 1; +constexpr int kDefaultScoresIndex = 2; +constexpr int kDefaultNumResultsIndex = 3; + +constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); + +constexpr char kLocationTensorName[] = "location"; +constexpr char kCategoryTensorName[] = "category"; +constexpr char kScoreTensorName[] = "score"; +constexpr char kNumberOfDetectionsTensorName[] = "number of detections"; StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties( const TensorMetadata& tensor_metadata) { @@ -138,7 +155,7 @@ ModelMetadataExtractor::FindFirstAssociatedFileName( tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS, locale); - absl::string_view display_names_file = nullptr; + absl::string_view display_names_file; if (!display_names_filename.empty()) { ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( display_names_filename)); @@ -154,14 +171,44 @@ metadata_extractor.FindFirstProcessUnit( tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); if (score_thresholding_process_unit == nullptr) { - return std::numeric_limits<float>::lowest(); + return kDefaultScoreThreshold; } return score_thresholding_process_unit->options_as_ScoreThresholdingOptions() ->global_score_threshold(); } +// Use tensor names in metadata to get the output order. +std::vector<int> GetOutputIndices( + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + tensor_metadatas) { + std::vector<int> output_indices = { + FindIndexByMetadataTensorName(tensor_metadatas, kLocationTensorName), + FindIndexByMetadataTensorName(tensor_metadatas, kCategoryTensorName), + FindIndexByMetadataTensorName(tensor_metadatas, kScoreTensorName), + FindIndexByMetadataTensorName(tensor_metadatas, + kNumberOfDetectionsTensorName)}; + + for (int i = 0; i < 4; i++) { + int output_index = output_indices[i]; + // If tensor name is not found, set the default output indices. + if (output_index == -1) { + LOG(WARNING) << absl::StrFormat( + "You don't seem to be matching tensor names in metadata list. The " + "tensor name \"%s\" at index %d in the model metadata doesn't match " + "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].", + tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName, + kCategoryTensorName, kScoreTensorName, kNumberOfDetectionsTensorName); + output_indices = {kDefaultLocationsIndex, kDefaultClassesIndex, + kDefaultScoresIndex, kDefaultNumResultsIndex}; + return output_indices; + } + } + return output_indices; +} + absl::Status SanityCheckOutputTensors( - const std::vector<const TfLiteTensor*>& output_tensors) { + const std::vector<const TfLiteTensor*>& output_tensors, + const std::vector<int>& output_indices) { if (output_tensors.size() != 4) { return CreateStatusWithPayload( StatusCode::kInternal, @@ -170,48 +217,55 @@ } // Get number of results. - if (output_tensors[3]->dims->data[0] != 1) { + const TfLiteTensor* num_results_tensor = output_tensors[output_indices[3]]; + if (num_results_tensor->dims->data[0] != 1) { return CreateStatusWithPayload( StatusCode::kInternal, absl::StrFormat( "Expected tensor with dimensions [1] at index 3, found [%d]", - output_tensors[3]->dims->data[0])); + num_results_tensor->dims->data[0])); } - int num_results = - static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]); + ASSIGN_OR_RETURN(float* num_results_data, + AssertAndReturnTypedTensor<float>(num_results_tensor)); + int num_results = static_cast<int>(num_results_data[0]); + + const TfLiteTensor* location_tensor = output_tensors[output_indices[0]]; // Check dimensions for the other tensors are correct. - if (output_tensors[0]->dims->data[0] != 1 || - output_tensors[0]->dims->data[1] != num_results || - output_tensors[0]->dims->data[2] != 4) { + if (location_tensor->dims->data[0] != 1 || + location_tensor->dims->data[1] < num_results || + location_tensor->dims->data[2] != 4) { return CreateStatusWithPayload( StatusCode::kInternal, absl::StrFormat( - "Expected locations tensor with dimensions [1,%d,4] at index 0, " - "found [%d,%d,%d].", - num_results, output_tensors[0]->dims->data[0], - output_tensors[0]->dims->data[1], - output_tensors[0]->dims->data[2])); + "Expected locations tensor with dimensions [1, num_detected_boxes, " + "4] at index 0, num_detected_boxes >= %d, found [%d,%d,%d].", + num_results, location_tensor->dims->data[0], + location_tensor->dims->data[1], location_tensor->dims->data[2])); } - if (output_tensors[1]->dims->data[0] != 1 || - output_tensors[1]->dims->data[1] != num_results) { + + const TfLiteTensor* class_tensor = output_tensors[output_indices[1]]; + if (class_tensor->dims->data[0] != 1 || + class_tensor->dims->data[1] < num_results) { return CreateStatusWithPayload( StatusCode::kInternal, absl::StrFormat( - "Expected classes tensor with dimensions [1,%d] at index 1, " - "found [%d,%d].", - num_results, output_tensors[1]->dims->data[0], - output_tensors[1]->dims->data[1])); + "Expected classes tensor with dimensions [1, num_detected_boxes] " + "at index 1, num_detected_boxes >= %d, found [%d,%d].", + num_results, class_tensor->dims->data[0], + class_tensor->dims->data[1])); } - if (output_tensors[2]->dims->data[0] != 1 || - output_tensors[2]->dims->data[1] != num_results) { + + const TfLiteTensor* scores_tensor = output_tensors[output_indices[2]]; + if (scores_tensor->dims->data[0] != 1 || + scores_tensor->dims->data[1] < num_results) { return CreateStatusWithPayload( StatusCode::kInternal, absl::StrFormat( - "Expected scores tensor with dimensions [1,%d] at index 2, " - "found [%d,%d].", - num_results, output_tensors[2]->dims->data[0], - output_tensors[2]->dims->data[1])); + "Expected scores tensor with dimensions [1, num_detected_boxes] " + "at index 2, num_detected_boxes >= %d, found [%d,%d].", + num_results, scores_tensor->dims->data[0], + scores_tensor->dims->data[1])); } return absl::OkStatus(); @@ -222,10 +276,14 @@ /* static */ absl::Status ObjectDetector::SanityCheckOptions( const ObjectDetectorOptions& options) { - if (!options.has_model_file_with_metadata()) { + int num_input_models = (options.base_options().has_model_file() ? 1 : 0) + + (options.has_model_file_with_metadata() ? 1 : 0); + if (num_input_models != 1) { return CreateStatusWithPayload( StatusCode::kInvalidArgument, - "Missing mandatory `model_file_with_metadata` field", + absl::StrFormat("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found %d.", + num_input_models), TfLiteSupportStatus::kInvalidArgumentError); } if (options.max_results() == 0) { @@ -260,10 +318,25 @@ // Copy options to ensure the ExternalFile outlives the constructed object. auto options_copy = absl::make_unique<ObjectDetectorOptions>(options); - ASSIGN_OR_RETURN(auto object_detector, - TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>( - &options_copy->model_file_with_metadata(), - std::move(resolver), options_copy->num_threads())); + std::unique_ptr<ObjectDetector> object_detector; + if (options_copy->has_model_file_with_metadata()) { + ASSIGN_OR_RETURN( + object_detector, + TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>( + &options_copy->model_file_with_metadata(), std::move(resolver), + options_copy->num_threads(), options_copy->compute_settings())); + } else if (options_copy->base_options().has_model_file()) { + ASSIGN_OR_RETURN(object_detector, + TaskAPIFactory::CreateFromBaseOptions<ObjectDetector>( + &options_copy->base_options(), std::move(resolver))); + } else { + // Should never happen because of SanityCheckOptions. + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 0."), + TfLiteSupportStatus::kInvalidArgumentError); + } RETURN_IF_ERROR(object_detector->Init(std::move(options_copy))); @@ -286,6 +359,10 @@ // Initialize class whitelisting/blacklisting, if any. RETURN_IF_ERROR(CheckAndSetClassIndexSet()); + // Perform final initialization (by default, initialize score calibration + // parameters, if any). + RETURN_IF_ERROR(PostInit()); + return absl::OkStatus(); } @@ -294,9 +371,83 @@ return absl::OkStatus(); } +absl::Status ObjectDetector::PostInit() { + return InitScoreCalibrations(); +} + +StatusOr<SigmoidCalibrationParameters> BuildCalibrationParametersIfAny( + const tflite::metadata::ModelMetadataExtractor& metadata_extractor, + const tflite::TensorMetadata& output_tensor_metadata, + const std::vector<LabelMapItem>& label_map_items, + bool* has_score_calibration) { + SigmoidCalibrationParameters sigmoid_params; + *has_score_calibration = false; + + // Build score calibration parameters, if present. + // + // TODO(tianlin): This code is similar to the block of classification_head.cc + // that does sanity checks and builds sigmoid calibration params in: + // https://github.com/tensorflow/tflite-support/blob/64e044408f3d3654de7fc10bca401ed900649ca3/tensorflow_lite_support/cc/task/vision/core/classification_head.cc#L75-L107 + // Consider to refactor it and reuse the same function. + ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_calibration_process_unit, + ModelMetadataExtractor::FindFirstProcessUnit( + output_tensor_metadata, + tflite::ProcessUnitOptions_ScoreCalibrationOptions)); + if (score_calibration_process_unit != nullptr) { + const std::string score_calibration_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); + ASSIGN_OR_RETURN( + absl::string_view score_calibration_file, + metadata_extractor.GetAssociatedFile(score_calibration_filename)); + + // Set has_score_calibration to true, only if sigmoid_params is built. + ASSIGN_OR_RETURN(sigmoid_params, + BuildSigmoidCalibrationParams( + *score_calibration_process_unit + ->options_as_ScoreCalibrationOptions(), + score_calibration_file, label_map_items)); + *has_score_calibration = true; + } + return sigmoid_params; +} + +absl::Status ObjectDetector::InitScoreCalibrations() { + StatusOr<SigmoidCalibrationParameters> calibration_params_status; + bool has_score_calibration = false; + + // Search the output tensor metadata, can try to get calibration_params. + const ModelMetadataExtractor* metadata_extractor = + GetTfLiteEngine()->metadata_extractor(); + const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* + output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata(); + const tflite::TensorMetadata* output_tensor = + output_tensor_metadata->Get(kDefaultScoresIndex); + ASSIGN_OR_RETURN( + auto calibration_params, + BuildCalibrationParametersIfAny(*metadata_extractor, *output_tensor, + label_map_, &has_score_calibration)); + + // If no calibration_params is found, just skip score calibration. + if (!has_score_calibration) { + return absl::OkStatus(); + } + + score_calibration_ = absl::make_unique<ScoreCalibration>(); + if (score_calibration_ == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Could not create score calibration object."); + } + RETURN_IF_ERROR( + score_calibration_->InitializeFromParameters(calibration_params)); + return absl::OkStatus(); +} + absl::Status ObjectDetector::CheckAndSetOutputs() { // First, sanity checks on the model itself. - const TfLiteEngine::Interpreter* interpreter = engine_->interpreter(); + const TfLiteEngine::Interpreter* interpreter = + GetTfLiteEngine()->interpreter(); // Check the number of output tensors. if (TfLiteEngine::OutputCount(interpreter) != 4) { return CreateStatusWithPayload( @@ -306,29 +457,10 @@ TfLiteEngine::OutputCount(interpreter)), TfLiteSupportStatus::kInvalidNumOutputTensorsError); } - // Check tensor dimensions and batch size. - for (int i = 0; i < 4; ++i) { - const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i); - if (tensor->dims->size != kOutputTensorsExpectedDims[i]) { - return CreateStatusWithPayload( - StatusCode::kInvalidArgument, - absl::StrFormat("Output tensor at index %d is expected to " - "have %d dimensions, found %d.", - i, kOutputTensorsExpectedDims[i], tensor->dims->size), - TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); - } - if (tensor->dims->data[0] != 1) { - return CreateStatusWithPayload( - StatusCode::kInvalidArgument, - absl::StrFormat("Expected batch size of 1, found %d.", - tensor->dims->data[0]), - TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); - } - } // Now, perform sanity checks and extract metadata. const ModelMetadataExtractor* metadata_extractor = - engine_->metadata_extractor(); + GetTfLiteEngine()->metadata_extractor(); // Check that metadata is available. if (metadata_extractor->GetModelMetadata() == nullptr || metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { @@ -352,10 +484,13 @@ TfLiteSupportStatus::kMetadataInconsistencyError); } + output_indices_ = GetOutputIndices(output_tensors_metadata); + // Extract mandatory BoundingBoxProperties for easier access at // post-processing time, performing sanity checks on the fly. ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties, - GetBoundingBoxProperties(*output_tensors_metadata->Get(0))); + GetBoundingBoxProperties( + *output_tensors_metadata->Get(output_indices_[0]))); if (bounding_box_properties->index() == nullptr) { bounding_box_corners_order_ = {0, 1, 2, 3}; } else { @@ -371,16 +506,39 @@ // Build label map (if available) from metadata. ASSIGN_OR_RETURN( label_map_, - GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1), + GetLabelMapIfAny(*metadata_extractor, + *output_tensors_metadata->Get(output_indices_[1]), options_->display_names_locale())); // Set score threshold. if (options_->has_score_threshold()) { score_threshold_ = options_->score_threshold(); } else { - ASSIGN_OR_RETURN(score_threshold_, - GetScoreThreshold(*metadata_extractor, - *output_tensors_metadata->Get(2))); + ASSIGN_OR_RETURN( + score_threshold_, + GetScoreThreshold(*metadata_extractor, + *output_tensors_metadata->Get(output_indices_[2]))); + } + + // Check tensor dimensions and batch size. + for (int i = 0; i < 4; ++i) { + std::size_t j = output_indices_[i]; + const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, j); + if (tensor->dims->size != kOutputTensorsExpectedDims[i]) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Output tensor at index %d is expected to " + "have %d dimensions, found %d.", + j, kOutputTensorsExpectedDims[i], tensor->dims->size), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + if (tensor->dims->data[0] != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected batch size of 1, found %d.", + tensor->dims->data[0]), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } } return absl::OkStatus(); @@ -439,9 +597,7 @@ BoundingBox roi; roi.set_width(frame_buffer.dimension().width); roi.set_height(frame_buffer.dimension().height); - // Rely on `Infer` instead of `InferWithFallback` as DetectionPostprocessing - // op doesn't support hardware acceleration at the time. - return Infer(frame_buffer, roi); + return InferWithFallback(frame_buffer, roi); } StatusOr<DetectionResult> ObjectDetector::Postprocess( @@ -451,11 +607,13 @@ // Most of the checks here should never happen, as outputs have been validated // at construction time. Checking nonetheless and returning internal errors if // something bad happens. - RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors)); + RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors, output_indices_)); // Get number of available results. - const int num_results = - static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]); + ASSIGN_OR_RETURN( + float* num_results_data, + AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[3]])); + const int num_results = static_cast<int>(num_results_data[0]); // Compute number of max results to return. const int max_results = options_->max_results() > 0 ? std::min(options_->max_results(), num_results) @@ -469,16 +627,32 @@ upright_input_frame_dimensions.Swap(); } - const float* locations = AssertAndReturnTypedTensor<float>(output_tensors[0]); - const float* classes = AssertAndReturnTypedTensor<float>(output_tensors[1]); - const float* scores = AssertAndReturnTypedTensor<float>(output_tensors[2]); + ASSIGN_OR_RETURN( + const float* locations, + AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[0]])); + ASSIGN_OR_RETURN( + const float* classes, + AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[1]])); + ASSIGN_OR_RETURN( + const float* scores, + AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[2]])); DetectionResult results; for (int i = 0; i < num_results; ++i) { const int class_index = static_cast<int>(classes[i]); - const float score = scores[i]; - if (!IsClassIndexAllowed(class_index) || score < score_threshold_) { + if (!IsClassIndexAllowed(class_index)) { continue; } + + float score = scores[i]; + // Calibrate score only if score_calibration_ is presented. + if (score_calibration_ != nullptr) { + const std::string& class_name = label_map_[class_index].name; + score = score_calibration_->ComputeCalibratedScore(class_name, score); + } + if (score <= score_threshold_) { + continue; + } + Detection* detection = results.add_detections(); // Denormalize the bounding box cooordinates in the upright frame // coordinates system, then rotate back from frame_buffer.orientation() to
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h index 5305407..c37fa877 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
@@ -18,9 +18,10 @@ #include <memory> -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" +#include "absl/container/flat_hash_set.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" #include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h" @@ -28,6 +29,7 @@ #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" #include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h" namespace tflite { namespace task { @@ -84,7 +86,7 @@ CreateFromOptions( const ObjectDetectorOptions& options, std::unique_ptr<tflite::OpResolver> resolver = - absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); // Performs actual detection on the provided FrameBuffer. // @@ -134,6 +136,9 @@ // Performs pre-initialization actions. virtual absl::Status PreInit(); + // Performs post-initialization actions. + virtual absl::Status PostInit(); + private: // Performs sanity checks on the model outputs and extracts their metadata. absl::Status CheckAndSetOutputs(); @@ -147,6 +152,10 @@ // Always returns true if no whitelist or blacklist were provided. bool IsClassIndexAllowed(int class_index); + // Initializes the score calibration parameters based on corresponding TFLite + // Model Metadata, if any. + absl::Status InitScoreCalibrations(); + // Given a DetectionResult object containing class indices, fills the name and // display name from the label map. absl::Status FillResultsFromLabelMap(DetectionResult* result); @@ -178,6 +187,15 @@ // discarded. If none is provided via metadata or options, -FLT_MAX is set as // default value. float score_threshold_; + + // List of score calibration parameters, if any. Built from TFLite Model + // Metadata. + std::unique_ptr<ScoreCalibration> score_calibration_; + + // Indices of the output tensors to match the output tensors to the correct + // index order of the output tensors: [location, categories, scores, + // num_detections]. + std::vector<int> output_indices_; }; } // namespace vision
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/BUILD index d1430eb..16ea0cd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/BUILD
@@ -1,8 +1,6 @@ -load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library") - package( default_visibility = [ - "//tensorflow_lite_support:users", + "//visibility:public", ], licenses = ["notice"], # Apache 2.0 ) @@ -14,9 +12,8 @@ srcs = ["bounding_box.proto"], ) -support_cc_proto_library( +cc_proto_library( name = "bounding_box_cc_proto", - srcs = ["bounding_box.proto"], deps = [ ":bounding_box_proto", ], @@ -33,9 +30,8 @@ srcs = ["class.proto"], ) -support_cc_proto_library( +cc_proto_library( name = "class_cc_proto", - srcs = ["class.proto"], deps = [ ":class_proto", ], @@ -53,14 +49,14 @@ name = "object_detector_options_proto", srcs = ["object_detector_options.proto"], deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto", ], ) -support_cc_proto_library( +cc_proto_library( name = "object_detector_options_cc_proto", - srcs = ["object_detector_options.proto"], - cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"], deps = [ ":object_detector_options_proto", ], @@ -71,6 +67,7 @@ hdrs = ["object_detector_options_proto_inc.h"], deps = [ ":object_detector_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", ], ) @@ -84,13 +81,8 @@ ], ) -support_cc_proto_library( +cc_proto_library( name = "detections_cc_proto", - srcs = ["detections.proto"], - cc_deps = [ - ":bounding_box_cc_proto", - ":class_cc_proto", - ], deps = [ ":detections_proto", ], @@ -112,14 +104,14 @@ name = "image_classifier_options_proto", srcs = ["image_classifier_options.proto"], deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto", ], ) -support_cc_proto_library( +cc_proto_library( name = "image_classifier_options_cc_proto", - srcs = ["image_classifier_options.proto"], - cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"], deps = [ ":image_classifier_options_proto", ], @@ -130,6 +122,7 @@ hdrs = ["image_classifier_options_proto_inc.h"], deps = [ ":image_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", ], ) @@ -142,10 +135,8 @@ ], ) -support_cc_proto_library( +cc_proto_library( name = "classifications_cc_proto", - srcs = ["classifications.proto"], - cc_deps = [":class_cc_proto"], deps = [ ":classifications_proto", ], @@ -166,14 +157,14 @@ name = "image_segmenter_options_proto", srcs = ["image_segmenter_options.proto"], deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto", ], ) -support_cc_proto_library( +cc_proto_library( name = "image_segmenter_options_cc_proto", - srcs = ["image_segmenter_options.proto"], - cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"], deps = [ ":image_segmenter_options_proto", ], @@ -184,6 +175,7 @@ hdrs = ["image_segmenter_options_proto_inc.h"], deps = [ ":image_segmenter_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", ], ) @@ -193,9 +185,8 @@ srcs = ["segmentations.proto"], ) -support_cc_proto_library( +cc_proto_library( name = "segmentations_cc_proto", - srcs = ["segmentations.proto"], deps = [ ":segmentations_proto", ], @@ -206,3 +197,48 @@ hdrs = ["segmentations_proto_inc.h"], deps = [":segmentations_cc_proto"], ) + +# ImageEmbedder protos. + +proto_library( + name = "image_embedder_options_proto", + srcs = ["image_embedder_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_proto", + ], +) + +cc_proto_library( + name = "image_embedder_options_cc_proto", + deps = [ + ":image_embedder_options_proto", + ], +) + +cc_library( + name = "image_embedder_options_proto_inc", + hdrs = ["image_embedder_options_proto_inc.h"], + deps = [ + ":image_embedder_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + ], +) + +proto_library( + name = "embeddings_proto", + srcs = ["embeddings.proto"], +) + +cc_proto_library( + name = "embeddings_cc_proto", + deps = [ + ":embeddings_proto", + ], +) + +cc_library( + name = "embeddings_proto_inc", + hdrs = ["embeddings_proto_inc.h"], + deps = [":embeddings_cc_proto"], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h index 2f9a409..38547ed 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h
@@ -13,8 +13,8 @@ limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ #include "tensorflow_lite_support/cc/task/vision/proto/class.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h index 62a5f117..1b3d538 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h
@@ -13,10 +13,10 @@ limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/classifications.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h index 2b63cad..c702511 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h
@@ -13,11 +13,11 @@ limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/detections.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/embeddings.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/embeddings.proto new file mode 100644 index 0000000..bff5faf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/embeddings.proto
@@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +// Defines a dense feature vector. Only one of the two fields is ever present. +// Feature vectors are assumed to be one-dimensional and L2-normalized. +message FeatureVector { + // Raw output of the embedding layer. Only provided if `quantize` is set to + // false in the ImageEmbedderOptions, which is the case by default. + repeated float value_float = 1 [packed = true]; + // Scalar-quantized embedding. Only provided if `quantize` is set to true in + // the ImageEmbedderOptions. + optional bytes value_string = 2; +} + +// Result produced by one of the embedder model output layers. +message Embedding { + // The output feature vector. + optional FeatureVector feature_vector = 1; + // The index of the model output layer that produced this feature vector. + optional int32 output_index = 2; +} + +// Embeddings produced by the ImageEmbedder. +message EmbeddingResult { + // The embeddings produced by each of the model output layers. + // + // Except in advanced cases, the embedding model has a single output layer, + // and this list is thus made of a single element. + repeated Embedding embeddings = 1; + // Reserved tags. + reserved 2; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h new file mode 100644 index 0000000..554aa98 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h
@@ -0,0 +1,20 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_EMBEDDINGS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_EMBEDDINGS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/vision/proto/embeddings.pb.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_EMBEDDINGS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto index 24cd85f..50b7682 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto
@@ -17,20 +17,29 @@ package tflite.task.vision; +import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto"; +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; // Options for setting up an ImageClassifier. -// Next Id: 14 +// Next Id: 15 message ImageClassifierOptions { - // The external model file, as a single standalone TFLite file. If it is - // packed with TFLite Model Metadata [1], those are used to populate e.g. the - // label map, score calibration and recommended score thresholds. Models - // without any such metadata or partial metadata are supported, but may result - // in the image classifier providing degraded functionality; typically, a - // model that doesn't contain any label map won't be able to return any class - // or display names but will be limited to returning class indices. + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 14; + + // Legacy method for specifying the TFLite model file. If it is packed with + // TFLite Model Metadata [1], those are used to populate e.g. the label map, + // score calibration and recommended score thresholds. Models without any such + // metadata or partial metadata are supported, but may result in the image + // classifier providing degraded functionality; typically, a model that + // doesn't contain any label map won't be able to return any class or display + // names but will be limited to returning class indices. // // [1]: https://www.tensorflow.org/lite/convert/metadata + // + // Deprecated: prefer using `base_options.model_file`, which is mutually + // exclusive with this field. optional core.ExternalFile model_file_with_metadata = 10; // The locale to use for display names specified through the TFLite Model @@ -42,7 +51,7 @@ // returned. optional int32 max_results = 2 [default = -1]; - // Score threshold in [0,1), overrides the ones provided in the model metadata + // Score threshold, overrides the ones provided in the model metadata // (if any). Results below this value are rejected. optional float score_threshold = 3; @@ -56,12 +65,35 @@ // class names are ignored. Mutually exclusive with class_name_whitelist. repeated string class_name_blacklist = 5; - // The number of threads to be used for TFLite ops that support - // multi-threading when running inference with CPU. + // Legacy method for specifying the number of threads to be used for TFLite + // ops that support multi-threading when running inference with CPU. // num_threads should be greater than 0 or equal to -1. Setting num_threads to // -1 has the effect to let TFLite runtime set the value. + // + // Deprecated: only works with `model_file_with_metadata`. Prefer using + // `base_options` to specifying the TFLite model and using + // `base_options.compute_settings.tflite_settings.cpu_settings.num_threads`, + // to configure the number of threads. optional int32 num_threads = 13 [default = -1]; + // Legacy method for specifying how to accelerate the model + // inference using dedicated delegates. Supported delegate type includes: + // NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal), + // and EDGETPU_CORAL. + // + // IMPORTANT: in order to use a delegate, the appropriate delegate plugin + // needs to be linked at build time. See comment above the "image_classifier" + // target at: + // https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/BUILD + // + // See settings definition at: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto + // + // Deprecated: only works with `model_file_with_metadata`. Prefer using + // `base_options` to specifying the TFLite model and using + // `base_options.compute_settings` to configure acceleration options. + optional tflite.proto.ComputeSettings compute_settings = 9; + // Reserved tags. - reserved 1, 6, 7, 8, 9, 12; + reserved 1, 6, 7, 8, 12; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h index 03dcd75..97a6deb4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h
@@ -13,10 +13,11 @@ limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_embedder_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_embedder_options.proto new file mode 100644 index 0000000..87693ad --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_embedder_options.proto
@@ -0,0 +1,66 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto"; +import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; + +// Options for setting up an ImageEmbedder. +// Next Id: 10. +message ImageEmbedderOptions { + // The external model file, as a single standalone TFLite, optionally packed + // with TFLite Model Metadata [1]. Those are mandatory only if the model input + // is of float type (kTfLiteFloat32), which requires `NormalizationOptions` to + // be set on the input tensor metadata. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional core.ExternalFile model_file_with_metadata = 9; + + // Whether to normalize the returned feature vector with L2 norm. Use this + // option only if the model does not already contain a native L2_NORMALIZATION + // TF Lite Op. In most cases, this is already the case and L2 norm is thus + // achieved through TF Lite inference. + optional bool l2_normalize = 6; + + // Whether the returned embedding should be quantized to bytes via scalar + // quantization. Embeddings are implicitly assumed to be unit-norm and + // therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + // the l2_normalize option if this is not the case. + optional bool quantize = 2; + + // The number of threads allowed for model inference. This value is used in + // building the TF Lite interpreter. + optional int32 num_threads = 7 [default = 1]; + + // Advanced settings specifying how to accelerate the model inference using + // dedicated delegates. Supported delegate type includes: + // NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal), + // and EDGETPU_CORAL. + // + // IMPORTANT: in order to use a delegate, the appropriate delegate plugin + // needs to be linked at build time. See comment above the "image_embedder" + // target at: + // https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/BUILD + // + // See settings definition at: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto + optional tflite.proto.ComputeSettings compute_settings = 8; + + // Reserved tags. + reserved 1, 3, 4, 5; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h new file mode 100644 index 0000000..606751e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h
@@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_EMBEDDER_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_EMBEDDER_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +#include "tensorflow_lite_support/cc/task/vision/proto/image_embedder_options.pb.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_EMBEDDER_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto index 3afed86aa..1aad5eb 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto
@@ -17,12 +17,18 @@ package tflite.task.vision; +import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto"; +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; // Options for setting up an ImageSegmenter. -// Next Id: 8 +// Next Id: 9 message ImageSegmenterOptions { - // The external model file, as a single standalone TFLite file. If it is + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 8; + + // Legacy method for specifying the TfLite model file. If it is // packed with TFLite Model Metadata [1], those are used to populate label // map. Models without any such metadata or partial metadata are supported, // but may result in the segmenter providing degraded functionality; @@ -30,6 +36,9 @@ // return any class or display names. // // [1]: https://www.tensorflow.org/lite/convert/metadata + // + // Deprecated: prefer using `base_options.model_file`, which is mutually + // exclusive with this field. optional core.ExternalFile model_file_with_metadata = 5; // The locale to use for display names specified through the TFLite Model @@ -50,12 +59,35 @@ // Optional output mask type. optional OutputType output_type = 3 [default = CATEGORY_MASK]; - // The number of threads to be used for TFLite ops that support - // multi-threading when running inference with CPU. + // Legacy method for specifying the number of threads to be used for TFLite + // ops that support multi-threading when running inference with CPU. // num_threads should be greater than 0 or equal to -1. Setting num_threads to // -1 has the effect to let TFLite runtime set the value. + // + // Deprecated: only works with `model_file_with_metadata`. Prefer using + // `base_options` to specifying the TFLite model and using + // `base_options.compute_settings.tflite_settings.cpu_settings.num_threads`, + // to configure the number of threads. optional int32 num_threads = 7 [default = -1]; + // Legacy method for specifying how to accelerate the model + // inference using dedicated delegates. Supported delegate type includes: + // NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal), + // and EDGETPU_CORAL. + // + // IMPORTANT: in order to use a delegate, the appropriate delegate plugin + // needs to be linked at build time. See comment above the "image_segmenter" + // target at: + // https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/BUILD + // + // See settings definition at: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto + // + // Deprecated: only works with `model_file_with_metadata`. Prefer using + // `base_options` to specifying the TFLite model and use + // `base_options.compute_settings` to configure acceleration options. + optional tflite.proto.ComputeSettings compute_settings = 4; + // Reserved tags. - reserved 1, 2, 4; + reserved 1, 2; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h index aaaecf36..f2d34c8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h
@@ -13,10 +13,11 @@ limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto index b55e9740..d10198b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto
@@ -17,16 +17,25 @@ package tflite.task.vision; +import "tensorflow/lite/experimental/acceleration/configuration/configuration.proto"; +import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; // Options for setting up an ObjectDetector. -// Next Id: 8. +// Next Id: 10. message ObjectDetectorOptions { - // The external model file, as a single standalone TFLite file packed with - // TFLite Model Metadata [1]. Those are mandatory, and used to populate e.g. - // the label map and recommended score threshold. + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional tflite.task.core.BaseOptions base_options = 9; + + // Legacy method for specifying the TFLite model file, as a single standalone + // TFLite file packed with TFLite Model Metadata [1]. Those are mandatory, and + // used to populate e.g. the label map and recommended score threshold. // // [1]: https://www.tensorflow.org/lite/convert/metadata + // + // Deprecated: prefer using `base_options.model_file`, which is mutually + // exclusive with this field. optional core.ExternalFile model_file_with_metadata = 1; // The locale to use for display names specified through the TFLite Model @@ -54,9 +63,32 @@ // class names are ignored. Mutually exclusive with class_name_whitelist. repeated string class_name_blacklist = 6; - // The number of threads to be used for TFLite ops that support - // multi-threading when running inference with CPU. + // Legacy method for specifying the number of threads to be used for TFLite + // ops that support multi-threading when running inference with CPU. // num_threads should be greater than 0 or equal to -1. Setting num_threads to // -1 has the effect to let TFLite runtime set the value. + // + // Deprecated: only works with `model_file_with_metadata`. Prefer using + // `base_options` to specifying the TFLite model and using + // `base_options.compute_settings.tflite_settings.cpu_settings.num_threads`, + // to configure the number of threads. optional int32 num_threads = 7 [default = -1]; + + // Legacy method for specifying how to accelerate the model + // inference using dedicated delegates. Supported delegate type includes: + // NONE, NNAPI, GPU, HEXAGON, XNNPACK, EDGETPU (Google internal), + // and EDGETPU_CORAL. + // + // IMPORTANT: in order to use a delegate, the appropriate delegate plugin + // needs to be linked at build time. See comment above the "object_detector" + // target at: + // https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/BUILD + // + // See settings definition at: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto + // + // Deprecated: only works with `model_file_with_metadata`. Prefer using + // `base_options` to specifying the TFLite model and using + // `base_options.compute_settings` to configure acceleration options. + optional tflite.proto.ComputeSettings compute_settings = 8; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h index 2789847..4d966630 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h
@@ -13,10 +13,11 @@ limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h index cfc96e6..86d38b5 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h
@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ -#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ #include "tensorflow_lite_support/cc/task/vision/proto/segmentations.pb.h" -#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD index 8995145..72e3961 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD
@@ -1,6 +1,8 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( default_visibility = [ - "//tensorflow_lite_support:users", + "//tensorflow_lite_support:internal", ], licenses = ["notice"], # Apache 2.0 ) @@ -9,6 +11,9 @@ name = "score_calibration", srcs = ["score_calibration.cc"], hdrs = ["score_calibration.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:status_macros", @@ -32,6 +37,7 @@ "frame_buffer_common_utils.h", "frame_buffer_utils_interface.h", ], + visibility = ["//visibility:public"], deps = [ "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros", @@ -52,6 +58,9 @@ hdrs = [ "frame_buffer_utils.h", ], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ ":frame_buffer_common_utils", ":libyuv_frame_buffer_utils", @@ -66,6 +75,7 @@ "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", + "@com_google_glog//:glog", "@org_tensorflow//tensorflow/lite/kernels:op_macros", "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility", ], @@ -75,6 +85,9 @@ name = "libyuv_frame_buffer_utils", srcs = ["libyuv_frame_buffer_utils.cc"], hdrs = ["libyuv_frame_buffer_utils.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ ":frame_buffer_common_utils", "//tensorflow_lite_support/cc:common", @@ -89,21 +102,22 @@ ], ) -cc_library( +cc_library_with_tflite( name = "image_tensor_specs", srcs = ["image_tensor_specs.cc"], hdrs = ["image_tensor_specs.h"], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], deps = [ "//tensorflow_lite_support/cc:common", "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/core:tflite_engine", "//tensorflow_lite_support/metadata:metadata_schema_cc", "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", - "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/c:common", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc index 9fbd284..cea7ef3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
@@ -18,8 +18,8 @@ #include <string> #include <vector> -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/status_macros.h" namespace tflite { @@ -33,29 +33,17 @@ constexpr int kRgbChannels = 3; constexpr int kGrayChannel = 1; -// Creates a FrameBuffer from raw NV12 buffer and passing arguments. -std::unique_ptr<FrameBuffer> CreateFromNV12RawBuffer( +// Creates a FrameBuffer from one plane raw NV21/NV12 buffer and passing +// arguments. +StatusOr<std::unique_ptr<FrameBuffer>> CreateFromOnePlaneNVRawBuffer( const uint8* input, FrameBuffer::Dimension dimension, - FrameBuffer::Orientation orientation, - const absl::Time timestamp) { - const std::vector<FrameBuffer::Plane> planes_nv12 = { - {input, /*stride=*/{dimension.width, kGrayChannel}}, - {input + dimension.Size(), /*stride=*/{dimension.width, 2}}}; - return FrameBuffer::Create(planes_nv12, dimension, FrameBuffer::Format::kNV12, - orientation, timestamp); -} - -// Creates a FrameBuffer from raw NV21 buffer and passing arguments. -std::unique_ptr<FrameBuffer> CreateFromNV21RawBuffer( - const uint8* input, - FrameBuffer::Dimension dimension, + FrameBuffer::Format format, FrameBuffer::Orientation orientation, const absl::Time timestamp) { FrameBuffer::Plane input_plane = {/*buffer=*/input, /*stride=*/{dimension.width, kGrayChannel}}; - return FrameBuffer::Create({input_plane}, dimension, - FrameBuffer::Format::kNV21, orientation, + return FrameBuffer::Create({input_plane}, dimension, format, orientation, timestamp); } @@ -101,12 +89,12 @@ case FrameBuffer::Format::kYV12: case FrameBuffer::Format::kYV21: return /*y plane*/ dimension.Size() + - /*uv plane*/ ((static_cast<float>(dimension.width + 1) / 2) * - (static_cast<float>(dimension.height + 1) / 2) * 2); + /*uv plane*/ (dimension.width + 1) / 2 * (dimension.height + 1) / + 2 * 2; case FrameBuffer::Format::kRGB: - return dimension.Size() * 3; + return dimension.Size() * kRgbPixelBytes; case FrameBuffer::Format::kRGBA: - return dimension.Size() * 4; + return dimension.Size() * kRgbaPixelBytes; case FrameBuffer::Format::kGRAY: return dimension.Size(); default: @@ -336,10 +324,14 @@ const uint8* input, FrameBuffer::Dimension dimension, FrameBuffer::Orientation orientation, - const absl::Time timestamp) { - FrameBuffer::Plane input_plane = { - /*buffer=*/input, - /*stride=*/{dimension.width * kRgbaChannels, kRgbaChannels}}; + const absl::Time timestamp, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kRgbaChannels; + stride.pixel_stride_bytes = kRgbaChannels; + } + FrameBuffer::Plane input_plane = {/*buffer=*/input, + /*stride=*/stride}; return FrameBuffer::Create({input_plane}, dimension, FrameBuffer::Format::kRGBA, orientation, timestamp); @@ -350,10 +342,14 @@ const uint8* input, FrameBuffer::Dimension dimension, FrameBuffer::Orientation orientation, - const absl::Time timestamp) { - FrameBuffer::Plane input_plane = { - /*buffer=*/input, - /*stride=*/{dimension.width * kRgbChannels, kRgbChannels}}; + const absl::Time timestamp, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kRgbChannels; + stride.pixel_stride_bytes = kRgbChannels; + } + FrameBuffer::Plane input_plane = {/*buffer=*/input, + /*stride=*/stride}; return FrameBuffer::Create({input_plane}, dimension, FrameBuffer::Format::kRGB, orientation, timestamp); } @@ -363,9 +359,14 @@ const uint8* input, FrameBuffer::Dimension dimension, FrameBuffer::Orientation orientation, - const absl::Time timestamp) { + const absl::Time timestamp, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kGrayChannel; + stride.pixel_stride_bytes = kGrayChannel; + } FrameBuffer::Plane input_plane = {/*buffer=*/input, - /*stride=*/{dimension.width, kGrayChannel}}; + /*stride=*/stride}; return FrameBuffer::Create({input_plane}, dimension, FrameBuffer::Format::kGRAY, orientation, timestamp); @@ -410,9 +411,11 @@ absl::Time timestamp) { switch (target_format) { case FrameBuffer::Format::kNV12: - return CreateFromNV12RawBuffer(buffer, dimension, orientation, timestamp); + return CreateFromOnePlaneNVRawBuffer(buffer, dimension, target_format, + orientation, timestamp); case FrameBuffer::Format::kNV21: - return CreateFromNV21RawBuffer(buffer, dimension, orientation, timestamp); + return CreateFromOnePlaneNVRawBuffer(buffer, dimension, target_format, + orientation, timestamp); case FrameBuffer::Format::kYV12: { ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension, GetUvPlaneDimension(dimension, target_format));
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h index 45670f4..7ebf69fa 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
@@ -17,9 +17,9 @@ #include <memory> -#include "absl/status/status.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/time/clock.h" // from @com_google_absl +#include "absl/time/time.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" @@ -29,6 +29,11 @@ namespace vision { constexpr int kRgbaPixelBytes = 4, kRgbPixelBytes = 3, kGrayPixelBytes = 1; +// Default stride value for creating frame buffer from raw buffer. When using +// this default value, the default row stride and pixel stride values will be +// applied. e.g. for RGB image, row_stride = width * kRgbPixelBytes, +// pixel_stride = kRgbPixelBytes. +inline constexpr FrameBuffer::Stride kDefaultStride = {0, 0}; // Miscellaneous Methods // ----------------------------------------------------------------- @@ -112,21 +117,24 @@ const uint8* input, FrameBuffer::Dimension dimension, FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, - absl::Time timestamp = absl::Now()); + absl::Time timestamp = absl::Now(), + FrameBuffer::Stride stride = kDefaultStride); // Creates a FrameBuffer from raw RGB buffer and passing arguments. std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( const uint8* input, FrameBuffer::Dimension dimension, FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, - absl::Time timestamp = absl::Now()); + absl::Time timestamp = absl::Now(), + FrameBuffer::Stride stride = kDefaultStride); // Creates a FrameBuffer from raw grayscale buffer and passing arguments. std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( const uint8* input, FrameBuffer::Dimension dimension, FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, - absl::Time timestamp = absl::Now()); + absl::Time timestamp = absl::Now(), + FrameBuffer::Stride stride = kDefaultStride); // Creates a FrameBuffer from raw YUV buffer and passing arguments. tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc index f6b3eba..4728c30 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
@@ -22,9 +22,9 @@ #include <utility> #include <vector> -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow_lite_support/cc/port/status_macros.h" @@ -331,6 +331,10 @@ } else if (absl::holds_alternative<CropResizeOperation>(operation)) { const auto& crop_resize = absl::get<CropResizeOperation>(operation); dimension = crop_resize.resize_dimension; + } else if (absl::holds_alternative<UniformCropResizeOperation>(operation)) { + const auto& uniform_crop_resize = + absl::get<UniformCropResizeOperation>(operation); + dimension = uniform_crop_resize.output_dimension; } return dimension; } @@ -361,9 +365,10 @@ planes.push_back( {buffer, /*stride=*/{/*row_stride_bytes=*/dimension.width, /*pixel_stride_bytes=*/1}}); - planes.push_back({buffer + (dimension.width * dimension.height), - /*stride=*/{/*row_stride_bytes=*/dimension.width, - /*pixel_stride_bytes=*/2}}); + planes.push_back( + {buffer + (dimension.width * dimension.height), + /*stride=*/{/*row_stride_bytes=*/(dimension.width + 1) / 2 * 2, + /*pixel_stride_bytes=*/2}}); } break; case FrameBuffer::Format::kYV12: case FrameBuffer::Format::kYV21: { @@ -414,6 +419,13 @@ (params.crop_dimension.width + params.crop_origin_x - 1), (params.crop_dimension.height + params.crop_origin_y - 1), output_buffer)); + } else if (absl::holds_alternative<UniformCropResizeOperation>(operation)) { + const auto& params = absl::get<UniformCropResizeOperation>(operation); + RETURN_IF_ERROR( + Crop(buffer, params.crop_origin_x, params.crop_origin_y, + (params.crop_dimension.width + params.crop_origin_x - 1), + (params.crop_dimension.height + params.crop_origin_y - 1), + output_buffer)); } else if (absl::holds_alternative<ConvertOperation>(operation)) { RETURN_IF_ERROR(Convert(buffer, output_buffer)); } else if (absl::holds_alternative<OrientOperation>(operation)) { @@ -584,7 +596,8 @@ absl::Status FrameBufferUtils::Preprocess( const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box, - FrameBuffer* output_buffer) { + FrameBuffer* output_buffer, + bool uniform_resizing) { std::vector<FrameBufferOperation> frame_buffer_operations; // Handle cropping and resizing. bool needs_dimension_swap = @@ -596,29 +609,56 @@ pre_orient_dimension.Swap(); } - if (bounding_box.has_value()) { - // Cropping case. + if (uniform_resizing && bounding_box.has_value()) { + // Crop and uniform resize. + frame_buffer_operations.push_back(UniformCropResizeOperation( + bounding_box.value().origin_x(), bounding_box.value().origin_y(), + FrameBuffer::Dimension{bounding_box.value().width(), + bounding_box.value().height()}, + pre_orient_dimension)); + } else if (uniform_resizing) { + // Uniform resize only. + frame_buffer_operations.push_back(UniformCropResizeOperation( + 0, 0, buffer.dimension(), pre_orient_dimension)); + } else if (bounding_box.has_value()) { + // Crop and non-uniform resize. frame_buffer_operations.push_back(CropResizeOperation( bounding_box.value().origin_x(), bounding_box.value().origin_y(), FrameBuffer::Dimension{bounding_box.value().width(), bounding_box.value().height()}, pre_orient_dimension)); } else if (pre_orient_dimension != buffer.dimension()) { - // Resizing case. + // non-uniform resize. frame_buffer_operations.push_back( CropResizeOperation(0, 0, buffer.dimension(), pre_orient_dimension)); } - // Handle color space conversion. - if (output_buffer->format() != buffer.format()) { - frame_buffer_operations.push_back( - ConvertOperation(output_buffer->format())); - } - - // Handle orientation conversion. - if (output_buffer->orientation() != buffer.orientation()) { - frame_buffer_operations.push_back( - OrientOperation(output_buffer->orientation())); + // Handle color space conversion first if the input format is RGB or RGBA, + // because the rotation performance for RGB and RGBA formats are not optimzed + // in libyuv. + if (buffer.format() == FrameBuffer::Format::kRGB || + buffer.format() == FrameBuffer::Format::kRGBA) { + if (output_buffer->format() != buffer.format()) { + frame_buffer_operations.push_back( + ConvertOperation(output_buffer->format())); + } + // Handle orientation conversion + if (output_buffer->orientation() != buffer.orientation()) { + frame_buffer_operations.push_back( + OrientOperation(output_buffer->orientation())); + } + } else { + // Handle orientation conversion first if the input format is not RGB or + // RGBA. + if (output_buffer->orientation() != buffer.orientation()) { + frame_buffer_operations.push_back( + OrientOperation(output_buffer->orientation())); + } + // Handle color space conversion + if (output_buffer->format() != buffer.format()) { + frame_buffer_operations.push_back( + ConvertOperation(output_buffer->format())); + } } // Execute the processing pipeline.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h index 794495e..4854946 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
@@ -19,9 +19,9 @@ #include <memory> #include <vector> -#include "absl/status/status.h" -#include "absl/types/optional.h" -#include "absl/types/variant.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/types/optional.h" // from @com_google_absl +#include "absl/types/variant.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" @@ -112,6 +112,39 @@ FrameBuffer::Dimension resize_dimension; }; +// The parameters needed to crop / resize / pad. +// +// The coordinate system has its origin at the upper left corner, and +// positive values extend down and to the right from it. +// +// After the operation, the `crop_origin` will become the new origin. +// `crop_width` and `crop_height` defines the desired cropping region. After +// cropping, a resize is performed based on the `resize_width` and +// `resize_height`. +// +// To perform just cropping, the `crop_width` and `crop_height` should be the +// same as `resize_width` `and resize_height`. +// +// The cropped region is resized uniformly (respecting the aspect ratio) to best +// match the size of the given `output_dimension` in both x and y dimensions. +// The resized region is aligned to the upper left pixel of the output buffer. +// The unfilled area of the output buffer remains untouched. +struct UniformCropResizeOperation { + UniformCropResizeOperation(int crop_origin_x, + int crop_origin_y, + FrameBuffer::Dimension crop_dimension, + FrameBuffer::Dimension output_dimension) + : crop_origin_x(crop_origin_x), + crop_origin_y(crop_origin_y), + crop_dimension(crop_dimension), + output_dimension(output_dimension) {} + + int crop_origin_x; + int crop_origin_y; + FrameBuffer::Dimension crop_dimension; + FrameBuffer::Dimension output_dimension; +}; + // The parameters needed to convert to the specified format. struct ConvertOperation { explicit ConvertOperation(FrameBuffer::Format to_format) @@ -128,8 +161,10 @@ // A variant of the supported operations on FrameBuffers. Alias for user // convenience. -using FrameBufferOperation = - absl::variant<CropResizeOperation, ConvertOperation, OrientOperation>; +using FrameBufferOperation = absl::variant<CropResizeOperation, + ConvertOperation, + OrientOperation, + UniformCropResizeOperation>; // Image processing utility. This utility provides both basic image buffer // manipulations (e.g. rotation, format conversion, resizing, etc) as well as @@ -139,7 +174,8 @@ // Examples: // // // Create an instance of FrameBufferUtils with Halide processing engine. -// std::unique_ptr<FrameBufferUtils> utils = FrameBufferUtils::Create(kHalide); +// std::unique_ptr<FrameBufferUtils> utils = +// FrameBufferUtils::Create(kHalide); // // // Perform single basic operation by each individual call. // std::unique_ptr<FrameBuffer> input = FrameBuffer::Create(...); @@ -263,10 +299,17 @@ // If the `buffer` is already in desired format, then an extra copy will be // performed. // + // If `uniform_resizing` is set to true, the source region is resized + // uniformly (respecting the aspect ratio) to best match the dimension of the + // given `output_buffer` in both x and y dimensions. The resized region is + // aligned to the upper left pixel of the output buffer. The unfilled area of + // the output buffer remains untouched. Default `uniform_resizing` to false; + // // The input param `bounding_box` is defined in the `buffer` coordinate space. absl::Status Preprocess(const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box, - FrameBuffer* output_buffer); + FrameBuffer* output_buffer, + bool uniform_resizing = false); private: // Returns the new FrameBuffer size after the operation is applied.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h index 36d1895e..59da220 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
@@ -16,7 +16,7 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_ #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_ -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" namespace tflite {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc index 51e72fa..afbe07dd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
@@ -14,7 +14,7 @@ ==============================================================================*/ #include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h" -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/status_macros.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h index 536eed4d..d15be3f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
@@ -17,7 +17,7 @@ #include <array> -#include "absl/types/optional.h" +#include "absl/types/optional.h" // from @com_google_absl #include "tensorflow/lite/c/common.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc index b50b500..a00c8223 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
@@ -20,15 +20,16 @@ #include <memory> #include <string> -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "libyuv.h" // from @libyuv #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "third_party/libyuv/include/libyuv.h" namespace tflite { namespace task { @@ -139,7 +140,7 @@ // the big endian style with R being the first byte in memory. int ret = libyuv::NV21ToRAW( yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer, - yuv_data.uv_pixel_stride, + yuv_data.uv_row_stride, const_cast<uint8*>(output_buffer->plane(0).buffer), output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width, buffer.dimension().height); @@ -154,7 +155,7 @@ // The libyuv ABGR format is interleaved RGBA format in memory. int ret = libyuv::NV21ToABGR( yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer, - yuv_data.uv_pixel_stride, + yuv_data.uv_row_stride, const_cast<uint8*>(output_buffer->plane(0).buffer), output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width, buffer.dimension().height); @@ -353,39 +354,30 @@ // Resizes NV12/NV21 `buffer` to the target `output_buffer`. absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { - const int buffer_size = - GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kYV21); - auto yuv_raw_buffer = absl::make_unique<uint8[]>(buffer_size); - ASSIGN_OR_RETURN( - std::unique_ptr<FrameBuffer> yuv_buffer, - CreateFromRawBuffer(yuv_raw_buffer.get(), buffer.dimension(), - FrameBuffer::Format::kYV21, buffer.orientation())); - // TODO(b/151375918): Current implementation is a workaround by converting - // input NV12/NV21 buffer to the YV12 formats, resizing the YV12 buffer, and - // converting the resized YV12 buffer back to the target format. Consider - // optimizes this by adding the support of NV12/NV21 resizing in Libyuv. - if (buffer.format() == FrameBuffer::Format::kNV12) { - RETURN_IF_ERROR(ConvertFromNv12(buffer, yuv_buffer.get())); - } else if (buffer.format() == FrameBuffer::Format::kNV21) { - RETURN_IF_ERROR(ConvertFromNv21(buffer, yuv_buffer.get())); - } else { - return CreateStatusWithPayload( - StatusCode::kInternal, - absl::StrFormat("Format %i is not supported.", buffer.format()), - TfLiteSupportStatus::kImageProcessingError); + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + const uint8* src_uv = input_data.u_buffer; + const uint8* dst_uv = output_data.u_buffer; + if (buffer.format() == FrameBuffer::Format::kNV21) { + src_uv = input_data.v_buffer; + dst_uv = output_data.v_buffer; } - const int resized_buffer_size = GetFrameBufferByteSize( - output_buffer->dimension(), FrameBuffer::Format::kYV12); - auto resized_yuv_raw_buffer = absl::make_unique<uint8[]>(resized_buffer_size); - ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> resized_yuv_buffer, - CreateFromRawBuffer(resized_yuv_raw_buffer.get(), - output_buffer->dimension(), - FrameBuffer::Format::kYV12, - output_buffer->orientation())); - RETURN_IF_ERROR(ResizeYv(*yuv_buffer, resized_yuv_buffer.get())); + int ret = libyuv::NV12Scale( + input_data.y_buffer, input_data.y_row_stride, src_uv, + input_data.uv_row_stride, buffer.dimension().width, + buffer.dimension().height, const_cast<uint8_t*>(output_data.y_buffer), + output_data.y_row_stride, const_cast<uint8_t*>(dst_uv), + output_data.uv_row_stride, output_buffer->dimension().width, + output_buffer->dimension().height, libyuv::FilterMode::kFilterBilinear); - RETURN_IF_ERROR(ConvertFromYv(*resized_yuv_buffer, output_buffer)); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV12Scale operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } return absl::OkStatus(); } @@ -1361,6 +1353,7 @@ TfLiteSupportStatus::kImageProcessingError); } } + } // namespace absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h index 24ed424..6f83559 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
@@ -16,7 +16,7 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_ #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_ -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc index 773ab76..d58969d9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
@@ -14,16 +14,17 @@ ==============================================================================*/ #include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h" +#include <algorithm> #include <cmath> #include <memory> #include <utility> #include <vector> -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "absl/types/optional.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" @@ -95,6 +96,17 @@ TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError); } } + + // Verify if scale is a non-negative value. + if (float_params[0] < 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected scale to be a non-negative value, but got %f.", + float_params[0]), + TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError); + } + Sigmoid sigmoid; sigmoid.label = std::string(label); sigmoid.scale = float_params[0]; @@ -160,13 +172,20 @@ // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0 // and exp(x) / (1+exp(x)) when scale_shifted_score < 0. + float calibrated_score; if (scale_shifted_score >= 0.0) { - return sigmoid.value().scale / - (1.0 + std::exp(static_cast<double>(-scale_shifted_score))); + calibrated_score = + sigmoid.value().scale / + (1.0 + std::exp(static_cast<double>(-scale_shifted_score))); } else { float score_exp = std::exp(static_cast<double>(scale_shifted_score)); - return sigmoid.value().scale * score_exp / (1.0 + score_exp); + calibrated_score = sigmoid.value().scale * score_exp / (1.0 + score_exp); } + // Scale is non-negative (checked in SigmoidFromLabelAndLine), + // thus calibrated_score should be in the range of [0, scale]. However, due to + // numberical stability issue, it may fall out of the boundary. Cap the value + // to [0, scale] instead. + return std::max(std::min(calibrated_score, sigmoid.value().scale), 0.0f); } absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h index 19d966b..e2b403d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
@@ -22,10 +22,10 @@ #include <utility> #include <vector> -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "absl/types/optional.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/BUILD new file mode 100644 index 0000000..a2a8007 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/BUILD
@@ -0,0 +1,31 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "test_utils", + testonly = 1, + srcs = ["test_utils.cc"], + hdrs = [ + "message_matchers.h", + "test_utils.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:proto2", + "@com_google_absl//absl/strings", + "@com_google_glog//:glog", + ], +) + +cc_test( + name = "common_test", + srcs = ["common_test.cc"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:cord", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc new file mode 100644 index 0000000..bc2f9dfd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
@@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/common.h" + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" + +namespace tflite { +namespace support { +namespace { + +using testing::Optional; + +TEST(CommonTest, CreateStatusWithPayloadWorks) { + absl::Status status = CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, "Bad schema version: BADF", + TfLiteSupportStatus::kMetadataInvalidSchemaVersionError); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(status.message(), std::string("Bad schema version: BADF")); + EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord("200"))); +} + +} // namespace +} // namespace support +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/message_matchers.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/message_matchers.h new file mode 100644 index 0000000..6820e20 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/message_matchers.h
@@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEST_MESSAGE_MATCHERS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEST_MESSAGE_MATCHERS_H_ + +#include <memory> + +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/proto2.h" + +namespace tflite { +namespace support { + +class ProtoMatcher { + public: + using is_gtest_matcher = void; + using MessageType = tflite::support::proto::MessageLite; + + explicit ProtoMatcher(const MessageType& message) + : message_(CloneMessage(message)) {} + + bool MatchAndExplain(const MessageType& m, + testing::MatchResultListener*) const { + return EqualsMessage(*message_, m); + } + bool MatchAndExplain(const MessageType* m, + testing::MatchResultListener*) const { + return EqualsMessage(*message_, *m); + } + + void DescribeTo(std::ostream* os) const { + *os << "has the same serialization as " << ExpectedMessageDescription(); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not have the same serialization as " + << ExpectedMessageDescription(); + } + + private: + std::unique_ptr<MessageType> CloneMessage(const MessageType& message) { + std::unique_ptr<MessageType> clone(message.New()); + clone->CheckTypeAndMergeFrom(message); + return clone; + } + + bool EqualsMessage(const MessageType& m_1, const MessageType& m_2) const { + std::string s_1, s_2; + m_1.SerializeToString(&s_1); + m_2.SerializeToString(&s_2); + return s_1 == s_2; + } + + std::string ExpectedMessageDescription() const { + return message_->DebugString(); + } + + const std::shared_ptr<MessageType> message_; +}; + +inline ProtoMatcher EqualsProto( + const tflite::support::proto::MessageLite& message) { + return ProtoMatcher(message); +} + +// for Pointwise +MATCHER(EqualsProto, "") { + const auto& a = ::testing::get<0>(arg); + const auto& b = ::testing::get<1>(arg); + return ::testing::ExplainMatchResult(EqualsProto(b), a, result_listener); +} + +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEST_MESSAGE_MATCHERS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD new file mode 100644 index 0000000..7376ad3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD
@@ -0,0 +1,32 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_test_with_tflite", +) + +package( + default_visibility = [ + "//visibility:private", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_test_with_tflite( + name = "image_preprocessor_test", + srcs = ["image_preprocessor_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/processor:image_preprocessor", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], + deps = [ + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/status", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc new file mode 100644 index 0000000..9ae9435 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
@@ -0,0 +1,116 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/processor/image_preprocessor.h" + +#include <memory> + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace task { +namespace processor { +namespace { + +using ::tflite::support::StatusOr; +using ::tflite::task::JoinPath; +using ::tflite::task::core::TfLiteEngine; +using ::tflite::task::vision::DecodeImageFromFile; +using ::tflite::task::vision::FrameBuffer; +using ::tflite::task::vision::ImageData; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; + +constexpr char kDilatedConvolutionModelWithMetaData[] = "dilated_conv.tflite"; + +StatusOr<ImageData> LoadImage(std::string image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +class DynamicInputTest : public tflite_shims::testing::Test { + protected: + void PreprocessImage() { + engine_ = absl::make_unique<TfLiteEngine>(); + SUPPORT_ASSERT_OK(engine_->BuildModelFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kDilatedConvolutionModelWithMetaData))); + SUPPORT_ASSERT_OK(engine_->InitInterpreter()); + + SUPPORT_ASSERT_OK_AND_ASSIGN(auto preprocessor, + ImagePreprocessor::Create(engine_.get(), {0})); + + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + + SUPPORT_ASSERT_OK(preprocessor->Preprocess(*frame_buffer)); + + ImageDataFree(&image); + } + + std::unique_ptr<TfLiteEngine> engine_ = nullptr; +}; + +// See if output tensor has been re-dimmed as per the input +// tensor. Expected shape: (1, input_height, input_width, 16). +TEST_F(DynamicInputTest, OutputDimensionCheck) { + PreprocessImage(); + + EXPECT_TRUE(engine_->interpreter_wrapper()->InvokeWithoutFallback().ok()); + EXPECT_EQ(engine_->GetOutputs()[0]->dims->data[0], 1); + EXPECT_EQ(engine_->GetOutputs()[0]->dims->data[1], + engine_->GetInputs()[0]->dims->data[1]); + EXPECT_EQ(engine_->GetOutputs()[0]->dims->data[2], + engine_->GetInputs()[0]->dims->data[2]); + EXPECT_EQ(engine_->GetOutputs()[0]->dims->data[3], 16); +} + +// Compare pre-processed input with an already pre-processed +// golden image. +TEST_F(DynamicInputTest, GoldenImageComparison) { + PreprocessImage(); + + // Get the processed input image. + SUPPORT_ASSERT_OK_AND_ASSIGN( + float* processed_input_data, + tflite::task::core::AssertAndReturnTypedTensor<float>( + engine_->GetInputs()[0])); + + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + const uint8* image_data = image.pixel_data; + const size_t image_size = image.width * image.height * 3; + + for (size_t i = 0; i < image_size; ++i, ++image_data, ++processed_input_data) + EXPECT_NEAR(static_cast<float>(*image_data), *processed_input_data, + std::numeric_limits<float>::epsilon()); + + ImageDataFree(&image); +} + +} // namespace +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD new file mode 100644 index 0000000..57f26345 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD
@@ -0,0 +1,48 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_test_with_tflite", +) + +package( + default_visibility = [ + "//visibility:private", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_test_with_tflite( + name = "bert_nl_classifier_test", + srcs = ["bert_nl_classifier_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:bert_nl_classifier_models", + ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "//tensorflow_lite_support/cc/task/text:bert_nl_classifier", + ], + deps = [ + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/test:test_utils", + ], +) + +cc_test( + name = "bert_question_answerer_test", + timeout = "long", + srcs = ["bert_question_answerer_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model", + "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model", + ], + tags = [ + "optonly", # The test takes long, and only run with -c opt. + ], + deps = [ + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/text:bert_question_answerer", + "//tensorflow_lite_support/cc/test:test_utils", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc new file mode 100644 index 0000000..c4a8cea --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
@@ -0,0 +1,210 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" + +#include <fcntl.h> + +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" + +namespace tflite { +namespace task { +namespace text { + +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::Category; +using ::tflite::task::core::LoadBinaryContent; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "text/"; + +constexpr char kTestModelPath[] = "bert_nl_classifier.tflite"; + +constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; + +constexpr int kMaxSeqLen = 128; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name); +} + +class BertNLClassifierTest : public tflite_shims::testing::Test {}; + +TEST_F(BertNLClassifierTest, CreateFromOptionsSucceedsWithModelWithMetadata) { + BertNLClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelPath)); + + SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromOptions(options)); +} + +TEST_F(BertNLClassifierTest, CreateFromOptionsFailsWithMissingBaseOptions) { + BertNLClassifierOptions options; + StatusOr<std::unique_ptr<BertNLClassifier>> classifier_or = + BertNLClassifier::CreateFromOptions(options); + + EXPECT_EQ(classifier_or.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(classifier_or.status().message(), + HasSubstr("Missing mandatory `base_options`")); + EXPECT_THAT(classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(BertNLClassifierTest, TestNLClassifierCreationFilePath) { + SUPPORT_ASSERT_OK( + BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath))); +} + +TEST_F(BertNLClassifierTest, TestNLClassifierCreationBinary) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromBuffer(model_buffer.data(), + model_buffer.size())); +} + +TEST_F(BertNLClassifierTest, TestNLClassifierCreationFailure) { + StatusOr<std::unique_ptr<BertNLClassifier>> classifier_or = + BertNLClassifier::CreateFromFile(kInvalidModelPath); + + EXPECT_EQ(classifier_or.status().code(), absl::StatusCode::kNotFound); + EXPECT_THAT(classifier_or.status().message(), + HasSubstr("Unable to open file at i/do/not/exist.tflite")); + EXPECT_THAT(classifier_or.status().GetPayload(kTfLiteSupportPayload), + testing::Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kFileNotFoundError)))); +} + +Category* GetCategoryWithClassName(const std::string& class_name, + std::vector<Category>& categories) { + for (Category& category : categories) { + if (category.class_name == class_name) { + return &category; + } + } + return nullptr; +} + +void verify_classifier(std::unique_ptr<BertNLClassifier> classifier, + bool verify_positive) { + if (verify_positive) { + std::vector<core::Category> results = + classifier->Classify("unflinchingly bleak and desperate"); + EXPECT_GT(GetCategoryWithClassName("negative", results)->score, + GetCategoryWithClassName("positive", results)->score); + } else { + std::vector<Category> results = + classifier->Classify("it's a charming and often affecting journey"); + EXPECT_GT(GetCategoryWithClassName("positive", results)->score, + GetCategoryWithClassName("negative", results)->score); + } +} + +TEST_F(BertNLClassifierTest, ClassifySucceedsWithBaseOptions) { + std::unique_ptr<BertNLClassifier> classifier; + + // Test creating BertNLClassifier when classifier outlives options. + { + std::string contents = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + BertNLClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_content( + contents); + + SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, + BertNLClassifier::CreateFromOptions(options)); + } + + verify_classifier(std::move(classifier), /*verify_positive=*/false); +} + +TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromBuffer( + model_buffer.data(), model_buffer.size())); + + verify_classifier(std::move(classifier), false); +} + +TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromBuffer( + model_buffer.data(), model_buffer.size())); + + verify_classifier(std::move(classifier), true); +} + +TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) { + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromFd( + open(GetFullPath(kTestModelPath).c_str(), O_RDONLY))); + + verify_classifier(std::move(classifier), false); +} + +TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) { + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromFd( + open(GetFullPath(kTestModelPath).c_str(), O_RDONLY))); + + verify_classifier(std::move(classifier), true); +} + +// BertNLClassifier limits the input sequence to kMaxSeqLen, test when input is +// longer than this the classifier still works correctly. +TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestModelPath).c_str()); + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromBuffer( + model_buffer.data(), model_buffer.size())); + + std::vector<core::Category> results = + classifier->Classify(ss_for_positive_review.str()); + + EXPECT_GT(GetCategoryWithClassName("positive", results)->score, + GetCategoryWithClassName("negative", results)->score); +} + +} // namespace + +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc new file mode 100644 index 0000000..a70dab7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
@@ -0,0 +1,238 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" + +#include <fcntl.h> + +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" + +namespace tflite { +namespace task { +namespace text { + +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::JoinPath; +using ::tflite::task::core::LoadBinaryContent; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "text/"; + +constexpr char kTestMobileBertModelPath[] = "mobilebert_float.tflite"; +constexpr char kTestVocabPath[] = "mobilebert_vocab.txt"; +constexpr char kTestMobileBertWithMetadataModelPath[] = + "mobilebert_with_metadata.tflite"; +constexpr char kTestAlBertModelPath[] = "albert.tflite"; +constexpr char kTestSPModelPath[] = "30k-clean.model"; +constexpr char kTestAlbertWithMetadataModelPath[] = + "albert_with_metadata.tflite"; + +constexpr char kQuestion[] = "What is a course of study called?"; +constexpr char kAnswer[] = "the curriculum."; +constexpr char kContext[] = + "The role of teacher is often formal and ongoing, carried out at a school " + "or other place of formal education. In many countries, a person who " + "wishes to become a teacher must first obtain specified professional " + "qualifications or credentials from a university or college. These " + "professional qualifications may include the study of pedagogy, the " + "science of teaching. Teachers, like other professionals, may have to " + "continue their education after they qualify, a process known as " + "continuing professional development. Teachers may use a lesson plan to " + "facilitate student learning, providing a course of study which is called " + "the curriculum."; +constexpr int kPredictAnsNum = 5; + +class BertQuestionAnswererTest : public tflite_shims::testing::Test {}; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name); +} + +TEST_F(BertQuestionAnswererTest, + CreateFromOptionsSucceedsWithModelWithMetadata) { + BertQuestionAnswererOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestMobileBertWithMetadataModelPath)); + + SUPPORT_ASSERT_OK(BertQuestionAnswerer::CreateFromOptions(options)); +} + +TEST_F(BertQuestionAnswererTest, CreateFromOptionsFailsWithMissingBaseOptions) { + BertQuestionAnswererOptions options; + StatusOr<std::unique_ptr<QuestionAnswerer>> question_answerer_or = + BertQuestionAnswerer::CreateFromOptions(options); + + EXPECT_EQ(question_answerer_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(question_answerer_or.status().message(), + HasSubstr("Missing mandatory `base_options`")); + EXPECT_THAT(question_answerer_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(BertQuestionAnswererTest, AnswerSucceedsWithModelWithMetadata) { + std::unique_ptr<QuestionAnswerer> question_answerer; + + // Test creating NLClassifier when question_answerer outlives options. + { + std::string contents = LoadBinaryContent( + GetFullPath(kTestMobileBertWithMetadataModelPath).c_str()); + + BertQuestionAnswererOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_content( + contents); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + question_answerer, BertQuestionAnswerer::CreateFromOptions(options)); + } + + std::vector<QaAnswer> answer = question_answerer->Answer(kContext, kQuestion); + ASSERT_EQ(answer.size(), kPredictAnsNum); + EXPECT_EQ(answer[0].text, kAnswer); +} + +TEST_F(BertQuestionAnswererTest, TestBertCreationFromBinary) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestMobileBertModelPath).c_str()); + std::string vocab_buffer = + LoadBinaryContent(GetFullPath(kTestVocabPath).c_str()); + SUPPORT_ASSERT_OK(BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( + model_buffer.data(), model_buffer.size(), vocab_buffer.data(), + vocab_buffer.size())); +} + +TEST_F(BertQuestionAnswererTest, TestBertCreationFromFile) { + SUPPORT_ASSERT_OK(BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( + GetFullPath(kTestMobileBertModelPath).c_str(), + GetFullPath(kTestVocabPath).c_str())); +} + +TEST_F(BertQuestionAnswererTest, TestBertAnswer) { + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QuestionAnswerer> question_answerer_status, + BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( + GetFullPath(kTestMobileBertModelPath).c_str(), + GetFullPath(kTestVocabPath).c_str())); + + std::vector<QaAnswer> answer = + question_answerer_status->Answer(kContext, kQuestion); + ASSERT_EQ(answer.size(), kPredictAnsNum); + EXPECT_EQ(answer[0].text, kAnswer); +} + +TEST_F(BertQuestionAnswererTest, TestAlbertCreationFromBinary) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestAlBertModelPath).c_str()); + std::string vocab_buffer = + LoadBinaryContent(GetFullPath(kTestSPModelPath).c_str()); + SUPPORT_ASSERT_OK(BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( + model_buffer.data(), model_buffer.size(), vocab_buffer.data(), + vocab_buffer.size())); +} + +TEST_F(BertQuestionAnswererTest, TestAlbertCreationFromFile) { + SUPPORT_ASSERT_OK(BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( + GetFullPath(kTestAlBertModelPath).c_str(), + GetFullPath(kTestSPModelPath).c_str())); +} + +TEST_F(BertQuestionAnswererTest, TestAlbertAnswer) { + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QuestionAnswerer> question_answerer_status, + BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile( + GetFullPath(kTestAlBertModelPath).c_str(), + GetFullPath(kTestSPModelPath).c_str())); + + std::vector<QaAnswer> answer = + question_answerer_status->Answer(kContext, kQuestion); + ASSERT_EQ(answer.size(), kPredictAnsNum); + EXPECT_EQ(answer[0].text, kAnswer); +} + +TEST_F(BertQuestionAnswererTest, TestCreateWithMetadata) { + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QuestionAnswerer> question_answerer_status, + BertQuestionAnswerer::CreateFromFile( + GetFullPath(kTestMobileBertWithMetadataModelPath).c_str())); + + std::vector<QaAnswer> answer = + question_answerer_status->Answer(kContext, kQuestion); + ASSERT_EQ(answer.size(), kPredictAnsNum); + EXPECT_EQ(answer[0].text, kAnswer); +} + +TEST_F(BertQuestionAnswererTest, TestCreateWithMetadataFromBinary) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestAlbertWithMetadataModelPath).c_str()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QuestionAnswerer> question_answerer_status, + BertQuestionAnswerer::CreateFromBuffer(model_buffer.data(), + model_buffer.size())); + + std::vector<QaAnswer> answer = + question_answerer_status->Answer(kContext, kQuestion); + ASSERT_EQ(answer.size(), kPredictAnsNum); + EXPECT_EQ(answer[0].text, kAnswer); +} + +TEST_F(BertQuestionAnswererTest, TestCreateWithFileDescriptor2) { + std::string model_buffer = + LoadBinaryContent(GetFullPath(kTestAlbertWithMetadataModelPath).c_str()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QuestionAnswerer> question_answerer_status, + BertQuestionAnswerer::CreateFromFd(open( + GetFullPath(kTestAlbertWithMetadataModelPath).c_str(), O_RDONLY))); + + std::vector<QaAnswer> answer = + question_answerer_status->Answer(kContext, kQuestion); + ASSERT_EQ(answer.size(), kPredictAnsNum); + EXPECT_EQ(answer[0].text, kAnswer); +} + +TEST_F(BertQuestionAnswererTest, + TestCreateWithMetadataFail_fromModelWithoutMetadata) { + StatusOr<std::unique_ptr<QuestionAnswerer>> question_answerer_or = + BertQuestionAnswerer::CreateFromFile( + GetFullPath(kTestMobileBertModelPath).c_str()); + + EXPECT_EQ(question_answerer_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(question_answerer_or.status().message(), + HasSubstr("No input process unit found from metadata.")); + EXPECT_THAT(question_answerer_or.status().GetPayload( + tflite::support::kTfLiteSupportPayload), + testing::Optional(absl::Cord(absl::StrCat( + TfLiteSupportStatus::kMetadataInvalidTokenizerError)))); +} + +} // namespace +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/BUILD new file mode 100644 index 0000000..28c5739 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/BUILD
@@ -0,0 +1,47 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_test_with_tflite", +) + +package( + default_visibility = [ + "//visibility:private", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "nl_classifier_test_utils", + srcs = ["nl_classifier_test_utils.cc"], + hdrs = ["nl_classifier_test_utils.h"], + visibility = ["//tensorflow_lite_support:internal"], + deps = [ + "//tensorflow_lite_support/cc/task/core:task_utils", + "@org_tensorflow//tensorflow/lite:op_resolver", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +) + +cc_test_with_tflite( + name = "nl_classifier_test", + srcs = ["nl_classifier_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + ], + deps = [ + ":nl_classifier_test_utils", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:deprecated_backends", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc new file mode 100644 index 0000000..81198cf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
@@ -0,0 +1,580 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +#include <utility> + +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" + +static constexpr char kInputStr[] = "hello"; + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::TestWithParam; +using ::testing::UnorderedElementsAreArray; +using ::testing::ValuesIn; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::JoinPath; +using ::tflite::task::core::LoadBinaryContent; +using NLClassifierProtoOptions = ::tflite::task::text::NLClassifierOptions; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "text/"; + +// The model has 1 input tensor and 4 output tensors with the following names +// and indices. +// The model also has three custom OPs to mimic classification model behaviors, +// see CUSTOM_OP_STRING_TO_FLOATS, CUSTOM_OP_STRING_TO_DOUBLES +// and CUSTOM_OP_GENERATE_LABELS in nl_classifier_test_utils for details. +constexpr char kTestModelPath[] = "test_model_nl_classifier.tflite"; + +// The model has 1 input tensor and 1 output tensor +// The model also has a custom OP to mimic classification model behaviors, +// see CUSTOM_OP_STRING_TO_BOOLS in nl_classifier_test_utils for details. +constexpr char kTestModelBoolOutputPath[] = + "test_model_nl_classifier_bool_output.tflite"; + +// The model has same input/output tensors with the above model, except its +// first output tensor is associated with metadata with name +// kMetadataOutputScoreTensorName and an associated label file. +constexpr char kTestModelWithLabelCustomOpsPath[] = + "test_model_nl_classifier_with_associated_label.tflite"; + +constexpr char kTestModelWithLabelBuiltInOpsPath[] = + "test_model_nl_classifier_with_associated_label_builtin_ops.tflite"; + +// This model expects input to be tokenized by a regex tokenizer. +constexpr char kTestModelWithRegexTokenizer[] = + "test_model_nl_classifier_with_regex_tokenizer.tflite"; + +constexpr char kPositiveInput[] = + "This is the best movie I’ve seen in recent years. Strongly recommend " + "it!"; +std::vector<core::Category> GetExpectedResultsOfPositiveInput() { + return { + {"Positive", 0.51342660188674927}, + {"Negative", 0.48657345771789551}, + }; +} + +constexpr char kNegativeInput[] = "What a waste of my time."; +std::vector<core::Category> GetExpectedResultsOfNegativeInput() { + return { + {"Positive", 0.18687039613723755}, + {"Negative", 0.81312954425811768}, + }; +} + +const uint8_t kOutputDequantizedTensorIndex = 0; +const uint8_t kOutputQuantizedTensorIndex = 1; +const uint8_t kOutputLabelTensorIndex = 2; +const uint8_t kOutputDequantizedTensorFloat64Index = 3; +constexpr char kInputTensorName[] = "INPUT"; +constexpr char kOutputDequantizedTensorName[] = "OUTPUT_SCORE_DEQUANTIZED"; +constexpr char kOutputDequantizedTensorFloat64Name[] = + "OUTPUT_SCORE_DEQUANTIZED_FLOAT64"; +constexpr char kOutputQuantizedTensorName[] = "OUTPUT_SCORE_QUANTIZED"; +constexpr char kOutputLabelTensorName[] = "LABELS"; +constexpr char kMetadataOutputScoreTensorName[] = "scores_dequantized"; +constexpr char kDefaultInputTensorName[] = "INPUT"; +constexpr char kDefaultOutputLabelTensorName[] = "OUTPUT_LABEL"; +constexpr int kDefaultInputTensorIndex = 0; +constexpr int kDefaultOutputLabelTensorIndex = -1; + +// Test the API with different combinations in creating proto +// NLClassifierOptions +struct ProtoOptionsTestParam { + // description of current test + std::string description; + NLClassifierProtoOptions options; +}; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name); +} + +class ProtoOptionsTest : public TestWithParam<ProtoOptionsTestParam> { + protected: + void SetUp() override { ASSERT_EQ(TfLiteInitializeShimsForTest(), 0); } +}; + +TEST_F(ProtoOptionsTest, CreateFromOptionsSucceedsWithModelWithMetadata) { + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelWithRegexTokenizer)); + + SUPPORT_ASSERT_OK(NLClassifier::CreateFromOptions(options)); +} + +TEST_F(ProtoOptionsTest, CreateFromOptionsFailsWithMissingBaseOptions) { + NLClassifierProtoOptions options; + StatusOr<std::unique_ptr<NLClassifier>> nl_classifier_or = + NLClassifier::CreateFromOptions(options); + + EXPECT_EQ(nl_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(nl_classifier_or.status().message(), + HasSubstr("Missing mandatory `base_options`")); + EXPECT_THAT(nl_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) { + std::unique_ptr<NLClassifier> classifier; + + // Test creating NLClassifier when classifier outlives options. + { + std::string contents = + LoadBinaryContent(GetFullPath(kTestModelWithRegexTokenizer).c_str()); + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_content( + contents); + + SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, + NLClassifier::CreateFromOptions(options)); + } + + std::vector<core::Category> positive_results = + classifier->Classify(kPositiveInput); + + EXPECT_THAT(positive_results, + UnorderedElementsAreArray(GetExpectedResultsOfPositiveInput())); + + std::vector<core::Category> negative_results = + classifier->Classify(kNegativeInput); + EXPECT_THAT(negative_results, + UnorderedElementsAreArray(GetExpectedResultsOfNegativeInput())); +} + +TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) { + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)); + options.set_input_tensor_name("invalid_tensor_name"); + options.set_input_tensor_index(-1); + + StatusOr<std::unique_ptr<NLClassifier>> nl_classifier_or = + NLClassifier::CreateFromOptions(options, CreateCustomResolver()); + + EXPECT_EQ(nl_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(nl_classifier_or.status().message(), + HasSubstr("No input tensor found with name " + "invalid_tensor_name or at index -1")); + EXPECT_THAT( + nl_classifier_or.status().GetPayload(kTfLiteSupportPayload), + absl::Cord(absl::StrCat(TfLiteSupportStatus::kInputTensorNotFoundError))); +} + +TEST_F(ProtoOptionsTest, CreationFromIncorrectOutputScoreTensor) { + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)); + options.set_output_score_tensor_name("invalid_tensor_name"); + options.set_output_score_tensor_index(-1); + + StatusOr<std::unique_ptr<NLClassifier>> nl_classifier_or = + NLClassifier::CreateFromOptions(options, CreateCustomResolver()); + + EXPECT_EQ(nl_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(nl_classifier_or.status().message(), + HasSubstr("No output score tensor found with name " + "invalid_tensor_name or at index -1")); + EXPECT_THAT(nl_classifier_or.status().GetPayload(kTfLiteSupportPayload), + absl::Cord(absl::StrCat( + TfLiteSupportStatus::kOutputTensorNotFoundError))); +} + +TEST_F(ProtoOptionsTest, TestInferenceWithRegexTokenizer) { + // The model with regex tokenizer doesn't need any custom ops. + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelWithRegexTokenizer)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromOptions(options)); + + std::vector<core::Category> positive_results = + classifier->Classify(kPositiveInput); + + EXPECT_THAT(positive_results, + UnorderedElementsAreArray(GetExpectedResultsOfPositiveInput())); + + std::vector<core::Category> negative_results = + classifier->Classify(kNegativeInput); + EXPECT_THAT(negative_results, + UnorderedElementsAreArray(GetExpectedResultsOfNegativeInput())); +} + +TEST_F(ProtoOptionsTest, TestInferenceWithBoolOutput) { + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelBoolOutputPath)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromOptions(options, CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"0", 1}, + {"1", 1}, + {"2", 0}, + }; + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +TEST_F(ProtoOptionsTest, TestInferenceWithAssociatedLabelCustomOps) { + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelWithLabelCustomOpsPath)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromOptions(options, CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"label0", 255}, + {"label1", 510}, + {"label2", 765}, + }; + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +TEST_F(ProtoOptionsTest, TestInferenceWithAssociatedLabelBuiltinOps) { + NLClassifierProtoOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + GetFullPath(kTestModelWithLabelBuiltInOpsPath)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromOptions(options)); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"Negative", 0.49332118034362793}, + {"Positive", 0.50667881965637207}, + }; + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +// Parameterized test. +struct ProtoOptionsTestParamToString { + std::string operator()( + const testing::TestParamInfo<ProtoOptionsTestParam>& info) const { + return info.param.description; + } +}; + +NLClassifierProtoOptions CreateProtoOptionsFromTensorName( + const char* input_tensor_name, + const char* output_score_tensor_name, + const char* output_label_tensor_name, + const char* model_path) { + NLClassifierProtoOptions options; + options.set_input_tensor_name(input_tensor_name); + options.set_output_score_tensor_name(output_score_tensor_name); + options.set_output_label_tensor_name(output_label_tensor_name); + + options.mutable_base_options()->mutable_model_file()->set_file_name( + model_path); + + return options; +} + +NLClassifierProtoOptions CreateProtoOptionsFromTensorIndex( + const int input_tensor_index, + const int output_score_tensor_index, + const int output_label_tensor_index, + const char* model_path) { + NLClassifierProtoOptions options; + options.set_input_tensor_index(input_tensor_index); + options.set_output_score_tensor_index(output_score_tensor_index); + options.set_output_label_tensor_index(output_label_tensor_index); + + options.mutable_base_options()->mutable_model_file()->set_file_name( + model_path); + + return options; +} + +std::vector<ProtoOptionsTestParam> ClassifyParams() { + return { + { + .description = "FindTensorByNameQuantizeOutputUseTensorLabel", + .options = CreateProtoOptionsFromTensorName( + kDefaultInputTensorName, kOutputQuantizedTensorName, + kOutputLabelTensorName, GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByNameQuantizeOutputUseIndexLabel", + .options = CreateProtoOptionsFromTensorName( + kDefaultInputTensorName, kOutputQuantizedTensorName, + kDefaultOutputLabelTensorName, + GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByNameDequantizeOutputUseTensorLabel", + .options = CreateProtoOptionsFromTensorName( + kDefaultInputTensorName, kOutputDequantizedTensorName, + kOutputLabelTensorName, GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByNameDequantizeOutputUseIndexLabel", + .options = CreateProtoOptionsFromTensorName( + kDefaultInputTensorName, kOutputDequantizedTensorName, + kDefaultOutputLabelTensorName, + GetFullPath(kTestModelPath).c_str()), + }, + { + .description = + "FindTensorByNameDequantizeFloat64OutputUseTensorLabel", + .options = CreateProtoOptionsFromTensorName( + kDefaultInputTensorName, kOutputDequantizedTensorFloat64Name, + kOutputLabelTensorName, GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByNameDequantizeFloat64OutputUseIndexLabel", + .options = CreateProtoOptionsFromTensorName( + kDefaultInputTensorName, kOutputDequantizedTensorFloat64Name, + kDefaultOutputLabelTensorName, + GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByIndexQuantizeOutputUseTensorLabel", + .options = CreateProtoOptionsFromTensorIndex( + kDefaultInputTensorIndex, kOutputQuantizedTensorIndex, + kOutputLabelTensorIndex, GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByIndexQuantizeOutputUseIndexLabel", + .options = CreateProtoOptionsFromTensorIndex( + kDefaultInputTensorIndex, kOutputQuantizedTensorIndex, + kDefaultOutputLabelTensorIndex, + GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByIndexDequantizeOutputUseTensorLabel", + .options = CreateProtoOptionsFromTensorIndex( + kDefaultInputTensorIndex, kOutputDequantizedTensorIndex, + kOutputLabelTensorIndex, GetFullPath(kTestModelPath).c_str()), + }, + { + .description = "FindTensorByIndexDequantizeOutputUseIndexLabel", + .options = CreateProtoOptionsFromTensorIndex( + kDefaultInputTensorIndex, kOutputDequantizedTensorIndex, + kDefaultOutputLabelTensorIndex, + GetFullPath(kTestModelPath).c_str()), + }, + { + .description = + "FindTensorByIndexDequantizeFloat64OutputUseTensorLabel", + .options = CreateProtoOptionsFromTensorIndex( + kDefaultInputTensorIndex, kOutputDequantizedTensorFloat64Index, + kOutputLabelTensorIndex, GetFullPath(kTestModelPath).c_str()), + }, + { + .description = + "FindTensorByIndexDequantizeFloat64OutputUseIndexLabel", + .options = CreateProtoOptionsFromTensorIndex( + kDefaultInputTensorIndex, kOutputDequantizedTensorFloat64Index, + kDefaultOutputLabelTensorIndex, + GetFullPath(kTestModelPath).c_str()), + }, + }; +} + +TEST_P(ProtoOptionsTest, TestClassify) { + NLClassifierProtoOptions options = GetParam().options; + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromOptions(options, CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + + bool assert_label_name = + options.output_label_tensor_index() == kOutputLabelTensorIndex || + options.output_label_tensor_name() == kOutputLabelTensorName; + + std::vector<core::Category> expected_class; + if (assert_label_name) { + expected_class = { + {"label0", 255}, + {"label1", 510}, + {"label2", 765}, + }; + } else { + expected_class = { + {"0", 255}, + {"1", 510}, + {"2", 765}, + }; + } + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +INSTANTIATE_TEST_SUITE_P(TestClassify, + ProtoOptionsTest, + ValuesIn(ClassifyParams()), + ProtoOptionsTestParamToString()); + +// Tests for struct sNLClassifierOptions. +class StructOptionsTest : public tflite_shims::testing::Test {}; + +void AssertStatus(absl::Status status, + absl::StatusCode status_code, + TfLiteSupportStatus tfls_code) { + ASSERT_EQ(status.code(), status_code); + EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload), + testing::Optional(absl::Cord(absl::StrCat(tfls_code)))); +} + +TEST_F(StructOptionsTest, TestApiCreationFromBuffer) { + std::string model_buffer = LoadBinaryContent( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath) + .c_str()); + SUPPORT_ASSERT_OK(NLClassifier::CreateFromBufferAndOptions( + model_buffer.data(), model_buffer.size(), {}, CreateCustomResolver())); +} + +TEST_F(StructOptionsTest, TestApiCreationFromFile) { + SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions( + GetFullPath(kTestModelPath), {}, CreateCustomResolver())); +} + +TEST_F(StructOptionsTest, TestApiCreationFromIncorrectInputTensor) { + NLClassifierOptions options; + options.input_tensor_index = -1; + options.input_tensor_name = "I do not exist"; + AssertStatus( + NLClassifier::CreateFromFileAndOptions( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath), + options, CreateCustomResolver()) + .status(), + absl::StatusCode::kInvalidArgument, + TfLiteSupportStatus::kInputTensorNotFoundError); +} + +TEST_F(StructOptionsTest, TestApiCreationFromIncorrectOutputScoreTensor) { + NLClassifierOptions options; + options.output_score_tensor_index = 123; + + AssertStatus(NLClassifier::CreateFromFileAndOptions( + GetFullPath(kTestModelPath), options, CreateCustomResolver()) + .status(), + absl::StatusCode::kInvalidArgument, + TfLiteSupportStatus::kOutputTensorNotFoundError); +} + +TEST_F(StructOptionsTest, TestInferenceWithRegexTokenizer) { + NLClassifierOptions options; + options.input_tensor_name = "input_text"; + options.output_score_tensor_name = "probability"; + + // The model with regex tokenizer doesn't need any custom ops. + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromFileAndOptions( + GetFullPath(kTestModelWithRegexTokenizer), options)); + + std::vector<core::Category> positive_results = + classifier->Classify(kPositiveInput); + + EXPECT_THAT(positive_results, + UnorderedElementsAreArray(GetExpectedResultsOfPositiveInput())); + + std::vector<core::Category> negative_results = + classifier->Classify(kNegativeInput); + EXPECT_THAT(negative_results, + UnorderedElementsAreArray(GetExpectedResultsOfNegativeInput())); +} + +TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) { + NLClassifierOptions options; + options.input_tensor_index = 0; + options.output_score_tensor_index = 0; + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromFileAndOptions( + GetFullPath(kTestModelBoolOutputPath), + options, CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"0", 1}, + {"1", 1}, + {"2", 0}, + }; + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +TEST_F(StructOptionsTest, TestInferenceWithAssociatedLabelCustomOps) { + NLClassifierOptions options; + options.output_score_tensor_name = kMetadataOutputScoreTensorName; + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromFileAndOptions( + GetFullPath(kTestModelWithLabelCustomOpsPath), options, + CreateCustomResolver())); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"label0", 255}, + {"label1", 510}, + {"label2", 765}, + }; + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +TEST_F(StructOptionsTest, TestInferenceWithAssociatedLabelBuiltinOps) { + NLClassifierOptions options; + options.input_tensor_index = 0; + options.output_score_tensor_index = 0; + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromFileAndOptions( + GetFullPath(kTestModelWithLabelBuiltInOpsPath), options)); + std::vector<core::Category> results = classifier->Classify(kInputStr); + std::vector<core::Category> expected_class = { + {"Negative", 0.49332118034362793}, + {"Positive", 0.50667881965637207}, + }; + + EXPECT_THAT(results, UnorderedElementsAreArray(expected_class)); +} + +} // namespace +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.cc new file mode 100644 index 0000000..92835a9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.cc
@@ -0,0 +1,197 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.h" + +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +namespace tflite { +namespace ops { +namespace custom { + +constexpr char kInputStr[] = "hello"; + +namespace string_floats { +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteIntArray* dims = TfLiteIntArrayCreate(1); + dims->data[0] = 3; // Size of 'data' in Invoke() + return context->ResizeTensor(context, output, dims); +} + +TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + std::vector<float> data = {255, 510, 765}; + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + // Default Quantize OP scale is 255, will be quantized to {1, 2, 3} + TF_LITE_ENSURE(context, + tflite::task::core::PopulateTensor(data, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a float32 tensor with +// value{0.1, 0.2, 0.3}, it's used to mimic a real text classification model +// which classifies a string into scores of different categories. +TfLiteRegistration* Register() { + // Dummy implementation of custom OP + // This op takes string as input and outputs float[] + static TfLiteRegistration r = { + .init = nullptr, .free = nullptr, .prepare = Prepare, .invoke = Invoke}; + return &r; +} +}; // namespace string_floats + +namespace string_doubles { +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteIntArray* dims = TfLiteIntArrayCreate(1); + dims->data[0] = 3; // Size of 'data' in Invoke() + return context->ResizeTensor(context, output, dims); +} + +TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + std::vector<double> data = {255, 510, 765}; + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + // Default Quantize OP scale is 255, will be quantized to {1, 2, 3} + TF_LITE_ENSURE(context, + tflite::task::core::PopulateTensor(data, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a float64 tensor with +// value{0.1, 0.2, 0.3}, it's used to mimic a real text classification model +// which classifies a string into scores of different categories. +TfLiteRegistration* Register() { + // Dummy implementation of custom OP + // This op takes string as input and outputs double[] + static TfLiteRegistration r = { + .init = nullptr, .free = nullptr, .prepare = Prepare, .invoke = Invoke}; + return &r; +} +}; // namespace string_doubles + +namespace string_bools { +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteIntArray* dims = TfLiteIntArrayCreate(1); + dims->data[0] = 3; // Size of 'data' in Invoke() + return context->ResizeTensor(context, output, dims); +} + +TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + bool data[] = {true, true, false}; + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE(context, + tflite::task::core::PopulateTensor(data, 3, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a bool tensor with +// value{true, true, false}, it's used to mimic a real text classification model +// which classifies a string into scores of different categories. +TfLiteRegistration* Register() { + // Dummy implementation of custom OP + // This op takes string as input and outputs bool[] + static TfLiteRegistration r = { + .init = nullptr, .free = nullptr, .prepare = Prepare, .invoke = Invoke}; + return &r; +} +}; // namespace string_bools + +TfLiteStatus GenerateLabelsInvoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + std::vector<std::string> data = {"label0", "label1", "label2"}; + TF_LITE_ENSURE(context, + tflite::task::core::PopulateTensor(data, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a string tensor with +// value{"label0", "label1", "label2"}, it's used to mimic a real text +// classification model that stores Class names inside a tensor. +TfLiteRegistration* Register_CUSTOM_OP_GENERATE_LABELS() { + static TfLiteRegistration r = {.init = nullptr, + .free = nullptr, + .prepare = nullptr, + .invoke = GenerateLabelsInvoke}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +std::unique_ptr<MutableOpResolver> CreateCustomResolver() { + MutableOpResolver resolver; + resolver.AddBuiltin(::tflite::BuiltinOperator_QUANTIZE, + ::tflite::ops::builtin::Register_QUANTIZE()); + resolver.AddCustom("CUSTOM_OP_STRING_TO_FLOATS", + ::tflite::ops::custom::string_floats::Register()); + resolver.AddCustom("CUSTOM_OP_STRING_TO_DOUBLES", + ::tflite::ops::custom::string_doubles::Register()); + resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", + ::tflite::ops::custom::string_bools::Register()); + resolver.AddCustom( + "CUSTOM_OP_GENERATE_LABELS", + ::tflite::ops::custom::Register_CUSTOM_OP_GENERATE_LABELS()); + return absl::make_unique<MutableOpResolver>(resolver); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.h new file mode 100644 index 0000000..b3c2501 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test_utils.h
@@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEST_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_TEST_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEST_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_TEST_UTILS_H_ + +#include <memory> + +#include "tensorflow/lite/op_resolver.h" + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +// Create a custom MutableOpResolver to provide custom OP implementations to +// mimic classification behavior. +std::unique_ptr<MutableOpResolver> CreateCustomResolver(); + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEST_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_TEST_UTILS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD new file mode 100644 index 0000000..24a4b9108 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD
@@ -0,0 +1,175 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_test_with_tflite", +) + +package( + default_visibility = [ + "//visibility:private", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_test_with_tflite( + name = "image_classifier_test", + srcs = ["image_classifier_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision:image_classifier", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +# To test it with Bazel, plugin a Coral device, and run the following command: +# bazel test tensorflow_lite_support/cc/test/task/vision:image_classifier_coral_test \ +# --define darwinn_portable=1 +cc_test( + name = "image_classifier_coral_test", + srcs = ["image_classifier_coral_test.cc"], + data = [ + "//tensorflow_lite_support/acceleration/configuration/testdata:test_files", + ], + tags = [ + "manual", + "notap", # Requires edge TPU device. + ], + deps = [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + "//tensorflow_lite_support/cc/port:configuration_proto_inc", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/status", + ], +) + +cc_test_with_tflite( + name = "object_detector_test", + srcs = ["object_detector_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision:object_detector", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +cc_test_with_tflite( + name = "image_segmenter_test", + srcs = ["image_segmenter_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision:image_segmenter", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +cc_test_with_tflite( + name = "image_embedder_test", + srcs = ["image_embedder_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:image_embedder", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:embeddings_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_embedder_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc new file mode 100644 index 0000000..7e6a311a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc
@@ -0,0 +1,53 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> + +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" +#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::tflite::support::StatusOr; + +constexpr char kEdgeTpuModelFilePath[] = + "tensorflow_lite_support/acceleration/configuration/testdata/" + "mobilenet_v1_1.0_224_quant_edgetpu.tflite"; +constexpr char kRegularModelFilePath[] = + "tensorflow_lite_support/acceleration/configuration/testdata/" + "mobilenet_v1_1.0_224_quant.tflite"; +constexpr char kImagePath[] = + "tensorflow_lite_support/acceleration/configuration/testdata/" + "burger.jpg"; + +using ClassifyTest = testing::TestWithParam<std::string>; + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc new file mode 100644 index 0000000..ae4e48c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
@@ -0,0 +1,666 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" + +#include <memory> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::JoinPath; +using ::tflite::task::ParseTextProtoOrDie; +using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::TaskAPIFactory; +using ::tflite::task::core::TfLiteEngine; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; +// Float model. +constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; +// Quantized model. +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; +// Hello world flowers classifier supporting 5 classes (quantized model). +constexpr char kAutoMLModelWithMetadata[] = "automl_labeler_model.tflite"; + +StatusOr<ImageData> LoadImage(std::string image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +// If the proto definition changes, please also change this function. +void ExpectApproximatelyEqual(const ClassificationResult& actual, + const ClassificationResult& expected) { + const float kPrecision = 1e-6; + EXPECT_EQ(actual.classifications_size(), expected.classifications_size()); + for (int i = 0; i < actual.classifications_size(); ++i) { + const Classifications& a = actual.classifications(i); + const Classifications& b = expected.classifications(i); + EXPECT_EQ(a.head_index(), b.head_index()); + EXPECT_EQ(a.classes_size(), b.classes_size()); + for (int j = 0; j < a.classes_size(); ++j) { + EXPECT_EQ(a.classes(j).index(), b.classes(j).index()); + EXPECT_EQ(a.classes(j).class_name(), b.classes(j).class_name()); + EXPECT_EQ(a.classes(j).display_name(), b.classes(j).display_name()); + EXPECT_NEAR(a.classes(j).score(), b.classes(j).score(), kPrecision); + } + } +} + +class MobileNetQuantizedOpResolver : public ::tflite::MutableOpResolver { + public: + MobileNetQuantizedOpResolver() { + AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + ::tflite::ops::builtin::Register_AVERAGE_POOL_2D()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + } + + MobileNetQuantizedOpResolver(const MobileNetQuantizedOpResolver& r) = delete; +}; + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ImageClassifierOptions options; + options.set_max_results(3); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + + SUPPORT_ASSERT_OK(ImageClassifier::CreateFromOptions( + options, absl::make_unique<MobileNetQuantizedOpResolver>())); +} + +class MobileNetQuantizedOpResolverMissingOps + : public ::tflite::MutableOpResolver { + public: + MobileNetQuantizedOpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + } + + MobileNetQuantizedOpResolverMissingOps( + const MobileNetQuantizedOpResolverMissingOps& r) = delete; +}; + +TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ImageClassifierOptions options; + options.set_max_results(3); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + + auto image_classifier_or = ImageClassifier::CreateFromOptions( + options, absl::make_unique<MobileNetQuantizedOpResolverMissingOps>()); + EXPECT_EQ(image_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier_or.status().message(), + HasSubstr("Didn't find op for builtin opcode")); + EXPECT_THAT(image_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kUnsupportedBuiltinOp)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); + + EXPECT_EQ(image_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier_or.status().message(), + HasSubstr("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 2.")); + EXPECT_THAT(image_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + ImageClassifierOptions options; + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); + + EXPECT_EQ(image_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier_or.status().message(), + HasSubstr("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 0.")); + EXPECT_THAT(image_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + options.set_max_results(0); + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); + + EXPECT_EQ(image_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier_or.status().message(), + HasSubstr("Invalid `max_results` option")); + EXPECT_THAT(image_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + options.add_class_name_whitelist("foo"); + options.add_class_name_blacklist("bar"); + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); + + EXPECT_EQ(image_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier_or.status().message(), + HasSubstr("mutually exclusive options")); + EXPECT_THAT(image_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) { + ImageClassifierOptions options; + options.set_num_threads(4); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + SUPPORT_ASSERT_OK(ImageClassifier::CreateFromOptions(options)); +} + +using NumThreadsTest = testing::TestWithParam<int>; + +INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2)); + +TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + ImageClassifierOptions options; + options.set_num_threads(GetParam()); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); + + EXPECT_EQ(image_classifier_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier_or.status().message(), + HasSubstr("`num_threads` must be greater than " + "0 or equal to -1")); + EXPECT_THAT(image_classifier_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST(ClassifyTest, SucceedsWithFloatModel) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ImageClassifierOptions options; + options.set_max_results(3); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); + ImageDataFree(&rgb_image); + SUPPORT_ASSERT_OK(result_or); + + const ClassificationResult& result = result_or.value(); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { + index: 934 + score: 0.7399742 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026928535 + class_name: "guacamole" + } + classes { index: 932 score: 0.025737215 class_name: "bagel" } + head_index: 0 + } + )pb")); +} + +TEST(ClassifyTest, SucceedsWithRegionOfInterest) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, + LoadImage("multi_objects.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ImageClassifierOptions options; + options.set_max_results(1); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + // Crop around the soccer ball. + BoundingBox roi; + roi.set_origin_x(406); + roi.set_origin_y(110); + roi.set_width(148); + roi.set_height(153); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer, roi); + ImageDataFree(&rgb_image); + SUPPORT_ASSERT_OK(result_or); + + const ClassificationResult& result = result_or.value(); + ExpectApproximatelyEqual(result, ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { + index: 806 + score: 0.99673367 + class_name: "soccer ball" + } + head_index: 0 + } + )pb")); +} + +TEST(ClassifyTest, SucceedsWithQuantizedModel) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ImageClassifierOptions options; + options.set_max_results(3); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileNetQuantizedWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); + ImageDataFree(&rgb_image); + SUPPORT_ASSERT_OK(result_or); + + const ClassificationResult& result = result_or.value(); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { + index: 934 + score: 0.96484375 + class_name: "cheeseburger" + } + classes { index: 948 score: 0.0078125 class_name: "mushroom" } + classes { index: 924 score: 0.00390625 class_name: "plate" } + head_index: 0 + } + )pb")); +} + +TEST(ClassifyTest, SucceedsWithBaseOptions) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ImageClassifierOptions options; + options.set_max_results(3); + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); + ImageDataFree(&rgb_image); + SUPPORT_ASSERT_OK(result_or); + + const ClassificationResult& result = result_or.value(); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { + index: 934 + score: 0.7399742 + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026928535 + class_name: "guacamole" + } + classes { index: 932 score: 0.025737215 class_name: "bagel" } + head_index: 0 + } + )pb")); +} + +TEST(ClassifyTest, GetInputCountSucceeds) { + ImageClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + int32_t input_count = image_classifier->GetInputCount(); + EXPECT_THAT(input_count, 1); +} + +TEST(ClassifyTest, GetInputShapeSucceeds) { + ImageClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + // Verify the shape array size. + const TfLiteIntArray* input_shape_0 = image_classifier->GetInputShape(0); + EXPECT_THAT(input_shape_0->size, 4); + + // Verify the shape array data. + auto shape_data = input_shape_0->data; + std::vector<int> shape_vector(shape_data, shape_data + input_shape_0->size); + EXPECT_THAT(shape_vector, ElementsAreArray({1, 224, 224, 3})); +} + +TEST(ClassifyTest, GetOutputCountSucceeds) { + ImageClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + int32_t output_count = image_classifier->GetOutputCount(); + EXPECT_THAT(output_count, 1); +} + +TEST(ClassifyTest, GetOutputShapeSucceeds) { + ImageClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + // Verify the shape array size. + const TfLiteIntArray* output_shape_0 = image_classifier->GetOutputShape(0); + EXPECT_THAT(output_shape_0->size, 2); + + // Verify the shape array data. + auto shape_data = output_shape_0->data; + std::vector<int> shape_vector(shape_data, shape_data + output_shape_0->size); + EXPECT_THAT(shape_vector, ElementsAreArray({1, 1001})); +} + +class PostprocessTest : public tflite_shims::testing::Test { + public: + class TestImageClassifier : public ImageClassifier { + public: + using ImageClassifier::ImageClassifier; + using ImageClassifier::Postprocess; + + static StatusOr<std::unique_ptr<TestImageClassifier>> CreateFromOptions( + const ImageClassifierOptions& options) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + auto options_copy = absl::make_unique<ImageClassifierOptions>(options); + + ASSIGN_OR_RETURN( + auto image_classifier, + TaskAPIFactory::CreateFromExternalFileProto<TestImageClassifier>( + &options_copy->model_file_with_metadata())); + + RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy))); + + return image_classifier; + } + + TfLiteTensor* GetOutputTensor() { + if (TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter()) != 1) { + return nullptr; + } + return TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), 0); + } + }; + + protected: + void SetUp() override { tflite_shims::testing::Test::SetUp(); } + void SetUp(const ImageClassifierOptions& options) { + StatusOr<std::unique_ptr<TestImageClassifier>> test_image_classifier_or = + TestImageClassifier::CreateFromOptions(options); + + init_status_ = test_image_classifier_or.status(); + + if (init_status_.ok()) { + test_image_classifier_ = std::move(test_image_classifier_or).value(); + } + + dummy_frame_buffer_ = CreateFromRgbRawBuffer(/*input=*/nullptr, {}); + } + + std::unique_ptr<TestImageClassifier> test_image_classifier_; + std::unique_ptr<FrameBuffer> dummy_frame_buffer_; + absl::Status init_status_; +}; + +TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.set_max_results(3); + + SetUp(options); + ASSERT_TRUE(test_image_classifier_ != nullptr) << init_status_; + + TfLiteTensor* output_tensor = test_image_classifier_->GetOutputTensor(); + ASSERT_NE(output_tensor, nullptr); + + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + ClassificationResult result, + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, + /*roi=*/{})); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { index: 2 score: 0.99609375 class_name: "roses" } + classes { index: 4 score: 0.5 class_name: "tulips" } + classes { index: 1 score: 0.25 class_name: "dandelion" } + head_index: 0 + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.set_score_threshold(0.4); + + SetUp(options); + ASSERT_TRUE(test_image_classifier_ != nullptr) << init_status_; + + TfLiteTensor* output_tensor = test_image_classifier_->GetOutputTensor(); + ASSERT_NE(output_tensor, nullptr); + + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + ClassificationResult result, + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, + /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { index: 2 score: 0.99609375 class_name: "roses" } + classes { index: 4 score: 0.5 class_name: "tulips" } + head_index: 0 + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.add_class_name_whitelist("dandelion"); + options.add_class_name_whitelist("daisy"); + + SetUp(options); + ASSERT_TRUE(test_image_classifier_ != nullptr) << init_status_; + + TfLiteTensor* output_tensor = test_image_classifier_->GetOutputTensor(); + ASSERT_NE(output_tensor, nullptr); + + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + ClassificationResult result, + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, + /*roi=*/{})); + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { index: 1 score: 0.25 class_name: "dandelion" } + classes { index: 0 score: 0 class_name: "daisy" } + head_index: 0 + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithBlacklistOption) { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata)); + options.add_class_name_blacklist("dandelion"); + options.add_class_name_blacklist("daisy"); + + SetUp(options); + ASSERT_TRUE(test_image_classifier_ != nullptr) << init_status_; + + TfLiteTensor* output_tensor = test_image_classifier_->GetOutputTensor(); + ASSERT_NE(output_tensor, nullptr); + + std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255, + /*sunflowers*/ 32, /*tulips*/ 128}; + SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + ClassificationResult result, + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_, + /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<ClassificationResult>( + R"pb(classifications { + classes { index: 2 score: 0.99609375 class_name: "roses" } + classes { index: 4 score: 0.5 class_name: "tulips" } + classes { index: 3 score: 0.125 class_name: "sunflowers" } + head_index: 0 + } + )pb")); +} + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc new file mode 100644 index 0000000..8877f28 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
@@ -0,0 +1,451 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/image_embedder.h" + +#include <memory> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::JoinPath; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; +// Test model. Float inputs, produces feature vectors that are not +// L2-normalized as this model doesn't include a L2_NORMALIZATION TFLite Op. +constexpr char kMobileNetV3[] = "mobilenet_v3_small_100_224_embedder.tflite"; +// Tolerancy for cosine similarity evaluation. +constexpr double kSimilarityTolerancy = 1e-6; + +StatusOr<ImageData> LoadImage(std::string image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +class MobileNetV3OpResolver : public ::tflite::MutableOpResolver { + public: + MobileNetV3OpResolver() { + AddBuiltin(::tflite::BuiltinOperator_MUL, + ::tflite::ops::builtin::Register_MUL()); + AddBuiltin(::tflite::BuiltinOperator_SUB, + ::tflite::ops::builtin::Register_SUB()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_HARD_SWISH, + ::tflite::ops::builtin::Register_HARD_SWISH()); + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_MEAN, + ::tflite::ops::builtin::Register_MEAN()); + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + ::tflite::ops::builtin::Register_AVERAGE_POOL_2D()); + AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + } + + MobileNetV3OpResolver(const MobileNetV3OpResolver& r) = delete; +}; + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + + SUPPORT_ASSERT_OK(ImageEmbedder::CreateFromOptions( + options, absl::make_unique<MobileNetV3OpResolver>())); +} + +class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver { + public: + MobileNetV3OpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + } + + MobileNetV3OpResolverMissingOps(const MobileNetV3OpResolverMissingOps& r) = + delete; +}; + +TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + + auto image_embedder_or = ImageEmbedder::CreateFromOptions( + options, absl::make_unique<MobileNetV3OpResolverMissingOps>()); + EXPECT_EQ(image_embedder_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_embedder_or.status().message(), + HasSubstr("Didn't find op for builtin opcode")); + EXPECT_THAT(image_embedder_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kUnsupportedBuiltinOp)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + ImageEmbedderOptions options; + + StatusOr<std::unique_ptr<ImageEmbedder>> image_embedder_or = + ImageEmbedder::CreateFromOptions(options); + + EXPECT_EQ(image_embedder_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_embedder_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +// Checks that CosineSimilarity fails if provided with a quantized and a float +// feature vector. +TEST(CosineSimilarityTest, FailsWithDifferentFeatureVectorTypes) { + FeatureVector u; + *u.mutable_value_string() = "\x01\x02"; + FeatureVector v; + v.add_value_float(0.1); + v.add_value_float(0.2); + + StatusOr<double> uv_similarity_or = ImageEmbedder::CosineSimilarity(u, v); + + EXPECT_EQ(uv_similarity_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(uv_similarity_or.status().message(), + HasSubstr("quantized and float")); +} + +// Checks that CosineSimilarity fails if provided with feature vectors of +// different sizes. +TEST(CosineSimilarityTest, FailsWithDifferentFeatureVectorSizes) { + FeatureVector u_float; + u_float.add_value_float(0.1); + FeatureVector v_float; + v_float.add_value_float(0.1); + v_float.add_value_float(0.2); + FeatureVector u_quantized; + *u_quantized.mutable_value_string() = "\x01"; + FeatureVector v_quantized; + *v_quantized.mutable_value_string() = "\x01\x02"; + + StatusOr<double> float_similarity_or = + ImageEmbedder::CosineSimilarity(u_float, v_float); + StatusOr<double> quantized_similarity_or = + ImageEmbedder::CosineSimilarity(u_quantized, v_quantized); + + EXPECT_EQ(float_similarity_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(float_similarity_or.status().message(), + HasSubstr("different sizes")); + EXPECT_EQ(quantized_similarity_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(quantized_similarity_or.status().message(), + HasSubstr("different sizes")); +} + +// Checks that CosineSimilarity fails if one of the feature vectors has 0 norm. +TEST(CosineSimilarityTest, FailsWithZeroNorm) { + FeatureVector u_float; + u_float.add_value_float(0.0); + u_float.add_value_float(0.0); + FeatureVector v_float; + v_float.add_value_float(0.1); + v_float.add_value_float(0.2); + FeatureVector u_quantized; + // Prevent literal from being interpreted as null-terminated C-style string. + *u_quantized.mutable_value_string() = std::string("\x00\x00", 2); + FeatureVector v_quantized; + *v_quantized.mutable_value_string() = "\x01\x02"; + + StatusOr<double> float_similarity_or = + ImageEmbedder::CosineSimilarity(u_float, v_float); + StatusOr<double> quantized_similarity_or = + ImageEmbedder::CosineSimilarity(u_quantized, v_quantized); + + EXPECT_EQ(float_similarity_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(float_similarity_or.status().message(), HasSubstr("0 norm")); + EXPECT_EQ(quantized_similarity_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(quantized_similarity_or.status().message(), HasSubstr("0 norm")); +} + +// Checks that CosineSimilarity produces expected results. +TEST(CosineSimilarityTest, Succeeds) { + FeatureVector u_float; + u_float.add_value_float(1.0); + u_float.add_value_float(0.0); + u_float.add_value_float(0.0); + u_float.add_value_float(0.0); + FeatureVector v_float; + v_float.add_value_float(0.5); + v_float.add_value_float(0.5); + v_float.add_value_float(0.5); + v_float.add_value_float(0.5); + FeatureVector u_quantized; + // Prevent literal from being interpreted as null-terminated C-style string. + *u_quantized.mutable_value_string() = std::string("\x7f\x00\x00\x00", 4); + FeatureVector v_quantized; + // Prevent literal from being interpreted as null-terminated C-style string. + *v_quantized.mutable_value_string() = std::string("\x80\x00\x00\x00", 4); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + double float_similarity, + ImageEmbedder::CosineSimilarity(u_float, v_float)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + double quantized_similarity, + ImageEmbedder::CosineSimilarity(u_quantized, v_quantized)); + + EXPECT_EQ(float_similarity, 0.5); + EXPECT_EQ(quantized_similarity, -1.0); +} + +// Extracts feature vectors without L2 normalization on two image (one being +// slightly cropped from the other) and checks that cosine similarity is high. +TEST(EmbedTest, SucceedsWithoutL2Normalization) { + // Create embedder. + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, + ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData crop, LoadImage("burger_crop.jpg")); + std::unique_ptr<FrameBuffer> crop_frame_buffer = CreateFromRgbRawBuffer( + crop.pixel_data, FrameBuffer::Dimension{crop.width, crop.height}); + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + embedder->Embed(*image_frame_buffer)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes + EXPECT_EQ(image_result.embeddings_size(), 1); + const FeatureVector& image_feature_vector = + image_result.embeddings(0).feature_vector(); + EXPECT_EQ(image_feature_vector.value_float_size(), 1024); + EXPECT_EQ(crop_result.embeddings_size(), 1); + const FeatureVector& crop_feature_vector = + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_float_size(), 1024); + // Check cosine similarity. + SUPPORT_ASSERT_OK_AND_ASSIGN( + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, + crop_feature_vector)); + double expected_similarity = 0.932738; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +// Same as above, but with `l2_normalize` option set to true. +TEST(EmbedTest, SucceedsWithL2Normalization) { + // Create embedder. + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + options.set_l2_normalize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, + ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData crop, LoadImage("burger_crop.jpg")); + std::unique_ptr<FrameBuffer> crop_frame_buffer = CreateFromRgbRawBuffer( + crop.pixel_data, FrameBuffer::Dimension{crop.width, crop.height}); + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + embedder->Embed(*image_frame_buffer)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes + EXPECT_EQ(image_result.embeddings_size(), 1); + const FeatureVector& image_feature_vector = + image_result.embeddings(0).feature_vector(); + EXPECT_EQ(image_feature_vector.value_float_size(), 1024); + EXPECT_EQ(crop_result.embeddings_size(), 1); + const FeatureVector& crop_feature_vector = + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_float_size(), 1024); + // Check cosine similarity. + SUPPORT_ASSERT_OK_AND_ASSIGN( + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, + crop_feature_vector)); + double expected_similarity = 0.932738; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +// Same as above, but with `quantize` option set to true. Requires also setting +// `l2_normalize` to true, as per the documentation. +// Same as above, but with `l2_normalize` option set to true. +TEST(EmbedTest, SucceedsWithQuantization) { + // Create embedder. + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + options.set_l2_normalize(true); + options.set_quantize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, + ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData crop, LoadImage("burger_crop.jpg")); + std::unique_ptr<FrameBuffer> crop_frame_buffer = CreateFromRgbRawBuffer( + crop.pixel_data, FrameBuffer::Dimension{crop.width, crop.height}); + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + embedder->Embed(*image_frame_buffer)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes + EXPECT_EQ(image_result.embeddings_size(), 1); + const FeatureVector& image_feature_vector = + image_result.embeddings(0).feature_vector(); + EXPECT_EQ(image_feature_vector.value_string().size(), 1024); + EXPECT_EQ(crop_result.embeddings_size(), 1); + const FeatureVector& crop_feature_vector = + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_string().size(), 1024); + // Check cosine similarity. + SUPPORT_ASSERT_OK_AND_ASSIGN( + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, + crop_feature_vector)); + // Close to but expectedly different from the above tests due to slight loss + // of precision during quantization: + double expected_similarity = 0.929717; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +// Extracts feature vectors on both the cropped image and the original image +// with a region of interest set to correspond to the cropped image, and checks +// that cosine similarity is close to 1. +TEST(EmbedTest, SucceedsWithRegionOfInterest) { + // Create embedder. + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, + ImageEmbedder::CreateFromOptions(options)); + // Load images: one is a crop of the other. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData crop, LoadImage("burger_crop.jpg")); + std::unique_ptr<FrameBuffer> crop_frame_buffer = CreateFromRgbRawBuffer( + crop.pixel_data, FrameBuffer::Dimension{crop.width, crop.height}); + // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". + BoundingBox roi; + roi.set_origin_x(0); + roi.set_origin_y(0); + roi.set_width(400); + roi.set_height(325); + + // Extract both embeddings. + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + embedder->Embed(*image_frame_buffer, roi)); + ImageDataFree(&image); + SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + embedder->Embed(*crop_frame_buffer)); + ImageDataFree(&crop); + + // Check results sizes + EXPECT_EQ(image_result.embeddings_size(), 1); + const FeatureVector& image_feature_vector = + image_result.embeddings(0).feature_vector(); + EXPECT_EQ(image_feature_vector.value_float_size(), 1024); + EXPECT_EQ(crop_result.embeddings_size(), 1); + const FeatureVector& crop_feature_vector = + crop_result.embeddings(0).feature_vector(); + EXPECT_EQ(crop_feature_vector.value_float_size(), 1024); + // Check cosine similarity. + SUPPORT_ASSERT_OK_AND_ASSIGN( + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector, + crop_feature_vector)); + double expected_similarity = 0.999914; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST(GetEmbeddingDimension, Succeeds) { + // Create embedder. + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, + ImageEmbedder::CreateFromOptions(options)); + + EXPECT_EQ(embedder->GetEmbeddingDimension(0), 1024); + EXPECT_EQ(embedder->GetEmbeddingDimension(1), -1); +} + +TEST(GetNumberOfOutputLayers, Succeeds) { + // Create embedder. + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder, + ImageEmbedder::CreateFromOptions(options)); + + EXPECT_EQ(embedder->GetNumberOfOutputLayers(), 1); +} + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc new file mode 100644 index 0000000..dc768a4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
@@ -0,0 +1,585 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" + +#include <memory> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/test/message_matchers.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::EqualsProto; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::JoinPath; +using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::TaskAPIFactory; +using ::tflite::task::core::TfLiteEngine; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; +constexpr char kDeepLabV3[] = "deeplabv3.tflite"; + +// All results returned by DeepLabV3 are expected to contain these in addition +// to the segmentation masks. +constexpr char kDeepLabV3PartialResult[] = + R"(width: 257 + height: 257 + colored_labels { r: 0 g: 0 b: 0 class_name: "background" } + colored_labels { r: 128 g: 0 b: 0 class_name: "aeroplane" } + colored_labels { r: 0 g: 128 b: 0 class_name: "bicycle" } + colored_labels { r: 128 g: 128 b: 0 class_name: "bird" } + colored_labels { r: 0 g: 0 b: 128 class_name: "boat" } + colored_labels { r: 128 g: 0 b: 128 class_name: "bottle" } + colored_labels { r: 0 g: 128 b: 128 class_name: "bus" } + colored_labels { r: 128 g: 128 b: 128 class_name: "car" } + colored_labels { r: 64 g: 0 b: 0 class_name: "cat" } + colored_labels { r: 192 g: 0 b: 0 class_name: "chair" } + colored_labels { r: 64 g: 128 b: 0 class_name: "cow" } + colored_labels { r: 192 g: 128 b: 0 class_name: "dining table" } + colored_labels { r: 64 g: 0 b: 128 class_name: "dog" } + colored_labels { r: 192 g: 0 b: 128 class_name: "horse" } + colored_labels { r: 64 g: 128 b: 128 class_name: "motorbike" } + colored_labels { r: 192 g: 128 b: 128 class_name: "person" } + colored_labels { r: 0 g: 64 b: 0 class_name: "potted plant" } + colored_labels { r: 128 g: 64 b: 0 class_name: "sheep" } + colored_labels { r: 0 g: 192 b: 0 class_name: "sofa" } + colored_labels { r: 128 g: 192 b: 0 class_name: "train" } + colored_labels { r: 0 g: 64 b: 128 class_name: "tv" })"; + +// The maximum fraction of pixels in the candidate mask that can have a +// different class than the golden mask for the test to pass. +constexpr float kGoldenMaskTolerance = 1e-2; +// Magnification factor used when creating the golden category masks to make +// them more human-friendly. Each pixel in the golden masks has its value +// multiplied by this factor, i.e. a value of 10 means class index 1, a value of +// 20 means class index 2, etc. +constexpr int kGoldenMaskMagnificationFactor = 10; + +StatusOr<ImageData> LoadImage(std::string image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +// Checks that the two provided `Segmentation` protos are equal. +// If the proto definition changes, please also change this function. +void ExpectApproximatelyEqual(const Segmentation& actual, + const Segmentation& expected) { + EXPECT_EQ(actual.height(), expected.height()); + EXPECT_EQ(actual.width(), expected.width()); + for (int i = 0; i < actual.colored_labels_size(); i++) { + EXPECT_THAT(actual.colored_labels(i), + EqualsProto(expected.colored_labels(i))); + } +} + +class DeepLabOpResolver : public ::tflite::MutableOpResolver { + public: + DeepLabOpResolver() { + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + ::tflite::ops::builtin::Register_AVERAGE_POOL_2D()); + AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, + ::tflite::ops::builtin::Register_CONCATENATION()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + // DeepLab uses different versions of DEPTHWISE_CONV_2D. + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(), + /*min_version=*/1, /*max_version=*/2); + AddBuiltin(::tflite::BuiltinOperator_RESIZE_BILINEAR, + ::tflite::ops::builtin::Register_RESIZE_BILINEAR()); + } + + DeepLabOpResolver(const DeepLabOpResolver& r) = delete; +}; + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions( + options, absl::make_unique<DeepLabOpResolver>())); +} + +class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { + public: + DeepLabOpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + } + + DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete; +}; + +TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + auto image_segmenter_or = ImageSegmenter::CreateFromOptions( + options, absl::make_unique<DeepLabOpResolverMissingOps>()); + + EXPECT_EQ(image_segmenter_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_segmenter_or.status().message(), + HasSubstr("Didn't find op for builtin opcode")); + EXPECT_THAT(image_segmenter_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kUnsupportedBuiltinOp)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = + ImageSegmenter::CreateFromOptions(options); + + EXPECT_EQ(image_segmenter_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_segmenter_or.status().message(), + HasSubstr("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 2.")); + EXPECT_THAT(image_segmenter_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + ImageSegmenterOptions options; + + auto image_segmenter_or = ImageSegmenter::CreateFromOptions(options); + + EXPECT_EQ(image_segmenter_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_segmenter_or.status().message(), + HasSubstr("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 0.")); + EXPECT_THAT(image_segmenter_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + options.set_output_type(ImageSegmenterOptions::UNSPECIFIED); + + auto image_segmenter_or = ImageSegmenter::CreateFromOptions(options); + + EXPECT_EQ(image_segmenter_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_segmenter_or.status().message(), + HasSubstr("`output_type` must not be UNSPECIFIED")); + EXPECT_THAT(image_segmenter_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) { + ImageSegmenterOptions options; + options.set_num_threads(4); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(options)); +} + +using NumThreadsTest = testing::TestWithParam<int>; + +INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2)); + +TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + ImageSegmenterOptions options; + options.set_num_threads(GetParam()); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + + StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = + ImageSegmenter::CreateFromOptions(options); + + EXPECT_EQ(image_segmenter_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_segmenter_or.status().message(), + HasSubstr("`num_threads` must be greater than " + "0 or equal to -1")); + EXPECT_THAT(image_segmenter_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +// Confidence masks tested in PostProcess unit tests below. +TEST(SegmentTest, SucceedsWithCategoryMask) { + // Load input and build frame buffer. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, + LoadImage("segmentation_input_rotation0.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + // Load golden mask output. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask, + LoadImage("segmentation_golden_rotation0.png")); + + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter, + ImageSegmenter::CreateFromOptions(options)); + SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result, + image_segmenter->Segment(*frame_buffer)); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_category_mask()); + const uint8* mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + + int inconsistent_pixels = 0; + int num_pixels = golden_mask.height * golden_mask.width; + for (int i = 0; i < num_pixels; ++i) { + inconsistent_pixels += + (mask[i] * kGoldenMaskMagnificationFactor != golden_mask.pixel_data[i]); + } + EXPECT_LT(static_cast<float>(inconsistent_pixels) / num_pixels, + kGoldenMaskTolerance); + ImageDataFree(&rgb_image); + ImageDataFree(&golden_mask); +} + +TEST(SegmentTest, SucceedsWithOrientation) { + // Load input and build frame buffer with kRightBottom orientation. + SUPPORT_ASSERT_OK_AND_ASSIGN( + ImageData rgb_image, LoadImage("segmentation_input_rotation90_flop.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}, + FrameBuffer::Orientation::kRightBottom); + // Load golden mask output. + SUPPORT_ASSERT_OK_AND_ASSIGN( + ImageData golden_mask, + LoadImage("segmentation_golden_rotation90_flop.png")); + + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter, + ImageSegmenter::CreateFromOptions(options)); + SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result, + image_segmenter->Segment(*frame_buffer)); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_category_mask()); + const uint8* mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + int inconsistent_pixels = 0; + int num_pixels = golden_mask.height * golden_mask.width; + for (int i = 0; i < num_pixels; ++i) { + inconsistent_pixels += + (mask[i] * kGoldenMaskMagnificationFactor != golden_mask.pixel_data[i]); + } + EXPECT_LT(static_cast<float>(inconsistent_pixels) / num_pixels, + kGoldenMaskTolerance); + ImageDataFree(&rgb_image); + ImageDataFree(&golden_mask); +} + +TEST(SegmentTest, SucceedsWithBaseOptions) { + // Load input and build frame buffer. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, + LoadImage("segmentation_input_rotation0.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + // Load golden mask output. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask, + LoadImage("segmentation_golden_rotation0.png")); + + ImageSegmenterOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter, + ImageSegmenter::CreateFromOptions(options)); + SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result, + image_segmenter->Segment(*frame_buffer)); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_category_mask()); + const uint8* mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + + int inconsistent_pixels = 0; + int num_pixels = golden_mask.height * golden_mask.width; + for (int i = 0; i < num_pixels; ++i) { + inconsistent_pixels += + (mask[i] * kGoldenMaskMagnificationFactor != golden_mask.pixel_data[i]); + } + EXPECT_LT(static_cast<float>(inconsistent_pixels) / num_pixels, + kGoldenMaskTolerance); + ImageDataFree(&rgb_image); + ImageDataFree(&golden_mask); +} + +class PostprocessTest : public tflite_shims::testing::Test { + public: + class TestImageSegmenter : public ImageSegmenter { + public: + using ImageSegmenter::ImageSegmenter; + using ImageSegmenter::Postprocess; + + static StatusOr<std::unique_ptr<TestImageSegmenter>> CreateFromOptions( + const ImageSegmenterOptions& options) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + auto options_copy = absl::make_unique<ImageSegmenterOptions>(options); + + ASSIGN_OR_RETURN( + auto image_segmenter, + TaskAPIFactory::CreateFromExternalFileProto<TestImageSegmenter>( + &options_copy->model_file_with_metadata())); + + RETURN_IF_ERROR(image_segmenter->Init(std::move(options_copy))); + + return image_segmenter; + } + + TfLiteTensor* GetOutputTensor() { + if (TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter()) != 1) { + return nullptr; + } + return TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), 0); + } + }; + + protected: + void SetUp() override { tflite_shims::testing::Test::SetUp(); } + void SetUp(const ImageSegmenterOptions& options) { + StatusOr<std::unique_ptr<TestImageSegmenter>> test_image_segmenter_or = + TestImageSegmenter::CreateFromOptions(options); + + init_status_ = test_image_segmenter_or.status(); + + if (init_status_.ok()) { + test_image_segmenter_ = std::move(test_image_segmenter_or).value(); + } + } + + StatusOr<const TfLiteTensor*> FillAndGetOutputTensor() { + TfLiteTensor* output_tensor = test_image_segmenter_->GetOutputTensor(); + + // Fill top-left corner and pad all other pixels with zeros. + std::vector<float> confidence_scores = confidence_scores_; + confidence_scores.resize(/*width*/ 257 * + /*height*/ 257 * + /*classes*/ 21); + RETURN_IF_ERROR(PopulateTensor(confidence_scores, output_tensor)); + + return output_tensor; + } + + std::unique_ptr<TestImageSegmenter> test_image_segmenter_; + absl::Status init_status_; + std::vector<float> confidence_scores_ = {/*background=*/0.01, + /*aeroplane=*/0.01, + /*bicycle=*/0.01, + /*bird=*/0.01, + /*boat=*/0.01, + /*bottle=*/0.01, + /*bus=*/0.21, + /*car=*/0.60, // highest (index=7) + /*cat=*/0.01, + /*chair=*/0.01, + /*cow=*/0.01, + /*dining table=*/0.01, + /*dog=*/0.01, + /*horse=*/0.01, + /*motorbike=*/0.01, + /*person=*/0.01, + /*potted plant=*/0.01, + /*sheep=*/0.01, + /*sofa=*/0.01, + /*train=*/0.01, + /*tv=*/0.01}; +}; + +TEST_F(PostprocessTest, SucceedsWithCategoryMask) { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbaRawBuffer(/*input=*/nullptr, {}); + + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, + FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, + test_image_segmenter_->Postprocess( + {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_category_mask()); + // Check top-left corner has expected class. + const uint8* category_mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + EXPECT_EQ(category_mask[0], /*car*/ 7); +} + +TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + // Frame buffer with kRightBottom orientation. + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer( + /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom); + + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, + FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, + test_image_segmenter_->Postprocess( + {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_category_mask()); + // Check bottom-right corner has expected class. + const uint8* category_mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + EXPECT_EQ(category_mask[/*width*/ 257 * /*height*/ 257 - 1], /*car*/ 7); +} + +TEST_F(PostprocessTest, SucceedsWithConfidenceMask) { + ImageSegmenterOptions options; + options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbaRawBuffer(/*input=*/nullptr, {}); + + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, + FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, + test_image_segmenter_->Postprocess( + {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_confidence_masks()); + const Segmentation::ConfidenceMasks confidence_masks = + segmentation.confidence_masks(); + EXPECT_EQ(confidence_masks.confidence_mask_size(), confidence_scores_.size()); + // Check top-left corner has expected confidences. + for (int index = 0; index < confidence_scores_.size(); ++index) { + const float* confidence_mask = reinterpret_cast<const float*>( + confidence_masks.confidence_mask(index).value().data()); + EXPECT_EQ(confidence_mask[0], confidence_scores_[index]); + } +} + +TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) { + ImageSegmenterOptions options; + options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK); + options.mutable_model_file_with_metadata()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3)); + // Frame buffer with kRightBottom orientation. + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer( + /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom); + + SetUp(options); + ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_; + SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor, + FillAndGetOutputTensor()); + SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result, + test_image_segmenter_->Postprocess( + {output_tensor}, *frame_buffer, /*roi=*/{})); + + EXPECT_EQ(result.segmentation_size(), 1); + const Segmentation& segmentation = result.segmentation(0); + ExpectApproximatelyEqual( + segmentation, ParseTextProtoOrDie<Segmentation>(kDeepLabV3PartialResult)); + EXPECT_TRUE(segmentation.has_confidence_masks()); + const Segmentation::ConfidenceMasks confidence_masks = + segmentation.confidence_masks(); + EXPECT_EQ(confidence_masks.confidence_mask_size(), confidence_scores_.size()); + // Check top-left corner has expected confidences. + for (int index = 0; index < confidence_scores_.size(); ++index) { + const float* confidence_mask = reinterpret_cast<const float*>( + confidence_masks.confidence_mask(index).value().data()); + EXPECT_EQ(confidence_mask[/*width*/ 257 * /*height*/ 257 - 1], + confidence_scores_[index]); + } +} + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc new file mode 100644 index 0000000..4a33e4b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
@@ -0,0 +1,642 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/object_detector.h" + +#include <memory> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/test/message_matchers.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +namespace tflite { + +namespace ops { +namespace custom { + +// Forward declaration for the custom Detection_PostProcess op. +// +// See: +// https://medium.com/@bsramasubramanian/running-a-tensorflow-lite-model-in-python-with-custom-ops-9b2b46efd355 +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); + +} // namespace custom +} // namespace ops + +namespace task { +namespace vision { +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::EqualsProto; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::JoinPath; +using ::tflite::task::ParseTextProtoOrDie; +using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::TaskAPIFactory; +using ::tflite::task::core::TfLiteEngine; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; +constexpr char kMobileSsdWithMetadata[] = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; +constexpr char kExpectResults[] = + R"pb(detections { + bounding_box { origin_x: 54 origin_y: 396 width: 393 height: 196 } + classes { index: 16 score: 0.64453125 class_name: "cat" } + } + detections { + bounding_box { origin_x: 602 origin_y: 157 width: 394 height: 447 } + classes { index: 16 score: 0.59765625 class_name: "cat" } + } + detections { + bounding_box { origin_x: 261 origin_y: 394 width: 179 height: 209 } + # Actually a dog, but the model gets confused. + classes { index: 16 score: 0.5625 class_name: "cat" } + } + detections { + bounding_box { origin_x: 389 origin_y: 197 width: 276 height: 409 } + classes { index: 17 score: 0.51171875 class_name: "dog" } + } + )pb"; +constexpr char kMobileSsdWithMetadataDummyScoreCalibration[] = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite"; +// The model has different output tensor order. +constexpr char kEfficientDetWithMetadata[] = + "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"; + +StatusOr<ImageData> LoadImage(std::string image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +// Checks that the two provided `DetectionResult` protos are equal, with a +// tolerancy on floating-point scores to account for numerical instabilities. +// If the proto definition changes, please also change this function. +void ExpectApproximatelyEqual(const DetectionResult& actual, + const DetectionResult& expected) { + const float kPrecision = 1e-6; + EXPECT_EQ(actual.detections_size(), expected.detections_size()); + for (int i = 0; i < actual.detections_size(); ++i) { + const Detection& a = actual.detections(i); + const Detection& b = expected.detections(i); + EXPECT_THAT(a.bounding_box(), EqualsProto(b.bounding_box())); + EXPECT_EQ(a.classes_size(), 1); + EXPECT_EQ(b.classes_size(), 1); + EXPECT_EQ(a.classes(0).index(), b.classes(0).index()); + EXPECT_EQ(a.classes(0).class_name(), b.classes(0).class_name()); + EXPECT_NEAR(a.classes(0).score(), b.classes(0).score(), kPrecision); + } +} + +// OpResolver including the custom Detection_PostProcess op. +class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver { + public: + MobileSsdQuantizedOpResolver() { + AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, + ::tflite::ops::builtin::Register_CONCATENATION()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + AddBuiltin(::tflite::BuiltinOperator_LOGISTIC, + ::tflite::ops::builtin::Register_LOGISTIC()); + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + AddCustom("TFLite_Detection_PostProcess", + tflite::ops::custom::Register_DETECTION_POSTPROCESS()); + } + + MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete; +}; + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions( + options, absl::make_unique<MobileSsdQuantizedOpResolver>())); +} + +// OpResolver missing the Detection_PostProcess op. +class MobileSsdQuantizedOpResolverMissingOps + : public ::tflite::MutableOpResolver { + public: + MobileSsdQuantizedOpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, + ::tflite::ops::builtin::Register_CONCATENATION()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + AddBuiltin(::tflite::BuiltinOperator_LOGISTIC, + ::tflite::ops::builtin::Register_LOGISTIC()); + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + } + + MobileSsdQuantizedOpResolverMissingOps( + const MobileSsdQuantizedOpResolverMissingOps& r) = delete; +}; + +TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + auto object_detector_or = ObjectDetector::CreateFromOptions( + options, absl::make_unique<MobileSsdQuantizedOpResolverMissingOps>()); + EXPECT_EQ(object_detector_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(object_detector_or.status().message(), + HasSubstr("Encountered unresolved custom op")); + EXPECT_THAT(object_detector_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kUnsupportedCustomOp)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); + + EXPECT_EQ(object_detector_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(object_detector_or.status().message(), + HasSubstr("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 2.")); + EXPECT_THAT(object_detector_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + ObjectDetectorOptions options; + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); + + EXPECT_EQ(object_detector_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(object_detector_or.status().message(), + HasSubstr("Expected exactly one of `base_options.model_file` or " + "`model_file_with_metadata` to be provided, found 0.")); + EXPECT_THAT(object_detector_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_max_results(0); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); + + EXPECT_EQ(object_detector_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(object_detector_or.status().message(), + HasSubstr("Invalid `max_results` option")); + EXPECT_THAT(object_detector_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.add_class_name_whitelist("foo"); + options.add_class_name_blacklist("bar"); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); + + EXPECT_EQ(object_detector_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(object_detector_or.status().message(), + HasSubstr("mutually exclusive options")); + EXPECT_THAT(object_detector_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) { + ObjectDetectorOptions options; + options.set_num_threads(4); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(options)); +} + +using NumThreadsTest = testing::TestWithParam<int>; + +INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2)); + +TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) { + ObjectDetectorOptions options; + options.set_num_threads(GetParam()); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); + + EXPECT_EQ(object_detector_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(object_detector_or.status().message(), + HasSubstr("`num_threads` must be greater than " + "0 or equal to -1")); + EXPECT_THAT(object_detector_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +class DetectTest : public tflite_shims::testing::Test {}; + +TEST_F(DetectTest, Succeeds) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, + LoadImage("cats_and_dogs.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ObjectDetectorOptions options; + options.set_max_results(4); + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result, + object_detector->Detect(*frame_buffer)); + ImageDataFree(&rgb_image); + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<DetectionResult>(kExpectResults)); +} + +TEST_F(DetectTest, SucceedswithBaseOptions) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, + LoadImage("cats_and_dogs.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ObjectDetectorOptions options; + options.set_max_results(4); + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result, + object_detector->Detect(*frame_buffer)); + ImageDataFree(&rgb_image); + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<DetectionResult>(kExpectResults)); +} + +TEST_F(DetectTest, SucceedswithScoreCalibrations) { + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, + LoadImage("cats_and_dogs.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + rgb_image.pixel_data, + FrameBuffer::Dimension{rgb_image.width, rgb_image.height}); + + ObjectDetectorOptions options; + options.set_max_results(4); + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + kMobileSsdWithMetadataDummyScoreCalibration)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result, + object_detector->Detect(*frame_buffer)); + ImageDataFree(&rgb_image); + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<DetectionResult>(kExpectResults)); +} + +class PostprocessTest : public tflite_shims::testing::Test { + public: + class TestObjectDetector : public ObjectDetector { + public: + using ObjectDetector::ObjectDetector; + using ObjectDetector::Postprocess; + + static StatusOr<std::unique_ptr<TestObjectDetector>> CreateFromOptions( + const ObjectDetectorOptions& options) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + auto options_copy = absl::make_unique<ObjectDetectorOptions>(options); + + ASSIGN_OR_RETURN( + auto object_detector, + TaskAPIFactory::CreateFromExternalFileProto<TestObjectDetector>( + &options_copy->model_file_with_metadata())); + + RETURN_IF_ERROR(object_detector->Init(std::move(options_copy))); + + return object_detector; + } + + std::vector<TfLiteTensor*> GetOutputTensors() { + std::vector<TfLiteTensor*> outputs; + int num_outputs = + TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter()); + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; i++) { + outputs.push_back( + TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), i)); + } + return outputs; + } + }; + + protected: + void SetUp() override { tflite_shims::testing::Test::SetUp(); } + void SetUp(const ObjectDetectorOptions& options) { + StatusOr<std::unique_ptr<TestObjectDetector>> test_object_detector_or = + TestObjectDetector::CreateFromOptions(options); + + init_status_ = test_object_detector_or.status(); + + if (init_status_.ok()) { + test_object_detector_ = std::move(test_object_detector_or).value(); + } + + dummy_frame_buffer_ = CreateFromRgbRawBuffer(/*input=*/nullptr, + {/*width=*/20, /*height=*/10}); + } + + StatusOr<std::vector<const TfLiteTensor*>> FillAndGetOutputTensors() { + std::vector<TfLiteTensor*> output_tensors = + test_object_detector_->GetOutputTensors(); + if (output_tensors.size() != 4) { + return absl::InternalError(absl::StrFormat( + "Expected 4 output tensors, found %d", output_tensors.size())); + } + + std::vector<const TfLiteTensor*> result; + + TfLiteTensor* locations = output_tensors[0]; + std::vector<float> locations_data = { + /*left=*/0.2, /*top=*/0.2, /*right=*/0.4, /*bottom=*/0.6, + /*left=*/0.4, /*top=*/0.2, /*right=*/0.6, /*bottom=*/0.6, + /*left=*/0.2, /*top=*/0.4, /*right=*/0.4, /*bottom=*/0.8}; + // Pad with zeros to fill the 10 locations. + locations_data.resize(4 * 10); + RETURN_IF_ERROR(PopulateTensor(locations_data, locations)); + result.push_back(locations); + + TfLiteTensor* classes = output_tensors[1]; + std::vector<float> classes_data = {/*bicycle*/ 1, /*car*/ 2, + /*motorcycle*/ 3}; + // Pad with zeros to fill the 10 classes. + classes_data.resize(10); + RETURN_IF_ERROR(PopulateTensor(classes_data, classes)); + result.push_back(classes); + + TfLiteTensor* scores = output_tensors[2]; + std::vector<float> scores_data = {0.8, 0.6, 0.4}; + // Pad with zeros to fill the 10 scores. + scores_data.resize(10); + RETURN_IF_ERROR(PopulateTensor(scores_data, scores)); + result.push_back(scores); + + TfLiteTensor* num_results = output_tensors[3]; + std::vector<float> num_results_data = {10}; + RETURN_IF_ERROR(PopulateTensor(num_results_data, num_results)); + result.push_back(num_results); + + return result; + } + + std::unique_ptr<TestObjectDetector> test_object_detector_; + std::unique_ptr<FrameBuffer> dummy_frame_buffer_; + absl::Status init_status_; +}; + +TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_score_threshold(0.5); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + + SUPPORT_ASSERT_OK_AND_ASSIGN( + const std::vector<const TfLiteTensor*> output_tensors, + FillAndGetOutputTensors()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + DetectionResult result, + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, + /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<DetectionResult>( + R"pb(detections { + bounding_box { origin_x: 4 origin_y: 2 width: 8 height: 2 } + classes { index: 1 score: 0.8 class_name: "bicycle" } + } + detections { + bounding_box { origin_x: 4 origin_y: 4 width: 8 height: 2 } + classes { index: 2 score: 0.6 class_name: "car" } + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) { + std::unique_ptr<FrameBuffer> frame_buffer_with_orientation = + CreateFromRgbRawBuffer(/*input=*/nullptr, {/*width=*/20, /*height=*/10}, + FrameBuffer::Orientation::kBottomRight); + + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_score_threshold(0.5); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + + SUPPORT_ASSERT_OK_AND_ASSIGN( + const std::vector<const TfLiteTensor*> output_tensors, + FillAndGetOutputTensors()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + DetectionResult result, + test_object_detector_->Postprocess( + output_tensors, *frame_buffer_with_orientation, /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<DetectionResult>( + R"pb(detections { + bounding_box { origin_x: 8 origin_y: 6 width: 8 height: 2 } + classes { index: 1 score: 0.8 class_name: "bicycle" } + } + detections { + bounding_box { origin_x: 8 origin_y: 4 width: 8 height: 2 } + classes { index: 2 score: 0.6 class_name: "car" } + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.set_max_results(1); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + + SUPPORT_ASSERT_OK_AND_ASSIGN( + const std::vector<const TfLiteTensor*> output_tensors, + FillAndGetOutputTensors()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + DetectionResult result, + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, + /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<DetectionResult>( + R"pb(detections { + bounding_box { origin_x: 4 origin_y: 2 width: 8 height: 2 } + classes { index: 1 score: 0.8 class_name: "bicycle" } + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithWhitelistOption) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.add_class_name_whitelist("car"); + options.add_class_name_whitelist("motorcycle"); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + + SUPPORT_ASSERT_OK_AND_ASSIGN( + const std::vector<const TfLiteTensor*> output_tensors, + FillAndGetOutputTensors()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + DetectionResult result, + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, + /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<DetectionResult>( + R"pb(detections { + bounding_box { origin_x: 4 origin_y: 4 width: 8 height: 2 } + classes { index: 2 score: 0.6 class_name: "car" } + } + detections { + bounding_box { origin_x: 8 origin_y: 2 width: 8 height: 2 } + classes { index: 3 score: 0.4 class_name: "motorcycle" } + } + )pb")); +} + +TEST_F(PostprocessTest, SucceedsWithBlacklistOption) { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata)); + options.add_class_name_blacklist("car"); + // Setting score threshold to discard the 7 padded-with-zeros results. + options.set_score_threshold(0.1); + + SetUp(options); + ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_; + + SUPPORT_ASSERT_OK_AND_ASSIGN( + const std::vector<const TfLiteTensor*> output_tensors, + FillAndGetOutputTensors()); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + DetectionResult result, + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_, + /*roi=*/{})); + + ExpectApproximatelyEqual( + result, + ParseTextProtoOrDie<DetectionResult>( + R"pb(detections { + bounding_box { origin_x: 4 origin_y: 2 width: 8 height: 2 } + classes { index: 1 score: 0.8 class_name: "bicycle" } + } + detections { + bounding_box { origin_x: 8 origin_y: 2 width: 8 height: 2 } + classes { index: 3 score: 0.4 class_name: "motorcycle" } + } + )pb")); +} + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc new file mode 100644 index 0000000..c16815c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
@@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/test/test_utils.h" + +#include "absl/strings/str_cat.h" // from @com_google_absl + +namespace tflite { +namespace task { + +std::string JoinPath(absl::string_view path1, absl::string_view path2) { + if (path1.empty()) + return std::string(path2); + if (path2.empty()) + return std::string(path1); + if (path1.back() == '/') { + if (path2.front() == '/') + return absl::StrCat(path1, absl::ClippedSubstr(path2, 1)); + } else { + if (path2.front() != '/') + return absl::StrCat(path1, "/", path2); + } + return absl::StrCat(path1, path2); +} + +namespace internal { + +// Given a collection of file paths, append them all together, +// ensuring that the proper path separators are inserted between them. +std::string JoinPathImpl(bool honor_abs, + std::initializer_list<absl::string_view> paths) { + std::string result; + + if (paths.size() != 0) { + // This size calculation is worst-case: it assumes one extra "/" for every + // path other than the first. + size_t total_size = paths.size() - 1; + for (const absl::string_view path : paths) + total_size += path.size(); + result.resize(total_size); + + auto begin = result.begin(); + auto out = begin; + bool trailing_slash = false; + for (absl::string_view path : paths) { + if (path.empty()) + continue; + if (path.front() == '/') { + if (honor_abs) { + out = begin; // wipe out whatever we've built up so far. + } else if (trailing_slash) { + path.remove_prefix(1); + } + } else { + if (!trailing_slash && out != begin) + *out++ = '/'; + } + const size_t this_size = path.size(); + memcpy(&*out, path.data(), this_size); + out += this_size; + trailing_slash = out[-1] == '/'; + } + result.erase(out - begin); + } + return result; +} + +} // namespace internal +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h new file mode 100644 index 0000000..1d730d5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
@@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEST_UTILS_TEST_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEST_UTILS_TEST_UTILS_H_ + +#include <glog/logging.h> +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/proto2.h" + +namespace tflite { +namespace task { +namespace internal { + +// Not part of the public API. +std::string JoinPathImpl(bool honor_abs, + std::initializer_list<absl::string_view> paths); + +} // namespace internal + +std::string JoinPath(absl::string_view path1, absl::string_view path2); + +template <typename... T> +inline std::string JoinPath(absl::string_view path1, + absl::string_view path2, + absl::string_view path3, + const T&... args) { + return internal::JoinPathImpl(false, {path1, path2, path3, args...}); +} + +template <typename T> +T ParseTextProtoOrDie(const std::string& input) { + T result; + CHECK(tflite::support::proto::TextFormat::ParseFromString(input, &result)); + return result; +} + +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEST_UTILS_TEST_UTILS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD new file mode 100644 index 0000000..72a88c2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD
@@ -0,0 +1,79 @@ +load("//tensorflow_lite_support/tools/build_rules:http_files.bzl", "tflite_file", "tflite_model") + +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "nl_classifier_models", + srcs = glob([ + "test_model_nl_classifier*.tflite", + ]), + visibility = [ + "//tensorflow_lite_support:internal", + ], +) + +filegroup( + name = "bert_nl_classifier_models", + srcs = [ + ":bert_nl_classifier", + ], +) + +filegroup( + name = "mobile_bert_model", + srcs = [ + ":mobilebert_float", + ":mobilebert_vocab", + ":mobilebert_with_metadata", + ], + visibility = [ + "//tensorflow_lite_support:internal", + ], +) + +filegroup( + name = "albert_model", + srcs = [ + ":30k-clean", + ":albert", + ":albert_with_metadata", + ], +) + +filegroup( + name = "regex_tokenizer_files", + srcs = [ + "empty_vocab_for_regex_tokenizer.txt", + "vocab_for_regex_tokenizer.txt", + ], +) + +filegroup( + name = "universal_sentence_encoder_qa", + data = [":universal_sentence_encoder_qa_with_metadata"], +) + +tflite_model(name = "bert_nl_classifier") + +tflite_model(name = "albert") + +tflite_model(name = "albert_with_metadata") + +tflite_model(name = "universal_sentence_encoder_qa_with_metadata") + +tflite_model(name = "mobilebert_float") + +tflite_file( + name = "mobilebert_vocab", + extension = "txt", +) + +tflite_model(name = "mobilebert_with_metadata") + +tflite_file( + name = "30k-clean", + extension = "model", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/albert_with_metadata.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/albert_with_metadata.json new file mode 100644 index 0000000..6d2e748 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/albert_with_metadata.json
@@ -0,0 +1,81 @@ +{ + "name": "Albert Question and Answerer", + "description": "Answers questions based on the content of a given passage. To integrate the model into your app, try the `BertQuestionAnswerer` API in the TensorFlow Lite Task library. `BertQuestionAnswerer` takes a passage string and a query string, and returns the answer strings. It encapsulates the processing logic of inputs and outputs and runs the inference with the best practice.", + "version": "v1", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of input text as concatenated query and passage.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "segment_ids", + "description": "0 for query and 1 for passage tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + } + ], + "output_tensor_metadata": [ + { + "name": "end_logits", + "description": "logits over the sequence which indicates the end position of the answer span with closed interval.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "start_logits", + "description": "logits over the sequence which indicates the start position of the answer span with closed interval.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + } + ], + "input_process_units": [ + { + "options_type": "SentencePieceTokenizerOptions", + "options": { + "sentencePiece_model": [ + { + "name": "30k-clean.model", + "description": "The sentence piece model file." + } + ], + "vocab_file": [ + { + "name": "30k-clean.vocab", + "description": "Vocabulary file for the SentencePiece tokenizer. This file is optional during tokenization, while the sentence piece model is mandatory.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "author": "TensorFlow", + "license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.", + "min_parser_version": "1.1.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/bert_nl_classifier.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/bert_nl_classifier.json new file mode 100644 index 0000000..44aedf97 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/bert_nl_classifier.json
@@ -0,0 +1,73 @@ +{ + "name": "MobileBert text classifier", + "description": "Classifies the input string based on the known catergories. To integrate the model into your app, try the `BertNLClassifier` API in the TensorFlow Lite Task library. `BertNLClassifier` takes an input string, and returns the classified label with probability. It encapsulates the processing logic of inputs and outputs and runs the inference with the best practice.", + "version": "v1", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of input text as concatenated query and passage.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "segment_ids", + "description": "0 for query and 1 for passage tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for classification categories.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file for the BertTokenizer.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "author": "TensorFlow", + "license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.", + "min_parser_version": "1.1.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/empty_vocab_for_regex_tokenizer.txt b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/empty_vocab_for_regex_tokenizer.txt new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/empty_vocab_for_regex_tokenizer.txt
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_with_metadata.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_with_metadata.json new file mode 100644 index 0000000..d2ba03f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_with_metadata.json
@@ -0,0 +1,75 @@ +{ + "name": "MobileBert Question and Answerer", + "description": "Answers questions based on the content of a given passage. To integrate the model into your app, try the `BertQuestionAnswerer` API in the TensorFlow Lite Task library. `BertQuestionAnswerer` takes a passage string and a query string, and returns the answer strings. It encapsulates the processing logic of inputs and outputs and runs the inference with the best practice.", + "version": "v1", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of input text as concatenated query and passage.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "segment_ids", + "description": "0 for query and 1 for passage tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + } + ], + "output_tensor_metadata": [ + { + "name": "end_logits", + "description": "logits over the sequence which indicates the end position of the answer span with closed interval.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + }, + { + "name": "start_logits", + "description": "logits over the sequence which indicates the start position of the answer span with closed interval.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file for the BertTokenizer.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "author": "TensorFlow", + "license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.", + "min_parser_version": "1.1.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier.tflite new file mode 100644 index 0000000..52e614c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_bool_output.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_bool_output.tflite new file mode 100644 index 0000000..20c6db1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_bool_output.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label.json new file mode 100644 index 0000000..d0a4301 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label.json
@@ -0,0 +1,50 @@ +{ + "name": "NL Classifier", + "description": "Identify the text string from a set of 3 categories.", + "version": "v1", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "text string", + "description": "Input string to be classified.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + } + } + ], + "output_tensor_metadata": [ + { + "name": "scores_dequantized", + "description": "Scores of the 3 labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "associated_files": [ + { + "name": "test_labels.txt", + "description": "Labels for objects that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + }, + { + "name": "scores_quantized" + }, + { + "name": "labels" + }, + { + "name": "scores_dequantized_float64" + } + ] + } + ], + "author": "TensorFlow", + "license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.", + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label.tflite new file mode 100644 index 0000000..4d04015 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label_builtin_ops.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label_builtin_ops.json new file mode 100644 index 0000000..f8d36ad --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label_builtin_ops.json
@@ -0,0 +1,64 @@ +{ + "name": "AverageWordVec text classifier", + "description": "Classify text into predefined categories.", + "version": "v1", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input_text", + "description": "Embedding vectors representing the input text to be classified. The input need to be converted from raw text to embedding vectors using the attached dictionary file.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "RegexTokenizerOptions", + "options": { + "delim_regex_pattern": "[^\\w\\']+", + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for the categories that the model can classify.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "author": "TensorFlow", + "license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.", + "min_parser_version": "1.2.1" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label_builtin_ops.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label_builtin_ops.tflite new file mode 100644 index 0000000..7bcfa882 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_associated_label_builtin_ops.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_regex_tokenizer.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_regex_tokenizer.json new file mode 100644 index 0000000..79b63a40 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_regex_tokenizer.json
@@ -0,0 +1,64 @@ +{ + "name": "Sentiment Analyzer (AverageWordVecModelSpec)", + "description": "Detect if the input text's sentiment is positive or negative. The model was trained on the IMDB Movie Reviews dataset so it is more accurate when input text is a movie review.", + "version": "v1", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input_text", + "description": "Embedding vectors representing the input text to be classified. The input need to be converted from raw text to embedding vectors using the attached dictionary file.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "RegexTokenizerOptions", + "options": { + "delim_regex_pattern": "[^\\w\\']+", + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for the categories that the model can classify.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "author": "TensorFlow", + "license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.", + "min_parser_version": "1.2.1" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_regex_tokenizer.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_regex_tokenizer.tflite new file mode 100644 index 0000000..f73e77c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/test_model_nl_classifier_with_regex_tokenizer.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/vocab_for_regex_tokenizer.txt b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/vocab_for_regex_tokenizer.txt new file mode 100644 index 0000000..0a27d7c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/vocab_for_regex_tokenizer.txt
@@ -0,0 +1,10000 @@ +<PAD> 0 +<START> 1 +<UNKNOWN> 2 +<UNUSED> 3 +the 4 +and 5 +a 6 +of 7 +to 8 +is 9 +br 10 +in 11 +it 12 +i 13 +this 14 +that 15 +was 16 +as 17 +for 18 +with 19 +movie 20 +but 21 +film 22 +on 23 +not 24 +you 25 +are 26 +his 27 +have 28 +he 29 +be 30 +one 31 +all 32 +at 33 +by 34 +an 35 +they 36 +who 37 +so 38 +from 39 +like 40 +her 41 +or 42 +just 43 +about 44 +it's 45 +out 46 +has 47 +if 48 +some 49 +there 50 +what 51 +good 52 +more 53 +when 54 +very 55 +up 56 +no 57 +time 58 +she 59 +even 60 +my 61 +would 62 +which 63 +only 64 +story 65 +really 66 +see 67 +their 68 +had 69 +can 70 +were 71 +me 72 +well 73 +than 74 +we 75 +much 76 +been 77 +bad 78 +get 79 +will 80 +do 81 +also 82 +into 83 +people 84 +other 85 +first 86 +great 87 +because 88 +how 89 +him 90 +most 91 +don't 92 +made 93 +its 94 +then 95 +way 96 +make 97 +them 98 +too 99 +could 100 +any 101 +movies 102 +after 103 +think 104 +characters 105 +watch 106 +two 107 +films 108 +character 109 +seen 110 +many 111 +being 112 +life 113 +plot 114 +never 115 +acting 116 +little 117 +best 118 +love 119 +over 120 +where 121 +did 122 +show 123 +know 124 +off 125 +ever 126 +does 127 +better 128 +your 129 +end 130 +still 131 +man 132 +here 133 +these 134 +say 135 +scene 136 +while 137 +why 138 +scenes 139 +go 140 +such 141 +something 142 +through 143 +should 144 +back 145 +i'm 146 +real 147 +those 148 +watching 149 +now 150 +though 151 +doesn't 152 +years 153 +old 154 +thing 155 +actors 156 +work 157 +10 158 +before 159 +another 160 +didn't 161 +new 162 +funny 163 +nothing 164 +actually 165 +makes 166 +director 167 +look 168 +find 169 +going 170 +few 171 +same 172 +part 173 +again 174 +every 175 +lot 176 +cast 177 +us 178 +quite 179 +down 180 +want 181 +world 182 +things 183 +pretty 184 +young 185 +seems 186 +around 187 +got 188 +horror 189 +however 190 +can't 191 +fact 192 +take 193 +big 194 +enough 195 +long 196 +thought 197 +that's 198 +both 199 +between 200 +series 201 +give 202 +may 203 +original 204 +own 205 +action 206 +i've 207 +right 208 +without 209 +always 210 +times 211 +comedy 212 +point 213 +gets 214 +must 215 +come 216 +role 217 +isn't 218 +saw 219 +almost 220 +interesting 221 +least 222 +family 223 +done 224 +there's 225 +whole 226 +bit 227 +music 228 +script 229 +far 230 +making 231 +guy 232 +anything 233 +minutes 234 +feel 235 +last 236 +since 237 +might 238 +performance 239 +he's 240 +2 241 +probably 242 +kind 243 +am 244 +away 245 +yet 246 +rather 247 +tv 248 +worst 249 +girl 250 +day 251 +sure 252 +fun 253 +hard 254 +woman 255 +played 256 +each 257 +found 258 +anyone 259 +having 260 +although 261 +especially 262 +our 263 +believe 264 +course 265 +comes 266 +looking 267 +screen 268 +trying 269 +set 270 +goes 271 +looks 272 +place 273 +book 274 +different 275 +put 276 +ending 277 +money 278 +maybe 279 +once 280 +sense 281 +reason 282 +true 283 +actor 284 +everything 285 +wasn't 286 +shows 287 +dvd 288 +three 289 +worth 290 +year 291 +job 292 +main 293 +someone 294 +together 295 +watched 296 +play 297 +american 298 +plays 299 +1 300 +said 301 +effects 302 +later 303 +takes 304 +instead 305 +seem 306 +beautiful 307 +john 308 +himself 309 +version 310 +audience 311 +high 312 +house 313 +night 314 +during 315 +everyone 316 +left 317 +special 318 +seeing 319 +half 320 +excellent 321 +wife 322 +star 323 +shot 324 +war 325 +idea 326 +nice 327 +black 328 +less 329 +mind 330 +simply 331 +read 332 +second 333 +else 334 +you're 335 +father 336 +fan 337 +poor 338 +help 339 +completely 340 +death 341 +3 342 +used 343 +home 344 +either 345 +short 346 +line 347 +given 348 +men 349 +top 350 +dead 351 +budget 352 +try 353 +performances 354 +wrong 355 +classic 356 +boring 357 +enjoy 358 +need 359 +rest 360 +use 361 +kids 362 +hollywood 363 +low 364 +production 365 +until 366 +along 367 +full 368 +friends 369 +camera 370 +truly 371 +women 372 +awful 373 +video 374 +next 375 +tell 376 +remember 377 +couple 378 +stupid 379 +start 380 +stars 381 +perhaps 382 +sex 383 +mean 384 +came 385 +recommend 386 +let 387 +moments 388 +wonderful 389 +episode 390 +understand 391 +small 392 +face 393 +terrible 394 +playing 395 +school 396 +getting 397 +written 398 +doing 399 +often 400 +keep 401 +early 402 +name 403 +perfect 404 +style 405 +human 406 +definitely 407 +gives 408 +others 409 +itself 410 +lines 411 +live 412 +become 413 +dialogue 414 +person 415 +lost 416 +finally 417 +piece 418 +head 419 +case 420 +felt 421 +yes 422 +liked 423 +supposed 424 +title 425 +couldn't 426 +absolutely 427 +white 428 +against 429 +boy 430 +picture 431 +sort 432 +worse 433 +certainly 434 +went 435 +entire 436 +waste 437 +cinema 438 +problem 439 +hope 440 +entertaining 441 +she's 442 +mr 443 +overall 444 +evil 445 +called 446 +loved 447 +based 448 +oh 449 +several 450 +fans 451 +mother 452 +drama 453 +beginning 454 +killer 455 +lives 456 +5 457 +direction 458 +care 459 +already 460 +becomes 461 +laugh 462 +example 463 +friend 464 +dark 465 +despite 466 +under 467 +seemed 468 +throughout 469 +4 470 +turn 471 +unfortunately 472 +wanted 473 +i'd 474 +– 475 +children 476 +final 477 +fine 478 +history 479 +amazing 480 +sound 481 +guess 482 +heart 483 +totally 484 +lead 485 +humor 486 +writing 487 +michael 488 +quality 489 +you'll 490 +close 491 +son 492 +guys 493 +wants 494 +works 495 +behind 496 +tries 497 +art 498 +side 499 +game 500 +past 501 +able 502 +b 503 +days 504 +turns 505 +child 506 +they're 507 +hand 508 +flick 509 +enjoyed 510 +act 511 +genre 512 +town 513 +favorite 514 +soon 515 +kill 516 +starts 517 +sometimes 518 +car 519 +gave 520 +run 521 +late 522 +eyes 523 +actress 524 +etc 525 +directed 526 +horrible 527 +won't 528 +viewer 529 +brilliant 530 +parts 531 +self 532 +themselves 533 +hour 534 +expect 535 +thinking 536 +stories 537 +stuff 538 +girls 539 +obviously 540 +blood 541 +decent 542 +city 543 +voice 544 +highly 545 +myself 546 +feeling 547 +fight 548 +except 549 +slow 550 +matter 551 +type 552 +anyway 553 +kid 554 +roles 555 +killed 556 +heard 557 +god 558 +age 559 +says 560 +moment 561 +took 562 +leave 563 +writer 564 +strong 565 +cannot 566 +violence 567 +police 568 +hit 569 +stop 570 +happens 571 +particularly 572 +known 573 +involved 574 +happened 575 +extremely 576 +daughter 577 +obvious 578 +told 579 +chance 580 +living 581 +coming 582 +lack 583 +alone 584 +experience 585 +wouldn't 586 +including 587 +murder 588 +attempt 589 +s 590 +please 591 +james 592 +happen 593 +wonder 594 +crap 595 +ago 596 +brother 597 +film's 598 +gore 599 +none 600 +complete 601 +interest 602 +score 603 +group 604 +cut 605 +simple 606 +save 607 +ok 608 +hell 609 +looked 610 +career 611 +number 612 +song 613 +possible 614 +seriously 615 +annoying 616 +shown 617 +exactly 618 +sad 619 +running 620 +musical 621 +serious 622 +taken 623 +yourself 624 +whose 625 +released 626 +cinematography 627 +david 628 +scary 629 +ends 630 +english 631 +hero 632 +usually 633 +hours 634 +reality 635 +opening 636 +i'll 637 +across 638 +today 639 +jokes 640 +light 641 +hilarious 642 +somewhat 643 +usual 644 +started 645 +cool 646 +ridiculous 647 +body 648 +relationship 649 +view 650 +level 651 +opinion 652 +change 653 +happy 654 +middle 655 +taking 656 +wish 657 +husband 658 +finds 659 +saying 660 +order 661 +talking 662 +ones 663 +documentary 664 +shots 665 +huge 666 +novel 667 +female 668 +mostly 669 +robert 670 +power 671 +episodes 672 +room 673 +important 674 +rating 675 +talent 676 +five 677 +major 678 +turned 679 +strange 680 +word 681 +modern 682 +call 683 +apparently 684 +disappointed 685 +single 686 +events 687 +due 688 +four 689 +songs 690 +basically 691 +attention 692 +7 693 +knows 694 +clearly 695 +supporting 696 +knew 697 +british 698 +television 699 +comic 700 +non 701 +fast 702 +earth 703 +country 704 +future 705 +cheap 706 +class 707 +thriller 708 +8 709 +silly 710 +king 711 +problems 712 +aren't 713 +easily 714 +words 715 +tells 716 +miss 717 +jack 718 +local 719 +sequence 720 +bring 721 +entertainment 722 +paul 723 +beyond 724 +upon 725 +whether 726 +predictable 727 +moving 728 +similar 729 +straight 730 +romantic 731 +sets 732 +review 733 +falls 734 +oscar 735 +mystery 736 +enjoyable 737 +needs 738 +appears 739 +talk 740 +rock 741 +george 742 +giving 743 +eye 744 +richard 745 +within 746 +ten 747 +animation 748 +message 749 +theater 750 +near 751 +above 752 +dull 753 +nearly 754 +sequel 755 +theme 756 +points 757 +' 758 +stand 759 +mention 760 +lady 761 +bunch 762 +add 763 +feels 764 +herself 765 +release 766 +red 767 +team 768 +storyline 769 +surprised 770 +ways 771 +using 772 +named 773 +haven't 774 +lots 775 +easy 776 +fantastic 777 +begins 778 +actual 779 +working 780 +effort 781 +york 782 +die 783 +hate 784 +french 785 +minute 786 +tale 787 +clear 788 +stay 789 +9 790 +elements 791 +feature 792 +among 793 +follow 794 +comments 795 +re 796 +viewers 797 +avoid 798 +sister 799 +showing 800 +typical 801 +editing 802 +what's 803 +famous 804 +tried 805 +sorry 806 +dialog 807 +check 808 +fall 809 +period 810 +season 811 +form 812 +certain 813 +filmed 814 +weak 815 +soundtrack 816 +means 817 +buy 818 +material 819 +somehow 820 +realistic 821 +figure 822 +crime 823 +doubt 824 +gone 825 +peter 826 +tom 827 +kept 828 +viewing 829 +t 830 +general 831 +leads 832 +greatest 833 +space 834 +lame 835 +suspense 836 +dance 837 +imagine 838 +brought 839 +third 840 +atmosphere 841 +hear 842 +particular 843 +sequences 844 +whatever 845 +parents 846 +move 847 +lee 848 +indeed 849 +learn 850 +rent 851 +de 852 +eventually 853 +note 854 +deal 855 +average 856 +reviews 857 +wait 858 +forget 859 +japanese 860 +sexual 861 +poorly 862 +premise 863 +okay 864 +zombie 865 +surprise 866 +believable 867 +stage 868 +possibly 869 +sit 870 +who's 871 +decided 872 +expected 873 +you've 874 +subject 875 +nature 876 +became 877 +difficult 878 +free 879 +killing 880 +screenplay 881 +truth 882 +romance 883 +dr 884 +nor 885 +reading 886 +needed 887 +question 888 +leaves 889 +street 890 +20 891 +meets 892 +hot 893 +unless 894 +begin 895 +baby 896 +superb 897 +credits 898 +imdb 899 +otherwise 900 +write 901 +shame 902 +let's 903 +situation 904 +dramatic 905 +memorable 906 +directors 907 +earlier 908 +meet 909 +disney 910 +open 911 +dog 912 +badly 913 +joe 914 +male 915 +weird 916 +acted 917 +forced 918 +laughs 919 +sci 920 +emotional 921 +older 922 +realize 923 +fi 924 +dream 925 +society 926 +writers 927 +interested 928 +footage 929 +forward 930 +comment 931 +crazy 932 +deep 933 +sounds 934 +plus 935 +beauty 936 +whom 937 +america 938 +fantasy 939 +directing 940 +keeps 941 +ask 942 +development 943 +features 944 +air 945 +quickly 946 +mess 947 +creepy 948 +towards 949 +perfectly 950 +mark 951 +worked 952 +box 953 +cheesy 954 +unique 955 +setting 956 +hands 957 +plenty 958 +result 959 +previous 960 +brings 961 +effect 962 +e 963 +total 964 +personal 965 +incredibly 966 +rate 967 +fire 968 +monster 969 +business 970 +leading 971 +apart 972 +casting 973 +admit 974 +joke 975 +powerful 976 +appear 977 +background 978 +telling 979 +girlfriend 980 +meant 981 +christmas 982 +hardly 983 +present 984 +battle 985 +potential 986 +create 987 +bill 988 +break 989 +pay 990 +masterpiece 991 +gay 992 +political 993 +return 994 +dumb 995 +fails 996 +fighting 997 +various 998 +era 999 +portrayed 1000 +co 1001 +cop 1002 +secret 1003 +inside 1004 +outside 1005 +nudity 1006 +reasons 1007 +ideas 1008 +twist 1009 +western 1010 +front 1011 +missing 1012 +boys 1013 +match 1014 +deserves 1015 +jane 1016 +expecting 1017 +fairly 1018 +villain 1019 +talented 1020 +married 1021 +ben 1022 +success 1023 +william 1024 +unlike 1025 +rich 1026 +attempts 1027 +spoilers 1028 +list 1029 +manages 1030 +social 1031 +odd 1032 +recently 1033 +remake 1034 +flat 1035 +cute 1036 +further 1037 +sadly 1038 +copy 1039 +wrote 1040 +agree 1041 +doctor 1042 +cold 1043 +plain 1044 +following 1045 +mentioned 1046 +sweet 1047 +incredible 1048 +missed 1049 +pure 1050 +crew 1051 +office 1052 +wasted 1053 +ended 1054 +produced 1055 +gun 1056 +filmmakers 1057 +large 1058 +caught 1059 +revenge 1060 +filled 1061 +pace 1062 +popular 1063 +waiting 1064 +'the 1065 +members 1066 +science 1067 +decides 1068 +considering 1069 +hold 1070 +public 1071 +cartoon 1072 +party 1073 +tension 1074 +created 1075 +slightly 1076 +uses 1077 +convincing 1078 +compared 1079 +la 1080 +familiar 1081 +neither 1082 +mary 1083 +spent 1084 +sees 1085 +6 1086 +suddenly 1087 +30 1088 +intelligent 1089 +escape 1090 +scott 1091 +fear 1092 +water 1093 +brothers 1094 +d 1095 +clever 1096 +entirely 1097 +kills 1098 +choice 1099 +bored 1100 +language 1101 +moves 1102 +spirit 1103 +laughing 1104 +dancing 1105 +we're 1106 +value 1107 +cover 1108 +credit 1109 +state 1110 +island 1111 +successful 1112 +trouble 1113 +visual 1114 +violent 1115 +ultimately 1116 +century 1117 +singing 1118 +15 1119 +concept 1120 +basic 1121 +italian 1122 +positive 1123 +german 1124 +animated 1125 +biggest 1126 +exciting 1127 +speak 1128 +runs 1129 +store 1130 +died 1131 +cat 1132 +consider 1133 +effective 1134 +walk 1135 +recent 1136 +depth 1137 +former 1138 +amusing 1139 +control 1140 +common 1141 +spend 1142 +band 1143 +appreciate 1144 +zombies 1145 +portrayal 1146 +force 1147 +c 1148 +pointless 1149 +rated 1150 +books 1151 +focus 1152 +hair 1153 +adventure 1154 +younger 1155 +solid 1156 +trash 1157 +adult 1158 +impressive 1159 +follows 1160 +respect 1161 +bizarre 1162 +tone 1163 +law 1164 +super 1165 +amount 1166 +impossible 1167 +mad 1168 +company 1169 +college 1170 +van 1171 +prison 1172 +weren't 1173 +conclusion 1174 +chemistry 1175 +win 1176 +showed 1177 +recommended 1178 +slasher 1179 +producers 1180 +culture 1181 +studio 1182 +fit 1183 +starring 1184 +heavy 1185 +situations 1186 +project 1187 +makers 1188 +trip 1189 +awesome 1190 +accent 1191 +considered 1192 +disturbing 1193 +changed 1194 +sick 1195 +failed 1196 +decide 1197 +somewhere 1198 +won 1199 +leaving 1200 +barely 1201 +honest 1202 +cause 1203 +questions 1204 +shooting 1205 +u 1206 +longer 1207 +post 1208 +f 1209 +anti 1210 +tough 1211 +aside 1212 +ghost 1213 +fake 1214 +cult 1215 +thanks 1216 +meaning 1217 +images 1218 +fiction 1219 +charming 1220 +audiences 1221 +computer 1222 +tony 1223 +brain 1224 +planet 1225 +south 1226 +literally 1227 +generally 1228 +touch 1229 +steve 1230 +stick 1231 +likes 1232 +ex 1233 +values 1234 +pathetic 1235 +magic 1236 +involving 1237 +surprisingly 1238 +alive 1239 +jim 1240 +immediately 1241 +grade 1242 +yeah 1243 +garbage 1244 +100 1245 +dad 1246 +bought 1247 +military 1248 +natural 1249 +camp 1250 +aspect 1251 +honestly 1252 +adaptation 1253 +utterly 1254 +detective 1255 +ability 1256 +fair 1257 +shoot 1258 +smith 1259 +explain 1260 +pick 1261 +genius 1262 +west 1263 +glad 1264 +frank 1265 +sitting 1266 +appearance 1267 +pictures 1268 +week 1269 +motion 1270 +appeal 1271 +army 1272 +standard 1273 +attack 1274 +knowing 1275 +personally 1276 +catch 1277 +drive 1278 +sexy 1279 +normal 1280 +rare 1281 +nowhere 1282 +added 1283 +sam 1284 +humour 1285 +walking 1286 +remains 1287 +purpose 1288 +edge 1289 +comedies 1290 +thinks 1291 +loud 1292 +beautifully 1293 +thank 1294 +silent 1295 +taste 1296 +unbelievable 1297 +naked 1298 +twists 1299 +master 1300 +touching 1301 +subtle 1302 +terms 1303 +date 1304 +equally 1305 +dreams 1306 +terrific 1307 +channel 1308 +drawn 1309 +mood 1310 +journey 1311 +door 1312 +chase 1313 +fully 1314 +complex 1315 +london 1316 +key 1317 +wow 1318 +managed 1319 +road 1320 +narrative 1321 +laughable 1322 +mistake 1323 +bottom 1324 +producer 1325 +themes 1326 +movie's 1327 +pieces 1328 +likely 1329 +climax 1330 +g 1331 +disappointing 1332 +club 1333 +lovely 1334 +harry 1335 +blue 1336 +nobody 1337 +excuse 1338 +outstanding 1339 +soldiers 1340 +issues 1341 +stewart 1342 +constantly 1343 +award 1344 +pass 1345 +thus 1346 +plan 1347 +surely 1348 +marriage 1349 +painful 1350 +justice 1351 +costumes 1352 +presented 1353 +batman 1354 +80's 1355 +innocent 1356 +soul 1357 +wild 1358 +noir 1359 +cinematic 1360 +spoiler 1361 +vampire 1362 +finish 1363 +slowly 1364 +ride 1365 +gang 1366 +contains 1367 +christopher 1368 +presence 1369 +places 1370 +besides 1371 +government 1372 +details 1373 +train 1374 +central 1375 +thrown 1376 +manner 1377 +chris 1378 +historical 1379 +stunning 1380 +photography 1381 +charm 1382 +hoping 1383 +impression 1384 +scenery 1385 +speaking 1386 +disappointment 1387 +loves 1388 +animals 1389 +you'd 1390 +developed 1391 +drug 1392 +smart 1393 +charles 1394 +indian 1395 +numbers 1396 +mysterious 1397 +expectations 1398 +color 1399 +hey 1400 +exception 1401 +throw 1402 +minor 1403 +ahead 1404 +double 1405 +track 1406 +stands 1407 +suppose 1408 +aspects 1409 +boss 1410 +woods 1411 +sent 1412 +festival 1413 +bother 1414 +cry 1415 +church 1416 +feelings 1417 +critics 1418 +green 1419 +brief 1420 +acts 1421 +opera 1422 +filming 1423 +mainly 1424 +support 1425 +emotion 1426 +element 1427 +held 1428 +fascinating 1429 +building 1430 +million 1431 +boyfriend 1432 +names 1433 +opportunity 1434 +serial 1435 +intended 1436 +forever 1437 +emotions 1438 +available 1439 +victim 1440 +charlie 1441 +dies 1442 +changes 1443 +compelling 1444 +bed 1445 +six 1446 +born 1447 +happening 1448 +bar 1449 +paris 1450 +likable 1451 +lived 1452 +twice 1453 +falling 1454 +hotel 1455 +zero 1456 +puts 1457 +tired 1458 +image 1459 +pain 1460 +lover 1461 +everybody 1462 +giant 1463 +offer 1464 +shock 1465 +spot 1466 +suggest 1467 +j 1468 +henry 1469 +include 1470 +confused 1471 +trailer 1472 +adults 1473 +difference 1474 +student 1475 +fresh 1476 +followed 1477 +bruce 1478 +r 1479 +kelly 1480 +hasn't 1481 +appeared 1482 +approach 1483 +victims 1484 +christian 1485 +fellow 1486 +hurt 1487 +impact 1488 +putting 1489 +gorgeous 1490 +step 1491 +sub 1492 +mix 1493 +event 1494 +notice 1495 +murders 1496 +share 1497 +laughed 1498 +confusing 1499 +content 1500 +mediocre 1501 +11 1502 +lacks 1503 +direct 1504 +supposedly 1505 +summer 1506 +actresses 1507 +flaws 1508 +porn 1509 +system 1510 +page 1511 +holes 1512 +wall 1513 +billy 1514 +moral 1515 +jerry 1516 +worthy 1517 +creative 1518 +relationships 1519 +rape 1520 +tragedy 1521 +race 1522 +thin 1523 +lighting 1524 +helps 1525 +random 1526 +answer 1527 +gem 1528 +funniest 1529 +ii 1530 +americans 1531 +jones 1532 +merely 1533 +proves 1534 +wondering 1535 +alien 1536 +students 1537 +ray 1538 +paid 1539 +al 1540 +land 1541 +seven 1542 +damn 1543 +agent 1544 +delivers 1545 +imagination 1546 +park 1547 +childhood 1548 +flying 1549 +hospital 1550 +forgotten 1551 +90 1552 +standards 1553 +flicks 1554 +impressed 1555 +finding 1556 +absolute 1557 +ugly 1558 +beat 1559 +jean 1560 +don 1561 +thoroughly 1562 +ms 1563 +attractive 1564 +ground 1565 +negative 1566 +wise 1567 +provides 1568 +latter 1569 +50 1570 +stuck 1571 +extreme 1572 +seemingly 1573 +seconds 1574 +becoming 1575 +winning 1576 +addition 1577 +reminded 1578 +tragic 1579 +offers 1580 +inspired 1581 +count 1582 +fell 1583 +thats 1584 +lose 1585 +affair 1586 +turning 1587 +folks 1588 +detail 1589 +faces 1590 +cliché 1591 +design 1592 +martin 1593 +collection 1594 +afraid 1595 +intense 1596 +fashion 1597 +pull 1598 +hidden 1599 +industry 1600 +man's 1601 +allen 1602 +apartment 1603 +o 1604 +quick 1605 +nasty 1606 +arthur 1607 +adds 1608 +area 1609 +rented 1610 +alan 1611 +angry 1612 +personality 1613 +artistic 1614 +length 1615 +shouldn't 1616 +therefore 1617 +information 1618 +chinese 1619 +brian 1620 +shocking 1621 +location 1622 +ready 1623 +professional 1624 +lets 1625 +animal 1626 +anymore 1627 +games 1628 +teen 1629 +states 1630 +soldier 1631 +listen 1632 +mom 1633 +describe 1634 +lord 1635 +news 1636 +picked 1637 +led 1638 +wooden 1639 +favourite 1640 +dirty 1641 +mouth 1642 +asks 1643 +food 1644 +deliver 1645 +onto 1646 +martial 1647 +bond 1648 +clothes 1649 +wars 1650 +struggle 1651 +queen 1652 +redeeming 1653 +stone 1654 +jason 1655 +scientist 1656 +p 1657 +wearing 1658 +ed 1659 +stephen 1660 +compare 1661 +castle 1662 +intelligence 1663 +creature 1664 +cross 1665 +sleep 1666 +teenage 1667 +allowed 1668 +wonderfully 1669 +necessary 1670 +carry 1671 +drugs 1672 +40 1673 +tears 1674 +fox 1675 +criminal 1676 +rip 1677 +helped 1678 +member 1679 +desperate 1680 +moved 1681 +sight 1682 +cgi 1683 +trust 1684 +deeply 1685 +roll 1686 +includes 1687 +willing 1688 +whatsoever 1689 +disaster 1690 +12 1691 +machine 1692 +ship 1693 +treat 1694 +began 1695 +mid 1696 +uncle 1697 +grace 1698 +phone 1699 +70's 1700 +williams 1701 +commentary 1702 +build 1703 +accident 1704 +captain 1705 +realized 1706 +plane 1707 +energy 1708 +station 1709 +warning 1710 +epic 1711 +davis 1712 +rarely 1713 +humans 1714 +loving 1715 +theatre 1716 +comedic 1717 +witch 1718 +pop 1719 +suicide 1720 +dying 1721 +powers 1722 +filmmaker 1723 +independent 1724 +introduced 1725 +nightmare 1726 +extra 1727 +engaging 1728 +actions 1729 +character's 1730 +superior 1731 +unusual 1732 +arts 1733 +apparent 1734 +suit 1735 +religious 1736 +heroes 1737 +danny 1738 +remarkable 1739 +artist 1740 +allow 1741 +pleasure 1742 +continue 1743 +unnecessary 1744 +x 1745 +ring 1746 +returns 1747 +physical 1748 +sky 1749 +teacher 1750 +pre 1751 +mental 1752 +watchable 1753 +provide 1754 +absurd 1755 +tim 1756 +memory 1757 +grand 1758 +technical 1759 +normally 1760 +wedding 1761 +desire 1762 +limited 1763 +anywhere 1764 +scared 1765 +russian 1766 +surprising 1767 +douglas 1768 +finished 1769 +brutal 1770 +skip 1771 +vision 1772 +process 1773 +intriguing 1774 +bloody 1775 +media 1776 +holds 1777 +exist 1778 +accept 1779 +nicely 1780 +suspect 1781 +000 1782 +jump 1783 +twenty 1784 +paced 1785 +wanting 1786 +search 1787 +cops 1788 +torture 1789 +growing 1790 +reminds 1791 +jr 1792 +according 1793 +pacing 1794 +legend 1795 +soft 1796 +passion 1797 +andy 1798 +player 1799 +hated 1800 +bits 1801 +fred 1802 +asked 1803 +faith 1804 +joy 1805 +johnny 1806 +clichés 1807 +jeff 1808 +academy 1809 +dressed 1810 +pilot 1811 +eddie 1812 +constant 1813 +anybody 1814 +ill 1815 +deserved 1816 +horse 1817 +gold 1818 +drunk 1819 +joan 1820 +blame 1821 +originally 1822 +explanation 1823 +dangerous 1824 +instance 1825 +smile 1826 +heaven 1827 +heads 1828 +sat 1829 +community 1830 +england 1831 +superman 1832 +deserve 1833 +issue 1834 +nonsense 1835 +met 1836 +dick 1837 +lies 1838 +capture 1839 +gotten 1840 +toward 1841 +kevin 1842 +somebody 1843 +soap 1844 +field 1845 +lovers 1846 +plots 1847 +taylor 1848 +mixed 1849 +players 1850 +nick 1851 +explained 1852 +record 1853 +fail 1854 +creating 1855 +vhs 1856 +knowledge 1857 +quiet 1858 +unknown 1859 +fights 1860 +starting 1861 +friendship 1862 +accurate 1863 +whilst 1864 +guns 1865 +price 1866 +adam 1867 +kate 1868 +hadn't 1869 +sucks 1870 +ball 1871 +river 1872 +floor 1873 +european 1874 +spanish 1875 +wide 1876 +cable 1877 +radio 1878 +fu 1879 +cars 1880 +jackson 1881 +realism 1882 +memories 1883 +moon 1884 +finest 1885 +heroine 1886 +aware 1887 +loose 1888 +eating 1889 +featuring 1890 +prince 1891 +lacking 1892 +responsible 1893 +saved 1894 +keeping 1895 +empty 1896 +understanding 1897 +japan 1898 +treated 1899 +eat 1900 +results 1901 +cuts 1902 +ice 1903 +bland 1904 +terribly 1905 +pulled 1906 +saving 1907 +below 1908 +officer 1909 +villains 1910 +candy 1911 +broken 1912 +sign 1913 +ladies 1914 +hopes 1915 +rubbish 1916 +delightful 1917 +vs 1918 +judge 1919 +witty 1920 +manage 1921 +fat 1922 +mine 1923 +gene 1924 +noticed 1925 +included 1926 +bright 1927 +months 1928 +forces 1929 +screaming 1930 +higher 1931 +kinda 1932 +wind 1933 +tarzan 1934 +cage 1935 +hits 1936 +loss 1937 +today's 1938 +monsters 1939 +youth 1940 +sing 1941 +numerous 1942 +partner 1943 +conflict 1944 +whenever 1945 +humanity 1946 +concerned 1947 +pretentious 1948 +fate 1949 +singer 1950 +dealing 1951 +mike 1952 +driving 1953 +jesus 1954 +private 1955 +talents 1956 +discovered 1957 +naturally 1958 +skills 1959 +unfunny 1960 +opposite 1961 +finale 1962 +bigger 1963 +v 1964 +ann 1965 +international 1966 +dated 1967 +kick 1968 +ups 1969 +prove 1970 +perspective 1971 +morning 1972 +mission 1973 +discover 1974 +portray 1975 +blonde 1976 +here's 1977 +loses 1978 +locations 1979 +visit 1980 +ordinary 1981 +bank 1982 +m 1983 +humorous 1984 +werewolf 1985 +streets 1986 +psychological 1987 +regular 1988 +reviewers 1989 +received 1990 +kong 1991 +w 1992 +edited 1993 +gags 1994 +ass 1995 +luck 1996 +curious 1997 +gary 1998 +continues 1999 +magnificent 2000 +13 2001 +we've 2002 +behavior 2003 +captured 2004 +jimmy 2005 +satire 2006 +survive 2007 +context 2008 +visually 2009 +breaks 2010 +existence 2011 +shallow 2012 +opens 2013 +l 2014 +mrs 2015 +debut 2016 +advice 2017 +calls 2018 +sea 2019 +foot 2020 +morgan 2021 +shop 2022 +h 2023 +murdered 2024 +connection 2025 +core 2026 +essentially 2027 +current 2028 +revealed 2029 +director's 2030 +corny 2031 +remembered 2032 +deals 2033 +blind 2034 +frankly 2035 +occasionally 2036 +lesson 2037 +genuine 2038 +scream 2039 +traditional 2040 +they've 2041 +lucky 2042 +identity 2043 +dimensional 2044 +african 2045 +bob 2046 +anthony 2047 +efforts 2048 +sean 2049 +golden 2050 +learned 2051 +segment 2052 +stock 2053 +window 2054 +cameo 2055 +owner 2056 +visuals 2057 +versions 2058 +village 2059 +albert 2060 +develop 2061 +santa 2062 +formula 2063 +miles 2064 +keaton 2065 +one's 2066 +sucked 2067 +decade 2068 +buddy 2069 +genuinely 2070 +grown 2071 +references 2072 +suffering 2073 +boat 2074 +lewis 2075 +unexpected 2076 +favor 2077 +study 2078 +washington 2079 +allows 2080 +program 2081 +national 2082 +grew 2083 +80s 2084 +proved 2085 +meanwhile 2086 +overly 2087 +ages 2088 +board 2089 +standing 2090 +logic 2091 +desert 2092 +spectacular 2093 +awkward 2094 +ultimate 2095 +comparison 2096 +reaction 2097 +rob 2098 +sheer 2099 +jennifer 2100 +reach 2101 +thomas 2102 +unable 2103 +failure 2104 +brilliantly 2105 +travel 2106 +grant 2107 +ford 2108 +vampires 2109 +types 2110 +parody 2111 +gangster 2112 +devil 2113 +steal 2114 +brown 2115 +passed 2116 +sudden 2117 +stereotypes 2118 +sake 2119 +flesh 2120 +leader 2121 +frame 2122 +bear 2123 +strength 2124 +speed 2125 +creates 2126 +eric 2127 +awards 2128 +laughter 2129 +dan 2130 +technology 2131 +delivered 2132 +author 2133 +bet 2134 +kung 2135 +crappy 2136 +wood 2137 +site 2138 +broadway 2139 +insane 2140 +trek 2141 +executed 2142 +relief 2143 +lake 2144 +hitler 2145 +gonna 2146 +discovers 2147 +emotionally 2148 +painfully 2149 +dreadful 2150 +marie 2151 +utter 2152 +commercial 2153 +decision 2154 +code 2155 +steven 2156 +fault 2157 +anime 2158 +majority 2159 +anne 2160 +round 2161 +pair 2162 +robin 2163 +caused 2164 +bomb 2165 +families 2166 +psycho 2167 +driven 2168 +attitude 2169 +clean 2170 +built 2171 +gratuitous 2172 +harris 2173 +native 2174 +luke 2175 +entertained 2176 +graphic 2177 +ran 2178 +killers 2179 +meeting 2180 +test 2181 +simon 2182 +flashbacks 2183 +underrated 2184 +nevertheless 2185 +model 2186 +seasons 2187 +asian 2188 +foreign 2189 +hill 2190 +levels 2191 +obsessed 2192 +evening 2193 +feet 2194 +halloween 2195 +vehicle 2196 +barbara 2197 +relate 2198 +treatment 2199 +rise 2200 +practically 2201 +range 2202 +endless 2203 +freedom 2204 +costs 2205 +religion 2206 +gory 2207 +cash 2208 +described 2209 +wit 2210 +pleasant 2211 +aged 2212 +ancient 2213 +tape 2214 +reviewer 2215 +center 2216 +president 2217 +chosen 2218 +lynch 2219 +product 2220 +combination 2221 +send 2222 +fly 2223 +seat 2224 +sell 2225 +70s 2226 +irritating 2227 +exploitation 2228 +excited 2229 +stopped 2230 +hearing 2231 +rescue 2232 +fill 2233 +howard 2234 +portrays 2235 +gordon 2236 +assume 2237 +parker 2238 +classics 2239 +pity 2240 +0 2241 +produce 2242 +hunter 2243 +breaking 2244 +dry 2245 +fame 2246 +anna 2247 +generation 2248 +sheriff 2249 +capable 2250 +believes 2251 +handsome 2252 +theatrical 2253 +asking 2254 +sports 2255 +largely 2256 +choose 2257 +theaters 2258 +sympathetic 2259 +extras 2260 +proper 2261 +ruined 2262 +cares 2263 +contrived 2264 +portraying 2265 +drew 2266 +individual 2267 +embarrassing 2268 +rules 2269 +unrealistic 2270 +learns 2271 +warm 2272 +victor 2273 +daniel 2274 +marry 2275 +appealing 2276 +safe 2277 +dubbed 2278 +depressing 2279 +canadian 2280 +freddy 2281 +shakespeare 2282 +recall 2283 +chick 2284 +uk 2285 +winner 2286 +hearted 2287 +contrast 2288 +sequels 2289 +involves 2290 +par 2291 +woody 2292 +crowd 2293 +matters 2294 +k 2295 +correct 2296 +chief 2297 +costume 2298 +haunting 2299 +paper 2300 +research 2301 +vote 2302 +strongly 2303 +heck 2304 +nominated 2305 +grow 2306 +clue 2307 +claim 2308 +facts 2309 +eight 2310 +protagonist 2311 +matt 2312 +rose 2313 +evidence 2314 +joseph 2315 +appropriate 2316 +disgusting 2317 +excitement 2318 +football 2319 +lousy 2320 +germany 2321 +cost 2322 +france 2323 +saturday 2324 +priest 2325 +talks 2326 +substance 2327 +losing 2328 +patrick 2329 +destroy 2330 +circumstances 2331 +tedious 2332 +training 2333 +thoughts 2334 +hunt 2335 +market 2336 +scare 2337 +voices 2338 +promise 2339 +naive 2340 +bringing 2341 +amateurish 2342 +teenager 2343 +angel 2344 +walter 2345 +captures 2346 +convinced 2347 +hanging 2348 +satisfying 2349 +bodies 2350 +united 2351 +fits 2352 +tend 2353 +jackie 2354 +trilogy 2355 +roy 2356 +horribly 2357 +lower 2358 +asleep 2359 +virtually 2360 +baseball 2361 +robot 2362 +hopefully 2363 +rental 2364 +alex 2365 +com 2366 +factor 2367 +haunted 2368 +teenagers 2369 +hall 2370 +walks 2371 +spoil 2372 +creatures 2373 +amateur 2374 +relatively 2375 +steals 2376 +mask 2377 +welcome 2378 +cinderella 2379 +covered 2380 +ryan 2381 +danger 2382 +europe 2383 +insult 2384 +category 2385 +continuity 2386 +mini 2387 +unlikely 2388 +drag 2389 +sinatra 2390 +skin 2391 +contemporary 2392 +louis 2393 +semi 2394 +viewed 2395 +fare 2396 +north 2397 +influence 2398 +depicted 2399 +handled 2400 +target 2401 +oliver 2402 +offensive 2403 +hat 2404 +initial 2405 +nancy 2406 +scale 2407 +lawyer 2408 +tiny 2409 +cutting 2410 +unfortunate 2411 +holding 2412 +witness 2413 +shocked 2414 +africa 2415 +remain 2416 +believed 2417 +fool 2418 +inner 2419 +politics 2420 +hide 2421 +reporter 2422 +presents 2423 +section 2424 +movement 2425 +provided 2426 +surreal 2427 +promising 2428 +designed 2429 +makeup 2430 +max 2431 +qualities 2432 +liners 2433 +refreshing 2434 +australian 2435 +source 2436 +14 2437 +structure 2438 +closer 2439 +drop 2440 +forgettable 2441 +touches 2442 +welles 2443 +display 2444 +angles 2445 +pile 2446 +fairy 2447 +repeated 2448 +till 2449 +texas 2450 +wayne 2451 +claims 2452 +previously 2453 +faced 2454 +sharp 2455 +deaths 2456 +ruin 2457 +accents 2458 +surprises 2459 +universal 2460 +degree 2461 +focused 2462 +propaganda 2463 +plans 2464 +serves 2465 +speaks 2466 +supernatural 2467 +highlight 2468 +service 2469 +peace 2470 +chose 2471 +related 2472 +cartoons 2473 +adventures 2474 +erotic 2475 +25 2476 +roger 2477 +suffers 2478 +blow 2479 +weekend 2480 +sisters 2481 +granted 2482 +mainstream 2483 +latest 2484 +weeks 2485 +prime 2486 +crash 2487 +cant 2488 +professor 2489 +experiences 2490 +speech 2491 +print 2492 +lesbian 2493 +harsh 2494 +deadly 2495 +veteran 2496 +mistakes 2497 +edward 2498 +routine 2499 +whoever 2500 +notch 2501 +uninteresting 2502 +realizes 2503 +invisible 2504 +combined 2505 +sympathy 2506 +accidentally 2507 +kim 2508 +twisted 2509 +brave 2510 +colors 2511 +dollars 2512 +security 2513 +draw 2514 +dogs 2515 +nude 2516 +rain 2517 +universe 2518 +struggling 2519 +dozen 2520 +teens 2521 +convince 2522 +guilty 2523 +path 2524 +appreciated 2525 +atrocious 2526 +mountain 2527 +treasure 2528 +walked 2529 +columbo 2530 +irish 2531 +frightening 2532 +would've 2533 +committed 2534 +aliens 2535 +technically 2536 +recognize 2537 +cowboy 2538 +blah 2539 +birth 2540 +enter 2541 +gritty 2542 +enemy 2543 +aka 2544 +spy 2545 +changing 2546 +magical 2547 +anderson 2548 +princess 2549 +department 2550 +gas 2551 +occasional 2552 +friday 2553 +sword 2554 +directly 2555 +false 2556 +massive 2557 +surface 2558 +narration 2559 +legendary 2560 +featured 2561 +victoria 2562 +anger 2563 +offered 2564 +paint 2565 +performed 2566 +moore 2567 +explains 2568 +abuse 2569 +suspenseful 2570 +vietnam 2571 +kinds 2572 +terror 2573 +experienced 2574 +friendly 2575 +subtitles 2576 +reputation 2577 +crying 2578 +hong 2579 +sorts 2580 +passing 2581 +junk 2582 +beach 2583 +multiple 2584 +forest 2585 +stolen 2586 +everywhere 2587 +figures 2588 +forth 2589 +statement 2590 +exact 2591 +powell 2592 +variety 2593 +required 2594 +clark 2595 +reveal 2596 +donald 2597 +regret 2598 +conversation 2599 +prior 2600 +darkness 2601 +remotely 2602 +execution 2603 +theory 2604 +trapped 2605 +proud 2606 +belief 2607 +urban 2608 +russell 2609 +lonely 2610 +placed 2611 +downright 2612 +wilson 2613 +san 2614 +fictional 2615 +melodrama 2616 +spends 2617 +insight 2618 +court 2619 +effectively 2620 +listening 2621 +grave 2622 +express 2623 +demons 2624 +crude 2625 +figured 2626 +bothered 2627 +abandoned 2628 +scares 2629 +network 2630 +unconvincing 2631 +jobs 2632 +hired 2633 +revolution 2634 +favorites 2635 +jon 2636 +wear 2637 +minds 2638 +metal 2639 +worthwhile 2640 +emma 2641 +california 2642 +dean 2643 +buying 2644 +blockbuster 2645 +lifetime 2646 +bus 2647 +paying 2648 +pulls 2649 +account 2650 +angle 2651 +happiness 2652 +von 2653 +blown 2654 +afternoon 2655 +imagery 2656 +rights 2657 +driver 2658 +alright 2659 +rolling 2660 +matrix 2661 +mexican 2662 +productions 2663 +amazed 2664 +idiot 2665 +rings 2666 +cultural 2667 +status 2668 +delivery 2669 +thankfully 2670 +grim 2671 +reveals 2672 +rule 2673 +stayed 2674 +handed 2675 +alice 2676 +stays 2677 +scenario 2678 +focuses 2679 +ha 2680 +significant 2681 +quest 2682 +rough 2683 +starred 2684 +examples 2685 +julia 2686 +jungle 2687 +sir 2688 +indie 2689 +lights 2690 +mere 2691 +views 2692 +murphy 2693 +shadow 2694 +sarah 2695 +bore 2696 +con 2697 +teeth 2698 +heavily 2699 +mature 2700 +device 2701 +table 2702 +skill 2703 +interview 2704 +caine 2705 +tight 2706 +necessarily 2707 +he'd 2708 +ron 2709 +sunday 2710 +clichéd 2711 +suffer 2712 +mexico 2713 +china 2714 +achieve 2715 +spite 2716 +understood 2717 +format 2718 +artists 2719 +position 2720 +initially 2721 +closing 2722 +campy 2723 +desperately 2724 +bound 2725 +fabulous 2726 +dress 2727 +sensitive 2728 +mgm 2729 +destroyed 2730 +hip 2731 +complicated 2732 +burns 2733 +demon 2734 +summary 2735 +seek 2736 +faithful 2737 +forgot 2738 +sun 2739 +decades 2740 +breath 2741 +gross 2742 +pitt 2743 +bourne 2744 +ghosts 2745 +titanic 2746 +cruel 2747 +murderer 2748 +stereotypical 2749 +deeper 2750 +lisa 2751 +facial 2752 +renting 2753 +ignore 2754 +pregnant 2755 +league 2756 +answers 2757 +racist 2758 +un 2759 +helping 2760 +ludicrous 2761 +beloved 2762 +flashback 2763 +slapstick 2764 +sleeping 2765 +17 2766 +dude 2767 +cell 2768 +musicals 2769 +fourth 2770 +wing 2771 +intellectual 2772 +beast 2773 +sounded 2774 +settings 2775 +environment 2776 +suck 2777 +critical 2778 +drinking 2779 +nazi 2780 +reminiscent 2781 +brad 2782 +calling 2783 +lugosi 2784 +dragon 2785 +description 2786 +susan 2787 +prefer 2788 +amazingly 2789 +task 2790 +mildly 2791 +pacino 2792 +disbelief 2793 +encounter 2794 +regarding 2795 +larry 2796 +inept 2797 +greater 2798 +learning 2799 +arms 2800 +dennis 2801 +extraordinary 2802 +turkey 2803 +storytelling 2804 +funnier 2805 +julie 2806 +halfway 2807 +ain't 2808 +expert 2809 +base 2810 +criticism 2811 +quirky 2812 +father's 2813 +leslie 2814 +warned 2815 +cabin 2816 +flight 2817 +titles 2818 +criminals 2819 +johnson 2820 +raw 2821 +praise 2822 +depiction 2823 +screening 2824 +throwing 2825 +extent 2826 +expression 2827 +kiss 2828 +jail 2829 +studios 2830 +freeman 2831 +truck 2832 +convey 2833 +originality 2834 +chan 2835 +entertain 2836 +choices 2837 +spoof 2838 +notorious 2839 +tree 2840 +raised 2841 +touched 2842 +children's 2843 +rachel 2844 +punch 2845 +experiment 2846 +daughters 2847 +prepared 2848 +comical 2849 +spoken 2850 +people's 2851 +timing 2852 +india 2853 +headed 2854 +purely 2855 +could've 2856 +basis 2857 +hoffman 2858 +bollywood 2859 +chilling 2860 +michelle 2861 +underground 2862 +dollar 2863 +via 2864 +picks 2865 +lie 2866 +inspiration 2867 +novels 2868 +wave 2869 +elizabeth 2870 +introduction 2871 +weapons 2872 +trick 2873 +lazy 2874 +jessica 2875 +graphics 2876 +breathtaking 2877 +notable 2878 +stomach 2879 +succeeds 2880 +term 2881 +crafted 2882 +join 2883 +throws 2884 +handle 2885 +strangely 2886 +properly 2887 +toy 2888 +nowadays 2889 +christ 2890 +sidney 2891 +reference 2892 +adding 2893 +claire 2894 +serve 2895 +ratings 2896 +locked 2897 +honor 2898 +wears 2899 +sitcom 2900 +ted 2901 +authentic 2902 +foster 2903 +regard 2904 +everyday 2905 +causes 2906 +maria 2907 +provoking 2908 +charge 2909 +protect 2910 +lesser 2911 +hitchcock 2912 +caring 2913 +mouse 2914 +mirror 2915 +bat 2916 +fallen 2917 +carrying 2918 +bitter 2919 +jewish 2920 +established 2921 +pet 2922 +amongst 2923 +east 2924 +shut 2925 +guard 2926 +midnight 2927 +sleazy 2928 +southern 2929 +determined 2930 +ned 2931 +challenge 2932 +daily 2933 +obnoxious 2934 +nonetheless 2935 +cases 2936 +carried 2937 +carries 2938 +wins 2939 +alas 2940 +remote 2941 +embarrassed 2942 +gruesome 2943 +hole 2944 +2006 2945 +lane 2946 +attempting 2947 +westerns 2948 +escapes 2949 +sinister 2950 +confusion 2951 +nation 2952 +tales 2953 +ironic 2954 +tradition 2955 +interpretation 2956 +arrives 2957 +busy 2958 +replaced 2959 +risk 2960 +enjoying 2961 +sold 2962 +essential 2963 +needless 2964 +aunt 2965 +hardy 2966 +burt 2967 +goofy 2968 +mass 2969 +obsession 2970 +minded 2971 +balance 2972 +flow 2973 +clips 2974 +existent 2975 +successfully 2976 +legs 2977 +presentation 2978 +screenwriter 2979 +jumps 2980 +exists 2981 +attacked 2982 +blair 2983 +laid 2984 +mentally 2985 +bbc 2986 +seeking 2987 +raise 2988 +topic 2989 +oddly 2990 +warner 2991 +inspector 2992 +horrific 2993 +fortunately 2994 +shape 2995 +marvelous 2996 +usa 2997 +intentions 2998 +buck 2999 +retarded 3000 +madness 3001 +stupidity 3002 +stops 3003 +text 3004 +stylish 3005 +stanley 3006 +che 3007 +rival 3008 +served 3009 +workers 3010 +maker 3011 +sides 3012 +ashamed 3013 +shower 3014 +packed 3015 +comedian 3016 +thrilling 3017 +wwii 3018 +interviews 3019 +nine 3020 +laura 3021 +frequently 3022 +upper 3023 +mob 3024 +mansion 3025 +bridge 3026 +remind 3027 +tongue 3028 +navy 3029 +wanna 3030 +contain 3031 +albeit 3032 +intensity 3033 +attacks 3034 +vacation 3035 +thief 3036 +delight 3037 +manager 3038 +chair 3039 +sum 3040 +hence 3041 +80 3042 +cheese 3043 +drives 3044 +2001 3045 +expressions 3046 +struggles 3047 +flawed 3048 +poignant 3049 +angels 3050 +personalities 3051 +rogers 3052 +riding 3053 +revolves 3054 +refuses 3055 +adapted 3056 +opened 3057 +greatly 3058 +credibility 3059 +philip 3060 +cooper 3061 +glass 3062 +pitch 3063 +tracy 3064 +1950s 3065 +jay 3066 +torn 3067 +dinner 3068 +bette 3069 +18 3070 +cynical 3071 +upset 3072 +pool 3073 +sin 3074 +tour 3075 +2000 3076 +internet 3077 +suspects 3078 +advantage 3079 +lessons 3080 +warn 3081 +lion 3082 +overcome 3083 +credible 3084 +wishes 3085 +thousands 3086 +spin 3087 +miller 3088 +racism 3089 +90's 3090 +mindless 3091 +wealthy 3092 +innocence 3093 +tense 3094 +broke 3095 +bugs 3096 +happily 3097 +catholic 3098 +guessing 3099 +trial 3100 +lucy 3101 +hood 3102 +hundreds 3103 +trite 3104 +physically 3105 +thrillers 3106 +cook 3107 +fish 3108 +alike 3109 +dubbing 3110 +fbi 3111 +crisis 3112 +per 3113 +pride 3114 +succeed 3115 +controversial 3116 +suffered 3117 +reed 3118 +bag 3119 +technique 3120 +wasting 3121 +dislike 3122 +medical 3123 +sexuality 3124 +countries 3125 +perform 3126 +patient 3127 +stranger 3128 +enjoyment 3129 +corner 3130 +arm 3131 +glimpse 3132 +gripping 3133 +reunion 3134 +franchise 3135 +holmes 3136 +ensemble 3137 +separate 3138 +hundred 3139 +lincoln 3140 +60's 3141 +sings 3142 +noble 3143 +shines 3144 +whereas 3145 +tied 3146 +ourselves 3147 +uncomfortable 3148 +infamous 3149 +neat 3150 +atmospheric 3151 +millions 3152 +shorts 3153 +contact 3154 +card 3155 +hint 3156 +pack 3157 +courage 3158 +irony 3159 +exceptional 3160 +plastic 3161 +storm 3162 +drink 3163 +ralph 3164 +searching 3165 +oscars 3166 +scripts 3167 +connected 3168 +italy 3169 +proof 3170 +sandler 3171 +snow 3172 +lying 3173 +flash 3174 +nose 3175 +curse 3176 +helen 3177 +sentimental 3178 +mst3k 3179 +grey 3180 +aired 3181 +holiday 3182 +steps 3183 +hills 3184 +performers 3185 +letting 3186 +chasing 3187 +suggests 3188 +dancer 3189 +tune 3190 +meaningful 3191 +idiotic 3192 +knife 3193 +quote 3194 +weapon 3195 +plague 3196 +sons 3197 +entry 3198 +kurt 3199 +fortune 3200 +cameos 3201 +consists 3202 +perfection 3203 +lovable 3204 +hoped 3205 +troubled 3206 +thousand 3207 +hiding 3208 +develops 3209 +unforgettable 3210 +accepted 3211 +noted 3212 +portrait 3213 +dear 3214 +equal 3215 +bettie 3216 +assistant 3217 +stretch 3218 +woman's 3219 +saves 3220 +colorful 3221 +annoyed 3222 +larger 3223 +attraction 3224 +condition 3225 +miscast 3226 +chases 3227 +brooks 3228 +virgin 3229 +spots 3230 +basement 3231 +host 3232 +dialogs 3233 +shoots 3234 +gain 3235 +horses 3236 +guilt 3237 +protagonists 3238 +oil 3239 +terrifying 3240 +month 3241 +cousin 3242 +neighborhood 3243 +vincent 3244 +pg 3245 +belongs 3246 +stealing 3247 +16 3248 +nelson 3249 +worry 3250 +burning 3251 +concert 3252 +ad 3253 +zone 3254 +strip 3255 +appearing 3256 +worlds 3257 +object 3258 +split 3259 +repeat 3260 +hang 3261 +boredom 3262 +destruction 3263 +thirty 3264 +redemption 3265 +hunting 3266 +encounters 3267 +imaginative 3268 +expensive 3269 +eerie 3270 +cube 3271 +seagal 3272 +jake 3273 +pie 3274 +competent 3275 +homeless 3276 +concerns 3277 +andrew 3278 +flaw 3279 +closely 3280 +bo 3281 +ultra 3282 +factory 3283 +1st 3284 +multi 3285 +civil 3286 +dramas 3287 +gag 3288 +stunts 3289 +wake 3290 +guts 3291 +sends 3292 +60 3293 +sutherland 3294 +glory 3295 +knock 3296 +matthau 3297 +massacre 3298 +letter 3299 +elsewhere 3300 +achieved 3301 +dig 3302 +checking 3303 +widmark 3304 +hooked 3305 +complaint 3306 +neck 3307 +endearing 3308 +segments 3309 +shark 3310 +sullivan 3311 +rushed 3312 +virus 3313 +ripped 3314 +charisma 3315 +incoherent 3316 +dragged 3317 +beating 3318 +dentist 3319 +essence 3320 +bears 3321 +profound 3322 +library 3323 +weight 3324 +tear 3325 +crimes 3326 +arnold 3327 +dare 3328 +appearances 3329 +solve 3330 +trade 3331 +pat 3332 +24 3333 +stanwyck 3334 +colour 3335 +teach 3336 +dorothy 3337 +roberts 3338 +rocks 3339 +fest 3340 +spell 3341 +catherine 3342 +dealt 3343 +stan 3344 +fitting 3345 +hitting 3346 +striking 3347 +pro 3348 +2005 3349 +tribute 3350 +tricks 3351 +60s 3352 +battles 3353 +believing 3354 +briefly 3355 +countless 3356 +fashioned 3357 +loser 3358 +goal 3359 +gothic 3360 +noise 3361 +techniques 3362 +n 3363 +videos 3364 +health 3365 +thumbs 3366 +attempted 3367 +scientists 3368 +st 3369 +painting 3370 +baker 3371 +strikes 3372 +inspiring 3373 +huh 3374 +sexually 3375 +birthday 3376 +secretary 3377 +curtis 3378 +jeremy 3379 +covers 3380 +pointed 3381 +slight 3382 +specific 3383 +tea 3384 +hearts 3385 +unintentionally 3386 +denzel 3387 +horrendous 3388 +charismatic 3389 +silver 3390 +surrounded 3391 +surrounding 3392 +reactions 3393 +branagh 3394 +importance 3395 +rochester 3396 +admittedly 3397 +carefully 3398 +jerk 3399 +tons 3400 +hype 3401 +relevant 3402 +they'd 3403 +walls 3404 +stood 3405 +eyed 3406 +bible 3407 +corrupt 3408 +rush 3409 +stunt 3410 +revelation 3411 +smoking 3412 +magazine 3413 +lloyd 3414 +kicks 3415 +karloff 3416 +stronger 3417 +grows 3418 +mild 3419 +hamlet 3420 +represents 3421 +dawn 3422 +andrews 3423 +intention 3424 +easier 3425 +enters 3426 +spending 3427 +scooby 3428 +fired 3429 +killings 3430 +stated 3431 +chances 3432 +shall 3433 +brand 3434 +exercise 3435 +university 3436 +increasingly 3437 +row 3438 +disagree 3439 +cardboard 3440 +winter 3441 +comics 3442 +requires 3443 +dropped 3444 +associated 3445 +world's 3446 +chuck 3447 +iii 3448 +medium 3449 +bush 3450 +projects 3451 +bride 3452 +occurs 3453 +korean 3454 +inevitable 3455 +messages 3456 +brando 3457 +le 3458 +strike 3459 +poverty 3460 +forgive 3461 +performing 3462 +stiff 3463 +attached 3464 +drags 3465 +luckily 3466 +ian 3467 +identify 3468 +1970s 3469 +gift 3470 +bobby 3471 +acceptable 3472 +resolution 3473 +eva 3474 +typically 3475 +canada 3476 +guest 3477 +nuclear 3478 +elvis 3479 +toilet 3480 +strictly 3481 +vague 3482 +spike 3483 +contract 3484 +hire 3485 +1980s 3486 +thrills 3487 +selling 3488 +hudson 3489 +homage 3490 +lab 3491 +boll 3492 +mafia 3493 +depression 3494 +sophisticated 3495 +fifteen 3496 +disease 3497 +allowing 3498 +brilliance 3499 +investigation 3500 +continued 3501 +struck 3502 +insulting 3503 +worker 3504 +instantly 3505 +useless 3506 +breasts 3507 +barry 3508 +jesse 3509 +sally 3510 +afterwards 3511 +chaplin 3512 +britain 3513 +carter 3514 +executive 3515 +handful 3516 +importantly 3517 +godfather 3518 +estate 3519 +hanks 3520 +pleased 3521 +overlooked 3522 +evident 3523 +burn 3524 +gotta 3525 +wreck 3526 +nights 3527 +2002 3528 +beings 3529 +ego 3530 +kidnapped 3531 +presumably 3532 +competition 3533 +press 3534 +partly 3535 +digital 3536 +shining 3537 +commit 3538 +tremendous 3539 +raped 3540 +menacing 3541 +silence 3542 +talked 3543 +derek 3544 +worthless 3545 +jamie 3546 +realise 3547 +ambitious 3548 +meat 3549 +wondered 3550 +photographed 3551 +sacrifice 3552 +arrested 3553 +buried 3554 +burton 3555 +threatening 3556 +smooth 3557 +aforementioned 3558 +superbly 3559 +boxing 3560 +kane 3561 +flawless 3562 +regardless 3563 +fears 3564 +creation 3565 +shy 3566 +heat 3567 +highlights 3568 +savage 3569 +persona 3570 +frustrated 3571 +drivel 3572 +conspiracy 3573 +individuals 3574 +wonders 3575 +listed 3576 +appalling 3577 +doc 3578 +'s 3579 +spiritual 3580 +pushed 3581 +returning 3582 +jumping 3583 +elvira 3584 +cox 3585 +corpse 3586 +size 3587 +characterization 3588 +bullets 3589 +walken 3590 +generous 3591 +string 3592 +rex 3593 +doors 3594 +pleasantly 3595 +bucks 3596 +relative 3597 +45 3598 +outrageous 3599 +kudos 3600 +planning 3601 +ticket 3602 +achievement 3603 +accomplished 3604 +miserably 3605 +monkey 3606 +beaten 3607 +neighbor 3608 +distant 3609 +fatal 3610 +repetitive 3611 +accused 3612 +picking 3613 +ironically 3614 +consequences 3615 +curiosity 3616 +union 3617 +admire 3618 +guide 3619 +splendid 3620 +prevent 3621 +reynolds 3622 +border 3623 +attracted 3624 +butt 3625 +clues 3626 +trap 3627 +notes 3628 +chain 3629 +opposed 3630 +watches 3631 +samurai 3632 +shortly 3633 +heston 3634 +twin 3635 +cole 3636 +glover 3637 +slightest 3638 +response 3639 +beer 3640 +territory 3641 +spooky 3642 +diamond 3643 +rap 3644 +horrors 3645 +20th 3646 +cup 3647 +dire 3648 +spirited 3649 +melodramatic 3650 +lucas 3651 +flynn 3652 +los 3653 +piano 3654 +push 3655 +revealing 3656 +spoiled 3657 +uninspired 3658 +ritter 3659 +convoluted 3660 +pulling 3661 +ken 3662 +root 3663 +they'll 3664 +streisand 3665 +motivation 3666 +directorial 3667 +installment 3668 +precious 3669 +titled 3670 +logical 3671 +documentaries 3672 +spring 3673 +lacked 3674 +suits 3675 +tall 3676 +subplot 3677 +mate 3678 +timeless 3679 +hatred 3680 +throat 3681 +blows 3682 +jealous 3683 +creators 3684 +blank 3685 +farce 3686 +spielberg 3687 +slap 3688 +ward 3689 +carol 3690 +subsequent 3691 +cared 3692 +mile 3693 +exaggerated 3694 +duke 3695 +morality 3696 +liberal 3697 +francisco 3698 +indians 3699 +psychotic 3700 +overdone 3701 +psychiatrist 3702 +astaire 3703 +intrigued 3704 +jet 3705 +blob 3706 +50's 3707 +conceived 3708 +fx 3709 +neil 3710 +aimed 3711 +remaining 3712 +doo 3713 +ignored 3714 +elderly 3715 +reasonably 3716 +mitchell 3717 +failing 3718 +sole 3719 +obscure 3720 +drunken 3721 +minimal 3722 +temple 3723 +progress 3724 +fancy 3725 +captivating 3726 +repeatedly 3727 +wes 3728 +tunes 3729 +shoes 3730 +grandmother 3731 +cia 3732 +nurse 3733 +marks 3734 +notably 3735 +emily 3736 +soviet 3737 +shirt 3738 +explore 3739 +smoke 3740 +souls 3741 +pushing 3742 +argument 3743 +distance 3744 +warrior 3745 +outcome 3746 +reduced 3747 +loosely 3748 +scientific 3749 +goldberg 3750 +gradually 3751 +bleak 3752 +timothy 3753 +manhattan 3754 +idiots 3755 +restaurant 3756 +scripted 3757 +misses 3758 +explicit 3759 +providing 3760 +elaborate 3761 +poster 3762 +lou 3763 +dignity 3764 +carpenter 3765 +norman 3766 +rid 3767 +turner 3768 +show's 3769 +davies 3770 +draws 3771 +discussion 3772 +exposed 3773 +mel 3774 +sticks 3775 +kenneth 3776 +definite 3777 +darker 3778 +laurel 3779 +intent 3780 +1950's 3781 +returned 3782 +superhero 3783 +sloppy 3784 +cried 3785 +worried 3786 +childish 3787 +shadows 3788 +craig 3789 +cruise 3790 +hysterical 3791 +imagined 3792 +reasonable 3793 +editor 3794 +ah 3795 +birds 3796 +horrid 3797 +areas 3798 +wicked 3799 +gentle 3800 +wannabe 3801 +alexander 3802 +thick 3803 +contrary 3804 +joey 3805 +empire 3806 +connect 3807 +discovery 3808 +unbearable 3809 +tortured 3810 +screams 3811 +fever 3812 +unbelievably 3813 +1930s 3814 +disc 3815 +99 3816 +load 3817 +heroic 3818 +absence 3819 +reached 3820 +ho 3821 +choreography 3822 +triumph 3823 +complain 3824 +annie 3825 +broad 3826 +improved 3827 +concerning 3828 +brazil 3829 +movements 3830 +2003 3831 +2004 3832 +dave 3833 +folk 3834 +eve 3835 +purple 3836 +commercials 3837 +futuristic 3838 +vicious 3839 +gray 3840 +freak 3841 +threat 3842 +cusack 3843 +extended 3844 +citizen 3845 +stole 3846 +anyways 3847 +glenn 3848 +existed 3849 +cheek 3850 +broadcast 3851 +photographer 3852 +translation 3853 +arrive 3854 +differences 3855 +displays 3856 +critic 3857 +slave 3858 +landscape 3859 +occurred 3860 +builds 3861 +drawing 3862 +incident 3863 +warren 3864 +burned 3865 +involvement 3866 +styles 3867 +bathroom 3868 +machines 3869 +narrator 3870 +antics 3871 +he'll 3872 +fisher 3873 +swear 3874 +australia 3875 +matthew 3876 +resembles 3877 +lily 3878 +overrated 3879 +currently 3880 +symbolism 3881 +ought 3882 +bare 3883 +audio 3884 +web 3885 +farm 3886 +contained 3887 +greek 3888 +affected 3889 +blend 3890 +q 3891 +recognized 3892 +duo 3893 +genres 3894 +population 3895 +carrie 3896 +ranks 3897 +demands 3898 +we'll 3899 +abc 3900 +prom 3901 +altogether 3902 +superficial 3903 +kitchen 3904 +pseudo 3905 +sunshine 3906 +sadness 3907 +secrets 3908 +bone 3909 +website 3910 +receive 3911 +popcorn 3912 +threw 3913 +craft 3914 +enjoys 3915 +occur 3916 +twelve 3917 +block 3918 +girl's 3919 +proceedings 3920 +dynamic 3921 +daring 3922 +swedish 3923 +argue 3924 +bite 3925 +wolf 3926 +adequate 3927 +investigate 3928 +harder 3929 +ruth 3930 +ridiculously 3931 +tap 3932 +dinosaurs 3933 +hugh 3934 +synopsis 3935 +beats 3936 +carrey 3937 +explosion 3938 +foul 3939 +merit 3940 +suited 3941 +holy 3942 +staged 3943 +journalist 3944 +pretend 3945 +composed 3946 +cagney 3947 +robots 3948 +giallo 3949 +aging 3950 +fay 3951 +sadistic 3952 +engaged 3953 +escaped 3954 +juvenile 3955 +rambo 3956 +ireland 3957 +conversations 3958 +thugs 3959 +modesty 3960 +selfish 3961 +margaret 3962 +dialogues 3963 +ease 3964 +cameras 3965 +tame 3966 +leg 3967 +rural 3968 +comfortable 3969 +nazis 3970 +clothing 3971 +innovative 3972 +terry 3973 +thrill 3974 +2nd 3975 +dancers 3976 +brosnan 3977 +explosions 3978 +bin 3979 +rage 3980 +overwhelming 3981 +jazz 3982 +vivid 3983 +coherent 3984 +bullet 3985 +odds 3986 +mountains 3987 +kidding 3988 +versus 3989 +lit 3990 +offering 3991 +mother's 3992 +trio 3993 +newspaper 3994 +pulp 3995 +ellen 3996 +dawson 3997 +bird 3998 +buddies 3999 +combat 4000 +dracula 4001 +lol 4002 +grab 4003 +orders 4004 +staff 4005 +nearby 4006 +cats 4007 +wealth 4008 +unpleasant 4009 +staying 4010 +devoted 4011 +centered 4012 +errors 4013 +disturbed 4014 +bell 4015 +atlantis 4016 +snake 4017 +felix 4018 +damage 4019 +clint 4020 +lust 4021 +groups 4022 +banned 4023 +blowing 4024 +fighter 4025 +removed 4026 +react 4027 +conventional 4028 +kapoor 4029 +intrigue 4030 +possessed 4031 +cringe 4032 +eyre 4033 +liking 4034 +implausible 4035 +philosophy 4036 +producing 4037 +abilities 4038 +seventies 4039 +bang 4040 +murderous 4041 +deliberately 4042 +gandhi 4043 +tommy 4044 +meaningless 4045 +subjects 4046 +lips 4047 +ingredients 4048 +mildred 4049 +perry 4050 +warming 4051 +causing 4052 +possibility 4053 +detailed 4054 +walker 4055 +garden 4056 +prostitute 4057 +nightmares 4058 +cameron 4059 +flop 4060 +influenced 4061 +spare 4062 +unwatchable 4063 +undoubtedly 4064 +celluloid 4065 +relies 4066 +resemblance 4067 +neo 4068 +parent 4069 +falk 4070 +uneven 4071 +unintentional 4072 +eccentric 4073 +mistaken 4074 +distracting 4075 +careers 4076 +yesterday 4077 +forbidden 4078 +panic 4079 +crack 4080 +brains 4081 +highest 4082 +occasion 4083 +signs 4084 +focusing 4085 +hollow 4086 +explored 4087 +aid 4088 +cary 4089 +scheme 4090 +shine 4091 +it'll 4092 +kirk 4093 +bedroom 4094 +satisfied 4095 +rat 4096 +passes 4097 +survival 4098 +coffee 4099 +furthermore 4100 +primary 4101 +succeeded 4102 +politically 4103 +pays 4104 +apes 4105 +stiller 4106 +dating 4107 +defeat 4108 +sport 4109 +catches 4110 +mickey 4111 +clown 4112 +roman 4113 +discuss 4114 +karen 4115 +clumsy 4116 +chaos 4117 +financial 4118 +official 4119 +trees 4120 +explaining 4121 +models 4122 +spirits 4123 +carl 4124 +jeffrey 4125 +duty 4126 +whale 4127 +funeral 4128 +secondly 4129 +sentence 4130 +2007 4131 +classes 4132 +sidekick 4133 +tracks 4134 +props 4135 +travels 4136 +flies 4137 +remarkably 4138 +smaller 4139 +wallace 4140 +awake 4141 +1996 4142 +brady 4143 +blatant 4144 +decisions 4145 +afford 4146 +notion 4147 +recorded 4148 +glorious 4149 +enterprise 4150 +maggie 4151 +consistently 4152 +toys 4153 +offended 4154 +officers 4155 +danes 4156 +backdrop 4157 +beneath 4158 +masters 4159 +measure 4160 +endings 4161 +doomed 4162 +mysteries 4163 +lifestyle 4164 +houses 4165 +portion 4166 +primarily 4167 +satan 4168 +hates 4169 +devoid 4170 +impress 4171 +outer 4172 +generic 4173 +dutch 4174 +punk 4175 +lyrics 4176 +yellow 4177 +eastwood 4178 +exotic 4179 +represent 4180 +instant 4181 +desperation 4182 +mixture 4183 +settle 4184 +frustration 4185 +unfolds 4186 +goodness 4187 +wives 4188 +directs 4189 +fetched 4190 +ape 4191 +cheating 4192 +dozens 4193 +rebel 4194 +cuba 4195 +paulie 4196 +enormous 4197 +revolutionary 4198 +hints 4199 +shelf 4200 +brooklyn 4201 +florida 4202 +dances 4203 +motives 4204 +destiny 4205 +1999 4206 +donna 4207 +hardcore 4208 +mill 4209 +wrestling 4210 +subtlety 4211 +forty 4212 +describes 4213 +drops 4214 +blake 4215 +stinker 4216 +doll 4217 +painted 4218 +fond 4219 +linda 4220 +principal 4221 +rank 4222 +ideal 4223 +kennedy 4224 +hammer 4225 +montage 4226 +hollywood's 4227 +tie 4228 +disjointed 4229 +3rd 4230 +reaches 4231 +amy 4232 +immensely 4233 +ginger 4234 +judging 4235 +companion 4236 +communist 4237 +urge 4238 +winds 4239 +developing 4240 +trailers 4241 +cliff 4242 +lawrence 4243 +stellar 4244 +topless 4245 +circle 4246 +surviving 4247 +avoided 4248 +relations 4249 +bold 4250 +hideous 4251 +voight 4252 +closet 4253 +et 4254 +surfing 4255 +melting 4256 +soccer 4257 +edie 4258 +matches 4259 +backgrounds 4260 +planned 4261 +enemies 4262 +advance 4263 +bull 4264 +authority 4265 +crush 4266 +outfit 4267 +emphasis 4268 +method 4269 +terrorist 4270 +senseless 4271 +pig 4272 +uwe 4273 +simplistic 4274 +benefit 4275 +adorable 4276 +eighties 4277 +ruthless 4278 +godzilla 4279 +blew 4280 +countryside 4281 +specifically 4282 +wont 4283 +performer 4284 +hbo 4285 +traveling 4286 +todd 4287 +practice 4288 +diane 4289 +fix 4290 +faster 4291 +1980 4292 +commented 4293 +sh 4294 +loyal 4295 +saga 4296 +ties 4297 +disappear 4298 +awe 4299 +earned 4300 +buff 4301 +rick 4302 +loads 4303 +link 4304 +angeles 4305 +corruption 4306 +forms 4307 +menace 4308 +miserable 4309 +claimed 4310 +vast 4311 +coach 4312 +divorce 4313 +hal 4314 +gadget 4315 +chorus 4316 +limits 4317 +cure 4318 +introduces 4319 +cards 4320 +solo 4321 +blues 4322 +splatter 4323 +april 4324 +endure 4325 +riveting 4326 +dedicated 4327 +tender 4328 +winters 4329 +illogical 4330 +choreographed 4331 +disappeared 4332 +unsettling 4333 +waters 4334 +guessed 4335 +lemmon 4336 +involve 4337 +transformation 4338 +depressed 4339 +rooms 4340 +lasted 4341 +displayed 4342 +weakest 4343 +leonard 4344 +philosophical 4345 +racial 4346 +interaction 4347 +arrogant 4348 +tag 4349 +rocket 4350 +similarities 4351 +hurts 4352 +thoughtful 4353 +realizing 4354 +harvey 4355 +justify 4356 +hook 4357 +survivors 4358 +represented 4359 +pot 4360 +possibilities 4361 +wore 4362 +disappoint 4363 +voiced 4364 +kicked 4365 +abysmal 4366 +hamilton 4367 +buffs 4368 +safety 4369 +widow 4370 +ears 4371 +nomination 4372 +trashy 4373 +honesty 4374 +stereotype 4375 +severe 4376 +formulaic 4377 +moody 4378 +similarly 4379 +stress 4380 +pan 4381 +chased 4382 +isolated 4383 +blond 4384 +stinks 4385 +mario 4386 +passionate 4387 +finger 4388 +shirley 4389 +march 4390 +hank 4391 +improve 4392 +mann 4393 +understandable 4394 +characters' 4395 +considerable 4396 +scope 4397 +holly 4398 +diana 4399 +grasp 4400 +command 4401 +solely 4402 +'em 4403 +concern 4404 +treats 4405 +akshay 4406 +promised 4407 +colonel 4408 +jonathan 4409 +faults 4410 +helicopter 4411 +inventive 4412 +sounding 4413 +quotes 4414 +trained 4415 +switch 4416 +celebrity 4417 +tad 4418 +swimming 4419 +orson 4420 +education 4421 +aids 4422 +nail 4423 +judy 4424 +cg 4425 +user 4426 +nervous 4427 +nostalgic 4428 +daddy 4429 +alert 4430 +amanda 4431 +facing 4432 +comparing 4433 +unhappy 4434 +preview 4435 +report 4436 +bonus 4437 +purchase 4438 +chess 4439 +wet 4440 +lately 4441 +horrifying 4442 +agrees 4443 +thru 4444 +dolls 4445 +cinematographer 4446 +ignorant 4447 +species 4448 +seed 4449 +consistent 4450 +downhill 4451 +corporate 4452 +photos 4453 +confidence 4454 +letters 4455 +berlin 4456 +dinosaur 4457 +rotten 4458 +taught 4459 +fooled 4460 +laws 4461 +nicholson 4462 +namely 4463 +shake 4464 +waited 4465 +wished 4466 +embarrassment 4467 +everyone's 4468 +boot 4469 +pretending 4470 +reaching 4471 +someone's 4472 +transfer 4473 +sits 4474 +armed 4475 +del 4476 +dub 4477 +defend 4478 +hart 4479 +35 4480 +constructed 4481 +mall 4482 +poetic 4483 +motivations 4484 +inane 4485 +behave 4486 +tonight 4487 +staring 4488 +humble 4489 +snl 4490 +elephant 4491 +agents 4492 +oz 4493 +grandfather 4494 +writes 4495 +relation 4496 +hop 4497 +delivering 4498 +fonda 4499 +edgar 4500 +cave 4501 +artificial 4502 +grinch 4503 +sappy 4504 +prize 4505 +1972 4506 +useful 4507 +buildings 4508 +li 4509 +cake 4510 +eager 4511 +closest 4512 +suitable 4513 +raising 4514 +destroying 4515 +combine 4516 +beatty 4517 +pants 4518 +cleverly 4519 +ballet 4520 +convincingly 4521 +porno 4522 +1990 4523 +miike 4524 +affect 4525 +engage 4526 +cd 4527 +conservative 4528 +wound 4529 +arrived 4530 +stevens 4531 +alcoholic 4532 +valuable 4533 +ya 4534 +reads 4535 +scottish 4536 +elegant 4537 +vegas 4538 +chest 4539 +charlotte 4540 +climactic 4541 +tiresome 4542 +z 4543 +conflicts 4544 +babe 4545 +vengeance 4546 +square 4547 +bath 4548 +secretly 4549 +airport 4550 +campbell 4551 +kingdom 4552 +september 4553 +inferior 4554 +1968 4555 +latin 4556 +plant 4557 +button 4558 +museum 4559 +maintain 4560 +wrapped 4561 +kicking 4562 +cheated 4563 +global 4564 +robbery 4565 +virginia 4566 +wells 4567 +waves 4568 +stilted 4569 +blunt 4570 +lena 4571 +boom 4572 +access 4573 +raymond 4574 +1960s 4575 +catching 4576 +nicholas 4577 +yelling 4578 +scarecrow 4579 +beliefs 4580 +paranoia 4581 +christians 4582 +vice 4583 +jumped 4584 +lay 4585 +iron 4586 +steel 4587 +lowest 4588 +reflect 4589 +closed 4590 +mummy 4591 +transition 4592 +advertising 4593 +vulnerable 4594 +abusive 4595 +1970's 4596 +spoke 4597 +plight 4598 +mars 4599 +spread 4600 +adams 4601 +wizard 4602 +poetry 4603 +im 4604 +sandra 4605 +germans 4606 +pokemon 4607 +progresses 4608 +70 4609 +00 4610 +hung 4611 +questionable 4612 +remarks 4613 +airplane 4614 +centers 4615 +potentially 4616 +bottle 4617 +chicago 4618 +guarantee 4619 +couples 4620 +messed 4621 +catchy 4622 +slick 4623 +gangsters 4624 +misery 4625 +blade 4626 +designs 4627 +construction 4628 +ethan 4629 +desired 4630 +miracle 4631 +carradine 4632 +firstly 4633 +scores 4634 +wandering 4635 +greedy 4636 +recognition 4637 +understated 4638 +restored 4639 +complexity 4640 +madonna 4641 +attitudes 4642 +rendition 4643 +hunters 4644 +intentionally 4645 +experiments 4646 +ruby 4647 +alongside 4648 +vaguely 4649 +inappropriate 4650 +copies 4651 +operation 4652 +brutally 4653 +taxi 4654 +amounts 4655 +stooges 4656 +joined 4657 +pearl 4658 +demand 4659 +crocodile 4660 +depicts 4661 +purchased 4662 +acid 4663 +myers 4664 +exploration 4665 +advise 4666 +illegal 4667 +balls 4668 +king's 4669 +gundam 4670 +disney's 4671 +gender 4672 +lengthy 4673 +survived 4674 +hopper 4675 +niro 4676 +advanced 4677 +simplicity 4678 +bela 4679 +parallel 4680 +ocean 4681 +slaughter 4682 +rising 4683 +witnesses 4684 +chicks 4685 +streep 4686 +visible 4687 +nostalgia 4688 +arguably 4689 +careful 4690 +intimate 4691 +online 4692 +floating 4693 +rubber 4694 +june 4695 +illness 4696 +resources 4697 +khan 4698 +jaw 4699 +newly 4700 +witches 4701 +showcase 4702 +signed 4703 +opinions 4704 +dust 4705 +eaten 4706 +civilization 4707 +shelley 4708 +incomprehensible 4709 +invasion 4710 +lee's 4711 +monkeys 4712 +resort 4713 +literature 4714 +junior 4715 +likewise 4716 +homosexual 4717 +family's 4718 +viewings 4719 +sue 4720 +wisdom 4721 +matched 4722 +amitabh 4723 +edition 4724 +witnessed 4725 +visits 4726 +mistress 4727 +1983 4728 +demented 4729 +basketball 4730 +neighbors 4731 +macy 4732 +fascinated 4733 +dreary 4734 +suspicious 4735 +accompanied 4736 +worn 4737 +mail 4738 +challenging 4739 +doom 4740 +ensues 4741 +manipulative 4742 +robinson 4743 +classical 4744 +olivier 4745 +agreed 4746 +appreciation 4747 +franco 4748 +montana 4749 +troops 4750 +capturing 4751 +alternate 4752 +bands 4753 +twilight 4754 +ridden 4755 +responsibility 4756 +proceeds 4757 +chapter 4758 +jenny 4759 +prisoners 4760 +pops 4761 +analysis 4762 +subplots 4763 +lively 4764 +nuts 4765 +prisoner 4766 +incompetent 4767 +damon 4768 +sellers 4769 +mayor 4770 +rats 4771 +simpson 4772 +90s 4773 +persons 4774 +feed 4775 +descent 4776 +reel 4777 +bay 4778 +assault 4779 +losers 4780 +widely 4781 +rabbit 4782 +smiling 4783 +relatives 4784 +excessive 4785 +defined 4786 +satisfy 4787 +solution 4788 +legal 4789 +molly 4790 +arrival 4791 +overacting 4792 +equivalent 4793 +iran 4794 +pit 4795 +masterful 4796 +capital 4797 +richardson 4798 +compelled 4799 +plausible 4800 +stale 4801 +scrooge 4802 +cities 4803 +francis 4804 +enthusiasm 4805 +lone 4806 +parties 4807 +tomatoes 4808 +channels 4809 +hilariously 4810 +rocky 4811 +crucial 4812 +dropping 4813 +unit 4814 +waitress 4815 +domestic 4816 +attorney 4817 +bakshi 4818 +serving 4819 +wrap 4820 +jaws 4821 +historically 4822 +3d 4823 +defense 4824 +hello 4825 +greed 4826 +1973 4827 +priceless 4828 +sincere 4829 +warmth 4830 +paltrow 4831 +gerard 4832 +tends 4833 +god's 4834 +patients 4835 +creep 4836 +counter 4837 +dalton 4838 +kay 4839 +whats 4840 +louise 4841 +peoples 4842 +exceptionally 4843 +nyc 4844 +pal 4845 +seeks 4846 +terrorists 4847 +lumet 4848 +morris 4849 +ninja 4850 +randomly 4851 +frequent 4852 +despair 4853 +irrelevant 4854 +dressing 4855 +pursuit 4856 +prequel 4857 +creativity 4858 +imitation 4859 +bumbling 4860 +hyde 4861 +property 4862 +muslim 4863 +wishing 4864 +richards 4865 +bargain 4866 +50s 4867 +creator 4868 +calm 4869 +bacall 4870 +gabriel 4871 +mentioning 4872 +rangers 4873 +methods 4874 +earl 4875 +royal 4876 +butler 4877 +justin 4878 +psychic 4879 +chooses 4880 +belong 4881 +der 4882 +photo 4883 +polanski 4884 +mundane 4885 +specially 4886 +mighty 4887 +homer 4888 +ear 4889 +masterpieces 4890 +generated 4891 +leo 4892 +improvement 4893 +poem 4894 +ham 4895 +cliche 4896 +marty 4897 +caliber 4898 +mentions 4899 +minimum 4900 +showdown 4901 +borrowed 4902 +elm 4903 +icon 4904 +brenda 4905 +polished 4906 +1984 4907 +mechanical 4908 +overlook 4909 +loaded 4910 +map 4911 +recording 4912 +craven 4913 +tiger 4914 +roth 4915 +awfully 4916 +suffice 4917 +troubles 4918 +introduce 4919 +equipment 4920 +ashley 4921 +wendy 4922 +pamela 4923 +empathy 4924 +phantom 4925 +betty 4926 +resident 4927 +unreal 4928 +ruins 4929 +performs 4930 +promises 4931 +monk 4932 +iraq 4933 +hippie 4934 +purposes 4935 +marketing 4936 +angela 4937 +keith 4938 +sink 4939 +gifted 4940 +opportunities 4941 +garbo 4942 +assigned 4943 +feminist 4944 +household 4945 +wacky 4946 +alfred 4947 +absent 4948 +sneak 4949 +popularity 4950 +trail 4951 +inducing 4952 +moronic 4953 +wounded 4954 +receives 4955 +willis 4956 +unseen 4957 +stretched 4958 +fulci 4959 +unaware 4960 +dimension 4961 +dolph 4962 +definition 4963 +testament 4964 +educational 4965 +survivor 4966 +attend 4967 +clip 4968 +contest 4969 +petty 4970 +13th 4971 +christy 4972 +respected 4973 +resist 4974 +year's 4975 +album 4976 +expressed 4977 +randy 4978 +quit 4979 +phony 4980 +unoriginal 4981 +punishment 4982 +activities 4983 +suspend 4984 +rolled 4985 +eastern 4986 +1933 4987 +instinct 4988 +distinct 4989 +championship 4990 +tech 4991 +doubts 4992 +interests 4993 +exposure 4994 +travesty 4995 +israel 4996 +sixties 4997 +pink 4998 +orange 4999 +resulting 5000 +spain 5001 +bergman 5002 +1987 5003 +verhoeven 5004 +distribution 5005 +laughably 5006 +depicting 5007 +kissing 5008 +tooth 5009 +shed 5010 +kubrick 5011 +pin 5012 +nonsensical 5013 +roots 5014 +assumed 5015 +swim 5016 +whoopi 5017 +domino 5018 +heights 5019 +spock 5020 +inevitably 5021 +abraham 5022 +stunned 5023 +businessman 5024 +correctly 5025 +deceased 5026 +buffalo 5027 +wholly 5028 +underlying 5029 +dud 5030 +othello 5031 +unpredictable 5032 +package 5033 +hopeless 5034 +teaching 5035 +valley 5036 +uplifting 5037 +peters 5038 +integrity 5039 +1993 5040 +biography 5041 +yard 5042 +brutality 5043 +america's 5044 +trademark 5045 +retired 5046 +shaw 5047 +reflection 5048 +maniac 5049 +– 5050 +meryl 5051 +accuracy 5052 +sid 5053 +compassion 5054 +dreck 5055 +2008 5056 +edgy 5057 +greatness 5058 +assassin 5059 +greg 5060 +palace 5061 +suggested 5062 +patience 5063 +landscapes 5064 +1971 5065 +mankind 5066 +supported 5067 +merits 5068 +directions 5069 +fed 5070 +romero 5071 +spider 5072 +mtv 5073 +metaphor 5074 +masses 5075 +puppet 5076 +seldom 5077 +wife's 5078 +loyalty 5079 +deaf 5080 +grayson 5081 +strangers 5082 +3000 5083 +passable 5084 +checked 5085 +connery 5086 +confess 5087 +shaky 5088 +drake 5089 +eugene 5090 +significance 5091 +pierce 5092 +unfair 5093 +maid 5094 +indulgent 5095 +comfort 5096 +orleans 5097 +willie 5098 +glasses 5099 +pressure 5100 +alec 5101 +composer 5102 +marion 5103 +nicole 5104 +tribe 5105 +fought 5106 +technicolor 5107 +watson 5108 +dee 5109 +emperor 5110 +adaptations 5111 +romp 5112 +peak 5113 +conditions 5114 +grabs 5115 +exchange 5116 +fury 5117 +immediate 5118 +women's 5119 +timon 5120 +omen 5121 +generations 5122 +barrymore 5123 +resemble 5124 +1995 5125 +1997 5126 +confrontation 5127 +landing 5128 +frustrating 5129 +demise 5130 +spacey 5131 +lackluster 5132 +disliked 5133 +kyle 5134 +y 5135 +victory 5136 +wretched 5137 +Â… 5138 +farrell 5139 +we'd 5140 +respectively 5141 +crazed 5142 +din 5143 +expedition 5144 +chicken 5145 +cannibal 5146 +conscious 5147 +experimental 5148 +astonishing 5149 +inability 5150 +examination 5151 +wilderness 5152 +tube 5153 +blast 5154 +nerd 5155 +legacy 5156 +companies 5157 +subjected 5158 +ships 5159 +rises 5160 +invented 5161 +stuart 5162 +ambiguous 5163 +grief 5164 +rave 5165 +cracking 5166 +unexpectedly 5167 +scotland 5168 +stargate 5169 +milk 5170 +singers 5171 +darren 5172 +billed 5173 +tripe 5174 +ordered 5175 +furious 5176 +flair 5177 +griffith 5178 +refused 5179 +fascination 5180 +tastes 5181 +owen 5182 +frightened 5183 +amused 5184 +masks 5185 +females 5186 +graham 5187 +rates 5188 +simultaneously 5189 +senses 5190 +walsh 5191 +marc 5192 +simmons 5193 +shanghai 5194 +premiere 5195 +remained 5196 +warriors 5197 +1936 5198 +josh 5199 +antwone 5200 +difficulties 5201 +shoulders 5202 +femme 5203 +alternative 5204 +sentiment 5205 +relax 5206 +ollie 5207 +leon 5208 +rooney 5209 +objective 5210 +deranged 5211 +alcohol 5212 +austin 5213 +sissy 5214 +tank 5215 +dysfunctional 5216 +vulgar 5217 +stumbled 5218 +desires 5219 +replace 5220 +dixon 5221 +claus 5222 +joel 5223 +hears 5224 +coast 5225 +poison 5226 +addicted 5227 +slice 5228 +lundgren 5229 +parade 5230 +gather 5231 +appropriately 5232 +abused 5233 +cream 5234 +challenged 5235 +awhile 5236 +tacky 5237 +interactions 5238 +function 5239 +pun 5240 +bud 5241 +filling 5242 +primitive 5243 +fishing 5244 +raises 5245 +infected 5246 +musicians 5247 +precisely 5248 +caricatures 5249 +karl 5250 +underneath 5251 +ross 5252 +alicia 5253 +prey 5254 +fingers 5255 +nephew 5256 +crystal 5257 +skull 5258 +remakes 5259 +favour 5260 +wildly 5261 +phil 5262 +phrase 5263 +julian 5264 +sopranos 5265 +complaints 5266 +presenting 5267 +noises 5268 +19th 5269 +twins 5270 +les 5271 +ramones 5272 +lands 5273 +joins 5274 +wakes 5275 +require 5276 +fifty 5277 +items 5278 +frankenstein 5279 +nathan 5280 +christianity 5281 +reid 5282 +accomplish 5283 +22 5284 +dana 5285 +wang 5286 +breed 5287 +millionaire 5288 +sums 5289 +knocked 5290 +teaches 5291 +literary 5292 +loneliness 5293 +fiancé 5294 +complaining 5295 +silliness 5296 +sharon 5297 +celebration 5298 +gentleman 5299 +ustinov 5300 +husband's 5301 +exposition 5302 +choppy 5303 +altman 5304 +minus 5305 +amusement 5306 +sugar 5307 +husbands 5308 +framed 5309 +other's 5310 +andre 5311 +unlikable 5312 +sunny 5313 +roommate 5314 +stark 5315 +absurdity 5316 +rifle 5317 +electric 5318 +posters 5319 +aspiring 5320 +conscience 5321 +fields 5322 +hackneyed 5323 +downey 5324 +buster 5325 +edit 5326 +straightforward 5327 +misleading 5328 +carell 5329 +murdering 5330 +credited 5331 +sung 5332 +releases 5333 +muddled 5334 +raines 5335 +coincidence 5336 +unfold 5337 +rude 5338 +charged 5339 +weakness 5340 +quietly 5341 +pitiful 5342 +marshall 5343 +objects 5344 +shared 5345 +inexplicably 5346 +automatically 5347 +heartfelt 5348 +agenda 5349 +dresses 5350 +trend 5351 +acclaimed 5352 +blacks 5353 +murray 5354 +beverly 5355 +asylum 5356 +belushi 5357 +en 5358 +moreover 5359 +shoddy 5360 +bernard 5361 +teachers 5362 +devices 5363 +cattle 5364 +preston 5365 +dont 5366 +grotesque 5367 +visited 5368 +discovering 5369 +roof 5370 +spark 5371 +realised 5372 +handling 5373 +adopted 5374 +bread 5375 +haired 5376 +ethnic 5377 +encourage 5378 +lock 5379 +conviction 5380 +imaginable 5381 +fog 5382 +crawford 5383 +firm 5384 +servant 5385 +invites 5386 +dirt 5387 +cancer 5388 +fantasies 5389 +rely 5390 +biased 5391 +occasions 5392 +dose 5393 +industrial 5394 +harm 5395 +hungry 5396 +vance 5397 +kansas 5398 +active 5399 +preposterous 5400 +profanity 5401 +positively 5402 +prepare 5403 +ladder 5404 +sketch 5405 +alison 5406 +controlled 5407 +squad 5408 +outfits 5409 +deniro 5410 +canyon 5411 +babies 5412 +frankie 5413 +referred 5414 +kumar 5415 +regarded 5416 +designer 5417 +1988 5418 +paradise 5419 +comedians 5420 +russia 5421 +fido 5422 +provocative 5423 +behaviour 5424 +region 5425 +1930's 5426 +baldwin 5427 +laurence 5428 +translated 5429 +tracking 5430 +clock 5431 +1939 5432 +chills 5433 +hawke 5434 +cue 5435 +heist 5436 +citizens 5437 +da 5438 +1978 5439 +mode 5440 +hk 5441 +counts 5442 +riot 5443 +uncut 5444 +musician 5445 +accepts 5446 +shoulder 5447 +heartbreaking 5448 +secondary 5449 +option 5450 +75 5451 +roller 5452 +1980's 5453 +fathers 5454 +mclaglen 5455 +hopelessly 5456 +tasteless 5457 +bye 5458 +challenges 5459 +bitch 5460 +additional 5461 +backs 5462 +should've 5463 +swing 5464 +betrayal 5465 +labor 5466 +lush 5467 +morbid 5468 +abrupt 5469 +gambling 5470 +historic 5471 +iv 5472 +insurance 5473 +1986 5474 +fade 5475 +screens 5476 +bike 5477 +damme 5478 +pages 5479 +nut 5480 +admirable 5481 +rejected 5482 +skits 5483 +lip 5484 +ignorance 5485 +chainsaw 5486 +cassidy 5487 +suspension 5488 +respective 5489 +nod 5490 +chuckle 5491 +recommendation 5492 +guitar 5493 +youngest 5494 +reign 5495 +1970 5496 +biko 5497 +severely 5498 +affection 5499 +coaster 5500 +visiting 5501 +kid's 5502 +darn 5503 +refer 5504 +boxer 5505 +naughty 5506 +macarthur 5507 +deserted 5508 +amazon 5509 +paramount 5510 +files 5511 +corpses 5512 +realm 5513 +nemesis 5514 +1979 5515 +sabrina 5516 +address 5517 +beware 5518 +shares 5519 +tomorrow 5520 +prejudice 5521 +el 5522 +guaranteed 5523 +wwe 5524 +sooner 5525 +reluctant 5526 +1989 5527 +invited 5528 +aim 5529 +dickens 5530 +evidently 5531 +lindsay 5532 +hyped 5533 +penny 5534 +praised 5535 +jews 5536 +sympathize 5537 +barrel 5538 +disappears 5539 +guests 5540 +anticipation 5541 +conventions 5542 +outs 5543 +tail 5544 +deleted 5545 +freaks 5546 +rome 5547 +indication 5548 +bunny 5549 +actor's 5550 +19 5551 +fist 5552 +mayhem 5553 +1969 5554 +policeman 5555 +cannon 5556 +thread 5557 +basinger 5558 +bridget 5559 +selection 5560 +palma 5561 +inconsistent 5562 +saint 5563 +stopping 5564 +gut 5565 +burst 5566 +visions 5567 +angst 5568 +daughter's 5569 +beside 5570 +reader 5571 +sentinel 5572 +nails 5573 +promote 5574 +weaknesses 5575 +heading 5576 +www 5577 +venture 5578 +malone 5579 +misguided 5580 +1960's 5581 +muppet 5582 +uh 5583 +drove 5584 +overlong 5585 +gal 5586 +cope 5587 +mccoy 5588 +threatens 5589 +iconic 5590 +rita 5591 +stages 5592 +underworld 5593 +adolescent 5594 +tip 5595 +previews 5596 +depending 5597 +hammy 5598 +behold 5599 +steady 5600 +circus 5601 +filler 5602 +conveys 5603 +glowing 5604 +vader 5605 +shades 5606 +acceptance 5607 +psychology 5608 +bent 5609 +banal 5610 +receiving 5611 +palance 5612 +reflects 5613 +cruelty 5614 +guy's 5615 +tyler 5616 +insipid 5617 +posted 5618 +hack 5619 +curly 5620 +sassy 5621 +nicolas 5622 +harmless 5623 +morally 5624 +affairs 5625 +macho 5626 +understands 5627 +fluff 5628 +demonstrates 5629 +exceptions 5630 +bow 5631 +investigating 5632 +widescreen 5633 +30's 5634 +remade 5635 +studies 5636 +records 5637 +bros 5638 +unexplained 5639 +sirk 5640 +oldest 5641 +firing 5642 +vein 5643 +explores 5644 +completed 5645 +eternal 5646 +marvel 5647 +preachy 5648 +triple 5649 +schlock 5650 +min 5651 +employed 5652 +campaign 5653 +difficulty 5654 +strongest 5655 +gregory 5656 +grainy 5657 +popping 5658 +disguise 5659 +filth 5660 +dates 5661 +obligatory 5662 +robbins 5663 +terrified 5664 +portrayals 5665 +commander 5666 +hokey 5667 +emerges 5668 +confident 5669 +connections 5670 +lifted 5671 +artsy 5672 +height 5673 +entitled 5674 +outing 5675 +rukh 5676 +hopkins 5677 +pounds 5678 +sending 5679 +hapless 5680 +physics 5681 +phenomenon 5682 +assuming 5683 +unrelated 5684 +kitty 5685 +repeating 5686 +stores 5687 +attract 5688 +fifties 5689 +assured 5690 +clan 5691 +insists 5692 +interestingly 5693 +patricia 5694 +mentality 5695 +knight 5696 +1981 5697 +bug 5698 +paxton 5699 +pole 5700 +hughes 5701 +communicate 5702 +sox 5703 +rhythm 5704 +nolan 5705 +bitten 5706 +despicable 5707 +slimy 5708 +predict 5709 +recognizable 5710 +rounded 5711 +shakespeare's 5712 +gate 5713 +1945 5714 +recycled 5715 +conclude 5716 +casual 5717 +disgusted 5718 +comparisons 5719 +zombi 5720 +couch 5721 +offs 5722 +vital 5723 +representation 5724 +rod 5725 +duck 5726 +martha 5727 +danish 5728 +yawn 5729 +studying 5730 +1976 5731 +clarke 5732 +woo 5733 +route 5734 +prominent 5735 +tarantino 5736 +legends 5737 +paintings 5738 +suitably 5739 +someday 5740 +snakes 5741 +absorbed 5742 +stairs 5743 +redeem 5744 +gear 5745 +shortcomings 5746 +agency 5747 +tempted 5748 +rapist 5749 +inexplicable 5750 +locals 5751 +http 5752 +clueless 5753 +pleasing 5754 +vibrant 5755 +independence 5756 +marries 5757 +clad 5758 +charms 5759 +rendered 5760 +heartwarming 5761 +melody 5762 +shouting 5763 +wig 5764 +defeated 5765 +friend's 5766 +stack 5767 +lois 5768 +novak 5769 +coup 5770 +globe 5771 +soup 5772 +claustrophobic 5773 +eats 5774 +flashy 5775 +trivia 5776 +spinal 5777 +thompson 5778 +considerably 5779 +forcing 5780 +befriends 5781 +grudge 5782 +chavez 5783 +net 5784 +shopping 5785 +gems 5786 +claiming 5787 +foxx 5788 +muppets 5789 +discussing 5790 +boston 5791 +ingenious 5792 +flowers 5793 +harold 5794 +feeding 5795 +eternity 5796 +norm 5797 +sharing 5798 +meg 5799 +quinn 5800 +election 5801 +camcorder 5802 +limit 5803 +genie 5804 +daniels 5805 +quaid 5806 +bacon 5807 +runner 5808 +tierney 5809 +champion 5810 +stallone 5811 +minister 5812 +publicity 5813 +static 5814 +springer 5815 +info 5816 +screw 5817 +inhabitants 5818 +'70s 5819 +renaissance 5820 +carla 5821 +screwed 5822 +delicate 5823 +marlon 5824 +weather 5825 +deserving 5826 +incidentally 5827 +depends 5828 +winchester 5829 +boyle 5830 +gina 5831 +immature 5832 +lift 5833 +wings 5834 +partners 5835 +rope 5836 +ace 5837 +phillips 5838 +kathryn 5839 +elite 5840 +pete 5841 +brother's 5842 +glamorous 5843 +transformed 5844 +blatantly 5845 +symbolic 5846 +traffic 5847 +belt 5848 +strings 5849 +excess 5850 +stalker 5851 +smiles 5852 +ton 5853 +politician 5854 +keen 5855 +esther 5856 +ambition 5857 +surgery 5858 +ants 5859 +audrey 5860 +housewife 5861 +ish 5862 +lasting 5863 +allen's 5864 +dvds 5865 +schools 5866 +concepts 5867 +hilarity 5868 +newman 5869 +shaking 5870 +28 5871 +programs 5872 +frames 5873 +coupled 5874 +cheer 5875 +disorder 5876 +salt 5877 +beatles 5878 +fuller 5879 +shorter 5880 +voted 5881 +toronto 5882 +raj 5883 +1940 5884 +exploring 5885 +debate 5886 +yeti 5887 +layers 5888 +fontaine 5889 +backwards 5890 +continually 5891 +feat 5892 +georges 5893 +organized 5894 +destined 5895 +bombs 5896 +differently 5897 +nope 5898 +bend 5899 +towers 5900 +mothers 5901 +partially 5902 +outdated 5903 +punches 5904 +stumbles 5905 +bully 5906 +threatened 5907 +thrilled 5908 +leigh 5909 +charlton 5910 +wax 5911 +bondage 5912 +kolchak 5913 +spree 5914 +assassination 5915 +doctors 5916 +remove 5917 +claude 5918 +europa 5919 +wire 5920 +leather 5921 +messy 5922 +item 5923 +institution 5924 +departure 5925 +centre 5926 +else's 5927 +detectives 5928 +triangle 5929 +lifeless 5930 +handles 5931 +hides 5932 +wanders 5933 +dudley 5934 +accurately 5935 +duration 5936 +hum 5937 +harrison 5938 +damaged 5939 +satirical 5940 +1950 5941 +minority 5942 +suggestion 5943 +insightful 5944 +hangs 5945 +btw 5946 +preferred 5947 +sorely 5948 +windows 5949 +formed 5950 +profession 5951 +boy's 5952 +commenting 5953 +newer 5954 +landed 5955 +colin 5956 +tenant 5957 +goers 5958 +gunga 5959 +uniformly 5960 +neurotic 5961 +trials 5962 +authorities 5963 +oriented 5964 +swept 5965 +northern 5966 +computers 5967 +dylan 5968 +racing 5969 +kline 5970 +95 5971 +vocal 5972 +steele 5973 +1990s 5974 +viewer's 5975 +bridges 5976 +proving 5977 +entered 5978 +demonic 5979 +natives 5980 +seeming 5981 +brendan 5982 +reeves 5983 +obtain 5984 +rear 5985 +evolution 5986 +ie 5987 +christine 5988 +token 5989 +elevator 5990 +braveheart 5991 +garner 5992 +ripping 5993 +refuse 5994 +firmly 5995 +outright 5996 +mermaid 5997 +exquisite 5998 +mutual 5999 +posey 6000 +biblical 6001 +disastrous 6002 +sleaze 6003 +bars 6004 +helpful 6005 +wendigo 6006 +eleven 6007 +choosing 6008 +neatly 6009 +engrossing 6010 +kidman 6011 +freddy's 6012 +earn 6013 +tops 6014 +uma 6015 +anton 6016 +justified 6017 +wtf 6018 +demanding 6019 +mannerisms 6020 +inspire 6021 +speeches 6022 +containing 6023 +pacific 6024 +myth 6025 +sleeps 6026 +reliable 6027 +fifth 6028 +gillian 6029 +setup 6030 +vile 6031 +cookie 6032 +4th 6033 +hitler's 6034 +bowl 6035 +she'll 6036 +sincerely 6037 +tapes 6038 +vanessa 6039 +insanity 6040 +casts 6041 +ratso 6042 +brooding 6043 +disgrace 6044 +luis 6045 +helpless 6046 +1991 6047 +mirrors 6048 +label 6049 +emerge 6050 +kent 6051 +altered 6052 +forgiven 6053 +predecessor 6054 +heels 6055 +skit 6056 +contempt 6057 +activity 6058 +crossing 6059 +describing 6060 +1985 6061 +duvall 6062 +rampage 6063 +healthy 6064 +knightley 6065 +mercy 6066 +undead 6067 +cemetery 6068 +spies 6069 +mesmerizing 6070 +homicide 6071 +cons 6072 +frontal 6073 +ariel 6074 +restrained 6075 +valentine 6076 +approaches 6077 +startling 6078 +cerebral 6079 +vain 6080 +rooting 6081 +destroys 6082 +preparing 6083 +subtly 6084 +1977 6085 +1974 6086 +jordan 6087 +hats 6088 +grateful 6089 +pc 6090 +boasts 6091 +gere 6092 +regards 6093 +creek 6094 +survives 6095 +mixing 6096 +realities 6097 +conan 6098 +topics 6099 +educated 6100 +shaped 6101 +insights 6102 +melissa 6103 +carey 6104 +tunnel 6105 +artwork 6106 +hulk 6107 +hartley 6108 +radical 6109 +deny 6110 +modest 6111 +unlikeable 6112 +compete 6113 +1994 6114 +sometime 6115 +statue 6116 +grounds 6117 +weaker 6118 +seedy 6119 +mitch 6120 +breakfast 6121 +inspirational 6122 +jess 6123 +hugely 6124 +leaders 6125 +coat 6126 +miami 6127 +scariest 6128 +owners 6129 +casino 6130 +miniseries 6131 +freeze 6132 +akin 6133 +timberlake 6134 +deer 6135 +jared 6136 +bulk 6137 +conrad 6138 +wardrobe 6139 +poker 6140 +crashes 6141 +hers 6142 +rapidly 6143 +applaud 6144 +tara 6145 +nominations 6146 +wrenching 6147 +votes 6148 +contribution 6149 +candidate 6150 +loretta 6151 +affects 6152 +homes 6153 +cinemas 6154 +dubious 6155 +child's 6156 +stare 6157 +banter 6158 +exploits 6159 +advertised 6160 +21st 6161 +guards 6162 +vastly 6163 +relentless 6164 +disguised 6165 +masterfully 6166 +critique 6167 +dim 6168 +located 6169 +refers 6170 +narrow 6171 +des 6172 +washed 6173 +origin 6174 +puppets 6175 +addict 6176 +internal 6177 +error 6178 +disgust 6179 +injured 6180 +cartoonish 6181 +bronson 6182 +gods 6183 +alvin 6184 +30s 6185 +shell 6186 +owes 6187 +repulsive 6188 +gimmick 6189 +boris 6190 +linear 6191 +randolph 6192 +photographs 6193 +rides 6194 +ingrid 6195 +scifi 6196 +abruptly 6197 +limitations 6198 +joker 6199 +youthful 6200 +dandy 6201 +unsure 6202 +dazzling 6203 +gained 6204 +arab 6205 +detract 6206 +underwear 6207 +christina 6208 +caricature 6209 +bloom 6210 +continuing 6211 +lasts 6212 +inaccurate 6213 +where's 6214 +swallow 6215 +standout 6216 +motive 6217 +nations 6218 +convicted 6219 +bravo 6220 +youtube 6221 +nolte 6222 +lauren 6223 +holocaust 6224 +vehicles 6225 +bones 6226 +thirties 6227 +audition 6228 +factors 6229 +headache 6230 +growth 6231 +natured 6232 +mason 6233 +expertly 6234 +spine 6235 +hires 6236 +zizek 6237 +undeniably 6238 +bates 6239 +excellently 6240 +highway 6241 +nina 6242 +screenwriters 6243 +buzz 6244 +chronicles 6245 +insults 6246 +corn 6247 +stunningly 6248 +dread 6249 +homosexuality 6250 +perception 6251 +antonio 6252 +lukas 6253 +reward 6254 +decline 6255 +son's 6256 +las 6257 +mol 6258 +unsuspecting 6259 +strengths 6260 +convinces 6261 +spit 6262 +entering 6263 +natalie 6264 +tossed 6265 +toni 6266 +colours 6267 +ronald 6268 +mathieu 6269 +implied 6270 +teams 6271 +resolved 6272 +tower 6273 +entirety 6274 +confront 6275 +wander 6276 +derivative 6277 +missile 6278 +definitive 6279 +gates 6280 +supply 6281 +bachelor 6282 +anyone's 6283 +divorced 6284 +attenborough 6285 +males 6286 +promptly 6287 +painter 6288 +sinking 6289 +polly 6290 +origins 6291 +endlessly 6292 +nerves 6293 +1959 6294 +wagner 6295 +carmen 6296 +judd 6297 +poe 6298 +walt 6299 +unimaginative 6300 +anil 6301 +mice 6302 +1940s 6303 +confronted 6304 +200 6305 +lend 6306 +authenticity 6307 +siblings 6308 +longest 6309 +repressed 6310 +alexandre 6311 +span 6312 +sergeant 6313 +stardom 6314 +cassavetes 6315 +vividly 6316 +salvation 6317 +yep 6318 +jacket 6319 +users 6320 +jarring 6321 +enhanced 6322 +puerto 6323 +colleagues 6324 +referring 6325 +jedi 6326 +tokyo 6327 +niece 6328 +published 6329 +jackson's 6330 +mates 6331 +cbs 6332 +damned 6333 +sgt 6334 +delicious 6335 +uniform 6336 +dominated 6337 +judgment 6338 +juliet 6339 +accessible 6340 +bsg 6341 +exterior 6342 +misfortune 6343 +zane 6344 +phillip 6345 +ally 6346 +giants 6347 +netflix 6348 +energetic 6349 +austen 6350 +unattractive 6351 +devil's 6352 +mobile 6353 +underwater 6354 +stalking 6355 +disabled 6356 +depict 6357 +offbeat 6358 +earnest 6359 +servants 6360 +jill 6361 +bruno 6362 +cliches 6363 +crisp 6364 +nerve 6365 +peck 6366 +wounds 6367 +hepburn 6368 +terminator 6369 +sized 6370 +suburban 6371 +depths 6372 +buys 6373 +hindi 6374 +sticking 6375 +literal 6376 +playboy 6377 +gable 6378 +meandering 6379 +belly 6380 +sensible 6381 +lighter 6382 +21 6383 +stranded 6384 +yokai 6385 +pray 6386 +mutant 6387 +sale 6388 +exit 6389 +estranged 6390 +anyhow 6391 +identical 6392 +foolish 6393 +eventual 6394 +errol 6395 +separated 6396 +bashing 6397 +cushing 6398 +soylent 6399 +antonioni 6400 +galaxy 6401 +glued 6402 +imo 6403 +tormented 6404 +syndrome 6405 +biting 6406 +dragons 6407 +macabre 6408 +dealer 6409 +filthy 6410 +residents 6411 +victorian 6412 +witchcraft 6413 +cents 6414 +improbable 6415 +inherent 6416 +alley 6417 +lester 6418 +readers 6419 +scratch 6420 +pirate 6421 +cher 6422 +pickford 6423 +astounding 6424 +devastating 6425 +breathing 6426 +clash 6427 +approaching 6428 +severed 6429 +owned 6430 +interact 6431 +cleaning 6432 +characteristics 6433 +expects 6434 +guinness 6435 +dismal 6436 +sniper 6437 +lance 6438 +sand 6439 +respectable 6440 +budgets 6441 +sought 6442 +scoop 6443 +slide 6444 +butch 6445 +nightclub 6446 +yours 6447 +blooded 6448 +she'd 6449 +appeals 6450 +ebert 6451 +harriet 6452 +farmer 6453 +stylized 6454 +owns 6455 +noticeable 6456 +kurosawa 6457 +dustin 6458 +id 6459 +balanced 6460 +fragile 6461 +sublime 6462 +salman 6463 +answered 6464 +penn 6465 +amrita 6466 +adore 6467 +logan 6468 +demonstrate 6469 +concentrate 6470 +exploit 6471 +races 6472 +laden 6473 +psychopath 6474 +affleck 6475 +1982 6476 +garland 6477 +worms 6478 +23 6479 +filmmaking 6480 +pattern 6481 +habit 6482 +incapable 6483 +isolation 6484 +fatale 6485 +decidedly 6486 +steam 6487 +jules 6488 +ford's 6489 +asia 6490 +possess 6491 +senior 6492 +reminder 6493 +cheaply 6494 +principals 6495 +immortal 6496 +christie 6497 +monty 6498 +sf 6499 +evelyn 6500 +denis 6501 +corporation 6502 +turd 6503 +soderbergh 6504 +deliverance 6505 +subway 6506 +potter 6507 +breakdown 6508 +flimsy 6509 +packs 6510 +judged 6511 +wisely 6512 +moe 6513 +bogus 6514 +enthusiastic 6515 +cries 6516 +conveyed 6517 +escaping 6518 +plotting 6519 +wilder 6520 +pale 6521 +deliberate 6522 +dvd's 6523 +informed 6524 +promoted 6525 +axe 6526 +flashes 6527 +cypher 6528 +tremendously 6529 +esquire 6530 +1944 6531 +feast 6532 +glaring 6533 +irene 6534 +spectacle 6535 +chopped 6536 +cyborg 6537 +assembled 6538 +drinks 6539 +dump 6540 +celebrated 6541 +quarter 6542 +boyer 6543 +clara 6544 +arguing 6545 +selected 6546 +numbing 6547 +romeo 6548 +volume 6549 +truman 6550 +combines 6551 +embrace 6552 +troma 6553 +expose 6554 +laurie 6555 +kidnapping 6556 +debt 6557 +contribute 6558 +ominous 6559 +jodie 6560 +magician 6561 +o'hara 6562 +conveniently 6563 +outline 6564 +excruciatingly 6565 +accounts 6566 +pound 6567 +pixar 6568 +pierre 6569 +hackman 6570 +lightning 6571 +absorbing 6572 +copied 6573 +clone 6574 +lola 6575 +ugh 6576 +burke 6577 +cecil 6578 +jan 6579 +mitchum 6580 +jealousy 6581 +advised 6582 +40s 6583 +ensure 6584 +collect 6585 +rewarding 6586 +updated 6587 +freaky 6588 +attacking 6589 +rescued 6590 +lex 6591 +1975 6592 +dilemma 6593 +colored 6594 +beowulf 6595 +hi 6596 +melvyn 6597 +ps 6598 +pocket 6599 +passengers 6600 +accepting 6601 +sydney 6602 +classy 6603 +whiny 6604 +loy 6605 +experiencing 6606 +exorcist 6607 +destructive 6608 +300 6609 +goods 6610 +spencer 6611 +corbett 6612 +shepherd 6613 +reports 6614 +expectation 6615 +sophie 6616 +sentimentality 6617 +pause 6618 +sidewalk 6619 +karate 6620 +quantum 6621 +intricate 6622 +tax 6623 +scarface 6624 +crippled 6625 +longing 6626 +nbc 6627 +reeve 6628 +vintage 6629 +crown 6630 +1998 6631 +quentin 6632 +obsessive 6633 +immense 6634 +knocks 6635 +bounty 6636 +indiana 6637 +adaption 6638 +delighted 6639 +er 6640 +naschy 6641 +liam 6642 +establish 6643 +addiction 6644 +europeans 6645 +tool 6646 +stroke 6647 +overblown 6648 +goldblum 6649 +jaded 6650 +pursue 6651 +sucker 6652 +slip 6653 +theories 6654 +rookie 6655 +havoc 6656 +1953 6657 +anticipated 6658 +dukes 6659 +principle 6660 +voyage 6661 +gamera 6662 +swearing 6663 +unsatisfying 6664 +wonderland 6665 +frontier 6666 +parallels 6667 +crashing 6668 +downs 6669 +incorrect 6670 +erika 6671 +aggressive 6672 +divine 6673 +paula 6674 +dashing 6675 +turmoil 6676 +suspected 6677 +aided 6678 +grass 6679 +story's 6680 +distract 6681 +cape 6682 +snuff 6683 +bach 6684 +comprehend 6685 +werewolves 6686 +masterson 6687 +resulted 6688 +miranda 6689 +tendency 6690 +fright 6691 +spaghetti 6692 +goals 6693 +rainy 6694 +reviewing 6695 +juliette 6696 +establishment 6697 +redundant 6698 +switched 6699 +taped 6700 +sarcastic 6701 +arguments 6702 +rider 6703 +peaceful 6704 +barbra 6705 +butcher 6706 +shootout 6707 +bubble 6708 +routines 6709 +demonstrated 6710 +spice 6711 +backed 6712 +polish 6713 +cultures 6714 +parsons 6715 +distress 6716 +hero's 6717 +chill 6718 +morons 6719 +slugs 6720 +subtext 6721 +ultimatum 6722 +intentional 6723 +virtual 6724 +morals 6725 +cutter 6726 +hayworth 6727 +mouthed 6728 +fleshed 6729 +fascist 6730 +dramatically 6731 +passage 6732 +realization 6733 +slaves 6734 +gentlemen 6735 +liu 6736 +hyper 6737 +peculiar 6738 +avoiding 6739 +lavish 6740 +adrian 6741 +vanilla 6742 +boiled 6743 +admired 6744 +thieves 6745 +moron 6746 +sixth 6747 +'cause 6748 +arranged 6749 +climb 6750 +horny 6751 +approached 6752 +alleged 6753 +pumbaa 6754 +predictably 6755 +wielding 6756 +armstrong 6757 +commitment 6758 +seymour 6759 +serum 6760 +odyssey 6761 +hybrid 6762 +messing 6763 +begging 6764 +alter 6765 +establishing 6766 +toby 6767 +whining 6768 +canceled 6769 +collective 6770 +define 6771 +dame 6772 +bikini 6773 +afterward 6774 +mystical 6775 +tourist 6776 +furniture 6777 +fairbanks 6778 +casper 6779 +revolt 6780 +remembering 6781 +exploding 6782 +consideration 6783 +arrest 6784 +inmates 6785 +1934 6786 +shift 6787 +aiming 6788 +samantha 6789 +puzzle 6790 +ghetto 6791 +arc 6792 +traits 6793 +apply 6794 +olds 6795 +sang 6796 +distraction 6797 +hateful 6798 +fools 6799 +anytime 6800 +reviewed 6801 +enhance 6802 +lunch 6803 +coke 6804 +upside 6805 +papers 6806 +insist 6807 +medieval 6808 +wine 6809 +vega 6810 +insomnia 6811 +arriving 6812 +keaton's 6813 +phenomenal 6814 +fills 6815 +graveyard 6816 +stella 6817 +exploited 6818 +writer's 6819 +acquired 6820 +strict 6821 +slapped 6822 +jewel 6823 +thelma 6824 +mcqueen 6825 +pedestrian 6826 +cal 6827 +anthology 6828 +vince 6829 +mythology 6830 +consciousness 6831 +kinnear 6832 +life's 6833 +carnage 6834 +courtroom 6835 +tolerable 6836 +populated 6837 +huston 6838 +contributed 6839 +poses 6840 +actors' 6841 +optimistic 6842 +verdict 6843 +rebellious 6844 +trace 6845 +whites 6846 +commits 6847 +kelly's 6848 +mouths 6849 +stream 6850 +respects 6851 +leap 6852 +sickening 6853 +puppy 6854 +overboard 6855 +diverse 6856 +monologue 6857 +tuned 6858 +corman 6859 +gypo 6860 +skilled 6861 +seasoned 6862 +settled 6863 +horrified 6864 +remembers 6865 +relentlessly 6866 +dj 6867 +— 6868 +jersey 6869 +psychologist 6870 +borders 6871 +lethal 6872 +tony's 6873 +shoe 6874 +smash 6875 +taboo 6876 +wiped 6877 +excuses 6878 +crosses 6879 +salesman 6880 +ritual 6881 +mormon 6882 +achieves 6883 +thunderbirds 6884 +scored 6885 +vanity 6886 +pad 6887 +aussie 6888 +explodes 6889 +ira 6890 +dynamics 6891 +preminger 6892 +franklin 6893 +verbal 6894 +feminine 6895 +policy 6896 +flavor 6897 +expense 6898 +suggesting 6899 +trains 6900 +instincts 6901 +nuances 6902 +dumber 6903 +flock 6904 +feeble 6905 +deanna 6906 +hoot 6907 +cuban 6908 +kathy 6909 +possession 6910 +document 6911 +cohen 6912 +foundation 6913 +diary 6914 +guinea 6915 +covering 6916 +vomit 6917 +readily 6918 +fluid 6919 +cigarette 6920 +tactics 6921 +deliciously 6922 +seductive 6923 +circles 6924 +phase 6925 +themed 6926 +busey 6927 +marilyn 6928 +amidst 6929 +posing 6930 +lean 6931 +cooking 6932 +deputy 6933 +duel 6934 +brainless 6935 +mute 6936 +meantime 6937 +unsympathetic 6938 +wheel 6939 +update 6940 +immigrant 6941 +weary 6942 +basket 6943 +attending 6944 +mortal 6945 +clive 6946 +regularly 6947 +delightfully 6948 +possesses 6949 +newcomer 6950 +porter 6951 +invention 6952 +sources 6953 +wash 6954 +contestants 6955 +shockingly 6956 +wheelchair 6957 +stephanie 6958 +ritchie 6959 +wong 6960 +pushes 6961 +ricky 6962 +audience's 6963 +einstein 6964 +controlling 6965 +mama 6966 +encountered 6967 +pathos 6968 +zorro 6969 +mysteriously 6970 +korea 6971 +bachchan 6972 +jury 6973 +keys 6974 +skinny 6975 +sells 6976 +satisfaction 6977 +romances 6978 +meal 6979 +explosive 6980 +defies 6981 +drab 6982 +clerk 6983 +pfeiffer 6984 +sunrise 6985 +symbol 6986 +pirates 6987 +otto 6988 +novelty 6989 +jacques 6990 +void 6991 +herbert 6992 +narrated 6993 +lionel 6994 +targets 6995 +august 6996 +razor 6997 +rivers 6998 +admitted 6999 +mum 7000 +sundance 7001 +lends 7002 +cliched 7003 +screwball 7004 +serials 7005 +neglected 7006 +olivia 7007 +truths 7008 +sided 7009 +steer 7010 +flower 7011 +indifferent 7012 +dumped 7013 +lucille 7014 +mole 7015 +products 7016 +beg 7017 +releasing 7018 +niven 7019 +stewart's 7020 +ordeal 7021 +darth 7022 +um 7023 +crosby 7024 +statements 7025 +followers 7026 +psyche 7027 +excruciating 7028 +noteworthy 7029 +swinging 7030 +deed 7031 +aftermath 7032 +ranch 7033 +consist 7034 +embarrassingly 7035 +unusually 7036 +convention 7037 +shifts 7038 +produces 7039 +motorcycle 7040 +tickets 7041 +wider 7042 +longoria 7043 +gwyneth 7044 +employee 7045 +instances 7046 +parking 7047 +intact 7048 +starters 7049 +rapid 7050 +arrow 7051 +thurman 7052 +debbie 7053 +dumbest 7054 +wastes 7055 +sarandon 7056 +economic 7057 +israeli 7058 +additionally 7059 +fanatic 7060 +planes 7061 +pursued 7062 +legitimate 7063 +discussed 7064 +forties 7065 +introducing 7066 +anxious 7067 +cannes 7068 +biker 7069 +deciding 7070 +sanders 7071 +fuzzy 7072 +agony 7073 +alot 7074 +assignment 7075 +stones 7076 +scorsese 7077 +caron 7078 +degrees 7079 +medicine 7080 +hannah 7081 +reverse 7082 +inaccuracies 7083 +july 7084 +attended 7085 +gilbert 7086 +forgetting 7087 +jane's 7088 +gielgud 7089 +angie 7090 +milo 7091 +laputa 7092 +branagh's 7093 +motions 7094 +auto 7095 +controversy 7096 +grandma 7097 +cunningham 7098 +professionals 7099 +criticize 7100 +kidnap 7101 +artistry 7102 +sarcasm 7103 +fishburne 7104 +brow 7105 +bogart 7106 +columbia 7107 +incidents 7108 +vera 7109 +meteor 7110 +georgia 7111 +arty 7112 +freaking 7113 +hadley 7114 +suspicion 7115 +scott's 7116 +coffin 7117 +juan 7118 +crossed 7119 +idol 7120 +grip 7121 +obstacles 7122 +mentor 7123 +consequently 7124 +begs 7125 +stating 7126 +ambitions 7127 +muslims 7128 +executives 7129 +daisy 7130 +manners 7131 +warns 7132 +1948 7133 +jolie 7134 +arquette 7135 +distracted 7136 +centuries 7137 +abound 7138 +jose 7139 +factual 7140 +goodbye 7141 +trigger 7142 +breast 7143 +invite 7144 +tcm 7145 +unanswered 7146 +indicate 7147 +shepard 7148 +session 7149 +daylight 7150 +minnelli 7151 +cindy 7152 +funding 7153 +pains 7154 +predator 7155 +flames 7156 +fried 7157 +scripting 7158 +rational 7159 +stabbed 7160 +collette 7161 +'i 7162 +compliment 7163 +hooker 7164 +cliffhanger 7165 +inclusion 7166 +debra 7167 +roughly 7168 +moss 7169 +1967 7170 +awakening 7171 +viewpoint 7172 +kazan 7173 +rejects 7174 +toned 7175 +sentences 7176 +denise 7177 +originals 7178 +cycle 7179 +informative 7180 +pros 7181 +harlow 7182 +stern 7183 +corey 7184 +stalked 7185 +foil 7186 +plodding 7187 +varied 7188 +sweden 7189 +detroit 7190 +misunderstood 7191 +clay 7192 +relevance 7193 +depictions 7194 +blamed 7195 +paints 7196 +pointing 7197 +click 7198 +stance 7199 +protest 7200 +chamber 7201 +robbers 7202 +gooding 7203 +soprano 7204 +likeable 7205 +exclusively 7206 +slim 7207 +campus 7208 +haines 7209 +cheadle 7210 +cap 7211 +cab 7212 +rambling 7213 +paranoid 7214 +seats 7215 +frances 7216 +rowlands 7217 +101 7218 +consequence 7219 +murky 7220 +abandon 7221 +gap 7222 +berkeley 7223 +ruining 7224 +stink 7225 +denouement 7226 +penelope 7227 +intro 7228 +abortion 7229 +tomei 7230 +replies 7231 +antagonist 7232 +gloria 7233 +stardust 7234 +tomb 7235 +gallery 7236 +bug's 7237 +determination 7238 +40's 7239 +c'mon 7240 +translate 7241 +bait 7242 +killer's 7243 +eagerly 7244 +relating 7245 +iranian 7246 +rips 7247 +momentum 7248 +uncanny 7249 +frozen 7250 +begun 7251 +generate 7252 +uniforms 7253 +intensely 7254 +dreamy 7255 +martian 7256 +festivals 7257 +grabbed 7258 +mock 7259 +jenna 7260 +che's 7261 +schedule 7262 +surroundings 7263 +coma 7264 +imaginary 7265 +schneider 7266 +gus 7267 +foremost 7268 +composition 7269 +robertson 7270 +politicians 7271 +services 7272 +hysterically 7273 +snowman 7274 +maureen 7275 +omar 7276 +republic 7277 +lurking 7278 +pans 7279 +alliance 7280 +hostel 7281 +diner 7282 +sheen 7283 +injury 7284 +rupert 7285 +hippies 7286 +rosario 7287 +chamberlain 7288 +ww2 7289 +scenarios 7290 +participants 7291 +realistically 7292 +communication 7293 +kris 7294 +sg 7295 +kathleen 7296 +brat 7297 +redneck 7298 +launch 7299 +therapy 7300 +quasi 7301 +miyazaki 7302 +hmmm 7303 +85 7304 +faux 7305 +geisha 7306 +bauer 7307 +mick 7308 +enigmatic 7309 +1951 7310 +phones 7311 +shaggy 7312 +hostage 7313 +destination 7314 +lens 7315 +glimpses 7316 +1943 7317 +lastly 7318 +rehash 7319 +gestures 7320 +shotgun 7321 +casablanca 7322 +dismiss 7323 +sights 7324 +periods 7325 +burnt 7326 +bats 7327 +resembling 7328 +charlie's 7329 +apt 7330 +linked 7331 +widowed 7332 +dominic 7333 +glance 7334 +cow 7335 +tho 7336 +traps 7337 +curiously 7338 +heath 7339 +envy 7340 +playwright 7341 +gigantic 7342 +paths 7343 +bleed 7344 +ambiguity 7345 +gaps 7346 +bosses 7347 +hayes 7348 +sterling 7349 +necessity 7350 +comeback 7351 +sketches 7352 +sondra 7353 +ignoring 7354 +revolving 7355 +apocalyptic 7356 +reiser 7357 +sailor 7358 +saloon 7359 +frantic 7360 +resistance 7361 +pegg 7362 +overs 7363 +precise 7364 +herman 7365 +rounds 7366 +arkin 7367 +gloomy 7368 +pressed 7369 +haunt 7370 +1992 7371 +enchanted 7372 +iturbi 7373 +fuel 7374 +blaise 7375 +mabel 7376 +laboratory 7377 +county 7378 +veterans 7379 +studied 7380 +cheers 7381 +bearing 7382 +eh 7383 +sunset 7384 +reflected 7385 +rolls 7386 +investigator 7387 +adele 7388 +pen 7389 +maintains 7390 +capacity 7391 +kubrick's 7392 +unstable 7393 +avid 7394 +midst 7395 +man' 7396 +qualify 7397 +bonnie 7398 +person's 7399 +mins 7400 +geek 7401 +nun 7402 +jude 7403 +angelina 7404 +galactica 7405 +sufficient 7406 +substantial 7407 +incest 7408 +handicapped 7409 +trier 7410 +ample 7411 +doctor's 7412 +warden 7413 +supreme 7414 +hinted 7415 +slashers 7416 +rewarded 7417 +rice 7418 +complications 7419 +trauma 7420 +biopic 7421 +sebastian 7422 +'80s 7423 +characterizations 7424 +awareness 7425 +popped 7426 +sparks 7427 +vignettes 7428 +psychedelic 7429 +unclear 7430 +kells 7431 +tightly 7432 +existing 7433 +du 7434 +entrance 7435 +offend 7436 +goldie 7437 +guardian 7438 +collins 7439 +targeted 7440 +talky 7441 +extensive 7442 +ny 7443 +benefits 7444 +epics 7445 +pilots 7446 +payoff 7447 +stadium 7448 +october 7449 +stake 7450 +characterisation 7451 +applied 7452 +applies 7453 +pivotal 7454 +lowe 7455 +gathering 7456 +marisa 7457 +brent 7458 +upcoming 7459 +1963 7460 +overbearing 7461 +eli 7462 +occult 7463 +joking 7464 +ol' 7465 +graduate 7466 +beckinsale 7467 +nuanced 7468 +homicidal 7469 +addressed 7470 +evans 7471 +lunatic 7472 +parrot 7473 +edith 7474 +revival 7475 +convict 7476 +ignores 7477 +safely 7478 +plate 7479 +sour 7480 +turkish 7481 +favourites 7482 +ajay 7483 +boundaries 7484 +northam 7485 +profile 7486 +russ 7487 +skeptical 7488 +frog 7489 +invested 7490 +repeats 7491 +bias 7492 +'60s 7493 +drowned 7494 +iq 7495 +diversity 7496 +outlandish 7497 +nightmarish 7498 +dynamite 7499 +unfolding 7500 +convent 7501 +clooney 7502 +observations 7503 +johansson 7504 +1955 7505 +enchanting 7506 +tire 7507 +stabbing 7508 +disco 7509 +excellence 7510 +27 7511 +clunky 7512 +valid 7513 +array 7514 +engine 7515 +sammo 7516 +doug 7517 +sly 7518 +interior 7519 +resolve 7520 +hating 7521 +olsen 7522 +interviewed 7523 +chong 7524 +protection 7525 +maximum 7526 +nauseating 7527 +versa 7528 +apocalypse 7529 +exploitative 7530 +observation 7531 +murderers 7532 +questioning 7533 +gosh 7534 +stereotyped 7535 +flag 7536 +shore 7537 +pose 7538 +acknowledge 7539 +fruit 7540 +caretaker 7541 +rosemary's 7542 +interpretations 7543 +shin 7544 +stations 7545 +flavia 7546 +nutshell 7547 +announced 7548 +assure 7549 +silverman 7550 +duh 7551 +sonny 7552 +1958 7553 +blockbusters 7554 +pornography 7555 +vivian 7556 +sensibility 7557 +courtesy 7558 +battlestar 7559 +macdonald 7560 +boots 7561 +brides 7562 +reunite 7563 +brooke 7564 +controls 7565 +masked 7566 +phantasm 7567 +prophecy 7568 +slower 7569 +relying 7570 +sweat 7571 +divided 7572 +mannered 7573 +marked 7574 +witnessing 7575 +girlfriends 7576 +snipes 7577 +fortunate 7578 +watcher 7579 +brett 7580 +ernie 7581 +villainous 7582 +strung 7583 +rebels 7584 +candle 7585 +counting 7586 +mccarthy 7587 +rodriguez 7588 +bonham 7589 +portuguese 7590 +daytime 7591 +rea 7592 +insert 7593 +misty 7594 +displaying 7595 +substitute 7596 +satanic 7597 +wayans 7598 +magically 7599 +sincerity 7600 +owl 7601 +cocaine 7602 +spotlight 7603 +inter 7604 +chewing 7605 +lopez 7606 +chiba 7607 +progressed 7608 +entries 7609 +demille 7610 +chuckles 7611 +climbing 7612 +26 7613 +chaotic 7614 +criticized 7615 +confined 7616 +sanity 7617 +goat 7618 +unhinged 7619 +bittersweet 7620 +collar 7621 +realises 7622 +peril 7623 +bust 7624 +smell 7625 +turtle 7626 +wartime 7627 +admits 7628 +commanding 7629 +evokes 7630 +beard 7631 +seduce 7632 +harrowing 7633 +janet 7634 +phoenix 7635 +stiles 7636 +interrupted 7637 +whore 7638 +shocks 7639 +inadvertently 7640 +jar 7641 +wright 7642 +fart 7643 +resume 7644 +lynch's 7645 +needing 7646 +delirious 7647 +upstairs 7648 +obscurity 7649 +famed 7650 +palm 7651 +weekly 7652 +replacement 7653 +monotonous 7654 +smug 7655 +preaching 7656 +projected 7657 +randall 7658 +enduring 7659 +hmm 7660 +organization 7661 +landmark 7662 +thereby 7663 +fundamental 7664 +ripoff 7665 +rightly 7666 +ins 7667 +chew 7668 +slavery 7669 +unnatural 7670 +arrogance 7671 +waking 7672 +manipulation 7673 +jagger 7674 +reserved 7675 +blazing 7676 +finishes 7677 +somethings 7678 +observe 7679 +raging 7680 +thrust 7681 +trivial 7682 +madsen 7683 +carlos 7684 +samuel 7685 +tones 7686 +commendable 7687 +crushed 7688 +similarity 7689 +deemed 7690 +choir 7691 +imagining 7692 +unappealing 7693 +understatement 7694 +apple 7695 +discipline 7696 +thailand 7697 +colleague 7698 +convenient 7699 +rendering 7700 +hines 7701 +cena 7702 +mandy 7703 +testing 7704 +motel 7705 +subsequently 7706 +fassbinder 7707 +reluctantly 7708 +platform 7709 +men's 7710 +egyptian 7711 +aesthetic 7712 +hooper 7713 +accompanying 7714 +protective 7715 +penned 7716 +fetish 7717 +kirsten 7718 +herd 7719 +layered 7720 +scarecrows 7721 +incestuous 7722 +thunder 7723 +boogie 7724 +participate 7725 +forgiveness 7726 +baddies 7727 +hardened 7728 +forgets 7729 +comparable 7730 +combs 7731 +understandably 7732 +shahid 7733 +laying 7734 +marine 7735 +recover 7736 +scheming 7737 +cancelled 7738 +vargas 7739 +stumble 7740 +celebrities 7741 +merry 7742 +russo 7743 +frost 7744 +unfamiliar 7745 +madeleine 7746 +isabelle 7747 +crooks 7748 +python 7749 +filmography 7750 +explode 7751 +sylvia 7752 +article 7753 +climatic 7754 +achievements 7755 +conductor 7756 +pizza 7757 +reminding 7758 +remark 7759 +lo 7760 +gackt 7761 +traumatic 7762 +benjamin 7763 +stuffed 7764 +accidental 7765 +travis 7766 +govinda 7767 +must've 7768 +quintessential 7769 +deathtrap 7770 +cheerful 7771 +hostile 7772 +orchestra 7773 +ninety 7774 +gorilla 7775 +marcel 7776 +cameraman 7777 +shred 7778 +sholay 7779 +wrestler 7780 +customers 7781 +hallmark 7782 +beers 7783 +glossy 7784 +despise 7785 +anita 7786 +goings 7787 +spontaneous 7788 +1932 7789 +fleet 7790 +shameless 7791 +charges 7792 +camping 7793 +finishing 7794 +district 7795 +sins 7796 +dallas 7797 +file 7798 +yell 7799 +serbian 7800 +myrna 7801 +wholesome 7802 +titular 7803 +boo 7804 +o'brien 7805 +implies 7806 +sack 7807 +flip 7808 +salvage 7809 +annoy 7810 +restraint 7811 +imho 7812 +creations 7813 +affecting 7814 +pornographic 7815 +spoiling 7816 +bonanza 7817 +ala 7818 +raid 7819 +raunchy 7820 +sales 7821 +cheering 7822 +captivated 7823 +je 7824 +espionage 7825 +license 7826 +defining 7827 +beforehand 7828 +se 7829 +conclusions 7830 +bakshi's 7831 +hawn 7832 +sherlock 7833 +caprica 7834 +ruled 7835 +unconventional 7836 +diego 7837 +awry 7838 +verge 7839 +krueger 7840 +grin 7841 +whimsical 7842 +ideals 7843 +meyer 7844 +surround 7845 +characteristic 7846 +digging 7847 +shameful 7848 +coolest 7849 +philo 7850 +cells 7851 +reagan 7852 +seattle 7853 +infinitely 7854 +sickness 7855 +excels 7856 +2009 7857 +novelist 7858 +1946 7859 +burial 7860 +fades 7861 +faded 7862 +shannon 7863 +traditions 7864 +fraud 7865 +perverted 7866 +sheets 7867 +voodoo 7868 +desk 7869 +abundance 7870 +flashing 7871 +hunted 7872 +betrayed 7873 +admission 7874 +gershwin 7875 +rampant 7876 +relaxed 7877 +fires 7878 +polar 7879 +kindly 7880 +tits 7881 +melancholy 7882 +drowning 7883 +semblance 7884 +temper 7885 +cracks 7886 +tide 7887 +oblivious 7888 +miraculously 7889 +clarity 7890 +elliott 7891 +inserted 7892 +considers 7893 +constraints 7894 +drift 7895 +sunk 7896 +distributed 7897 +unnecessarily 7898 +welles' 7899 +flows 7900 +sexist 7901 +beckham 7902 +summed 7903 +henchmen 7904 +tools 7905 +transparent 7906 +devotion 7907 +hitchcock's 7908 +earliest 7909 +scarlett 7910 +dangerously 7911 +taut 7912 +dafoe 7913 +dreaming 7914 +seth 7915 +prop 7916 +cain 7917 +wesley 7918 +adapt 7919 +openly 7920 +sane 7921 +hugo 7922 +creasy 7923 +chops 7924 +pitched 7925 +juice 7926 +riff 7927 +blandings 7928 +shah 7929 +screened 7930 +tashan 7931 +meredith 7932 +doyle 7933 +mud 7934 +zodiac 7935 +regime 7936 +irritated 7937 +eagle 7938 +paycheck 7939 +egypt 7940 +spiral 7941 +letdown 7942 +wherever 7943 +madison 7944 +deeds 7945 +robotic 7946 +faint 7947 +outrageously 7948 +sheep 7949 +elsa 7950 +baron 7951 +overtones 7952 +searched 7953 +unleashed 7954 +sporting 7955 +lennon 7956 +gangs 7957 +dahmer 7958 +peggy 7959 +vapid 7960 +heap 7961 +circa 7962 +simpsons 7963 +slater 7964 +permanent 7965 +voyager 7966 +presidential 7967 +compensate 7968 +deepest 7969 +reject 7970 +uneasy 7971 +ghastly 7972 +gretchen 7973 +sophia 7974 +warehouse 7975 +switching 7976 +cedric 7977 +lara 7978 +evoke 7979 +flame 7980 +automatic 7981 +submarine 7982 +plug 7983 +programme 7984 +sucking 7985 +pursuing 7986 +avoids 7987 +assistance 7988 +assumes 7989 +orphan 7990 +mart 7991 +practical 7992 +joining 7993 +failures 7994 +liner 7995 +garfield 7996 +dwight 7997 +slut 7998 +oprah 7999 +committing 8000 +intend 8001 +ealing 8002 +shirts 8003 +locke 8004 +admirer 8005 +awaiting 8006 +ram 8007 +fritz 8008 +melbourne 8009 +contestant 8010 +timmy 8011 +rivals 8012 +buffy 8013 +clouds 8014 +ambiance 8015 +babes 8016 +ensue 8017 +coburn 8018 +occupied 8019 +sergio 8020 +sitcoms 8021 +variation 8022 +censorship 8023 +ferrell 8024 +radiation 8025 +snap 8026 +underdeveloped 8027 +takashi 8028 +hobgoblins 8029 +finney 8030 +listened 8031 +fiancée 8032 +complained 8033 +pauline 8034 +kinski 8035 +alarm 8036 +engineer 8037 +chloe 8038 +proceed 8039 +demeanor 8040 +suzanne 8041 +battlefield 8042 +rebellion 8043 +criticisms 8044 +remainder 8045 +ghostly 8046 +spaceship 8047 +howling 8048 +motivated 8049 +joint 8050 +carpenter's 8051 +fodder 8052 +bert 8053 +dominate 8054 +monks 8055 +dragging 8056 +inclined 8057 +upbeat 8058 +encouraged 8059 +networks 8060 +han 8061 +loren 8062 +brazilian 8063 +atlantic 8064 +flowing 8065 +progression 8066 +tess 8067 +meek 8068 +darkly 8069 +disappearance 8070 +colman 8071 +crashed 8072 +caper 8073 +solved 8074 +fairness 8075 +distinction 8076 +sensual 8077 +feinstone 8078 +sho 8079 +warrant 8080 +grease 8081 +visitor 8082 +marijuana 8083 +sections 8084 +avenge 8085 +tv's 8086 +croc 8087 +sober 8088 +badness 8089 +who've 8090 +ninjas 8091 +myrtle 8092 +runaway 8093 +helmet 8094 +scratching 8095 +quaint 8096 +busby 8097 +defending 8098 +buttons 8099 +artemisia 8100 +cloak 8101 +noting 8102 +confuse 8103 +experts 8104 +whip 8105 +borrow 8106 +barney 8107 +garage 8108 +happenings 8109 +mega 8110 +1990's 8111 +disregard 8112 +bean 8113 +aaron 8114 +edges 8115 +diving 8116 +investment 8117 +wee 8118 +electronic 8119 +gena 8120 +gypsy 8121 +suave 8122 +mustache 8123 +toxic 8124 +mira 8125 +bartender 8126 +prologue 8127 +transport 8128 +atrocity 8129 +everett 8130 +bernsen 8131 +notices 8132 +jo 8133 +boogeyman 8134 +knees 8135 +1966 8136 +1000 8137 +robbed 8138 +epitome 8139 +bennett 8140 +vcr 8141 +who'd 8142 +'a 8143 +detached 8144 +brit 8145 +hometown 8146 +jack's 8147 +prone 8148 +enormously 8149 +gilliam 8150 +jackman 8151 +dom 8152 +impending 8153 +bloodbath 8154 +mister 8155 +macmurray 8156 +vigilante 8157 +offense 8158 +prostitutes 8159 +fashions 8160 +idealistic 8161 +pigs 8162 +abomination 8163 +carpet 8164 +battling 8165 +principles 8166 +paz 8167 +pretends 8168 +awarded 8169 +admiration 8170 +incidental 8171 +tin 8172 +pairing 8173 +woefully 8174 +chip 8175 +classmates 8176 +timed 8177 +budding 8178 +gandolfini 8179 +revolver 8180 +liberty 8181 +associate 8182 +padding 8183 +colony 8184 +zelah 8185 +drum 8186 +vincenzo 8187 +secure 8188 +palestinian 8189 +girls' 8190 +blames 8191 +torment 8192 +kids' 8193 +framing 8194 +tackle 8195 +tended 8196 +peers 8197 +policemen 8198 +facility 8199 +ostensibly 8200 +harron 8201 +prank 8202 +lindy 8203 +bimbo 8204 +1957 8205 +saints 8206 +capote 8207 +shrek 8208 +breathe 8209 +nineties 8210 +worrying 8211 +believability 8212 +paragraph 8213 +mediocrity 8214 +influences 8215 +reported 8216 +conveying 8217 +programming 8218 +stoned 8219 +val 8220 +barnes 8221 +sharks 8222 +unravel 8223 +courageous 8224 +deck 8225 +giovanna 8226 +grating 8227 +britney 8228 +distinctive 8229 +blondell 8230 +spoofs 8231 +brush 8232 +effortlessly 8233 +riders 8234 +midget 8235 +annoyance 8236 +counterparts 8237 +economy 8238 +rivalry 8239 +stab 8240 +knights 8241 +socially 8242 +symbols 8243 +bodyguard 8244 +qualifies 8245 +connie 8246 +acclaim 8247 +managing 8248 +vibe 8249 +monroe 8250 +frat 8251 +baked 8252 +combining 8253 +martians 8254 +boobs 8255 +prostitution 8256 +closure 8257 +senator 8258 +outset 8259 +magazines 8260 +respond 8261 +interiors 8262 +division 8263 +slam 8264 +celebrate 8265 +elected 8266 +zu 8267 +monica 8268 +dillinger 8269 +brashear 8270 +cohesive 8271 +clinic 8272 +gig 8273 +tacked 8274 +coward 8275 +parodies 8276 +greene 8277 +billing 8278 +weirdness 8279 +dunst 8280 +rourke 8281 +manipulated 8282 +concentration 8283 +sinks 8284 +dreyfuss 8285 +asset 8286 +duchovny 8287 +superstar 8288 +clyde 8289 +december 8290 +pompous 8291 +fabric 8292 +placement 8293 +gibson 8294 +bless 8295 +boards 8296 +troopers 8297 +reese 8298 +goodman 8299 +transplant 8300 +shocker 8301 +examine 8302 +chock 8303 +scarlet 8304 +informs 8305 +responds 8306 +collapse 8307 +data 8308 +swiss 8309 +reasoning 8310 +confines 8311 +categories 8312 +injustice 8313 +laser 8314 +dish 8315 +employees 8316 +smith's 8317 +em 8318 +gasp 8319 +sacrifices 8320 +maurice 8321 +worship 8322 +screenplays 8323 +tolerate 8324 +pee 8325 +overshadowed 8326 +dern 8327 +reunited 8328 +brick 8329 +loner 8330 +holt 8331 +sites 8332 +uncertain 8333 +theatres 8334 +morse 8335 +yells 8336 +sibling 8337 +cheech 8338 +butchered 8339 +mae 8340 +ernest 8341 +sensibilities 8342 +500 8343 +ali 8344 +irving 8345 +castro 8346 +influential 8347 +terrorism 8348 +strained 8349 +derived 8350 +chandler 8351 +slept 8352 +perspectives 8353 +bleeding 8354 +madman 8355 +1942 8356 +inconsistencies 8357 +sensitivity 8358 +jam 8359 +hans 8360 +sustain 8361 +systems 8362 +armor 8363 +burgess 8364 +fiery 8365 +queens 8366 +katie 8367 +gruff 8368 +ewoks 8369 +faye 8370 +tramp 8371 +brandon 8372 +lighthearted 8373 +inform 8374 +cursed 8375 +retro 8376 +250 8377 +malden 8378 +cody 8379 +spelled 8380 +manic 8381 +labeled 8382 +perverse 8383 +collector 8384 +drain 8385 +shelter 8386 +spade 8387 +fallon 8388 +ang 8389 +gino 8390 +kareena 8391 +depardieu 8392 +apollo 8393 +officially 8394 +playful 8395 +informer 8396 +banks 8397 +retirement 8398 +booth 8399 +replacing 8400 +transforms 8401 +surrender 8402 +shield 8403 +jigsaw 8404 +fiend 8405 +predecessors 8406 +judgement 8407 +bing 8408 +englund 8409 +ads 8410 +damsel 8411 +stirring 8412 +structured 8413 +patty 8414 +poet 8415 +signature 8416 +tolerance 8417 +bites 8418 +dash 8419 +seriousness 8420 +casted 8421 +mercifully 8422 +edison 8423 +advances 8424 +padded 8425 +czech 8426 +lingering 8427 +sensational 8428 +crowded 8429 +bigfoot 8430 +captive 8431 +plotted 8432 +premiered 8433 +dictator 8434 +locale 8435 +bastard 8436 +manga 8437 +fighters 8438 +sophistication 8439 +lifts 8440 +yarn 8441 +spelling 8442 +uptight 8443 +farrah 8444 +drummer 8445 +amid 8446 +kidnaps 8447 +peaks 8448 +drastically 8449 +cringing 8450 +coop 8451 +dealers 8452 +geoffrey 8453 +rousing 8454 +supermarket 8455 +standpoint 8456 +thereafter 8457 +portions 8458 +latino 8459 +henchman 8460 +berenger 8461 +slash 8462 +sandy 8463 +lurid 8464 +coal 8465 +interplay 8466 +stares 8467 +willingly 8468 +mines 8469 +ss 8470 +ceremony 8471 +inexperienced 8472 +awfulness 8473 +condemned 8474 +benny 8475 +alba 8476 +mythical 8477 +spotted 8478 +sara 8479 +fierce 8480 +thereof 8481 +bloodshed 8482 +enthralling 8483 +geniuses 8484 +lars 8485 +rant 8486 +theodore 8487 +heather 8488 +echoes 8489 +maintaining 8490 +bombed 8491 +bitchy 8492 +fiasco 8493 +powered 8494 +tina 8495 +ossessione 8496 +worm 8497 +godard 8498 +observed 8499 +staging 8500 +attendant 8501 +anxiety 8502 +villa 8503 +varying 8504 +stepmother 8505 +aircraft 8506 +david's 8507 +justification 8508 +identified 8509 +downfall 8510 +anguish 8511 +shoved 8512 +allan 8513 +bliss 8514 +caution 8515 +transported 8516 +impressions 8517 +miike's 8518 +alexandra 8519 +shout 8520 +functions 8521 +imitate 8522 +norris 8523 +dwarf 8524 +nearest 8525 +funky 8526 +drugged 8527 +stabs 8528 +marrying 8529 +hallucinations 8530 +allies 8531 +communism 8532 +fixed 8533 +sorrow 8534 +orlando 8535 +register 8536 +surf 8537 +scarier 8538 +freed 8539 +tasty 8540 +baddie 8541 +vet 8542 +attic 8543 +representing 8544 +widower 8545 +cunning 8546 +plagued 8547 +hunky 8548 +apartheid 8549 +cockney 8550 +luc 8551 +islands 8552 +fur 8553 +emphasize 8554 +confession 8555 +ceiling 8556 +hairy 8557 +warhols 8558 +stricken 8559 +presume 8560 +rosenstrasse 8561 +meadows 8562 +distorted 8563 +virtue 8564 +natali 8565 +forrest 8566 +starship 8567 +lampoon 8568 +depend 8569 +marvin 8570 +mixes 8571 +jewelry 8572 +correctness 8573 +nest 8574 +myra 8575 +rockets 8576 +russians 8577 +glenda 8578 +byron 8579 +sammy 8580 +grandpa 8581 +monday 8582 +entertains 8583 +adultery 8584 +egg 8585 +massey 8586 +drawings 8587 +travolta 8588 +tricked 8589 +abu 8590 +bio 8591 +lin 8592 +fagin 8593 +cowardly 8594 +overwrought 8595 +determine 8596 +throne 8597 +ratio 8598 +tsui 8599 +paired 8600 +cannibals 8601 +fuss 8602 +client 8603 +animator 8604 +hurry 8605 +romania 8606 +foreboding 8607 +pub 8608 +earns 8609 +bon 8610 +gen 8611 +della 8612 +photograph 8613 +pecker 8614 +censors 8615 +groundbreaking 8616 +predicted 8617 +crooked 8618 +engagement 8619 +arnie 8620 +torturing 8621 +towns 8622 +intellectually 8623 +bald 8624 +finely 8625 +confirmed 8626 +natasha 8627 +hale 8628 +chemical 8629 +spells 8630 +loony 8631 +richly 8632 +edmund 8633 +groove 8634 +vaudeville 8635 +bills 8636 +ma 8637 +millennium 8638 +gladiator 8639 +icy 8640 +irrational 8641 +ballroom 8642 +daria 8643 +conflicted 8644 +clarence 8645 +subdued 8646 +sigh 8647 +artistically 8648 +keanu 8649 +laced 8650 +potent 8651 +representative 8652 +gently 8653 +reckless 8654 +dopey 8655 +jerky 8656 +deborah 8657 +decency 8658 +grossly 8659 +predictability 8660 +consumed 8661 +belle 8662 +blessed 8663 +parks 8664 +curtain 8665 +dukakis 8666 +federal 8667 +analyze 8668 +echo 8669 +contributes 8670 +accomplishment 8671 +cheesiness 8672 +romanian 8673 +almighty 8674 +continuously 8675 +gathered 8676 +dive 8677 +undercover 8678 +diaz 8679 +profoundly 8680 +identities 8681 +crypt 8682 +downbeat 8683 +1949 8684 +gusto 8685 +missions 8686 +sasquatch 8687 +locate 8688 +borrows 8689 +maturity 8690 +harbor 8691 +denial 8692 +emmy 8693 +arch 8694 +animations 8695 +airing 8696 +superfluous 8697 +lists 8698 +officials 8699 +steaming 8700 +operate 8701 +threads 8702 +significantly 8703 +aniston 8704 +goldsworthy 8705 +anchors 8706 +disappoints 8707 +collaboration 8708 +trusted 8709 +lays 8710 +sync 8711 +1920s 8712 +wrongly 8713 +lindsey 8714 +optimism 8715 +vertigo 8716 +abroad 8717 +judges 8718 +continent 8719 +lizard 8720 +muni 8721 +helena 8722 +hartley's 8723 +zeta 8724 +denying 8725 +proportions 8726 +winners 8727 +ll 8728 +monologues 8729 +gravity 8730 +forbes 8731 +launched 8732 +robbing 8733 +mash 8734 +mocking 8735 +confronts 8736 +mutants 8737 +beetle 8738 +nifty 8739 +fence 8740 +horn 8741 +luxury 8742 +athletic 8743 +imprisoned 8744 +scriptwriter 8745 +mack 8746 +handy 8747 +pia 8748 +uninspiring 8749 +rhyme 8750 +1964 8751 +promoting 8752 +73 8753 +flew 8754 +98 8755 +corbin 8756 +chevy 8757 +mobster 8758 +altman's 8759 +extraordinarily 8760 +applause 8761 +abstract 8762 +switches 8763 +garde 8764 +icons 8765 +showcases 8766 +intelligently 8767 +capitalism 8768 +developments 8769 +lions 8770 +hanzo 8771 +hypnotic 8772 +temptation 8773 +dedication 8774 +opposition 8775 +sensation 8776 +kristofferson 8777 +barton 8778 +lds 8779 +bothers 8780 +satisfactory 8781 +nora 8782 +genetic 8783 +moonstruck 8784 +illustrate 8785 +notwithstanding 8786 +elephants 8787 +stripper 8788 +grendel 8789 +fulfilling 8790 +languages 8791 +hilton 8792 +autobiography 8793 +pleasures 8794 +lightweight 8795 +increasing 8796 +preferably 8797 +shifting 8798 +bearable 8799 +prefers 8800 +idiocy 8801 +heroin 8802 +manipulate 8803 +uncredited 8804 +sheridan 8805 +conniving 8806 +surgeon 8807 +nonexistent 8808 +deservedly 8809 +clutter 8810 +bullies 8811 +penalty 8812 +scattered 8813 +owe 8814 +lawn 8815 +upbringing 8816 +increase 8817 +oblivion 8818 +fanning 8819 +shiny 8820 +cynicism 8821 +kings 8822 +hazzard 8823 +preacher 8824 +ongoing 8825 +luthor 8826 +sister's 8827 +quirks 8828 +michaels 8829 +transitions 8830 +ravishing 8831 +reno 8832 +corridors 8833 +shady 8834 +cloth 8835 +liotta 8836 +spinning 8837 +sleeper 8838 +auteur 8839 +plummer 8840 +appalled 8841 +reportedly 8842 +dodgy 8843 +todays 8844 +harilal 8845 +kilmer 8846 +blackmail 8847 +toss 8848 +distinctly 8849 +violently 8850 +ebay 8851 +limp 8852 +marines 8853 +lesbians 8854 +vaughn 8855 +bart 8856 +knocking 8857 +palma's 8858 +boost 8859 +aboard 8860 +defy 8861 +civilians 8862 +brunette 8863 +fewer 8864 +cinematographic 8865 +liberties 8866 +shrill 8867 +youngsters 8868 +strain 8869 +hammerhead 8870 +inhabit 8871 +thug 8872 +dyke 8873 +euro 8874 +cassie 8875 +fellini 8876 +puzzled 8877 +chop 8878 +sweeping 8879 +throats 8880 +thirds 8881 +billion 8882 +witted 8883 +operating 8884 +atomic 8885 +lt 8886 +supportive 8887 +henderson 8888 +profit 8889 +prolific 8890 +sore 8891 +virginity 8892 +sleepy 8893 +golf 8894 +outlaw 8895 +unnerving 8896 +expresses 8897 +mills 8898 +forsythe 8899 +authors 8900 +behaving 8901 +visconti 8902 +efficient 8903 +visceral 8904 +glow 8905 +jones' 8906 +melinda 8907 +muscle 8908 +pepper 8909 +heavenly 8910 +unwilling 8911 +1965 8912 +roach 8913 +marcus 8914 +tables 8915 +shelves 8916 +dunne 8917 +tedium 8918 +illustrated 8919 +explanations 8920 +snowy 8921 +patriotic 8922 +alcoholism 8923 +whipped 8924 +ledger 8925 +slaughtered 8926 +redford 8927 +percent 8928 +rapes 8929 +disasters 8930 +dickinson 8931 +examined 8932 +cradle 8933 +fleeing 8934 +healing 8935 +lightly 8936 +nerdy 8937 +torch 8938 +rodney 8939 +believer 8940 +teddy 8941 +meyers 8942 +lorre 8943 +denver 8944 +dangers 8945 +architect 8946 +vulnerability 8947 +knives 8948 +dillon 8949 +goo 8950 +numbingly 8951 +inch 8952 +compositions 8953 +flipping 8954 +amoral 8955 +wrath 8956 +rack 8957 +imply 8958 +bonds 8959 +pistol 8960 +perceived 8961 +aura 8962 +tobe 8963 +seventh 8964 +verhoeven's 8965 +insignificant 8966 +simpler 8967 +shatner 8968 +mac 8969 +kornbluth 8970 +barbarian 8971 +zoom 8972 +proudly 8973 +hawaii 8974 +hustler 8975 +penguin 8976 +supports 8977 +thumb 8978 +segal 8979 +fulfill 8980 +bothering 8981 +jurassic 8982 +compromise 8983 +annoyingly 8984 +kenny 8985 +scandal 8986 +overtly 8987 +fleeting 8988 +metropolis 8989 +guru 8990 +rotting 8991 +sixteen 8992 +deadpan 8993 +retrieve 8994 +moderately 8995 +chat 8996 +lang 8997 +simon's 8998 +illusion 8999 +heartless 9000 +backwoods 9001 +climate 9002 +righteous 9003 +beth 9004 +grisly 9005 +prejudices 9006 +immigrants 9007 +alienation 9008 +muscular 9009 +astonishingly 9010 +doses 9011 +traveled 9012 +happier 9013 +electricity 9014 +succession 9015 +cousins 9016 +mandatory 9017 +dental 9018 +breakthrough 9019 +freaked 9020 +clockwork 9021 +ursula 9022 +recurring 9023 +notions 9024 +mechanic 9025 +recovering 9026 +zhang 9027 +comprised 9028 +coverage 9029 +elder 9030 +afghanistan 9031 +trendy 9032 +keeper 9033 +hungarian 9034 +attributes 9035 +brennan 9036 +protecting 9037 +priests 9038 +aztec 9039 +ranger 9040 +recipe 9041 +vienna 9042 +ogre 9043 +farnsworth 9044 +tasks 9045 +romero's 9046 +purse 9047 +subtitled 9048 +lansbury 9049 +pickup 9050 +pals 9051 +unconscious 9052 +animators 9053 +legion 9054 +meanings 9055 +needlessly 9056 +sleuth 9057 +association 9058 +slips 9059 +doris 9060 +pond 9061 +improvised 9062 +relates 9063 +mcdowell 9064 +volumes 9065 +ranging 9066 +zany 9067 +irresistible 9068 +elisha 9069 +herrings 9070 +coppola 9071 +prolonged 9072 +relaxing 9073 +1931 9074 +1938 9075 +rudd 9076 +heir 9077 +innuendo 9078 +urgency 9079 +bloke 9080 +flamboyant 9081 +muriel 9082 +prophet 9083 +reruns 9084 +christensen 9085 +lure 9086 +cracker 9087 +levy 9088 +shakespearean 9089 +encourages 9090 +mockery 9091 +swords 9092 +penis 9093 +pam 9094 +welcomed 9095 +rugged 9096 +academic 9097 +honeymoon 9098 +climbs 9099 +snatch 9100 +overwhelmed 9101 +gays 9102 +roommates 9103 +jolly 9104 +heavens 9105 +placing 9106 +watered 9107 +fable 9108 +zealand 9109 +carnival 9110 +gee 9111 +archer 9112 +locales 9113 +thorn 9114 +smarmy 9115 +kiddie 9116 +farewell 9117 +cheat 9118 +hopeful 9119 +backdrops 9120 +treating 9121 +kamal 9122 +irresponsible 9123 +behalf 9124 +benoit 9125 +unemployed 9126 +backyard 9127 +norton 9128 +stumbling 9129 +theirs 9130 +anonymous 9131 +temporary 9132 +distinguished 9133 +moore's 9134 +inhabited 9135 +wwi 9136 +eastwood's 9137 +pranks 9138 +custody 9139 +yearning 9140 +interspersed 9141 +agatha 9142 +chocolate 9143 +hug 9144 +guided 9145 +martino 9146 +steamy 9147 +feared 9148 +opponents 9149 +crawl 9150 +mans 9151 +jew 9152 +bombing 9153 +assortment 9154 +poke 9155 +imitating 9156 +management 9157 +keitel 9158 +frenzy 9159 +mcadams 9160 +architecture 9161 +spitting 9162 +48 9163 +hector 9164 +fitzgerald 9165 +rko 9166 +redgrave 9167 +induced 9168 +plants 9169 +rusty 9170 +janitor 9171 +weaver 9172 +recreate 9173 +islam 9174 +rogue 9175 +roads 9176 +rewrite 9177 +dodge 9178 +balloon 9179 +honey 9180 +neeson 9181 +conquest 9182 +slug 9183 +wolves 9184 +neglect 9185 +shawn 9186 +concentrated 9187 +tested 9188 +existential 9189 +expanded 9190 +worldwide 9191 +truthful 9192 +unlucky 9193 +liz 9194 +compassionate 9195 +limbs 9196 +impeccable 9197 +dogma 9198 +shattering 9199 +sailors 9200 +peterson 9201 +jock 9202 +rizzo 9203 +kalifornia 9204 +mcdermott 9205 +versatile 9206 +400 9207 +michael's 9208 +naval 9209 +burden 9210 +cheung 9211 +largest 9212 +culkin 9213 +retelling 9214 +muted 9215 +leaps 9216 +theo 9217 +passive 9218 +bucket 9219 +pertwee 9220 +eddy 9221 +rapture 9222 +continuous 9223 +gage 9224 +stretches 9225 +giggle 9226 +marx 9227 +concludes 9228 +stalks 9229 +amok 9230 +adequately 9231 +melt 9232 +stature 9233 +counted 9234 +borderline 9235 +mastermind 9236 +boxes 9237 +posh 9238 +taker 9239 +counterpart 9240 +izzard 9241 +straw 9242 +toe 9243 +shamelessly 9244 +crenna 9245 +tango 9246 +pour 9247 +behaves 9248 +sematary 9249 +expand 9250 +azumi 9251 +country's 9252 +stimulating 9253 +grady 9254 +expressing 9255 +payne 9256 +crass 9257 +intellect 9258 +booker 9259 +dani 9260 +parents' 9261 +lotr 9262 +miyazaki's 9263 +wits 9264 +waving 9265 +traumatized 9266 +illiterate 9267 +chan's 9268 +puzzling 9269 +splitting 9270 +subtleties 9271 +seduction 9272 +condescending 9273 +rebecca 9274 +inherited 9275 +seal 9276 +consisted 9277 +stubborn 9278 +didnt 9279 +lieutenant 9280 +slows 9281 +john's 9282 +glee 9283 +honorable 9284 +'73 9285 +valerie 9286 +smoothly 9287 +poo 9288 +evolved 9289 +darling 9290 +planted 9291 +mold 9292 +supremacy 9293 +opener 9294 +seuss 9295 +craven's 9296 +celine 9297 +hesitate 9298 +conception 9299 +supporters 9300 +revolting 9301 +practices 9302 +orgy 9303 +cheaper 9304 +town's 9305 +forgivable 9306 +nutty 9307 +speechless 9308 +nailed 9309 +associates 9310 +platoon 9311 +disdain 9312 +waits 9313 +knox 9314 +it´s 9315 +collecting 9316 +alligator 9317 +hispanic 9318 +mutated 9319 +woven 9320 +hardest 9321 +lubitsch 9322 +january 9323 +apprentice 9324 +uber 9325 +sarne 9326 +pets 9327 +fawcett 9328 +marred 9329 +elevate 9330 +drivers 9331 +creepiness 9332 +revive 9333 +harlem 9334 +vivah 9335 +kindness 9336 +marathon 9337 +bishop 9338 +gannon 9339 +carole 9340 +brits 9341 +submit 9342 +embarrass 9343 +boyfriends 9344 +dreadfully 9345 +oppressive 9346 +discernible 9347 +intruder 9348 +tourists 9349 +conduct 9350 +rehearsal 9351 +bolivia 9352 +astronaut 9353 +joanna 9354 +grounded 9355 +sessions 9356 +cocktail 9357 +stir 9358 +gimmicks 9359 +archive 9360 +stereotyping 9361 +aweigh 9362 +18th 9363 +undeveloped 9364 +rico 9365 +concentrates 9366 +bruckheimer 9367 +psychiatric 9368 +incompetence 9369 +villagers 9370 +customs 9371 +alienate 9372 +slew 9373 +footsteps 9374 +approximately 9375 +discussions 9376 +blink 9377 +vault 9378 +transformers 9379 +sloane 9380 +choke 9381 +infidelity 9382 +relied 9383 +undertaker 9384 +lovingly 9385 +casually 9386 +luzhin 9387 +disappearing 9388 +historians 9389 +shaolin 9390 +mastroianni 9391 +midler 9392 +atrocities 9393 +bash 9394 +inc 9395 +hedy 9396 +drums 9397 +bonding 9398 +entertainer 9399 +revelations 9400 +holland 9401 +floriane 9402 +downtown 9403 +denied 9404 +connor 9405 +stupidest 9406 +tel 9407 +sinatra's 9408 +lyrical 9409 +woke 9410 +knack 9411 +dripping 9412 +saddest 9413 +loathing 9414 +insects 9415 +hoover 9416 +apologize 9417 +premises 9418 +elmer 9419 +screamed 9420 +lecture 9421 +skipping 9422 +bursts 9423 +noam 9424 +passions 9425 +cocky 9426 +prevalent 9427 +regrets 9428 +suspended 9429 +shack 9430 +democracy 9431 +overacts 9432 +enhances 9433 +deathstalker 9434 +1960 9435 +choreographer 9436 +keeler 9437 +cillian 9438 +contemplate 9439 +smarter 9440 +marlene 9441 +philadelphia 9442 +sammi 9443 +kingsley 9444 +micheal 9445 +mpaa 9446 +duryea 9447 +creeps 9448 +capsule 9449 +converted 9450 +zabriskie 9451 +perceive 9452 +confronting 9453 +administration 9454 +arizona 9455 +viggo 9456 +ecstasy 9457 +candidates 9458 +branch 9459 +passenger 9460 +benson 9461 +sans 9462 +victoria's 9463 +callahan 9464 +intestines 9465 +swamp 9466 +sparse 9467 +request 9468 +overseas 9469 +bass 9470 +surpasses 9471 +organs 9472 +rohmer 9473 +montages 9474 +joshua 9475 +ella 9476 +maguire 9477 +rhys 9478 +cloud 9479 +stripped 9480 +rushes 9481 +kentucky 9482 +tensions 9483 +mom's 9484 +operas 9485 +chapters 9486 +monstrous 9487 +usage 9488 +fugitive 9489 +shaun 9490 +slipped 9491 +documents 9492 +email 9493 +classified 9494 +norwegian 9495 +reception 9496 +ash 9497 +sacrificed 9498 +switzerland 9499 +rightfully 9500 +cruella 9501 +psychologically 9502 +bury 9503 +liar 9504 +clumsily 9505 +crow 9506 +mindset 9507 +untrue 9508 +barker 9509 +lange 9510 +toro 9511 +ahmad 9512 +wipe 9513 +sixty 9514 +brink 9515 +insanely 9516 +mourning 9517 +vets 9518 +wu 9519 +1956 9520 +restless 9521 +loop 9522 +fanatics 9523 +rests 9524 +guevara 9525 +connecting 9526 +city's 9527 +friendships 9528 +satellite 9529 +empathize 9530 +surfers 9531 +immersed 9532 +mostel 9533 +squeeze 9534 +backing 9535 +admirably 9536 +confirm 9537 +equals 9538 +vengeful 9539 +pauses 9540 +snippets 9541 +mamet 9542 +that'll 9543 +anchorman 9544 +dense 9545 +strikingly 9546 +daphne 9547 +misplaced 9548 +1941 9549 +streak 9550 +shrink 9551 +garnered 9552 +breathless 9553 +hiv 9554 +delve 9555 +grain 9556 +spectrum 9557 +dusty 9558 +durbin 9559 +locks 9560 +november 9561 +o'neill 9562 +crook 9563 +render 9564 +participation 9565 +deception 9566 +replay 9567 +apartments 9568 +sr 9569 +lawyers 9570 +requisite 9571 +telly 9572 +basil 9573 +kinky 9574 +assist 9575 +spectacularly 9576 +scantily 9577 +prevented 9578 +obscene 9579 +reincarnation 9580 +morgana 9581 +bout 9582 +looney 9583 +adventurous 9584 +sykes 9585 +maverick 9586 +lucio 9587 +travelling 9588 +diabolical 9589 +capt 9590 +promotion 9591 +partial 9592 +eater 9593 +dime 9594 +bathing 9595 +criminally 9596 +underdog 9597 +interpret 9598 +suggestive 9599 +springs 9600 +graves 9601 +spielberg's 9602 +technological 9603 +wan 9604 +cortez 9605 +proverbial 9606 +granger 9607 +phrases 9608 +societies 9609 +thankful 9610 +palette 9611 +outrage 9612 +betrays 9613 +lung 9614 +marquis 9615 +ing 9616 +regal 9617 +oriental 9618 +duties 9619 +whacked 9620 +kerr 9621 +documented 9622 +700 9623 +stoic 9624 +fairytale 9625 +listing 9626 +acknowledged 9627 +allison 9628 +matching 9629 +longtime 9630 +garcia 9631 +elliot 9632 +33 9633 +adopt 9634 +flea 9635 +carlito's 9636 +1940's 9637 +coleman 9638 +draft 9639 +witless 9640 +kramer 9641 +haha 9642 +lap 9643 +alternately 9644 +1930 9645 +sentenced 9646 +harry's 9647 +daisies 9648 +overt 9649 +mining 9650 +stepped 9651 +eliminate 9652 +chains 9653 +regain 9654 +nuance 9655 +italians 9656 +hurting 9657 +honour 9658 +sealed 9659 +societal 9660 +indifference 9661 +lombard 9662 +teamed 9663 +cathy 9664 +its' 9665 +unfinished 9666 +floors 9667 +downside 9668 +tucker 9669 +paperhouse 9670 +compound 9671 +eggs 9672 +underused 9673 +incarnation 9674 +hunk 9675 +goer 9676 +presumed 9677 +caruso 9678 +interpreted 9679 +colourful 9680 +stills 9681 +caroline 9682 +keyboard 9683 +claw 9684 +snappy 9685 +camps 9686 +crop 9687 +sheet 9688 +overnight 9689 +dung 9690 +booze 9691 +risks 9692 +rub 9693 +oddball 9694 +exhibit 9695 +anchor 9696 +fireworks 9697 +batwoman 9698 +gesture 9699 +skinned 9700 +undertones 9701 +achieving 9702 +lanza 9703 +goofs 9704 +flee 9705 +recalls 9706 +stable 9707 +fantastically 9708 +exposing 9709 +shakes 9710 +addressing 9711 +prototype 9712 +carface 9713 +hes 9714 +competently 9715 +retain 9716 +schemes 9717 +hogan 9718 +voting 9719 +episodic 9720 +occurring 9721 +topped 9722 +1954 9723 +norma 9724 +chore 9725 +chang 9726 +shouts 9727 +rainer 9728 +colonial 9729 +recreation 9730 +forum 9731 +companions 9732 +apologies 9733 +insulted 9734 +holidays 9735 +throwaway 9736 +tepid 9737 +darkest 9738 +pulse 9739 +pita 9740 +superiors 9741 +grumpy 9742 +illustrates 9743 +sweetheart 9744 +showtime 9745 +aiello 9746 +btk 9747 +cbc 9748 +baseketball 9749 +horizon 9750 +eliminated 9751 +weirdo 9752 +welch 9753 +stepping 9754 +leno 9755 +beau 9756 +affections 9757 +leopold 9758 +inheritance 9759 +masturbation 9760 +itchy 9761 +locker 9762 +universally 9763 +shadowy 9764 +employ 9765 +skywalker 9766 +grips 9767 +gardens 9768 +sorvino 9769 +expertise 9770 +irwin 9771 +t'aime 9772 +babysitter 9773 +bryan 9774 +positions 9775 +coarse 9776 +tremors 9777 +iceberg 9778 +monumental 9779 +thinner 9780 +allegedly 9781 +dominick 9782 +allied 9783 +bogdanovich 9784 +raving 9785 +supplies 9786 +kaufman 9787 +sacred 9788 +shootings 9789 +primal 9790 +hiring 9791 +hockey 9792 +flamenco 9793 +thirteen 9794 +carlito 9795 +polite 9796 +exudes 9797 +gaining 9798 +darius 9799 +quarters 9800 +willem 9801 +crummy 9802 +duff 9803 +sorta 9804 +rigid 9805 +eponymous 9806 +smitten 9807 +attributed 9808 +variations 9809 +mischievous 9810 +unborn 9811 +wayne's 9812 +circuit 9813 +integrated 9814 +unimpressive 9815 +carson 9816 +150 9817 +siege 9818 +endured 9819 +surrogate 9820 +gifts 9821 +practicing 9822 +disgruntled 9823 +drifter 9824 +renowned 9825 +chef 9826 +operatic 9827 +maiden 9828 +frenetic 9829 +wal 9830 +roaring 9831 +author's 9832 +wondrous 9833 +greta 9834 +gamut 9835 +marital 9836 +gym 9837 +offerings 9838 +zatoichi 9839 +emerged 9840 +exaggeration 9841 +planets 9842 +raft 9843 +connolly 9844 +mcintire 9845 +strangest 9846 +marvellous 9847 +runtime 9848 +misfire 9849 +extremes 9850 +swift 9851 +seinfeld 9852 +jackass 9853 +harmony 9854 +plantation 9855 +bravery 9856 +pavarotti 9857 +catastrophe 9858 +malcolm 9859 +portman 9860 +solving 9861 +albums 9862 +winston 9863 +corky 9864 +allegory 9865 +spears 9866 +saif 9867 +goof 9868 +outta 9869 +virtues 9870 +monstrosity 9871 +ideology 9872 +edits 9873 +celebrating 9874 +adapting 9875 +ferry 9876 +desolate 9877 +jessie 9878 +inflicted 9879 +rocker 9880 +projection 9881 +irs 9882 +cambodia 9883 +enthralled 9884 +ensuing 9885 +leia 9886 +o'toole 9887 +transferred 9888 +exposes 9889 +competing 9890 +yourselves 9891 +sentiments 9892 +kisses 9893 +stray 9894 +turgid 9895 +declares 9896 +nuns 9897 +mercilessly 9898 +it'd 9899 +exceedingly 9900 +ted's 9901 +insecure 9902 +ben's 9903 +tanks 9904 +kusturica 9905 +spaces 9906 +spliced 9907 +sheila 9908 +crowds 9909 +balcony 9910 +menu 9911 +lamas 9912 +diver 9913 +secluded 9914 +integral 9915 +redeemed 9916 +halt 9917 +decapitated 9918 +stealth 9919 +budgeted 9920 +voters 9921 +overweight 9922 +praying 9923 +stevenson 9924 +cleveland 9925 +stakes 9926 +mattei 9927 +charity 9928 +stalk 9929 +olympia 9930 +olympic 9931 +aspirations 9932 +decoration 9933 +slack 9934 +bullying 9935 +bum 9936 +mo 9937 +capitalize 9938 +jameson 9939 +skimpy 9940 +wicker 9941 +starving 9942 +frenchman 9943 +frye 9944 +ate 9945 +monastery 9946 +wb 9947 +hayden 9948 +banana 9949 +grandparents 9950 +vacuous 9951 +willy 9952 +darkman 9953 +neutral 9954 +rumors 9955 +somber 9956 +aunts 9957 +amateurs 9958 +radar 9959 +ounce 9960 +bagdad 9961 +stud 9962 +closeups 9963 +insisted 9964 +jed 9965 +geeky 9966 +64 9967 +aims 9968 +complains 9969 +ewan 9970 +exhausted 9971 +day's 9972 +weaves 9973 +gladly 9974 +misogynistic 9975 +soles 9976 +michel 9977 +uniquely 9978 +interminable 9979 +aristocrat 9980 +paul's 9981 +everybody's 9982 +avant 9983 +answering 9984 +smallest 9985 +contacts 9986 +enlightenment 9987 +murphy's 9988 +employs 9989 +unforgivable 9990 +punchline 9991 +culminating 9992 +talentless 9993 +grabbing 9994 +soulless 9995 +unfairly 9996 +grail 9997 +retrospect 9998 +edged 9999
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD new file mode 100644 index 0000000..6de3411 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD
@@ -0,0 +1,21 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(glob(["*.*"])) + +filegroup( + name = "test_models", + srcs = glob([ + "*.tflite", + ]), +) + +filegroup( + name = "test_images", + srcs = glob([ + "*.jpg", + "*.png", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/automl_labeler_model.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/automl_labeler_model.tflite new file mode 100644 index 0000000..2c3a3b3d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/automl_labeler_model.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger-224.png b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger-224.png new file mode 100644 index 0000000..7e1e243 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger-224.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger.jpg new file mode 100644 index 0000000..a94f84e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger_crop.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger_crop.jpg new file mode 100644 index 0000000..7aabbb9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger_crop.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger_rotation180.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger_rotation180.jpg new file mode 100644 index 0000000..42cfff3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/burger_rotation180.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/cats_and_dogs.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/cats_and_dogs.jpg new file mode 100644 index 0000000..f6be4ab --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/cats_and_dogs.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/cats_and_dogs_rotation180.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/cats_and_dogs_rotation180.jpg new file mode 100644 index 0000000..23e80cc --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/cats_and_dogs_rotation180.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite new file mode 100644 index 0000000..9de3534 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite new file mode 100644 index 0000000..02ecbf8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite new file mode 100644 index 0000000..6416c4c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/deeplabv3.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/deeplabv3.tflite new file mode 100644 index 0000000..e3fdb7d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/deeplabv3.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/dilated_conv.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/dilated_conv.tflite new file mode 100644 index 0000000..e871a26 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/dilated_conv.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_1_default_1.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_1_default_1.tflite new file mode 100644 index 0000000..78a6fad --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_1_default_1.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_1_metadata_1.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_1_metadata_1.tflite new file mode 100644 index 0000000..443609b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_1_metadata_1.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_quant.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_quant.tflite new file mode 100644 index 0000000..18a20f9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_quant.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite new file mode 100644 index 0000000..c29c233 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224.tflite new file mode 100644 index 0000000..e123012 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224_without_labels.json b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224_without_labels.json new file mode 100644 index 0000000..7fd6567 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224_without_labels.json
@@ -0,0 +1,28 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ] + } + ], + "output_tensor_metadata": [ + { + } + ] + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224_without_labels.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224_without_labels.tflite new file mode 100644 index 0000000..acf3e26 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v2_1.0_224_without_labels.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v3_small_100_224_embedder.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v3_small_100_224_embedder.tflite new file mode 100644 index 0000000..7a48343 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v3_small_100_224_embedder.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/multi_objects.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/multi_objects.jpg new file mode 100644 index 0000000..992684d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/multi_objects.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation0.png b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation0.png new file mode 100644 index 0000000..3a44939 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation0.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation0_yuv.png b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation0_yuv.png new file mode 100644 index 0000000..79f255d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation0_yuv.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation90_flop.png b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation90_flop.png new file mode 100644 index 0000000..7eefb94 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_golden_rotation90_flop.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_input_rotation0.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_input_rotation0.jpg new file mode 100644 index 0000000..1b79c037 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_input_rotation0.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_input_rotation90_flop.jpg b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_input_rotation90_flop.jpg new file mode 100644 index 0000000..51f3bb0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/segmentation_input_rotation90_flop.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/sparrow.png b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/sparrow.png new file mode 100644 index 0000000..17eec7b1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/sparrow.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc index 1e3506b..53c8831 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
@@ -61,7 +61,7 @@ } WordpieceTokenizerResult BertTokenizer::TokenizeWordpiece( - const std::string& input) { + const std::string& input) const { WordpieceTokenizerResult result; std::vector<std::string>& subwords = result.subwords; std::vector<int>& wp_absolute_begin_offset = result.wp_begin_offset;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h index 533a8ee..1de54fa 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
@@ -20,7 +20,7 @@ #include <string> #include <vector> -#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_map.h" // from @com_google_absl #include "re2/re2.h" #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" #include "tensorflow_lite_support/cc/utils/common_utils.h" @@ -115,7 +115,7 @@ // Perform tokenization, return wordpiece-specific tokenized result including // subwords and offsets - WordpieceTokenizerResult TokenizeWordpiece(const std::string& input); + WordpieceTokenizerResult TokenizeWordpiece(const std::string& input) const; // Check if a certain key is included in the vocab. tensorflow::text::LookupStatus Contains(const absl::string_view key,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc index bcf13d6..249bc2d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
@@ -17,7 +17,7 @@ #include <string> -#include "absl/memory/memory.h" +#include "absl/memory/memory.h" // from @com_google_absl #include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc index 44c43b2d..ded6fbd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
@@ -17,8 +17,8 @@ #include <iostream> -#include "absl/strings/str_cat.h" -#include "absl/strings/substitute.h" +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/strings/substitute.h" // from @com_google_absl #include "tensorflow_lite_support/cc/utils/common_utils.h" namespace tflite { namespace support {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h index c53ae49..a82f500 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h
@@ -16,7 +16,7 @@ #ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ #define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ -#include "absl/container/node_hash_map.h" +#include "absl/container/node_hash_map.h" // from @com_google_absl #include "re2/re2.h" #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc index 6711fe4f..8ca14c5 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
@@ -20,8 +20,8 @@ #include <utility> #include <vector> -#include "absl/memory/memory.h" -#include "absl/strings/str_split.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl #include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h index ed5d3da7..d6e8631 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h
@@ -20,7 +20,7 @@ #include <string> #include <vector> -#include "src/sentencepiece_processor.h" +#include "src/sentencepiece_processor.h" // from @com_google_sentencepiece #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" namespace tflite {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h index c754506..9c86b2b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h
@@ -20,7 +20,7 @@ #include <string> #include <vector> -#include "absl/strings/string_view.h" +#include "absl/strings/string_view.h" // from @com_google_absl namespace tflite { namespace support {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h index f3758c8..fd76f3a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
@@ -17,8 +17,6 @@ #include <jni.h> -#include <string> - #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc index 46786fd7..32957d1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
@@ -15,7 +15,7 @@ #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" -#include "absl/status/status.h" +#include "absl/status/status.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" @@ -28,7 +28,6 @@ namespace tokenizer { using ::tflite::ProcessUnit; -using ::tflite::SentencePieceTokenizerOptions; using ::tflite::support::CreateStatusWithPayload; using ::tflite::support::StatusOr; using ::tflite::support::TfLiteSupportStatus; @@ -72,6 +71,12 @@ return absl::make_unique<BertTokenizer>(vocab_buffer.data(), vocab_buffer.size()); } + case ProcessUnitOptions_SentencePieceTokenizerOptions: { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Chromium does not support sentencepiece tokenization", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } case ProcessUnitOptions_RegexTokenizerOptions: { const tflite::RegexTokenizerOptions* options = tokenizer_process_unit->options_as<RegexTokenizerOptions>(); @@ -106,7 +111,7 @@ TfLiteSupportStatus::kMetadataInvalidTokenizerError); } - return regex_tokenizer; + return std::move(regex_tokenizer); } default: return CreateStatusWithPayload(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/BUILD index 07c832f..fc6de34 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/BUILD
@@ -1,9 +1,11 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) -cc_library( +cc_library_with_tflite( name = "jni_utils", srcs = [ "jni_utils.cc", @@ -11,8 +13,24 @@ hdrs = [ "jni_utils.h", ], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:delegate_plugin", + "@org_tensorflow//tensorflow/lite/core/shims:delegate_registry", + "@org_tensorflow//tensorflow/lite/core/shims:jni_initialization", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_plugin_converter", + ], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:configuration_proto_inc", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/java/jni", ], ) @@ -25,6 +43,9 @@ hdrs = [ "common_utils.h", ], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc index 17c4823..3ea6b14 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
@@ -17,7 +17,7 @@ #include <fstream> -#include "absl/strings/str_split.h" +#include "absl/strings/str_split.h" // from @com_google_absl namespace tflite { namespace support {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h index ef1bf33..275c4932 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
@@ -19,7 +19,7 @@ #include <string> #include <vector> -#include "absl/container/node_hash_map.h" +#include "absl/container/node_hash_map.h" // from @com_google_absl namespace tflite { namespace support {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc index fd0887f..35ce8229 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
@@ -15,11 +15,100 @@ #include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include <dlfcn.h> #include <string.h> +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow/lite/core/shims/c/experimental/acceleration/configuration/delegate_plugin.h" +#include "tensorflow/lite/core/shims/cc/experimental/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/experimental/acceleration/configuration/delegate_plugin_converter.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + namespace tflite { namespace support { namespace utils { +namespace { + +using ::absl::StatusCode; +using ::tflite::proto::Delegate; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite_shims::delegates::DelegatePluginRegistry; + +// delegate_name should be one of the following: +// gpu / hexagon +absl::Status loadDelegatePluginLibrary(const std::string& delegate_name) { + // Load "lib<delegate_name>_plugin.so". + std::string lib_name = + absl::StrFormat("lib%s_delegate_plugin.so", delegate_name); + + // Choosing RTLD_NOW over RTLD_LAZY: RTLD_NOW loads symbols now and + // makes sure there's no unresolved symbols. Using RTLD_LAZY will not + // discover unresolved symbols issues right away, and may lead to crash later + // during inference, which should be avoided. + // Choosing RTLD_LOCAL over RTLD_GLOBAL: the symbols should not be available + // for subsequently loaded libraries. + // Not choosing RTLD_DEEPBIND due to portability concerns; also we're using a + // linker script to hide internal symbols, so we don't really need it. + // Not choosing RTLD_NODELETE to avoid a (bounded) memory leak: + // if we used RTLD_NODELETE, dlclose() would not free the memory for the + // library. + void* handle = dlopen(lib_name.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!handle) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Error loading %s. %s", lib_name, dlerror())); + } + + // Load the method "TfLite<camel_name>DelegatePluginCApi". + std::string camel_name(delegate_name); + camel_name[0] = toupper(camel_name[0]); + std::string new_method_name = + absl::StrFormat("TfLite%sDelegatePluginCApi", camel_name); + TfLiteOpaqueDelegatePlugin* (*new_delegate_c)(); + new_delegate_c = reinterpret_cast<decltype(new_delegate_c)>( + dlsym(handle, new_method_name.c_str())); + if (!new_delegate_c) { + // Ignore the return value of dlclose as we deliberately hide it from users. + dlclose(handle); + handle = nullptr; + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Error loading method, %s from %s", new_method_name, + lib_name)); + } + + // Register the delegate. + new DelegatePluginRegistry::Register( + absl::StrFormat("%sPlugin", camel_name), + tflite::delegates::DelegatePluginConverter(*new_delegate_c())); + + return absl::OkStatus(); +} + +} // namespace + +tflite::support::StatusOr<Delegate> ConvertToProtoDelegate(jint delegate) { + // The supported delegate types should match + // org.tensorflow.lite.task.core.ComputeSettings.Delegate. + switch (delegate) { + case 0: + return Delegate::NONE; + case 1: + return Delegate::NNAPI; + case 2: + RETURN_IF_ERROR(loadDelegatePluginLibrary("gpu")); + return Delegate::GPU; + default: + break; + } + // Should never happen. + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("The delegate type is unsupported: %d", delegate)); +} std::string JStringToString(JNIEnv* env, jstring jstr) { if (jstr == nullptr) { @@ -96,6 +185,26 @@ env->ThrowNew(e_class, message); } +const char* GetExceptionClassNameForStatusCode(StatusCode status_code) { + switch (status_code) { + case StatusCode::kOk: + return nullptr; + case StatusCode::kInvalidArgument: + return kIllegalArgumentException; + // TODO(b/197650198): Uncomment this before the next major version bump + // and update the signature, as IOException is a checked exception. + // case StatusCode::kNotFound: + // return kIOException; + case StatusCode::kInternal: + return kIllegalStateException; + // kUnknown and all other status codes are mapped to a generic + // RuntimeException. + case StatusCode::kUnknown: + default: + return kRuntimeException; + } +} + } // namespace utils } // namespace support } // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h index 1a215e7..7caf49e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
@@ -22,7 +22,10 @@ #include <string> #include <vector> -#include "absl/strings/string_view.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" +#include "tensorflow_lite_support/cc/port/statusor.h" namespace tflite { namespace support { @@ -32,6 +35,8 @@ const char kIllegalStateException[] = "java/lang/IllegalStateException"; const char kNullPointerException[] = "java/lang/NullPointerException"; const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; +const char kIOException[] = "java/io/IOException"; +const char kRuntimeException[] = "java/lang/RuntimeException"; const char kUnsupportedOperationException[] = "java/lang/UnsupportedOperationException"; const char kAssertionError[] = "java/lang/AssertionError"; @@ -71,6 +76,10 @@ return array_list_object; } +// Converts delegate Java int type to delegate proto type. +tflite::support::StatusOr<tflite::proto::Delegate> ConvertToProtoDelegate( + jint delegate); + std::string JStringToString(JNIEnv* env, jstring jstr); std::vector<std::string> StringListToVector(JNIEnv* env, jobject list_object); @@ -87,6 +96,8 @@ const char* clazz, const char* message); +const char* GetExceptionClassNameForStatusCode(absl::StatusCode status_code); + } // namespace utils } // namespace support } // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/codegen/BUILD index c56bc6c..b224f98 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/BUILD
@@ -15,6 +15,8 @@ hdrs = [ "utils.h", ], + deps = [ + ], ) cc_library(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc index 0754bfe..bb8f1f4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
@@ -656,7 +656,7 @@ public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException { Model model = Model.createModel(context, modelPath, runningOptions); Metadata metadata = new Metadata(model.getData(), model); - MyImageClassifier instance = new MyImageClassifier(model, metadata);)"); + {{MODEL_CLASS_NAME}} instance = new {{MODEL_CLASS_NAME}}(model, metadata);)"); for (const auto& tensor : model.inputs) { SetCodeWriterWithTensorInfo(code_writer, tensor); code_writer->Append( @@ -779,7 +779,11 @@ code_writer->Append(R"(buildscript { repositories { google() - jcenter() + mavenCentral() // For versioned releases + maven { // For snapshot releases + name 'ossrh-snapshot' + url 'http://oss.sonatype.org/content/repositories/snapshots' + } } dependencies { classpath 'com.android.tools.build:gradle:3.2.1' @@ -835,8 +839,8 @@ dependencies { compileOnly 'org.checkerframework:checker-qual:2.5.8' - api 'org.tensorflow:tensorflow-lite:0.0.0-nightly' - api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly' + api 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT' + api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT' api files("$buildDir/libs/tensorflow-lite-support-metadata.jar") implementation 'org.apache.commons:commons-compress:1.19' })");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/BUILD index f47288b..11db944 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/BUILD
@@ -11,24 +11,27 @@ name = "tf_text_flex_delegate", additional_deps = ["@org_tensorflow_text//tensorflow_text:ops_lib"], models = [ - # TODO(b/160817619) Replace with a more complex model. "testdata/sentencepiece_tokenizer_flex_op.tflite", + "testdata/wiki40b-lm-en.tflite", ], ) # bazel test --config=monolithic tensorflow_lite_support/custom_ops:tflite_inference_test cc_test( name = "tflite_inference_test", - srcs = ["tflite_inference_main.cc"], - args = ["--model=tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite"], - data = ["//tensorflow_lite_support/custom_ops:testdata/sentencepiece_tokenizer_flex_op.tflite"], + srcs = ["tflite_inference_test.cc"], + data = [ + "//tensorflow_lite_support/custom_ops:testdata/sentencepiece_tokenizer_flex_op.tflite", + "//tensorflow_lite_support/custom_ops:testdata/wiki40b-lm-en.tflite", + ], deps = [ ":tf_text_flex_delegate", + "//testing/base/public:gunit", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:string_util", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - "@org_tensorflow//tensorflow/lite/tools:command_line_flags", + "@com_google_googletest//:gtest_main", ] + select({ "@org_tensorflow//tensorflow:android": [ "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", @@ -41,19 +44,3 @@ ], }), ) - -# We expect several libraries to already be installed on the system, e.g. via -# `pip install numpy` -py_library( - name = "expect_numpy_installed", -) - -# `pip install tensorflow` -py_library( - name = "expect_tfpy_installed", -) - -# `pip install tensorflow_text` -py_library( - name = "expect_tftext_installed", -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/BUILD index 5d63917..300dbd1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/BUILD
@@ -1,4 +1,4 @@ -# Placeholder for internal Python strict compatibility macro. +# Placeholder for internal Python strict test compatibility macro. load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") package( @@ -70,9 +70,9 @@ srcs_version = "PY3", deps = [ ":_pywrap_whitespace_tokenizer_op_resolver", - "//tensorflow_lite_support/custom_ops:expect_numpy_installed", - "//tensorflow_lite_support/custom_ops:expect_tfpy_installed", - "//tensorflow_lite_support/custom_ops:expect_tftext_installed", + # build rule placeholder: numpy dep, + # build rule placeholder: tensorflow dep, + # build rule placeholder: tensorflow_text dep, "@absl_py//absl/logging", "@absl_py//absl/testing:parameterized", ], @@ -137,8 +137,8 @@ srcs_version = "PY3", deps = [ ":_pywrap_ngrams_op_resolver", - "//tensorflow_lite_support/custom_ops:expect_tfpy_installed", - "//tensorflow_lite_support/custom_ops:expect_tftext_installed", + # build rule placeholder: tensorflow dep, + # build rule placeholder: tensorflow_text dep, "//tensorflow_lite_support/custom_ops/python:tflite_text_api", "@absl_py//absl/logging", "@absl_py//absl/testing:parameterized",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py index e52ca28..33a2231 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py
@@ -71,14 +71,12 @@ @tf.function(input_signature=input_signature) def __call__(self, values, *args): row_splits = list(args) - row_splits.reverse() input_tensor = tf.RaggedTensor.from_nested_row_splits( flat_values=values, nested_row_splits=tuple(row_splits)) output_tensor = ngrams( input_tensor, width, reduction_type=tf_text.Reduction.STRING_JOIN) output = [output_tensor.flat_values] output.extend(list(output_tensor.nested_row_splits)) - output.reverse() return tuple(output) tf.saved_model.save(Model(), temp_dir) @@ -151,29 +149,21 @@ input_tensor = tf.ragged.constant(test_case) tf_output = tf_text.ngrams( input_tensor, 2, reduction_type=tf_text.Reduction.STRING_JOIN) - rank = input_tensor.shape.rank model = self._make_model(rank, 2, ragged_tensor=True, flex=False) interpreter = interpreter_wrapper.InterpreterWithCustomOps( model_content=model, custom_op_registerers=['AddNgramsCustomOp']) - interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + signature_fn = interpreter.get_signature_runner() + signature_kwargs = {} + signature_kwargs['values'] = input_tensor.flat_values.numpy() for r in range(rank - 1): - interpreter.resize_tensor_input(r + 1, - input_tensor.nested_row_splits[r].shape) - interpreter.allocate_tensors() - interpreter.set_tensor(interpreter.get_input_details()[0]['index'], - input_tensor.flat_values.numpy()) - for r in range(rank - 1): - interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], - input_tensor.nested_row_splits[r].numpy()) - interpreter.invoke() - tflite_output_values = interpreter.get_tensor( - interpreter.get_output_details()[0]['index']) + signature_kwargs[f'args_{r}'] = input_tensor.nested_row_splits[r].numpy() + output = signature_fn(**signature_kwargs) + tflite_output_values = output['output_0'] self.assertEqual(tf_output.flat_values.numpy().tolist(), tflite_output_values.tolist()) for i in range(rank - 1): - tflite_output_cur_row_splits = interpreter.get_tensor( - interpreter.get_output_details()[i + 1]['index']) + tflite_output_cur_row_splits = output[f'output_{i + 1}'] self.assertEqual(tf_output.nested_row_splits[i].numpy().tolist(), tflite_output_cur_row_splits.tolist()) @@ -187,24 +177,18 @@ model = self._make_model(rank, 3, ragged_tensor=True, flex=False) interpreter = interpreter_wrapper.InterpreterWithCustomOps( model_content=model, custom_op_registerers=['AddNgramsCustomOp']) - interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + signature_fn = interpreter.get_signature_runner() + signature_kwargs = {} + signature_kwargs['values'] = ( + input_tensor.flat_values.numpy().astype('bytes')) for r in range(rank - 1): - interpreter.resize_tensor_input(r + 1, - input_tensor.nested_row_splits[r].shape) - interpreter.allocate_tensors() - interpreter.set_tensor(interpreter.get_input_details()[0]['index'], - input_tensor.flat_values.numpy()) - for r in range(rank - 1): - interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], - input_tensor.nested_row_splits[r].numpy()) - interpreter.invoke() - tflite_output_values = interpreter.get_tensor( - interpreter.get_output_details()[0]['index']) + signature_kwargs[f'args_{r}'] = input_tensor.nested_row_splits[r].numpy() + output = signature_fn(**signature_kwargs) + tflite_output_values = output['output_0'] self.assertEqual(tf_output.flat_values.numpy().tolist(), tflite_output_values.tolist()) for i in range(rank - 1): - tflite_output_cur_row_splits = interpreter.get_tensor( - interpreter.get_output_details()[i + 1]['index']) + tflite_output_cur_row_splits = output[f'output_{i+1}'] self.assertEqual(tf_output.nested_row_splits[i].numpy().tolist(), tflite_output_cur_row_splits.tolist()) @@ -217,19 +201,14 @@ model = self._make_model(rank, 3, ragged_tensor=True, flex=False) interpreter = interpreter_wrapper.InterpreterWithCustomOps( model_content=model, custom_op_registerers=['AddNgramsCustomOp']) - interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) - for r in range(rank - 1): - interpreter.resize_tensor_input(r + 1, - input_tensor.nested_row_splits[r].shape) - interpreter.allocate_tensors() - interpreter.set_tensor(interpreter.get_input_details()[0]['index'], - input_tensor.flat_values.numpy()) - for r in range(rank - 1): - interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], - input_tensor.nested_row_splits[r].numpy()) - start_time = timeit.default_timer() - for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK): - interpreter.invoke() + signature_fn = interpreter.get_signature_runner() + signature_kwargs = {} + signature_kwargs['values'] = input_tensor.flat_values.numpy() + for r in range(rank - 1): + signature_kwargs[f'args_{r}'] = input_tensor.nested_row_splits[r].numpy() + start_time = timeit.default_timer() + for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK): + _ = signature_fn(**signature_kwargs) latency_op = latency_op + timeit.default_timer() - start_time latency_op = latency_op / ( INVOKES_FOR_SINGLE_OP_BENCHMARK * len(TEST_CASES)) @@ -241,20 +220,17 @@ rank = input_tensor.shape.rank model = self._make_model(rank, 3, ragged_tensor=True, flex=True) interpreter = interpreter_wrapper.Interpreter(model_content=model) - interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + signature_fn = interpreter.get_signature_runner() + signature_kwargs = {} + signature_kwargs['values'] = input_tensor.flat_values.numpy() + for r in range(rank - 1): - interpreter.resize_tensor_input(r + 1, - input_tensor.nested_row_splits[r].shape) - interpreter.allocate_tensors() - interpreter.set_tensor(interpreter.get_input_details()[0]['index'], - input_tensor.flat_values.numpy()) - for r in range(rank - 1): - interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], - input_tensor.nested_row_splits[r].numpy()) + signature_kwargs[f'args_{r}'] = input_tensor.nested_row_splits[r].numpy( + ) start_time = timeit.default_timer() for _ in range(INVOKES_FOR_FLEX_DELEGATE_BENCHMARK): - interpreter.invoke() - latency_flex = latency_flex + timeit.default_timer() - start_time + _ = signature_fn(**signature_kwargs) + latency_flex = latency_flex + timeit.default_timer() - start_time latency_flex = latency_flex / ( INVOKES_FOR_FLEX_DELEGATE_BENCHMARK * len(TEST_CASES))
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD index 38a55c6..a512cdc8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD
@@ -1,7 +1,5 @@ # RaggedTensors suppport in TFLite -load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") - package( default_visibility = ["//tensorflow_lite_support:users"], licenses = ["notice"], # Apache 2.0 @@ -51,24 +49,6 @@ alwayslink = 1, ) -pybind_extension( - name = "pywrap_tflite_registerer", - srcs = [ - "pywrap_tflite_registerer.cc", - ], - hdrs = ["py_tflite_registerer.h"], - additional_exported_symbols = ["TFLite_RaggedTensorToTensorRegisterer"], - module_name = "pywrap_tflite_registerer", - srcs_version = "PY3ONLY", - deps = [ - ":py_tflite_registerer", - "@local_config_python//:python_headers", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - "@pybind11", - ], -) - cc_library( name = "ragged_range_tflite", srcs = ["ragged_range_tflite.cc"],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD new file mode 100644 index 0000000..650ab90 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD
@@ -0,0 +1,27 @@ +# Python wrapper used for test. + +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "pywrap_tflite_registerer", + srcs = [ + "pywrap_tflite_registerer.cc", + ], + additional_exported_symbols = ["TFLite_RaggedTensorToTensorRegisterer"], + module_name = "pywrap_tflite_registerer", + srcs_version = "PY3ONLY", + deps = [ + "//tensorflow_lite_support/custom_ops/kernel/ragged:py_tflite_registerer", + "@local_config_python//:python_headers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@pybind11", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/pywrap_tflite_registerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc similarity index 100% rename from third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/pywrap_tflite_registerer.cc rename to third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD index 2bdb5754..79982c2 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD
@@ -4,7 +4,6 @@ load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load(":native.bzl", "micore_tf_copts", "micore_tf_deps") -load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") package( default_visibility = [ @@ -274,6 +273,8 @@ ":encoder_config", ":model_converter", ":optimized_encoder", + "//tensorflow_lite_support/cc/test:test_utils", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", @@ -296,6 +297,8 @@ ":double_array_trie_builder", ":model_converter", ":optimized_decoder", + "//tensorflow_lite_support/cc/test:test_utils", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", @@ -318,47 +321,6 @@ alwayslink = 1, ) -pybind_extension( - name = "pywrap_tflite_registerer", - srcs = [ - "pywrap_tflite_registerer.cc", - ], - hdrs = ["py_tflite_registerer.h"], - additional_exported_symbols = ["TFLite_SentencepieceTokenizerRegisterer"], - module_name = "pywrap_tflite_registerer", - srcs_version = "PY3ONLY", - deps = [ - ":py_tflite_registerer", - "@local_config_python//:python_headers", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - "@pybind11", - ], -) - -pybind_extension( - name = "pywrap_model_converter", - srcs = ["pywrap_model_converter.cc"], - hdrs = ["model_converter.h"], - additional_exported_symbols = [ - "ConvertSentencepieceModel", - "ConvertSentencepieceModelForDecoder", - "GetVocabularySize", - ], - copts = ["-fexceptions"], - features = ["-use_header_modules"], - module_name = "pywrap_model_converter", - srcs_version = "PY3ONLY", - deps = [ - ":model_converter", - "//tensorflow_lite_support/cc/port:statusor", - "@com_google_absl//absl/status", - "@local_config_python//:python_headers", - "@org_tensorflow//tensorflow/lite:framework", - "@pybind11", - ], -) - config_setting( name = "armeabi_v7a_and_fastbuild", values = {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc index 72b7262..75388ce 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc
@@ -18,7 +18,7 @@ #include <algorithm> #include <memory> -#include "include/darts.h" +#include "include/darts.h" // from @darts_clone namespace tflite { namespace ops {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc index 53f4f66..47ba9fd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
@@ -15,9 +15,9 @@ #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" -#include "absl/status/status.h" -#include "absl/strings/str_replace.h" -#include "src/sentencepiece_model.pb.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_replace.h" // from @com_google_absl +#include "src/sentencepiece_model.pb.h" // from @com_google_sentencepiece #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h" #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" @@ -94,8 +94,10 @@ const auto pieces_trie_fbs = pieces_trie_builder.Finish(); // Converting normalization. - const auto [normalization_trie, normalization_strings] = + const auto normalization = DecodePrecompiledCharsmap(model_config.normalizer_spec()); + const auto normalization_trie = std::get<0>(normalization); + const auto normalization_strings = std::get<1>(normalization); const auto normalization_trie_vector = builder.CreateVector(normalization_trie); TrieBuilder normalization_trie_builder(builder);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc index 04d1c85..94161c2 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
@@ -19,9 +19,12 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> -#include "src/sentencepiece.pb.h" -#include "src/sentencepiece_processor.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece +#include "src/sentencepiece_processor.h" // from @com_google_sentencepiece #include "tensorflow/core/platform/env.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" namespace tflite { @@ -33,8 +36,8 @@ tensorflow::Status TFReadFileToString(const std::string& filepath, std::string* data) { - return tensorflow::ReadFileToString(tensorflow::Env::Default(), - /*test_path*/ filepath, data); + return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, + data); } absl::Status StdReadFileToString(const std::string& filepath, @@ -54,22 +57,25 @@ } // namespace internal namespace { +using ::tflite::task::JoinPath; + static char kConfigFilePath[] = - "tensorflow_lite_support/custom_ops/kernel/" + "/tensorflow_lite_support/custom_ops/kernel/" "sentencepiece/testdata/sentencepiece.model"; TEST(OptimizedEncoder, ConfigConverter) { std::string config; - auto status = internal::StdReadFileToString(kConfigFilePath, &config); + auto status = internal::StdReadFileToString( + JoinPath("./" /*test src dir*/, kConfigFilePath), &config); ASSERT_TRUE(status.ok()); ::sentencepiece::SentencePieceProcessor processor; - ASSERT_OK(processor.LoadFromSerializedProto(config)); + ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); const auto converted_model = ConvertSentencepieceModelForDecoder(config); const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); ::sentencepiece::SentencePieceText reference_encoded; - CHECK_OK(processor.Encode(test_string, &reference_encoded)); + ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); std::vector<int> encoded_vector; encoded_vector.reserve(reference_encoded.pieces_size()); @@ -77,7 +83,7 @@ encoded_vector.push_back(piece.id()); } std::string ref_decoded; - ASSERT_OK(processor.Decode(encoded_vector, &ref_decoded)); + ASSERT_TRUE(processor.Decode(encoded_vector, &ref_decoded).ok()); const auto decoded = DecodeString(encoded_vector, converted_model.data()); ASSERT_EQ(decoded.type, DecoderResultType::SUCCESS); ASSERT_EQ(ref_decoded, decoded.decoded);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc index b79c513..4148f8e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
@@ -39,7 +39,9 @@ std::vector<int> result_offsets; result_offsets.reserve(offsets.size()); for (int i = 0, j = 0; i < input.size();) { - auto [consumed, new_string] = pc(input.data() + i, input.size() - i); + auto result = pc(input.data() + i, input.size() - i); + auto consumed = std::get<0>(result); + auto new_string = std::get<1>(result); if (consumed == 0) { // Skip the current byte and move forward. result_string.push_back(input[i]);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc index 47a5c04..dd956a2 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
@@ -19,11 +19,13 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "src/sentencepiece.pb.h" -#include "src/sentencepiece_processor.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece +#include "src/sentencepiece_processor.h" // from @com_google_sentencepiece #include "tensorflow/core/platform/env.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" @@ -37,8 +39,8 @@ tensorflow::Status TFReadFileToString(const std::string& filepath, std::string* data) { - return tensorflow::ReadFileToString(tensorflow::Env::Default(), - /*test_path*/ filepath, data); + return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, + data); } absl::Status StdReadFileToString(const std::string& filepath, @@ -58,8 +60,10 @@ namespace { +using ::tflite::task::JoinPath; + static char kConfigFilePath[] = - "tensorflow_lite_support/custom_ops/kernel/" + "/tensorflow_lite_support/custom_ops/kernel/" "sentencepiece/testdata/sentencepiece.model"; TEST(OptimizedEncoder, NormalizeStringWhitestpaces) { @@ -71,12 +75,16 @@ FinishEncoderConfigBuffer(builder, ecb.Finish()); const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); { - const auto [res_string, offsets] = NormalizeString("x y", *config); + const auto result = NormalizeString("x y", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); } { - const auto [res_string, offsets] = NormalizeString("\tx y\n", *config); + const auto result = NormalizeString("\tx y\n", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); } @@ -101,8 +109,9 @@ FinishEncoderConfigBuffer(builder, ecb.Finish()); const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); { - const auto [res_string, offsets] = - NormalizeString("ABAABAAABAAAA", *config); + const auto result = NormalizeString("ABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); EXPECT_EQ(res_string, "A1BA2BA3BA4"); EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); @@ -129,8 +138,9 @@ FinishEncoderConfigBuffer(builder, ecb.Finish()); const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); { - const auto [res_string, offsets] = - NormalizeString("XXABAABAAABAAAA", *config); + const auto result = NormalizeString("XXABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); EXPECT_EQ(res_string, " A1BA2BA3BA4"); EXPECT_THAT(offsets, ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); @@ -139,7 +149,8 @@ TEST(OptimizedEncoder, ConfigConverter) { std::string config; - auto status = internal::StdReadFileToString(kConfigFilePath, &config); + auto status = internal::StdReadFileToString( + JoinPath("./" /*test src dir*/, kConfigFilePath), &config); ASSERT_TRUE(status.ok()); ::sentencepiece::SentencePieceProcessor processor;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/pywrap_model_converter.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/pywrap_model_converter.cc deleted file mode 100644 index f74c156..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/pywrap_model_converter.cc +++ /dev/null
@@ -1,45 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace sentencepiece { - -namespace py = pybind11; - -PYBIND11_MODULE(pywrap_model_converter, m) { - m.def("convert_sentencepiece_model", [](py::bytes model_string) { - return py::bytes(ConvertSentencepieceModel(std::string(model_string))); - }); - - m.def("convert_sentencepiece_model_for_decoder", [](py::bytes model_string) { - return py::bytes( - ConvertSentencepieceModelForDecoder(std::string(model_string))); - }); - - m.def("get_vocabulary_size", [](py::bytes model_string) { - return GetVocabularySize(std::string(model_string)); - }); -} - -} // namespace sentencepiece -} // namespace custom -} // namespace ops -} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/pywrap_tflite_registerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/pywrap_tflite_registerer.cc deleted file mode 100644 index dc380c6..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/pywrap_tflite_registerer.cc +++ /dev/null
@@ -1,35 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "pybind11/pybind11.h" -#include "pybind11/pytypes.h" -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h" - -PYBIND11_MODULE(pywrap_tflite_registerer, m) { - m.doc() = R"pbdoc( - pywrap_tflite_registerer - A module with a wrapper that adds to a Python wrapper for TFLite - sentencepiece tokenizer. - )pbdoc"; - m.def( - "TFLite_SentencepieceTokenizerRegisterer", - [](uintptr_t resolver) { - TFLite_SentencepieceTokenizerRegisterer( - reinterpret_cast<tflite::MutableOpResolver*>(resolver)); - }, - R"pbdoc( - The function that adds Sentencepiece Tokenizer to the TFLite interpreter. - )pbdoc"); -}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/BUILD index dad8f28..33d3f08 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/BUILD
@@ -9,8 +9,8 @@ name = "tflite_text_api", srcs = ["tflite_text_api.py"], deps = [ - "//tensorflow_lite_support/custom_ops:expect_tfpy_installed", - "//tensorflow_lite_support/custom_ops:expect_tftext_installed", + # build rule placeholder: tensorflow dep, + # build rule placeholder: tensorflow_text dep, ], ) @@ -23,8 +23,8 @@ ], srcs_version = "PY3", deps = [ - "//tensorflow_lite_support/custom_ops:expect_tfpy_installed", - "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:pywrap_model_converter", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece/py:pywrap_model_converter", ], ) @@ -37,9 +37,9 @@ python_version = "PY3", deps = [ ":sentencepiece_tokenizer", - "//tensorflow_lite_support/custom_ops:expect_tfpy_installed", - "//tensorflow_lite_support/custom_ops:expect_tftext_installed", - "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:pywrap_tflite_registerer", + # build rule placeholder: tensorflow dep, + # build rule placeholder: tensorflow_text dep, + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece/py:pywrap_tflite_registerer", "@absl_py//absl:app", "@absl_py//absl/flags", "@absl_py//absl/logging", @@ -52,8 +52,8 @@ srcs = ["ragged_tensor_to_tensor_test.py"], python_version = "PY3", deps = [ - "//tensorflow_lite_support/custom_ops:expect_tfpy_installed", - "//tensorflow_lite_support/custom_ops/kernel/ragged:pywrap_tflite_registerer", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/custom_ops/kernel/ragged/py:pywrap_tflite_registerer", "@absl_py//absl:app", "@absl_py//absl/flags", "@absl_py//absl/logging",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py index 319131e..b1fa3d1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py
@@ -35,7 +35,7 @@ concrete_function = ragged_tensor_function.get_concrete_function() converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function]) + [concrete_function], ragged_tensor_function) converter.allow_custom_ops = True tflite_model = converter.convert() interpreter = interpreter_wrapper.InterpreterWithCustomOps(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py index 55aa0c9..21efed56 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py
@@ -27,7 +27,7 @@ from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader gen_sentencepiece_tokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_tokenizer_op.so')) -from tensorflow_lite_support.custom_ops.kernel.sentencepiece import pywrap_model_converter as model_converter +from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_model_converter as model_converter class SentencepieceTokenizer:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py index e0a9c81..3609b469 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py
@@ -33,7 +33,7 @@ from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import from tensorflow.python.platform import resource_loader from tensorflow_lite_support.custom_ops.python import sentencepiece_tokenizer -from tensorflow_lite_support.custom_ops.kernel.sentencepiece import pywrap_tflite_registerer +from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_tflite_registerer FLAGS = flags.FLAGS
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/testdata/wiki40b-lm-en.tflite b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/testdata/wiki40b-lm-en.tflite new file mode 100644 index 0000000..15f5b49 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/testdata/wiki40b-lm-en.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/tflite_inference_main.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/tflite_inference_main.cc deleted file mode 100644 index 2819dee..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/tflite_inference_main.cc +++ /dev/null
@@ -1,105 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This program runs the tflite model specified in --model with random inputs. -// For string type, the input is filled with a fixed string. - -#include <string> - -#include <glog/logging.h> -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/string_util.h" -#include "tensorflow/lite/tools/command_line_flags.h" - -void FillRandomString(tflite::DynamicBuffer* buffer, - const TfLiteIntArray* dim_array, - const std::function<std::string()>& random_func) { - int num_elements = 1; - for (size_t i = 0; i < dim_array->size; i++) { - num_elements *= dim_array->data[i]; - } - for (int i = 0; i < num_elements; ++i) { - auto str = random_func(); - buffer->AddString(str.data(), str.length()); - } -} - -void RunWithRandomInputs(const std::string& filename) { - std::unique_ptr<tflite::FlatBufferModel> model = - tflite::FlatBufferModel::BuildFromFile(filename.c_str()); - - // Build the interpreter - tflite::ops::builtin::BuiltinOpResolver resolver; - std::unique_ptr<tflite::Interpreter> interpreter; - if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { - LOG(FATAL) << "Could not initialize interpreter for TFLite model."; - } - - // Resize input tensors, if desired. - if (interpreter->AllocateTensors() != kTfLiteOk) { - LOG(FATAL) << "Could not allocate tensor."; - } - - // Fill the random data. - std::vector<std::vector<uint8_t>> sample; - for (int tensor_idx : interpreter->inputs()) { - auto tensor = interpreter->tensor(tensor_idx); - if (tensor->type == kTfLiteString) { - tflite::DynamicBuffer buffer; - FillRandomString(&buffer, tensor->dims, []() { - return "we're have some friends over saturday to hang out in the " - "yard"; - }); - buffer.WriteToTensor(tensor, /*new_shape=*/nullptr); - } else { - std::vector<uint8_t> data(tensor->bytes); - for (auto it = data.begin(); it != data.end(); ++it) { - *it = random(); - } - sample.push_back(data); - tensor->data.raw = reinterpret_cast<char*>(sample.rbegin()->data()); - } - } - - // Running inference. - if (interpreter->Invoke() != kTfLiteOk) { - LOG(FATAL) << "Failed to run the model."; - } - - // Get the output. - for (int tensor_idx : interpreter->outputs()) { - auto tensor = interpreter->tensor(tensor_idx); - LOG(INFO) << "Output type: " << TfLiteTypeGetName(tensor->type); - } -} - -int main(int argc, char** argv) { - // Parse flags to get the filename. - std::string filename; - std::vector<tflite::Flag> flag_list{tflite::Flag::CreateFlag( - "model", &filename, "The tflite model to run sample inference.", - tflite::Flag::kRequired)}; - tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list); - tensorflow::port::InitMain(argv[0], &argc, &argv); - - // Run the model with random inputs. - RunWithRandomInputs(filename); - return 0; -}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/tflite_inference_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/tflite_inference_test.cc new file mode 100644 index 0000000..79e0652 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/tflite_inference_test.cc
@@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This program runs the tflite model specified in --model with random inputs. +// For string type, the input is filled with a fixed string. + +#include <string> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/string_util.h" + +void FillRandomString(tflite::DynamicBuffer* buffer, + const TfLiteIntArray* dim_array, + const std::function<std::string()>& random_func) { + int num_elements = 1; + for (size_t i = 0; i < dim_array->size; i++) { + num_elements *= dim_array->data[i]; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} + +bool RunWithRandomInputs(const std::string& filename) { + std::unique_ptr<tflite::FlatBufferModel> model = + tflite::FlatBufferModel::BuildFromFile(filename.c_str()); + + // Build the interpreter + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr<tflite::Interpreter> interpreter; + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + LOG(ERROR) << "Could not initialize interpreter for TFLite model."; + return false; + } + + // Resize input tensors, if desired. + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(ERROR) << "Could not allocate tensor."; + return false; + } + + // Fill the random data. + std::vector<std::vector<uint8_t>> sample; + for (int tensor_idx : interpreter->inputs()) { + auto tensor = interpreter->tensor(tensor_idx); + if (tensor->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, tensor->dims, []() { + return "we're have some friends over saturday to hang out in the " + "yard"; + }); + buffer.WriteToTensor(tensor, /*new_shape=*/nullptr); + } else { + std::vector<uint8_t> data(tensor->bytes); + for (auto it = data.begin(); it != data.end(); ++it) { + *it = random(); + } + sample.push_back(data); + tensor->data.raw = reinterpret_cast<char*>(sample.rbegin()->data()); + } + } + + // Running inference. + if (interpreter->Invoke() != kTfLiteOk) { + LOG(ERROR) << "Failed to run the model."; + return false; + } + + // Get the output. + for (int tensor_idx : interpreter->outputs()) { + auto tensor = interpreter->tensor(tensor_idx); + LOG(INFO) << "Output type: " << TfLiteTypeGetName(tensor->type); + } + return true; +} + +TEST(SelectiveBuiltTest, SentencePieceTokenizerModel) { + std::string model = + "tensorflow_lite_support/custom_ops/testdata/" + "sentencepiece_tokenizer_flex_op.tflite"; + EXPECT_THAT(RunWithRandomInputs(model), true); +} + +TEST(SelectiveBuiltTest, Wiki40bLmEnModel) { + std::string model = + "tensorflow_lite_support/custom_ops/testdata/" + "wiki40b-lm-en.tflite"; + EXPECT_THAT(RunWithRandomInputs(model), true); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/BUILD new file mode 100644 index 0000000..d70b641 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "darwinn_portable", + values = { + "define": "darwinn_portable=1", + }, +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/BUILD new file mode 100644 index 0000000..fa6a12e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/BUILD
@@ -0,0 +1,46 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "audio_classifier_lib", + srcs = ["audio_classifier_lib.cc"], + hdrs = ["audio_classifier_lib.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/audio/utils:wav_io", + "//tensorflow_lite_support/cc/task/audio:audio_classifier", + "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/core:category", + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), +) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/audio/desktop:audio_classifier_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --audio_wav_path=/path/to/wav.tflite +cc_binary( + name = "audio_classifier_demo", + srcs = ["audio_classifier_demo.cc"], + deps = [ + ":audio_classifier_lib", + "//tensorflow_lite_support/cc/task/audio:audio_classifier", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/README.md new file mode 100644 index 0000000..3c51aa4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/README.md
@@ -0,0 +1,98 @@ +# CLI Demos for C++ Audio Task APIs + +This folder contains simple command-line tools for easily trying out the C++ +Audio Task APIs. + +## Coral Integration + +Task Library now supports fast TFLite inference delegated onto +[Coral Edge TPU devices](https://coral.ai/docs/edgetpu/inference/) on Linux and +macOS. See the +[documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview#run_task_library_with_delegates) +for more details. + +To run the demo on a Coral device, add `--define darwinn_portable=1` to the +bazel command. + +Note the `libusb` package is required. It can be installed as follows: + +```bash +# On Linux +sudo apt-get install libusb-1.0-0-dev + +# On macOS using MacPorts +port install libusb +# or Homebrew +brew install libusb +``` + +See the example commands in each task demo below. + +You can also explore more [pretrained Coral model](https://coral.ai/models) and +try them in the demo. All the models have populated with +[TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata). + +## Audio Classification + +### Prerequisites +You will need: + +- a TFLite audio classification model with metadata (e.g. + https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1, an + environmental sound classification model available on TensorFlow Hub), +- a mono-channel 16-bit PCM WAV file. The sample rate of the WAV file should + be the same as what model requires (described in the Metadata). + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite' \ + -o /tmp/yamnet.tflite + +# Download the audio file: +curl \ + -L https://storage.googleapis.com/audioset/miaow_16k.wav \ + -o /tmp/miao.wav + +# Run the classification tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/audio/desktop:audio_classifier_demo -- \ + --model_path=/tmp/yamnet.tflite \ + --score_threshold=0.5 \ + --audio_wav_path=/tmp/miao.wav +``` + +To run the demo on a [Coral Edge TPU device](https://coral.ai/products/), check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the Coral model: +curl \ + -L 'https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite' \ + -o /tmp/yamnet_edgetpu.tflite + +# Run the classification tool: +bazel run -c opt --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/audio/desktop:audio_classifier_demo -- \ + --model_path=/tmp/yamnet_edgetpu.tflite \ + --audio_wav_path=/path/to/the/audio_file.wav \ + --score_threshold=0.5 \ + --use_coral=true +``` + +### Results +In the console, you should get: + +```bash +Time cost to classify the input audio clip on CPU: 51.4087 ms +Note: Only showing classes with score higher than 0.5 + +Head[0]: scores + category[Cat]: 0.73828 + category[Animal]: 0.66797 + category[Domestic animals, pets]: 0.66797 +```
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc new file mode 100644 index 0000000..6339ed7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
@@ -0,0 +1,77 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/audio/desktop:audio_classifier_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --audio_wav_path=/path/to/audio.wav + +#include <cstddef> +#include <iostream> +#include <limits> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h" + +ABSL_FLAG(std::string, + model_path, + "", + "Absolute path to the '.tflite' audio classification model."); +ABSL_FLAG(std::string, + audio_wav_path, + "", + "Absolute path to the 16-bit PCM WAV file to classify. The WAV " + "file must be monochannel and has a sampling rate matches the model " + "expected sampling rate (as in the Metadata). If the WAV file is " + "longer than what the model requires, only the beginning section is " + "used for inference."); +ABSL_FLAG(float, + score_threshold, + 0.001f, + "Apply a filter on the results. Only display classes with score " + "higher than the threshold."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_audio_wav_path).empty()) { + std::cerr << "Missing mandatory 'audio_wav_path' argument.\n"; + return 1; + } + + // Run classification. + auto result = tflite::task::audio::Classify( + absl::GetFlag(FLAGS_model_path), absl::GetFlag(FLAGS_audio_wav_path), + absl::GetFlag(FLAGS_use_coral)); + if (result.ok()) { + tflite::task::audio::Display(result.value(), + absl::GetFlag(FLAGS_score_threshold)); + } else { + std::cerr << "Classification failed: " << result.status().message() << "\n"; + return 1; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc new file mode 100644 index 0000000..a843501 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
@@ -0,0 +1,116 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h" + +#include <iostream> +#include <string> +#include <vector> + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h" +#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/audio/utils/wav_io.h" +#include "tensorflow_lite_support/cc/task/core/category.h" + +namespace tflite { +namespace task { +namespace audio { + +tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile( + const std::string& wav_file, + int buffer_size, + std::vector<float>* wav_data) { + std::string contents = ReadFile(wav_file); + + uint32 decoded_sample_count; + uint16 decoded_channel_count; + uint32 decoded_sample_rate; + RETURN_IF_ERROR(DecodeLin16WaveAsFloatVector( + contents, wav_data, &decoded_sample_count, &decoded_channel_count, + &decoded_sample_rate)); + + if (decoded_sample_count > buffer_size) { + decoded_sample_count = buffer_size; + } + + return AudioBuffer( + wav_data->data(), decoded_sample_count, + {decoded_channel_count, static_cast<int>(decoded_sample_rate)}); +} + +tflite::support::StatusOr<ClassificationResult> Classify( + const std::string& model_path, + const std::string& wav_file, + bool use_coral) { + AudioClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + model_path); + if (use_coral) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } + ASSIGN_OR_RETURN(std::unique_ptr<AudioClassifier> classifier, + AudioClassifier::CreateFromOptions(options)); + + // `wav_data` holds data loaded from the file and needs to outlive `buffer`. + std::vector<float> wav_data; + ASSIGN_OR_RETURN( + AudioBuffer buffer, + LoadAudioBufferFromFile( + wav_file, classifier->GetRequiredInputBufferSize(), &wav_data)); + + auto start_classify = std::chrono::steady_clock::now(); + ASSIGN_OR_RETURN(ClassificationResult result, classifier->Classify(buffer)); + auto end_classify = std::chrono::steady_clock::now(); + std::string delegate = use_coral ? "Coral Edge TPU" : "CPU"; + const auto duration_ms = + std::chrono::duration<float, std::milli>(end_classify - start_classify); + std::cout << "Time cost to classify the input audio clip on " << delegate + << ": " << duration_ms.count() << " ms" << std::endl; + + return result; +} + +void Display(const ClassificationResult& result, float score_threshold) { + std::cout << "Note: Only showing classes with score higher than " + << score_threshold << std::endl; + + for (int i = 0; i < result.classifications_size(); i++) { + const auto& head = result.classifications(i); + std::cout << absl::StrFormat("\nHead[%d]: %s\n", i, head.head_name()); + for (int j = 0; j < head.classes_size(); j++) { + const auto& category = head.classes(j); + if (category.score() < score_threshold) + continue; + std::cout << absl::StrFormat("\tcategory[%s]: %.5f\t", + category.class_name(), category.score()); + if (!category.display_name().empty()) { + std::cout << category.display_name(); + } + std::cout << std::endl; + } + } +} + +} // namespace audio +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h new file mode 100644 index 0000000..13b2d77 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
@@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_AUDIO_DESKTOP_AUDIO_CLASSIFIER_LIB_H_ +#define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_AUDIO_DESKTOP_AUDIO_CLASSIFIER_LIB_H_ + +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" + +namespace tflite { +namespace task { +namespace audio { + +// Loads `wav_file` from filesystem and runs classification using TFLite model +// in `model_path` with default options. If the content of `wav_file` is longer +// than what the model requires, only the beginning section is used for +// inference. +tflite::support::StatusOr<ClassificationResult> Classify( + const std::string& model_path, + const std::string& wav_file, + bool use_coral = false); + +// Prints the output classification result in the standard output. It only +// displays classes whose score is higher than the `score_threshold`. +void Display(const ClassificationResult& result, float score_threshold); + +} // namespace audio +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_AUDIO_DESKTOP_AUDIO_CLASSIFIER_LIB_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD new file mode 100644 index 0000000..20b4db24 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD
@@ -0,0 +1,21 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +# bazel run \ +# tensorflow_lite_support/examples/task/audio/desktop/python:audio_classifier_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --audio="What a waste of my time." +py_binary( + name = "audio_classifier_demo", + srcs = ["audio_classifier_demo.py"], + data = ["//tensorflow_lite_support/examples/task/audio/desktop:audio_classifier_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md new file mode 100644 index 0000000..f3e30073 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md
@@ -0,0 +1,109 @@ +# CLI Demos for Python Audio Task APIs + +A Python wrapper for the C++ Audio Task APIs. + +## Background + +This Python API is based on the C++ Audio Task APIs. It uses Python's +[subprocess](https://docs.python.org/3/library/subprocess.html) to call C++ Audio +Task APIs. + +## Coral Integration + +Task Library now supports fast TFLite inference delegated onto +[Coral Edge TPU devices](https://coral.ai/docs/edgetpu/inference/) on Linux and +macOS. See the +[documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview#run_task_library_with_delegates) +for more details. + +To run the demo on a Coral device, add `--define darwinn_portable=1` to the +bazel command. + +Note the `libusb` package is required. It can be installed as follows: + +```bash +# On Linux +sudo apt-get install libusb-1.0-0-dev + +# On macOS using MacPorts +port install libusb +# or Homebrew +brew install libusb +``` + +See the example commands in each task demo below. + +You can also explore more [pretrained Coral model](https://coral.ai/models) and +try them in the demo. All the models have populated with +[TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata). + +## Audio Classification + +### Prerequisites +You will need: + +- a TFLite audio classification model with metadata (e.g. + https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1, an + environmental sound classification model available on TensorFlow Hub), +- a mono-channel 16-bit PCM WAV file. The sample rate of the WAV file should + be the same as what model requires (described in the Metadata). + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite' \ + -o /tmp/yamnet.tflite + +# Download the audio file: +curl \ + -L https://storage.googleapis.com/audioset/miaow_16k.wav \ + -o /tmp/miao.wav +``` + +##### Build the demo from source + +Run the classification tool: + +```bash +bazel run \ + tensorflow_lite_support/examples/task/audio/desktop/python:audio_classifier_demo -- \ + --model_path=/tmp/yamnet.tflite \ + --score_threshold=0.5 \ + --audio_wav_path=/tmp/miao.wav +``` + +To run the demo on a [Coral Edge TPU device](https://coral.ai/products/), check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the Coral model: +curl \ + -L 'https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite' \ + -o /tmp/yamnet_edgetpu.tflite + +# Run the classification tool: +bazel run --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/audio/desktop/python:audio_classifier_demo -- \ + --model_path=/tmp/yamnet_edgetpu.tflite \ + --audio_wav_path=/path/to/the/audio_file.wav \ + --score_threshold=0.5 \ + --use_coral=true +``` + +#### Results + +In the console, you should get: + +```bash +Time cost to classify the input audio clip on CPU: 51.4087 ms +Note: Only showing classes with score higher than 0.5 + +Head[0]: scores + category[Cat]: 0.73828 + category[Animal]: 0.66797 + category[Domestic animals, pets]: 0.66797 +``` \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py new file mode 100644 index 0000000..07b6bf9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py
@@ -0,0 +1,87 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for Audio Classification.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string( + 'model_path', None, + 'Absolute path to the ".tflite" audio classification model.') +flags.DEFINE_string( + 'audio_wav_path', None, + 'Absolute path to the 16-bit PCM WAV file to classify. The WAV ' + 'file must be monochannel and has a sampling rate matches the model ' + 'expected sampling rate (as in the Metadata). If the WAV file is ' + 'longer than what the model requires, only the beginning section is ' + 'used for inference.') +flags.DEFINE_float( + 'score_threshold', None, + 'Apply a filter on the results. Only display classes with score ' + 'higher than the threshold.') +flags.DEFINE_bool( + 'use_coral', False, + 'If true, inference will be delegated to a connected Coral Edge TPU ' + 'device.') + +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('score_threshold') +flags.mark_flag_as_required('audio_wav_path') + +_AUDIO_CLASSIFICATION_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../audio_classifier_demo') + + +def classify(model_path, score_threshold, audio_wav_path, use_coral): + """Classifies input audio clip into different categories. + + Args: + model_path: Path to model + score_threshold: Absolute path to the 16-bit PCM WAV file to classify + audio_wav_path: Apply a filter on the results + use_coral: Optional; If true, inference will be delegated to a connected + Coral Edge TPU device. + """ + # Run the classification tool: + subprocess.run([ + _AUDIO_CLASSIFICATION_NATIVE_PATH + ' --model_path=' + model_path + + ' --score_threshold=' + str(score_threshold) + ' --audio_wav_path="' + + audio_wav_path + '" --use_coral=' + str(use_coral) + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.score_threshold, FLAGS.audio_wav_path, + FLAGS.use_coral) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD index 067d59c..2d69942 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD
@@ -1,6 +1,6 @@ package( default_visibility = [ - "//tensorflow_lite_support:users", + "//tensorflow_lite_support:internal", ], licenses = ["notice"], # Apache 2.0 ) @@ -16,13 +16,19 @@ name = "bert_question_answerer_demo", srcs = ["bert_question_answerer_demo.cc"], deps = [ - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - ], + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/text:bert_question_answerer", + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), ) # Example usage: @@ -35,14 +41,20 @@ name = "bert_nl_classifier_demo", srcs = ["bert_nl_classifier_demo.cc"], deps = [ - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/core:category", - "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - ], + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/text:bert_nl_classifier", + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), ) # Example usage: @@ -57,12 +69,40 @@ name = "nl_classifier_demo", srcs = ["nl_classifier_demo.cc"], deps = [ - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/core:category", - "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), +) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_main \ +# -- \ +# --model_path=/path/to/model_with_metadata.tflite +cc_binary( + name = "universal_sentence_encoder_qa_demo", + srcs = [ + "universal_sentence_encoder_qa_demo.cc", + ], + deps = [ + "//tensorflow_lite_support/cc/task/text:universal_sentence_encoder_qa", + "//tensorflow_lite_support/custom_ops/kernel/ragged:py_tflite_registerer", + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:py_tflite_registerer", + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_tokenizer_op", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md index 859504f..ecd925b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md
@@ -9,8 +9,8 @@ You will need: -* a TFLite bert based question answerer model from model maker. -(e.g. [mobilebert][1] or [albert][2] available on TensorFlow Hub). +* a TFLite bert based question answerer model from model maker. (e.g. + [mobilebert][1] or [albert][2] available on TensorFlow Hub). #### Usage @@ -19,7 +19,7 @@ ```bash # Download the model: curl \ - -L 'https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1?lite-format=tflite' \ + -L 'https://tfhub.dev/tensorflow/lite-model/mobilebert/1/metadata/1?lite-format=tflite' \ -o /tmp/mobilebert.tflite # Run the classification tool: @@ -40,6 +40,7 @@ In the console, you should get: ``` +Time cost to answer the input question on CPU: 783 ms answer[0]: 'South America.' logit: 1.84847, start_index: 39, end_index: 40 answer[1]: 'most of the Amazon basin of South America.' @@ -58,10 +59,8 @@ You will need: -* a TFLite text classification model with certain format. -(e.g. [movie_review_model][3], a model to classify movie reviews), you'll need -to configure the input tensor and out tensor for the API, see the [doc][4] for -details. +* a TFLite text classification model with certain format. (e.g. + [movie_review_model][3], a model to classify movie reviews). #### Usage @@ -77,9 +76,7 @@ bazel run -c opt \ tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo -- \ --model_path=/tmp/movie_review.tflite \ - --text="What a waste of my time." \ - --input_tensor_name="input_text" \ - --output_score_tensor_name="probability" + --text="What a waste of my time." ``` #### Results @@ -87,6 +84,7 @@ In the console, you should get: ``` +Time cost to classify the input text on CPU: 0.088 ms category[0]: 'Negative' : '0.81313' category[1]: 'Positive' : '0.18687' ``` @@ -99,7 +97,8 @@ You will need: -* a Bert based TFLite text classification model from model maker. (e.g. [movie_review_model][5] available on TensorFlow Hub). +* a Bert based TFLite text classification model from model maker. (e.g. + [movie_review_model][5] available on TensorFlow Hub). #### Usage @@ -123,12 +122,49 @@ In the console, you should get: ``` +Time cost to classify the input text on CPU: 491 ms category[0]: 'negative' : '0.00006' category[1]: 'positive' : '0.99994' ``` +## UniversalSentenceEncoderQA + +#### Prerequisites + +You will need: + +* a universal sentence encoder QA model from [TensorFlow Hub][6]. + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/universal-sentence-encoder-qa-ondevice/1?lite-format=tflite' \ + -o /tmp/universal_sentence_encoder_qa_with_metadata.tflite + +# Run UniversalSentenceEncoderQA model: +bazel run -c opt \ + tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_demo -- \ + --model_path=/tmp/universal_sentence_encoder_qa_with_metadata.tflite +``` + +#### Results + +In the console, you should get: + +``` +How are you feeling today? +I'm not feeling very well., , 14.9595 +He looks good., , 8.80944 +Paris is the capital of France., , 5.63753 +``` + [1]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 [2]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 [3]: https://www.tensorflow.org/lite/models/text_classification/overview [4]: https://github.com/tensorflow/tflite-support/blob/fe8b69002f5416900285dc69e2baa078c91bd994/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h#L55 [5]: http://bert/nl/classifier/model +[6]: https://tfhub.dev/google/lite-model/universal-sentence-encoder-qa-ondevice/1
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc index e88d565..52032008 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
@@ -15,32 +15,59 @@ #include <iostream> #include <limits> -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/category.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" +#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" ABSL_FLAG(std::string, model_path, "", "Absolute path to the '.tflite' bert classification model."); ABSL_FLAG(std::string, text, "", "Text to classify."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); namespace tflite { namespace task { namespace text { -namespace nlclassifier { + +namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +} // namespace absl::Status Classify() { - ASSIGN_OR_RETURN( - std::unique_ptr<BertNLClassifier> classifier, - BertNLClassifier::CreateFromFile(absl::GetFlag(FLAGS_model_path))); + BertNLClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } + ASSIGN_OR_RETURN(std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromOptions(options)); + + auto start_classify = steady_clock::now(); std::vector<core::Category> categories = classifier->Classify(absl::GetFlag(FLAGS_text)); + auto end_classify = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to classify the input text on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_classify - + start_classify) + .count() + << " ms" << std::endl; for (int i = 0; i < categories.size(); ++i) { const core::Category& category = categories[i]; @@ -51,7 +78,6 @@ return absl::OkStatus(); } -} // namespace nlclassifier } // namespace text } // namespace task } // namespace tflite @@ -69,7 +95,7 @@ } // Run classification. - absl::Status status = tflite::task::text::nlclassifier::Classify(); + absl::Status status = tflite::task::text::Classify(); if (status.ok()) { return 0; } else {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc index 534b86a..f2577cf 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
@@ -15,12 +15,12 @@ #include <iostream> #include <limits> -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" -#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" +#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" ABSL_FLAG(std::string, model_path, @@ -31,19 +31,47 @@ context, "", "Context the asked question is based upon."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); namespace tflite { namespace task { namespace text { -namespace qa { + +namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +} // namespace absl::Status Answer() { - ASSIGN_OR_RETURN( - std::unique_ptr<QuestionAnswerer> answerer, - BertQuestionAnswerer::CreateFromFile(absl::GetFlag(FLAGS_model_path))); + BertQuestionAnswererOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } + ASSIGN_OR_RETURN(std::unique_ptr<QuestionAnswerer> answerer, + BertQuestionAnswerer::CreateFromOptions(options)); + + auto start_answer = steady_clock::now(); std::vector<QaAnswer> answers = answerer->Answer( absl::GetFlag(FLAGS_context), absl::GetFlag(FLAGS_question)); + auto end_answer = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to answer the input question on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_answer - + start_answer) + .count() + << " ms" << std::endl; + for (int i = 0; i < answers.size(); ++i) { const QaAnswer& answer = answers[i]; std::cout << absl::StrFormat( @@ -54,7 +82,6 @@ return absl::OkStatus(); } -} // namespace qa } // namespace text } // namespace task } // namespace tflite @@ -75,7 +102,7 @@ return 1; } // Run the answerer. - absl::Status status = tflite::task::text::qa::Answer(); + absl::Status status = tflite::task::text::Answer(); if (status.ok()) { return 0; } else {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc index 9e62714..613744ff 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
@@ -15,10 +15,10 @@ #include <iostream> #include <limits> -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/category.h" #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" @@ -28,64 +28,47 @@ "", "Absolute path to the '.tflite' classification model."); ABSL_FLAG(std::string, text, "", "Text to classify."); -ABSL_FLAG(int, input_tensor_index, -1, "Input tensor index of the model."); -ABSL_FLAG(int, - output_score_tensor_index, - -1, - "Output score tensor index of the model."); -ABSL_FLAG(int, - output_label_tensor_index, - -1, - "Output label tensor index of the model."); -ABSL_FLAG(std::string, - input_tensor_name, - "", - "Input tensor name of the model."); -ABSL_FLAG(std::string, - output_score_tensor_name, - "", - "Output score tensor name of the model."); -ABSL_FLAG(std::string, - output_label_tensor_name, - "", - "Output label tensor name of the model."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); namespace tflite { namespace task { namespace text { namespace nlclassifier { +namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +} // namespace + absl::Status Classify() { - NLClassifierOptions options{}; - if (absl::GetFlag(FLAGS_input_tensor_index) >= 0) { - options.input_tensor_index = absl::GetFlag(FLAGS_input_tensor_index); - } - if (absl::GetFlag(FLAGS_output_score_tensor_index) >= 0) { - options.output_score_tensor_index = - absl::GetFlag(FLAGS_output_score_tensor_index); - } - if (absl::GetFlag(FLAGS_output_label_tensor_index) >= 0) { - options.output_label_tensor_index = - absl::GetFlag(FLAGS_output_label_tensor_index); - } - if (!absl::GetFlag(FLAGS_input_tensor_name).empty()) { - options.input_tensor_name = absl::GetFlag(FLAGS_input_tensor_name); - } - if (!absl::GetFlag(FLAGS_output_score_tensor_name).empty()) { - options.output_score_tensor_name = - absl::GetFlag(FLAGS_output_score_tensor_name); - } - if (!absl::GetFlag(FLAGS_output_label_tensor_name).empty()) { - options.output_label_tensor_name = - absl::GetFlag(FLAGS_output_label_tensor_name); + tflite::task::text::NLClassifierOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); } ASSIGN_OR_RETURN(std::unique_ptr<NLClassifier> classifier, - NLClassifier::CreateFromFileAndOptions( - absl::GetFlag(FLAGS_model_path), options)); + NLClassifier::CreateFromOptions(options)); + auto start_classify = steady_clock::now(); std::vector<core::Category> categories = classifier->Classify(absl::GetFlag(FLAGS_text)); + auto end_classify = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to classify the input text on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_classify - + start_classify) + .count() + << " ms" << std::endl; for (int i = 0; i < categories.size(); ++i) { const core::Category& category = categories[i];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/BUILD new file mode 100644 index 0000000..06ae5a7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/BUILD
@@ -0,0 +1,56 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +# bazel run \ +# tensorflow_lite_support/examples/task/text/desktop/python:nl_classifier_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --text="What a waste of my time." +py_binary( + name = "nl_classifier_demo", + srcs = ["nl_classifier_demo.py"], + data = ["//tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) + +# bazel run \ +# tensorflow_lite_support/examples/task/text/desktop/python:bert_nl_classifier_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --text="it's a charming and often affecting journey" +py_binary( + name = "bert_nl_classifier_demo", + srcs = ["bert_nl_classifier_demo.py"], + data = ["//tensorflow_lite_support/examples/task/text/desktop:bert_nl_classifier_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) + +# bazel run \ +# tensorflow_lite_support/examples/task/text/desktop/python:bert_question_answerer_demo -- \ +# --model_path=/tmp/mobilebert.tflite \ +# --question="Where is Amazon rainforest?" \ +# --context="The Amazon rainforest, alternatively, the Amazon Jungle, also known in \ +# English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon \ +# biome that covers most of the Amazon basin of South America. This basin \ +# encompasses 7,000,000 km2 (2,700,000 sq mi), of which \ +# 5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region \ +# includes territory belonging to nine nations." +py_binary( + name = "bert_question_answerer_demo", + srcs = ["bert_question_answerer_demo.py"], + data = ["//tensorflow_lite_support/examples/task/text/desktop:bert_question_answerer_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/README.md new file mode 100644 index 0000000..b24ee79 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/README.md
@@ -0,0 +1,147 @@ +# CLI Demos for Python Text Task APIs + +A Python wrapper for the C++ Text Task APIs. + +## Background + +This Python API is based on the C++ Text Task APIs. It uses Python's +[subprocess](https://docs.python.org/3/library/subprocess.html) to call C++ Text +Task APIs. + +## Bert Question Answerer + +#### Prerequisites + +You will need: + +* a TFLite bert based question answerer model from model maker. (e.g. + [mobilebert][1] or [albert][2] available on TensorFlow Hub). + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/tensorflow/lite-model/mobilebert/1/metadata/1?lite-format=tflite' \ + -o /tmp/mobilebert.tflite + +# Run the classification tool: +bazel run \ + tensorflow_lite_support/examples/task/text/desktop/python:bert_question_answerer_demo -- \ + --model_path=/tmp/mobilebert.tflite \ + --question="Where is Amazon rainforest?" \ + --context="The Amazon rainforest, alternatively, the Amazon Jungle, also known in \ +English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon \ +biome that covers most of the Amazon basin of South America. This basin \ +encompasses 7,000,000 km2 (2,700,000 sq mi), of which \ +5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region \ +includes territory belonging to nine nations." +``` + +#### Results + +In the console, you should get: + +``` +Time cost to answer the input question on CPU: 783 ms +answer[0]: 'South America.' +logit: 1.84847, start_index: 39, end_index: 40 +answer[1]: 'most of the Amazon basin of South America.' +logit: 1.2921, start_index: 34, end_index: 40 +answer[2]: 'the Amazon basin of South America.' +logit: -0.0959535, start_index: 36, end_index: 40 +answer[3]: 'the Amazon biome that covers most of the Amazon basin of South America.' +logit: -0.498558, start_index: 28, end_index: 40 +answer[4]: 'Amazon basin of South America.' +logit: -0.774266, start_index: 37, end_index: 40 +``` + +## NLClassifier + +#### Prerequisites + +You will need: + +* a TFLite text classification model with certain format. (e.g. + [movie_review_model][3], a model to classify movie reviews). + +#### Usage + +First, download the pretrained model by: + +```bash +curl \ + -L 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite' \ + -o /tmp/movie_review.tflite +``` + +##### Build the demo from source + +Run the demo tool: + +```bash +bazel run \ + tensorflow_lite_support/examples/task/text/desktop/python:nl_classifier_demo \ + -- \ + --model_path=/tmp/movie_review.tflite \ + --text="What a waste of my time." +``` + +#### Results + +In the console, you should get: + +``` +Time cost to classify the input text on CPU: 0.088 ms +category[0]: 'Negative' : '0.81313' +category[1]: 'Positive' : '0.18687' +``` + +## BertNLClassifier + +#### Prerequisites + +You will need: + +* a Bert based TFLite text classification model from model maker. (e.g. + [movie_review_model][5] available on TensorFlow Hub). + +#### Usage + +First, download the pretrained model by: + +```bash +curl \ + -L 'https://url/to/bert/nl/classifier' \ + -o /tmp/bert_movie_review.tflite +``` + +##### Build the demo from source + +Run the demo tool: + +```bash +bazel run \ + tensorflow_lite_support/examples/task/text/desktop/python:bert_nl_classifier_demo \ + -- \ + --model_path=/tmp/bert_movie_review.tflite \ + --text="it's a charming and often affecting journey" +``` + +#### Results + +In the console, you should get: + +``` +Time cost to classify the input text on CPU: 491 ms +category[0]: 'negative' : '0.00006' +category[1]: 'positive' : '0.99994' +``` + +[1]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +[2]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +[3]: https://www.tensorflow.org/lite/models/text_classification/overview +[4]: https://github.com/tensorflow/tflite-support/blob/fe8b69002f5416900285dc69e2baa078c91bd994/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h#L55 +[5]: http://bert/nl/classifier/model
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/bert_nl_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/bert_nl_classifier_demo.py new file mode 100644 index 0000000..4fed738 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/bert_nl_classifier_demo.py
@@ -0,0 +1,65 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for BertNLClassifier.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string('model_path', None, 'Model Path') +flags.DEFINE_string('text', None, 'Text to Predict') + +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('text') + +_BERT_NL_CLASSIFIER_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../bert_nl_classifier_demo') + + +def classify(model_path, text): + """Classifies input text into different categories. + + Args: + model_path: path to model + text: input text + """ + # Run the detection tool: + subprocess.run([ + _BERT_NL_CLASSIFIER_NATIVE_PATH + ' --model_path=' + model_path + + ' --text="' + text + '"' + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.text) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/bert_question_answerer_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/bert_question_answerer_demo.py new file mode 100644 index 0000000..7276c0b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/bert_question_answerer_demo.py
@@ -0,0 +1,78 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for BertQuestionAnswerer.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string( + 'model_path', None, + 'Absolute path to the ".tflite" bert question answerer model.') +flags.DEFINE_string('question', None, 'Question to ask.') +flags.DEFINE_string('context', None, + 'Context the asked question is based upon.') +flags.DEFINE_bool( + 'use_coral', False, + 'If true, inference will be delegated to a connected Coral Edge TPU device.' +) + +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('question') +flags.mark_flag_as_required('context') + +_BERT_QUESTION_ANSWERER_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../bert_question_answerer_demo') + + +def classify(model_path, question, context, use_coral): + """Predicts the answers for the given question based on the given context + + Args: + model_path: Path to model + question: Question to ask + context: Context the asked question is based upon + use_coral: Optional; If true, inference will be delegated to a connected + Coral Edge TPU device. + """ + # Run the detection tool: + subprocess.run([ + _BERT_QUESTION_ANSWERER_NATIVE_PATH + ' --model_path=' + model_path + + ' --question="' + question + '" --context="' + context + + '" --use_coral=' + str(use_coral) + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.question, FLAGS.context, FLAGS.use_coral) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/nl_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/nl_classifier_demo.py new file mode 100644 index 0000000..3b3bda3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/python/nl_classifier_demo.py
@@ -0,0 +1,65 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for NLClassifier.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string('model_path', None, 'Model Path') +flags.DEFINE_string('text', None, 'Text to Predict') + +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('text') + +_NL_CLASSIFIER_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../nl_classifier_demo') + + +def classify(model_path, text): + """Classifies input text into different categories. + + Args: + model_path: path to model + text: input text + """ + # Run the detection tool: + subprocess.run([ + _NL_CLASSIFIER_NATIVE_PATH + ' --model_path=' + model_path + ' --text="' + + text + '"' + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.text) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc new file mode 100644 index 0000000..8ba00cb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
@@ -0,0 +1,91 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Demostration the usage of UniversalSentenceEncoderQA. +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h" +using tflite::task::text::RetrievalInput; +using tflite::task::text::RetrievalOptions; +using tflite::task::text::RetrievalOutput; +using tflite::task::text::retrieval::UniversalSentenceEncoderQA; + +ABSL_FLAG(std::string, + model_path, + "", + "Absolute path to the '.tflite' UniversalSentenceEncoderQA model."); +ABSL_FLAG(std::string, + question, + "How are you feeling today?", + "Question to ask."); +ABSL_FLAG( + std::string, + answers, + "I'm not feeling very well.:Paris is the capital of France.:He looks good.", + "Candidate answers seperated by `:`."); + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_question).empty()) { + std::cerr << "Missing mandatory 'question' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_answers).empty()) { + std::cerr << "Missing mandatory 'answers' argument.\n"; + return 1; + } + + RetrievalOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + auto status = UniversalSentenceEncoderQA::CreateFromOption(options); + CHECK_OK(status); + std::unique_ptr<UniversalSentenceEncoderQA> client = + std::move(status.value()); + + // Create RetrievalInput with a query and responses. + RetrievalInput input; + // Set a sentence of text as the query. + input.set_query_text(absl::GetFlag(FLAGS_question)); + // Add candidate responses, and each one contains a sentence of text. (May + // set context too). + for (const auto& ans : absl::StrSplit(absl::GetFlag(FLAGS_answers), ':')) { + input.add_responses()->mutable_raw_text()->set_text(ans); + } + + // Run inference with the Retrieve function. + const absl::StatusOr<RetrievalOutput>& output_status = + client->Retrieve(input); + CHECK_OK(output_status); // Check ok + const RetrievalOutput& output = output_status.value(); + + // Get top results (may set optional parameter k=? to limit top-K results). + const std::vector<size_t>& top = client->Top(output); + + // Consume the results according to the ranking. Here we just print them out. + std::cout << input.query_text() << std::endl; + for (size_t k : top) { + std::cout << input.responses(k).raw_text().text() << ", " + << input.responses(k).raw_text().context() << ", " + << output.response_results(k).score() << std::endl; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD index f61984e..601533c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD
@@ -1,6 +1,6 @@ package( default_visibility = [ - "//tensorflow_lite_support:users", + "//tensorflow_lite_support:internal", ], licenses = ["notice"], # Apache 2.0 ) @@ -9,6 +9,10 @@ name = "image_classifier_demo", srcs = ["image_classifier_demo.cc"], deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core:external_file_handler", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", @@ -18,17 +22,24 @@ "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - ], + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), ) cc_binary( name = "object_detector_demo", srcs = ["object_detector_demo.cc"], deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core:external_file_handler", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", @@ -39,18 +50,24 @@ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), ) cc_binary( name = "image_segmenter_demo", srcs = ["image_segmenter_demo.cc"], deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core:external_file_handler", "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", @@ -59,10 +76,36 @@ "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), +) + +cc_binary( + name = "image_embedder_demo", + srcs = ["image_embedder_demo.cc"], + deps = [ "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - ], + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:image_embedder", + "//tensorflow_lite_support/cc/task/vision/proto:embeddings_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_embedder_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + ] + select({ + "//tensorflow_lite_support/examples/task:darwinn_portable": [ + "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", + ], + "//conditions:default": [ + ], + }), )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md index 73c6b63..ef7969a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md
@@ -3,15 +3,44 @@ This folder contains simple command-line tools for easily trying out the C++ Vision Task APIs. +## Coral Integration + +Task Library now supports fast TFLite inference delegated onto +[Coral Edge TPU devices][4] on Linux and macOS. See the +[documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview#run_task_library_with_delegates) +for more details. + +To run the demo on a Coral device, add `--define darwinn_portable=1` to the +bazel command. + +Note the `libusb` package is required. It can be installed as follows: + +```bash +# On Linux +sudo apt-get install libusb-1.0-0-dev + +# On macOS using MacPorts +port install libusb +# or Homebrew +brew install libusb +``` + +See the example commands in each task demo below. + +You can also explore more [pretrained Coral model](https://coral.ai/models) and +try them in the demo. All the models have populated with +[TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata). + ## Image Classifier #### Prerequisites You will need: -* a TFLite image classification model (e.g. [aiy/vision/classifier/birds_V1][1], -a bird classification model available on TensorFlow Hub), -* a PNG, JPEG or GIF image to run classification on, e.g.: +* a TFLite image classification model (e.g. + [aiy/vision/classifier/birds_V1][1], a bird classification model available + on TensorFlow Hub), +* a PNG, JPEG or GIF image to run classification on, e.g.: ![sparrow](g3doc/sparrow.jpg) @@ -34,11 +63,31 @@ --max_results=3 ``` +To run the demo on a [Coral Edge TPU device][4], check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the Coral model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite' \ + -o /tmp/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite + +# Run the classification tool: +bazel run -c opt --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo -- \ + --model_path=/tmp/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg \ + --max_results=3 \ + --use_coral=true +``` + #### Results In the console, you should get: ``` +Time cost to classify the input image on CPU: 109ms Results: Rank #0: index : 671 @@ -63,9 +112,9 @@ You will need: -* a TFLite object detection model (e.g. [ssd_mobilenet_v1][2], a generic object -detection model available on TensorFlow Hub), -* a PNG, JPEG or GIF image to run detection on, e.g.: +* a TFLite object detection model (e.g. [ssd_mobilenet_v1][2], a generic + object detection model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run detection on, e.g.: ![dogs](g3doc/dogs.jpg) @@ -89,11 +138,32 @@ --max_results=2 ``` +To run the demo on a [Coral Edge TPU device][4], check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite' \ + -o /tmp/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite + +# Run the detection tool: +bazel run -c opt --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo -- \ + --model_path=/tmp/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg \ + --output_png=/tmp/detection-output.png \ + --max_results=2 \ + --use_coral=true +``` + #### Results In the console, you should get: ``` +Time cost to detect the input image on CPU: 123 ms Results saved to: /tmp/detection-output.png Results: Detection #0 (red): @@ -120,11 +190,11 @@ You will need: -* a TFLite image segmentation model (e.g. [deeplab_v3][3], a generic -segmentation model available on TensorFlow Hub), -* a PNG, JPEG or GIF image to run segmentation on, e.g.: +* a TFLite image segmentation model (e.g. [deeplab_v3][3], a generic + segmentation model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run segmentation on, e.g.: -![plane](g3doc/plane.jpg) +![cat](g3doc/cat.jpg) #### Usage @@ -133,40 +203,54 @@ ```bash # Download the model: curl \ - -L 'https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1?lite-format=tflite' \ - -o /tmp/deeplabv3_1_metadata_1.tflite + -L 'https://github.com/google-coral/test_data/raw/master/keras_post_training_unet_mv2_128_quant.tflite' \ + -o /tmp/keras_post_training_unet_mv2_128_quant.tflite # Run the segmentation tool: bazel run -c opt \ tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo -- \ - --model_path=/tmp/deeplabv3_1_metadata_1.tflite \ + --model_path=/tmp/keras_post_training_unet_mv2_128_quant.tflite \ --image_path=\ -$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg \ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg \ --output_mask_png=/tmp/segmentation-output.png ``` +To run the demo on a [Coral Edge TPU device][4], check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/keras_post_training_unet_mv2_128_quant_edgetpu.tflite' \ + -o /tmp/keras_post_training_unet_mv2_128_quant_edgetpu.tflite + +# Run the segmentation tool: +bazel run -c opt --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo -- \ + --model_path=/tmp/keras_post_training_unet_mv2_128_quant_edgetpu.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg \ + --output_mask_png=/tmp/segmentation-output.png \ + --use_coral=true +``` + #### Results In the console, you should get: ``` +Time cost to segment the input image on CPU: 89.9316 ms Category mask saved to: /tmp/segmentation-output.png Color Legend: (r: 000, g: 000, b: 000): index : 0 - class name : background + class name : pet (r: 128, g: 000, b: 000): index : 1 - class name : aeroplane - -# (omitting multiple lines for conciseness) ... - - (r: 128, g: 192, b: 000): - index : 19 - class name : train - (r: 000, g: 064, b: 128): - index : 20 - class name : tv + class name : background + (r: 000, g: 128, b: 000): + index : 2 + class name : border Tip: use a color picker on the output PNG file to inspect the output mask with this legend. ``` @@ -178,3 +262,4 @@ [1]: https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3 [2]: https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2 [3]: https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 +[4]: https://coral.ai/docs/edgetpu/inference/
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg new file mode 100644 index 0000000..b4acd1a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg deleted file mode 100644 index 0edefa4..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg +++ /dev/null Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png index e871df3..9a27d30 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc index bcca6904..bd2aaaf1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
@@ -22,10 +22,10 @@ #include <iostream> -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" @@ -70,14 +70,24 @@ "Comma-separated list of class names that acts as a blacklist. If " "non-empty, classification results whose 'class_name' is in this list " "are filtered out. Mutually exclusive with 'class_name_whitelist'."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); namespace tflite { namespace task { namespace vision { +namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +} // namespace + ImageClassifierOptions BuildOptions() { ImageClassifierOptions options; - options.mutable_model_file_with_metadata()->set_file_name( + options.mutable_base_options()->mutable_model_file()->set_file_name( absl::GetFlag(FLAGS_model_path)); options.set_max_results(absl::GetFlag(FLAGS_max_results)); if (absl::GetFlag(FLAGS_score_threshold) >= 0) { @@ -91,6 +101,12 @@ absl::GetFlag(FLAGS_class_name_blacklist)) { options.add_class_name_blacklist(class_name); } + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } return options; } @@ -143,8 +159,18 @@ } // Run classification and display results. + auto start_classify = steady_clock::now(); ASSIGN_OR_RETURN(ClassificationResult result, image_classifier->Classify(*frame_buffer)); + auto end_classify = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to classify the input image on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_classify - + start_classify) + .count() + << " ms" << std::endl; + DisplayResult(result); // Cleanup and return.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc new file mode 100644 index 0000000..040878a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
@@ -0,0 +1,195 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Computes and displays cosine similarity between the feature vectors extracted +// on two images. +// +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:image_embedder_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --first_image_path=/path/to/first/image.jpg \ +// --second_image_path=/path/to/second/image.jpg + +#include <iostream> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/image_embedder.h" +#include "tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, + model_path, + "", + "Absolute path to the '.tflite' image embedder model."); +ABSL_FLAG(std::string, + first_image_path, + "", + "Absolute path to the first image, whose feature vector will be " + "extracted and compared to the second image using cosine similarity. " + "The image must be RGB or RGBA (grayscale is not supported). The " + "image EXIF orientation flag, if any, is NOT taken into account."); +ABSL_FLAG(std::string, + second_image_path, + "", + "Absolute path to the second image, whose feature vector will be " + "extracted and compared to the first image using cosine similarity. " + "The image must be RGB or RGBA (grayscale is not supported). The " + "image EXIF orientation flag, if any, is NOT taken into account."); +ABSL_FLAG(bool, + l2_normalize, + false, + "If true, the raw feature vectors returned by the image embedder " + "will be normalized with L2-norm. Generally only needed if the model " + "doesn't already contain a L2_NORMALIZATION TFLite Op."); +ABSL_FLAG( + bool, + quantize, + false, + "If true, the raw feature vectors returned by the image embedder will " + "be quantized to 8 bit integers (uniform quantization) via post-processing " + "before cosine similarity is computed."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +namespace tflite { +namespace task { +namespace vision { + +namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +using ::tflite::support::StatusOr; +} // namespace + +ImageEmbedderOptions BuildOptions() { + ImageEmbedderOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + options.set_l2_normalize(absl::GetFlag(FLAGS_l2_normalize)); + options.set_quantize(absl::GetFlag(FLAGS_quantize)); + + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_compute_settings()->mutable_tflite_settings()->set_delegate( + ::tflite::proto::Delegate::EDGETPU_CORAL); + } + return options; +} + +StatusOr<std::unique_ptr<FrameBuffer>> BuildFrameBufferFromImageData( + const ImageData& image) { + std::unique_ptr<FrameBuffer> frame_buffer; + if (image.channels == 3) { + return CreateFromRgbRawBuffer(image.pixel_data, + {image.width, image.height}); + } else if (image.channels == 4) { + return CreateFromRgbaRawBuffer(image.pixel_data, + {image.width, image.height}); + } + return absl::InvalidArgumentError(absl::StrFormat( + "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d", + image.channels)); +} + +absl::Status ComputeCosineSimilarity() { + // Build ImageEmbedder. + const ImageEmbedderOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ImageEmbedder> image_embedder, + ImageEmbedder::CreateFromOptions(options)); + + // Load images into FrameBuffer objects. + ASSIGN_OR_RETURN(ImageData first_image, + DecodeImageFromFile(absl::GetFlag(FLAGS_first_image_path))); + ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> first_frame_buffer, + BuildFrameBufferFromImageData(first_image)); + ASSIGN_OR_RETURN(ImageData second_image, + DecodeImageFromFile(absl::GetFlag(FLAGS_second_image_path))); + ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> second_frame_buffer, + BuildFrameBufferFromImageData(second_image)); + + // Extract feature vectors. + auto start_embed = steady_clock::now(); + ASSIGN_OR_RETURN(const EmbeddingResult& first_embedding_result, + image_embedder->Embed(*first_frame_buffer)); + auto end_embed = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to embed the input image on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_embed - start_embed) + .count() + << " ms" << std::endl; + + ASSIGN_OR_RETURN(const EmbeddingResult& second_embedding_result, + image_embedder->Embed(*second_frame_buffer)); + // Compute cosine similarity. + ASSIGN_OR_RETURN( + double cosine_similarity, + ImageEmbedder::CosineSimilarity( + image_embedder->GetEmbeddingByIndex(first_embedding_result, 0) + .feature_vector(), + image_embedder->GetEmbeddingByIndex(second_embedding_result, 0) + .feature_vector())); + + // Display result. + std::cout << absl::StrFormat("Cosine similarity: %f\n", cosine_similarity); + + // Cleanup and return. + ImageDataFree(&first_image); + ImageDataFree(&second_image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_first_image_path).empty()) { + std::cerr << "Missing mandatory 'first_image_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_second_image_path).empty()) { + std::cerr << "Missing mandatory 'second_image_path' argument.\n"; + return 1; + } + + // Compute cosine similarity. + absl::Status status = tflite::task::vision::ComputeCosineSimilarity(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Cosine similarity computation failed: " << status.message() + << "\n"; + return 1; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc index a7539df..6487fe9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
@@ -23,11 +23,11 @@ #include <iostream> -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/match.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" @@ -52,17 +52,34 @@ "", "Absolute path to the output category mask (confidence masks outputs " "are not supported by this tool). Must have a '.png' extension."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); namespace tflite { namespace task { namespace vision { +namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +} // namespace + ImageSegmenterOptions BuildOptions() { ImageSegmenterOptions options; - options.mutable_model_file_with_metadata()->set_file_name( + options.mutable_base_options()->mutable_model_file()->set_file_name( absl::GetFlag(FLAGS_model_path)); // Confidence masks are not supported by this tool: output_type is set to // CATEGORY_MASK by default. + + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } return options; } @@ -160,8 +177,18 @@ } // Run segmentation and save category mask. + auto start_segment = steady_clock::now(); ASSIGN_OR_RETURN(SegmentationResult result, image_segmenter->Segment(*frame_buffer)); + auto end_segment = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to segment the input image on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_segment - + start_segment) + .count() + << " ms" << std::endl; + RETURN_IF_ERROR(EncodeMaskToPngFile(result)); // Display the legend.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc index f6f626f..9208439 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
@@ -24,11 +24,11 @@ #include <iostream> #include <limits> -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/match.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" @@ -79,12 +79,22 @@ "Comma-separated list of class names that acts as a blacklist. If " "non-empty, detections results whose 'class_name' is in this list " "are filtered out. Mutually exclusive with 'class_name_whitelist'."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); namespace tflite { namespace task { namespace vision { namespace { +using std::chrono::microseconds; +using std::chrono::steady_clock; +} // namespace + +namespace { // The line thickness (in pixels) for drawing the detection results. constexpr int kLineThickness = 3; @@ -105,7 +115,7 @@ ObjectDetectorOptions BuildOptions() { ObjectDetectorOptions options; - options.mutable_model_file_with_metadata()->set_file_name( + options.mutable_base_options()->mutable_model_file()->set_file_name( absl::GetFlag(FLAGS_model_path)); options.set_max_results(absl::GetFlag(FLAGS_max_results)); if (absl::GetFlag(FLAGS_score_threshold) > @@ -120,6 +130,12 @@ absl::GetFlag(FLAGS_class_name_blacklist)) { options.add_class_name_blacklist(class_name); } + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } return options; } @@ -212,8 +228,18 @@ } // Run object detection and draw results on input image. + auto start_detect = steady_clock::now(); ASSIGN_OR_RETURN(DetectionResult result, object_detector->Detect(*frame_buffer)); + auto end_detect = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to detect the input image on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_detect - + start_detect) + .count() + << " ms" << std::endl; + RETURN_IF_ERROR(EncodeResultToPngFile(result, &image)); // Display results as text.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD new file mode 100644 index 0000000..baf5e784 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD
@@ -0,0 +1,55 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +# bazel run \ +# tensorflow_lite_support/examples/task/vision/desktop/python:image_classifier_demo -- \ +# --model_path=/tmp/aiy_vision_classifier_birds_V1_3.tflite \ +# --image_path=\ +# $(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg \ +# --max_results=3 +py_binary( + name = "image_classifier_demo", + srcs = ["image_classifier_demo.py"], + data = ["//tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) + +# bazel run \ +# tensorflow_lite_support/examples/task/vision/desktop/python:object_detector_demo -- \ +# --model_path=/tmp/ssd_mobilenet_v1_1_metadata_1.tflite \ +# --image_path=\ +# $(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg \ +# --output_png=/tmp/detection-output.png \ +# --max_results=2 +py_binary( + name = "object_detector_demo", + srcs = ["object_detector_demo.py"], + data = ["//tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) + +# bazel run \ +# tensorflow_lite_support/examples/task/vision/desktop/python:image_segmenter_demo -- \ +# --model_path=/tmp/keras_post_training_unet_mv2_128_quant.tflite \ +# --image_path=\ +# $(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg \ +# --output_mask_png=/tmp/segmentation-output.png +py_binary( + name = "image_segmenter_demo", + srcs = ["image_segmenter_demo.py"], + data = ["//tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo"], + deps = [ + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/README.md new file mode 100644 index 0000000..d808ce9b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/README.md
@@ -0,0 +1,270 @@ +# CLI Demos for Python Vision Task APIs + +A Python wrapper for the C++ Vision Task APIs. + +## Background + +This Python API is based on the C++ Vision Task APIs. It uses Python's +[subprocess](https://docs.python.org/3/library/subprocess.html) to call C++ Vision +Task APIs. + +## Coral Integration + +Task Library now supports fast TFLite inference delegated onto +[Coral Edge TPU devices][4] on Linux and macOS. See the +[documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview#run_task_library_with_delegates) +for more details. + +To run the demo on a Coral device, add `--define darwinn_portable=1` to the +bazel command. + +Note the `libusb` package is required. It can be installed as follows: + +```bash +# On Linux +sudo apt-get install libusb-1.0-0-dev + +# On macOS using MacPorts +port install libusb +# or Homebrew +brew install libusb +``` + +See the example commands in each task demo below. + +You can also explore more [pretrained Coral model](https://coral.ai/models) and +try them in the demo. All the models have populated with +[TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata). + +## Image Classifier + +#### Prerequisites + +You will need: + +* a TFLite image classification model (e.g. + [aiy/vision/classifier/birds_V1][1], a bird classification model available + on TensorFlow Hub), +* a PNG, JPEG or GIF image to run classification on, e.g.: + +![sparrow](../g3doc/sparrow.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3?lite-format=tflite' \ + -o /tmp/aiy_vision_classifier_birds_V1_3.tflite + +# Run the classification tool: +bazel run \ + tensorflow_lite_support/examples/task/vision/desktop/python:image_classifier_demo -- \ + --model_path=/tmp/aiy_vision_classifier_birds_V1_3.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg \ + --max_results=3 +``` + +To run the demo on a [Coral Edge TPU device][4], check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the Coral model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite' \ + -o /tmp/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite + +# Run the classification tool: +bazel run --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/vision/desktop/python:image_classifier_demo -- \ + --model_path=/tmp/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg \ + --max_results=3 \ + --use_coral=true +``` + +#### Results + +In the console, you should get: + +``` +Time cost to classify the input image on CPU: 109ms +Results: + Rank #0: + index : 671 + score : 0.91406 + class name : /m/01bwb9 + display name: Passer domesticus + Rank #1: + index : 670 + score : 0.00391 + class name : /m/01bwbt + display name: Passer montanus + Rank #2: + index : 495 + score : 0.00391 + class name : /m/0bwm6m + display name: Passer italiae +``` + +## Object Detector + +#### Prerequisites + +You will need: + +* a TFLite object detection model (e.g. [ssd_mobilenet_v1][2], a generic + object detection model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run detection on, e.g.: + +![dogs](../g3doc/dogs.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite' \ + -o /tmp/ssd_mobilenet_v1_1_metadata_1.tflite + +# Run the detection tool: +bazel run \ + tensorflow_lite_support/examples/task/vision/desktop/python:object_detector_demo -- \ + --model_path=/tmp/ssd_mobilenet_v1_1_metadata_1.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg \ + --output_png=/tmp/detection-output.png \ + --max_results=2 +``` + +To run the demo on a [Coral Edge TPU device][4], check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite' \ + -o /tmp/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite + +# Run the detection tool: +bazel run --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/vision/desktop/python:object_detector_demo -- \ + --model_path=/tmp/ssd_mobilenet_v1_coco_quant_postprocess_edgetpu.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg \ + --output_png=/tmp/detection-output.png \ + --max_results=2 \ + --use_coral=true +``` + +#### Results + +In the console, you should get: + +``` +Time cost to detect the input image on CPU: 123 ms +Results saved to: /tmp/detection-output.png +Results: + Detection #0 (red): + Box: (x: 355, y: 133, w: 190, h: 206) + Top-1 class: + index : 17 + score : 0.73828 + class name : dog + Detection #1 (green): + Box: (x: 103, y: 15, w: 138, h: 369) + Top-1 class: + index : 17 + score : 0.73047 + class name : dog +``` + +And `/tmp/detection-output.jpg` should contain: + +![detection-output](../g3doc/detection-output.png) + +## Image Segmenter + +#### Prerequisites + +You will need: + +* a TFLite image segmentation model (e.g. [deeplab_v3][3], a generic + segmentation model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run segmentation on, e.g.: + +![cat](../g3doc/cat.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/keras_post_training_unet_mv2_128_quant.tflite' \ + -o /tmp/keras_post_training_unet_mv2_128_quant.tflite + +# Run the segmentation tool: +bazel run \ + tensorflow_lite_support/examples/task/vision/desktop/python:image_segmenter_demo -- \ + --model_path=/tmp/keras_post_training_unet_mv2_128_quant.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg \ + --output_mask_png=/tmp/segmentation-output.png +``` + +To run the demo on a [Coral Edge TPU device][4], check +[Coral Integration](#coral-integration) section and then run: + +```bash +# Download the model: +curl \ + -L 'https://github.com/google-coral/test_data/raw/master/keras_post_training_unet_mv2_128_quant_edgetpu.tflite' \ + -o /tmp/keras_post_training_unet_mv2_128_quant_edgetpu.tflite + +# Run the segmentation tool: +bazel run --define darwinn_portable=1 \ + tensorflow_lite_support/examples/task/vision/desktop/python:image_segmenter_demo -- \ + --model_path=/tmp/keras_post_training_unet_mv2_128_quant_edgetpu.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/cat.jpg \ + --output_mask_png=/tmp/segmentation-output.png \ + --use_coral=true +``` + +#### Results + +In the console, you should get: + +``` +Time cost to segment the input image on CPU: 89.9316 ms +Category mask saved to: /tmp/segmentation-output.png +Color Legend: + (r: 000, g: 000, b: 000): + index : 0 + class name : pet + (r: 128, g: 000, b: 000): + index : 1 + class name : background + (r: 000, g: 128, b: 000): + index : 2 + class name : border +Tip: use a color picker on the output PNG file to inspect the output mask with +this legend. +``` + +And `/tmp/segmentation-output.jpg` should contain the segmentation mask: + +![segmentation-output](../g3doc/segmentation-output.png) + +[1]: https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3 +[2]: https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2 +[3]: https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 +[4]: https://coral.ai/docs/edgetpu/inference/
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py new file mode 100644 index 0000000..c1ee07b3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py
@@ -0,0 +1,107 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for Image Classification.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string('model_path', None, + 'Absolute path to the ".tflite" image classifier model.') +flags.DEFINE_string( + 'image_path', None, + 'Absolute path to the image to classify. The image must be RGB or ' + 'RGBA (grayscale is not supported). The image EXIF orientation ' + 'flag, if any, is NOT taken into account.') +flags.DEFINE_integer('max_results', 5, + 'Maximum number of classification results to display.') +flags.DEFINE_float( + 'score_threshold', 0, + 'Classification results with a confidence score below this value are ' + 'rejected. If >= 0, overrides the score threshold(s) provided in the ' + 'TFLite Model Metadata. Ignored otherwise.') +flags.DEFINE_string( + 'class_name_whitelist', '', + 'Comma-separated list of class names that acts as a whitelist. If ' + 'non-empty, classification results whose "class_name" is not in this list ' + 'are filtered out. Mutually exclusive with "class_name_blacklist".') +flags.DEFINE_string( + 'class_name_blacklist', '', + 'Comma-separated list of class names that acts as a blacklist. If ' + 'non-empty, classification results whose "class_name" is in this list ' + 'are filtered out. Mutually exclusive with "class_name_whitelist".') +flags.DEFINE_bool( + 'use_coral', False, + 'If true, inference will be delegated to a connected Coral Edge TPU ' + 'device.') +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('image_path') + +_IMAGE_CLASSIFICATION_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../image_classifier_demo') + + +def classify(model_path, image_path, max_results, score_threshold, + class_name_whitelist, class_name_blacklist, use_coral): + """Classifies input image into different categories. + + Args: + model_path: Path to model + image_path: Absolute path to the image to classify + max_results: Maximum number of classification results to display + score_threshold: Optional; Classification results with a confidence + score below this value are rejected + class_name_whitelist: Optional; Comma-separated list of class names + that acts as a whitelist + class_name_blacklist: Optional; Comma-separated list of class names + that acts as a blacklist + use_coral: Optional; If true, inference will be delegated to a + connected Coral Edge TPU device + """ + # Run the classification tool: + subprocess.run([ + _IMAGE_CLASSIFICATION_NATIVE_PATH + ' --model_path=' + model_path + + ' --image_path=' + image_path + ' --max_results=' + str(max_results) + + ' --score_threshold=' + str(score_threshold) + + ' --class_name_whitelist="' + str(class_name_whitelist) + + '" --class_name_blacklist="' + str(class_name_blacklist) + + '" --use_coral=' + str(use_coral) + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.image_path, FLAGS.max_results, + FLAGS.score_threshold, FLAGS.class_name_whitelist, + FLAGS.class_name_blacklist, FLAGS.use_coral) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py new file mode 100644 index 0000000..f46bf1f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py
@@ -0,0 +1,85 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for Image Segmentation.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string('model_path', None, + 'Absolute path to the ".tflite" image segmenter model.') +flags.DEFINE_string( + 'image_path', None, + 'Absolute path to the image to segment. The image must be RGB or ' + 'RGBA (grayscale is not supported). The image EXIF orientation ' + 'flag, if any, is NOT taken into account.') +flags.DEFINE_string( + 'output_mask_png', None, + 'Absolute path to the output category mask (confidence masks outputs ' + 'are not supported by this tool). Must have a ".png" extension.') +flags.DEFINE_bool( + 'use_coral', False, + 'If true, inference will be delegated to a connected Coral Edge TPU ' + 'device.') +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('image_path') +flags.mark_flag_as_required('output_mask_png') + +_IMAGE_SEGMENTATION_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../image_segmenter_demo') + + +def classify(model_path, image_path, output_mask_png, use_coral): + """Segments the input image. + + Args: + model_path: Path to model + image_path: Absolute path to the image to segment + output_mask_png: Absolute path to the output category mask (confidence + masks outputs are not supported by this tool + use_coral: Optional; If true, inference will be delegated to a connected + Coral Edge TPU device + """ + # Run the segmentation tool: + subprocess.run([ + _IMAGE_SEGMENTATION_NATIVE_PATH + ' --model_path=' + model_path + + ' --image_path=' + image_path + ' --output_mask_png=' + output_mask_png + + ' --use_coral=' + str(use_coral) + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.image_path, FLAGS.output_mask_png, + FLAGS.use_coral) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py new file mode 100644 index 0000000..7842b6c2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py
@@ -0,0 +1,116 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python demo tool for Object Detection.""" + +import inspect +import os.path as _os_path +import subprocess +import sys + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string('model_path', None, + 'Absolute path to the ".tflite" object detector model.') +flags.DEFINE_string( + 'image_path', None, + 'Absolute path to the image to run detection on. The image must be ' + 'RGB or RGBA (grayscale is not supported). The image EXIF ' + 'orientation flag, if any, is NOT taken into account.') +flags.DEFINE_string( + 'output_png', None, + 'Absolute path to a file where to draw the detection results on top ' + 'of the input image. Must have a ".png" extension.') +flags.DEFINE_integer('max_results', 5, + 'Maximum number of detection results to display.') +flags.DEFINE_float( + 'score_threshold', 0, + 'Detection results with a confidence score below this value are ' + 'rejected. If specified, overrides the score threshold(s) provided in the ' + 'TFLite Model Metadata. Ignored otherwise.') +flags.DEFINE_string( + 'class_name_whitelist', '', + 'Comma-separated list of class names that acts as a whitelist. If ' + 'non-empty, detections results whose "class_name" is not in this list ' + 'are filtered out. Mutually exclusive with "class_name_blacklist".') +flags.DEFINE_string( + 'class_name_blacklist', '', + 'Comma-separated list of class names that acts as a blacklist. If ' + 'non-empty, detections results whose "class_name" is in this list ' + 'are filtered out. Mutually exclusive with "class_name_whitelist".') +flags.DEFINE_bool( + 'use_coral', False, + 'If true, inference will be delegated to a connected Coral Edge TPU ' + 'device.') + +# Required flag. +flags.mark_flag_as_required('model_path') +flags.mark_flag_as_required('image_path') +flags.mark_flag_as_required('output_png') +flags.mark_flag_as_required('max_results') + +_OBJECT_DETECTION_NATIVE_PATH = _os_path.join( + _os_path.dirname(inspect.getfile(inspect.currentframe())), + '../object_detector_demo') + + +def classify(model_path, image_path, output_png, max_results, score_threshold, + class_name_whitelist, class_name_blacklist, use_coral): + """Detects the input image. + + Args: + model_path: Path to model + image_path: Absolute path to the image to run detection on + output_png: Absolute path to a file where to draw the detection results + on top of the input image + max_results: Maximum number of detection results to display + score_threshold: Optional; Detection results with a confidence score + below this value are rejected + class_name_whitelist: Optional; Comma-separated list of class names + that acts as a whitelist. + class_name_blacklist: Optional; Comma-separated list of class names + that acts as a blacklist. + use_coral: Optional; If true, inference will be delegated to a + connected Coral Edge TPU device + """ + # Run the detection tool: + subprocess.run([ + _OBJECT_DETECTION_NATIVE_PATH + ' --model_path=' + model_path + + ' --image_path=' + image_path + ' --output_png=' + output_png + + ' --max_results=' + str(max_results) + ' --score_threshold=' + + str(score_threshold) + ' --class_name_whitelist="' + + str(class_name_whitelist) + '" --class_name_blacklist="' + + str(class_name_blacklist) + '" --use_coral=' + str(use_coral) + ], + shell=True, + check=True) + + +def run_main(argv): + del argv # Unused. + classify(FLAGS.model_path, FLAGS.image_path, FLAGS.output_png, + FLAGS.max_results, FLAGS.score_threshold, FLAGS.class_name_whitelist, + FLAGS.class_name_blacklist, FLAGS.use_coral) + + +# Simple wrapper to make the code pip-friendly +def main(): + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD index 9a837e2d..481fd8ec 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD
@@ -1,6 +1,6 @@ package( default_visibility = [ - "//tensorflow_lite_support:users", + "//tensorflow_lite_support:internal", ], licenses = ["notice"], # Apache 2.0 ) @@ -9,6 +9,9 @@ name = "image_utils", srcs = ["image_utils.cc"], hdrs = ["image_utils.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], deps = [ "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc index 004d89a..efdcda9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc
@@ -23,11 +23,11 @@ #define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_WRITE_IMPLEMENTATION -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" -#include "stb_image.h" -#include "stb_image_write.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/match.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "stb_image.h" // from @stblib +#include "stb_image_write.h" // from @stblib #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h index 38f62b6..9e7e3ba 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h
@@ -15,8 +15,8 @@ #ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ #define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/BUILD index 07b8951..b1e6c4c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/BUILD
@@ -1,9 +1,9 @@ -# TensorFlow Lite Task Library - Text +# TensorFlow Lite Task Library load( - "@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", + "@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_MINIMUM_OS_VERSION", - "tflite_ios_static_framework", + "tflite_ios_framework", ) load( "//tensorflow_lite_support/ios:ios.bzl", @@ -18,10 +18,17 @@ strip_c_api_include_path_prefix( name = "strip_c_api_include_path", hdr_labels = [ - "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api.h", - "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api.h", - "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api_common.h", - "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api.h", + "//tensorflow_lite_support/c/task/text:bert_nl_classifier.h", + "//tensorflow_lite_support/c/task/text:nl_classifier.h", + "//tensorflow_lite_support/c/task/text:nl_classifier_common.h", + "//tensorflow_lite_support/c/task/text:bert_question_answerer.h", + "//tensorflow_lite_support/c/task/vision:image_classifier.h", + "//tensorflow_lite_support/c/task/processor:bounding_box.h", + "//tensorflow_lite_support/c/task/vision/core:frame_buffer.h", + "//tensorflow_lite_support/c/task/processor:classification_result.h", + "//tensorflow_lite_support/c/task/processor:classification_options.h", + "//tensorflow_lite_support/c/task/core:base_options.h", + "//tensorflow_lite_support/c:common.h", ], ) @@ -29,20 +36,75 @@ # which includes the TFLite runtime in it. # # bazel build -c opt --config=ios_fat //tensorflow_lite_support/ios:TensorFlowLiteTaskTextC_framework -tflite_ios_static_framework( +tflite_ios_framework( name = "TensorFlowLiteTaskTextC_framework", hdrs = [ - ":bert_nl_classifier_c_api.h", - ":bert_qa_c_api.h", - ":nl_classifier_c_api.h", - ":nl_classifier_c_api_common.h", + ":bert_nl_classifier.h", + ":bert_question_answerer.h", + ":nl_classifier.h", + ":nl_classifier_common.h", ], allowlist_symbols_file = ":allowlist_TensorFlowLiteTaskText.txt", bundle_name = "TensorFlowLiteTaskTextC", minimum_os_version = TFL_MINIMUM_OS_VERSION, deps = [ - "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api", - "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api", - "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api", + "//tensorflow_lite_support/c/task/text:bert_nl_classifier", + "//tensorflow_lite_support/c/task/text:bert_question_answerer", + "//tensorflow_lite_support/c/task/text:nl_classifier", + ], +) + +# Xcode 12 does not support ios fat libraries. Frameworks built for multiple +# architectures should be compiled into a .xcframework inside. Bazel currently +# does not support building .xcframework. You have to build the framework +# for the architecture you decide to test on. +# Use the below command to build for arm64 which lets you test the library on +# iOS devices. +# bazel build -c opt --config=ios_arm64 //tensorflow_lite_support/ios:TensorFlowLiteTaskVisionC_framework +tflite_ios_framework( + name = "TensorFlowLiteTaskVisionC_framework", + hdrs = [ + ":base_options.h", + ":bounding_box.h", + ":classification_options.h", + ":classification_result.h", + ":common.h", + ":frame_buffer.h", + ":image_classifier.h", + ], + allowlist_symbols_file = ":allowlist_TensorFlowLiteTaskVision.txt", + bundle_name = "TensorFlowLiteTaskVisionC", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + deps = [ + "//tensorflow_lite_support/c/task/vision:image_classifier", + ], +) + +objc_library( + name = "TFLCommon", + hdrs = [ + "sources/TFLCommon.h", + ], + module_name = "TFLCommon", + visibility = [ + "//tensorflow_lite_support:__subpackages__", + ], +) + +objc_library( + name = "TFLCommonUtils", + srcs = [ + "sources/TFLCommonUtils.m", + ], + hdrs = [ + "sources/TFLCommonUtils.h", + ], + module_name = "TFLCommonUtils", + visibility = [ + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + "//tensorflow_lite_support/c:common", + "//tensorflow_lite_support/ios:TFLCommon", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template b/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template index 62c3f33..63bd9f7b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template
@@ -13,7 +13,7 @@ s.module_name = 'TensorFlowLiteTaskText' s.static_framework = true - s.dependency 'GoogleToolboxForMac', '2.2.1' + s.dependency 'GoogleToolboxForMac', '~> 2.2' objc_dir = 'tensorflow_lite_support/ios/task/text/' s.public_header_files = [
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt b/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt index 3af5b0b1..e8ae288 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt
@@ -1,3 +1 @@ -_NLClassifier* -_BertNLClassifier* -_BertQuestionAnswerer* +_TfLite*
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskVision.txt b/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskVision.txt new file mode 100644 index 0000000..e8ae288 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskVision.txt
@@ -0,0 +1 @@ +_TfLite*
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h new file mode 100644 index 0000000..2ca42fb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
@@ -0,0 +1,213 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * @enum TFLSupportErrorCode + * This enum specifies error codes for TensorFlow Lite Task Library. + * It maintains a 1:1 mapping to TfLiteSupportErrorCode of C libray. + */ +typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) { + + /** Unspecified error. */ + TFLSupportErrorCodeUnspecifiedError = 1, + + /** Invalid argument specified. */ + TFLSupportErrorCodeInvalidArgumentError = 2, + + /** Invalid FlatBuffer file or buffer specified. */ + TFLSupportErrorCodeInvalidFlatBufferError = 3, + + /** Model contains a builtin op that isn't supported by the OpResolver or + delegates. */ + TFLSupportErrorCodeUnsupportedBuiltinOpError = 4, + + /** Model contains a custom op that isn't supported by the OpResolver or + * delegates. */ + TFLSupportErrorCodeUnsupportedCustomOpError = 5, + + /** File I/O error codes. */ + + /** No such file. */ + TFLSupportErrorCodeFileNotFoundError = 100, + + /** Permission issue. */ + TFLSupportErrorCodeFilePermissionDeniedError, + + /** I/O error when reading file. */ + TFLSupportErrorCodeFileReadError, + + /** I/O error when mmap-ing file. */ + TFLSupportErrorCodeFileMmapError, + + /** TensorFlow Lite metadata error codes. */ + + /** Unexpected schema version (aka file_identifier) in the Metadata + FlatBuffer. */ + TFLSupportErrorCodeMetadataInvalidSchemaVersionError = 200, + + /** No such associated file within metadata, or file has not been packed. */ + TFLSupportErrorCodeMetadataAssociatedFileNotFoundError, + + /** ZIP I/O error when unpacking an associated file. */ + TFLSupportErrorCodeMetadataAssociatedFileZipError, + + /** + * Inconsistency error between the metadata and actual TF Lite model. + * E.g.: number of labels and output tensor values differ. + */ + TFLSupportErrorCodeMetadataInconsistencyError, + + /** + * Invalid process units specified. + * E.g.: multiple ProcessUnits with the same type for a given tensor. + */ + TFLSupportErrorCodeMetadataInvalidProcessUnitsError, + + /** + * Inconsistency error with the number of labels. + * E.g.: label files for different locales have a different number of labels. + */ + TFLSupportErrorCodeMetadataNumLabelsMismatchError, + + /** + * Score calibration parameters parsing error. + * E.g.: too many parameters provided in the corresponding associated file. + */ + TFLSupportErrorCodeMetadataMalformedScoreCalibrationError, + + /** Unexpected number of subgraphs for the current task. + * E.g.: image classification expects a single subgraph. + */ + TFLSupportErrorCodeMetadataInvalidNumSubgraphsError, + /** + * A given tensor requires NormalizationOptions but none were found. + * E.g.: float input tensor requires normalization to preprocess input images. + */ + TFLSupportErrorCodeMetadataMissingNormalizationOptionsError, + + /** + * Invalid ContentProperties specified. + * E.g. expected ImageProperties, got BoundingBoxProperties. + */ + TFLSupportErrorCodeMetadataInvalidContentPropertiesError, + + /** + * Metadata is mandatory but was not found. + * E.g. current task requires TFLite Model Metadata but none was found. + */ + TFLSupportErrorCodeMetadataNotFoundError, + + /** + * Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + * none was found or it was empty. + * E.g. current task requires labels but none were found. + */ + TFLSupportErrorCodeMetadataMissingLabelsError, + + /** + * The ProcessingUnit for tokenizer is not correctly configured. + * E.g BertTokenizer doesn't have a valid vocab file associated. + */ + TFLSupportErrorCodeMetadataInvalidTokenizerError, + + /** Input tensor(s) error codes. */ + + /** + * Unexpected number of input tensors for the current task. + * E.g. current task expects a single input tensor. + */ + TFLSupportErrorCodeInvalidNumInputTensorsError = 300, + + /** + * Unexpected input tensor dimensions for the current task. + * E.g.: only 4D input tensors supported. + */ + TFLSupportErrorCodeInvalidInputTensorDimensionsError, + + /** + * Unexpected input tensor type for the current task. + * E.g.: current task expects a uint8 pixel image as input. + */ + TFLSupportErrorCodeInvalidInputTensorTypeError, + + /** + * Unexpected input tensor bytes size. + * E.g.: size in bytes does not correspond to the expected number of pixels. + */ + TFLSupportErrorCodeInvalidInputTensorSizeError, + + /** + * No correct input tensor found for the model. + * E.g.: input tensor name is not part of the text model's input tensors. + */ + TFLSupportErrorCodeInputTensorNotFoundError, + + /** Output tensor(s) error codes. */ + + /** + * Unexpected output tensor dimensions for the current task. + * E.g.: only a batch size of 1 is supported. + */ + TFLSupportErrorCodeInvalidOutputTensorDimensionsError = 400, + + /** + * Unexpected input tensor type for the current task. + * E.g.: multi-head model with different output tensor types. + */ + TFLSupportErrorCodeInvalidOutputTensorTypeError, + + /** + * No correct output tensor found for the model. + * E.g.: output tensor name is not part of the text model's output tensors. + */ + TFLSupportErrorCodeOutputTensorNotFoundError, + + /** + * Unexpected number of output tensors for the current task. + * E.g.: current task expects a single output tensor. + */ + TFLSupportErrorCodeInvalidNumOutputTensorsError, + + /** Image processing error codes. **/ + + /** Unspecified image processing failures. */ + TFLSupportErrorCodeImageProcessingError = 500, + + /** + * Unexpected input or output buffer metadata. + * E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + */ + TFLSupportErrorCodeImageProcessingInvalidArgumentError, + /** + * Image processing operation failures. + * E.g. libyuv rotation failed for an unknown reason. + */ + TFLSupportErrorCodeImageProcessingBackendError, + + /** kNotFound indicates some requested entity (such as a file or directory) + was not found. */ + TFLSupportErrorCodeNotFoundError = 900, + + /** kInternal indicates an internal error has occurred and some invariants + * expected by the underlying system have not been satisfied. This error code + * is reserved for serious errors. + */ + TFLSupportErrorCodeInternalError, +} NS_SWIFT_NAME(SupportErrorCode); + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h new file mode 100644 index 0000000..a194b28 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
@@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> +#include "tensorflow_lite_support/c/common.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Helper utility for the all tasks which encapsulates common functionality. */ +@interface TFLCommonUtils : NSObject + +/** + * Creates and saves an error originating from the task library with the given + * error code and description. + * + * @param code Error code. + * @param description Error description. + * @param error Pointer to the memory location where the created error should be + * saved. If `nil`, no error will be saved. + */ ++ (void)customErrorWithCode:(NSInteger)code + description:(NSString*)description + error:(NSError**)error; + +/** + * Creates and saves an error originating from the task library from a C library + * error, TfLiteSupportError . + * + * @param supportError C library error. + * @param error Pointer to the memory location where the created error should be + * saved. If `nil`, no error will be saved. + */ ++ (void)errorFromTfLiteSupportError:(TfLiteSupportError*)supportError + error:(NSError**)error; + +/** + * Allocates a block of memory with the specified size and returns a pointer to + * it. If memory cannot be allocated because of an invalid memSize, it saves an + * error. In other cases, it terminates program execution. + * + * @param memSize size of memory to be allocated + * @param error Pointer to the memory location where errors if any should be + * saved. If `nil`, no error will be saved. + * + * @return Pointer to the allocated block of memory on successfull allocation. + * nil in case as error is encountered because of invalid memSize. If failure is + * due to any other reason, method terminates program execution. + */ ++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error; +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m new file mode 100644 index 0000000..2f2d85a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
@@ -0,0 +1,63 @@ +// Copyright 2021 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" + +/** Error domain of TensorFlow Lite Support related errors. */ +static NSString* const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; + +@implementation TFLCommonUtils + ++ (void)customErrorWithCode:(NSInteger)code + description:(NSString*)description + error:(NSError**)error { + if (error) + *error = + [NSError errorWithDomain:TFLSupportTaskErrorDomain + code:code + userInfo:@{NSLocalizedDescriptionKey : description}]; +} + ++ (void)errorFromTfLiteSupportError:(TfLiteSupportError*)supportError + error:(NSError**)error { + if (supportError && error) + *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain + code:supportError->code + userInfo:@{ + NSLocalizedDescriptionKey : [NSString + stringWithCString:supportError->message + encoding:NSUTF8StringEncoding] + }]; +} + ++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error { + if (!memSize) { + [TFLCommonUtils + customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError + description: + @"Invalid memory size passed for allocation of object." + error:error]; + return NULL; + } + + void* allocedMemory = malloc(memSize); + if (!allocedMemory && memSize) { + exit(-1); + } + + return allocedMemory; +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/BUILD new file mode 100644 index 0000000..6102841 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/BUILD
@@ -0,0 +1,30 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLBaseOptions", + srcs = [ + "sources/TFLBaseOptions.m", + ], + hdrs = [ + "sources/TFLBaseOptions.h", + ], + module_name = "TFLBaseOptions", +) + +objc_library( + name = "TFLBaseOptionsHelpers", + srcs = [ + "sources/TFLBaseOptions+Helpers.m", + ], + hdrs = [ + "sources/TFLBaseOptions+Helpers.h", + ], + module_name = "TFLBaseOptionsHelpers", + deps = [ + "//tensorflow_lite_support/c/task/core:base_options", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptions", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h new file mode 100644 index 0000000..90864c7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
@@ -0,0 +1,24 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "tensorflow_lite_support/c/task/core/base_options.h" +#import "third_party/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLBaseOptions (Helpers) +- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions*)cBaseOptions; +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m new file mode 100644 index 0000000..ddab0f7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m
@@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h" + +@implementation TFLBaseOptions (Helpers) + +- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions*)cBaseOptions { + if (self.modelFile.filePath) { + cBaseOptions->model_file.file_path = self.modelFile.filePath.UTF8String; + } +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h new file mode 100644 index 0000000..0f92dd1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
@@ -0,0 +1,78 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds cpu settings. + */ +NS_SWIFT_NAME(CpuSettings) +@interface TFLCpuSettings : NSObject <NSCopying> + +/** Specifies the number of threads to be used for TFLite ops that support + * multi-threadingwhen running inference with CPU. + * @discussion This property hould be greater than 0 or equal to -1. Setting it + * to -1 has the effect to let TFLite runtime set the value. + */ +@property(nonatomic, assign) int numThreads; + +@end + +/** + * Holds settings for one possible acceleration configuration. + */ +NS_SWIFT_NAME(ComputeSettings) +@interface TFLComputeSettings : NSObject <NSCopying> + +/** Holds cpu settings. */ +@property(nonatomic, copy) TFLCpuSettings* cpuSettings; + +@end + +/** + * Holds settings for one possible acceleration configuration. + */ +NS_SWIFT_NAME(ExternalFile) +@interface TFLExternalFile : NSObject <NSCopying> + +/** Path to the file in bundle. */ +@property(nonatomic, copy) NSString* filePath; +/// Add provision for other sources in future. + +@end + +/** + * Holds the base options that is used for creation of any type of task. It has + * fields with important information acceleration configuration, tflite model + * source etc. + */ +NS_SWIFT_NAME(BaseOptions) +@interface TFLBaseOptions : NSObject <NSCopying> + +/** + * The external model file, as a single standalone TFLite file. It could be + * packed with TFLite Model Metadata[1] and associated files if exist. Fail to + * provide the necessary metadata and associated files might result in errors. + */ +@property(nonatomic, copy) TFLExternalFile* modelFile; + +/** + * Holds settings for one possible acceleration configuration including.cpu/gpu + * settings. Please see documentation of TfLiteComputeSettings and its members + * for more details. + */ +@property(nonatomic, copy) TFLComputeSettings* computeSettings; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m new file mode 100644 index 0000000..1e536cd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m
@@ -0,0 +1,94 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h" + +@implementation TFLCpuSettings +@synthesize numThreads; + +- (instancetype)init { + self = [super init]; + if (self) { + self.numThreads = -1; + } + return self; +} + +- (id)copyWithZone:(NSZone*)zone { + TFLCpuSettings* cpuSettings = [[TFLCpuSettings alloc] init]; + + [cpuSettings setNumThreads:self.numThreads]; + + return cpuSettings; +} + +@end + +@implementation TFLComputeSettings +@synthesize cpuSettings; + +- (instancetype)init { + self = [super init]; + if (self) { + self.cpuSettings = [[TFLCpuSettings alloc] init]; + } + return self; +} + +- (id)copyWithZone:(NSZone*)zone { + TFLComputeSettings* computeSettings = [[TFLComputeSettings alloc] init]; + + [computeSettings setCpuSettings:self.cpuSettings]; + + return computeSettings; +} + +@end + +@implementation TFLExternalFile +@synthesize filePath; + +- (id)copyWithZone:(NSZone*)zone { + TFLExternalFile* externalFile = [[TFLExternalFile alloc] init]; + + [externalFile setFilePath:self.filePath]; + + return externalFile; +} + +@end + +@implementation TFLBaseOptions +@synthesize modelFile; +@synthesize computeSettings; + +- (instancetype)init { + self = [super init]; + if (self) { + self.computeSettings = [[TFLComputeSettings alloc] init]; + self.modelFile = [[TFLExternalFile alloc] init]; + } + return self; +} + +- (id)copyWithZone:(NSZone*)zone { + TFLBaseOptions* baseOptions = [[TFLBaseOptions alloc] init]; + + [baseOptions setModelFile:self.modelFile]; + [baseOptions setComputeSettings:self.computeSettings]; + + return baseOptions; +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD new file mode 100644 index 0000000..f6600e0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD
@@ -0,0 +1,43 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLClassificationOptions", + srcs = [ + "sources/TFLClassificationOptions.m", + ], + hdrs = [ + "sources/TFLClassificationOptions.h", + ], + module_name = "TFLClassificationOptions", +) + +objc_library( + name = "TFLClassificationResult", + srcs = [ + "sources/TFLClassificationResult.m", + ], + hdrs = [ + "sources/TFLClassificationResult.h", + ], + module_name = "TFLClassificationResult", +) + +objc_library( + name = "TFLClassificationOptionsHelpers", + srcs = [ + "sources/TFLClassificationOptions+Helpers.m", + ], + hdrs = [ + "sources/TFLClassificationOptions+Helpers.h", + ], + module_name = "TFLClassificationOptionsHelpers", + deps = [ + "//tensorflow_lite_support/c/task/processor:classification_options", + "//tensorflow_lite_support/ios:TFLCommon", + "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptions", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h new file mode 100644 index 0000000..78a1f965 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
@@ -0,0 +1,29 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "tensorflow_lite_support/c/task/processor/classification_options.h" +#import "third_party/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLClassificationOptions (Helpers) +- (BOOL)copyClassificationOptionsToCClassificationOptions: + (TfLiteClassificationOptions*)cClassificationOptions + error:(NSError**)error; + +- (void)deleteCStringArraysOfClassificationOptions: + (TfLiteClassificationOptions*)cClassificationOptions; +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m new file mode 100644 index 0000000..07254ab --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
@@ -0,0 +1,118 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h" + +@implementation TFLClassificationOptions (Helpers) + ++ (char**)cStringArrayFromNSArray:(NSArray<NSString*>*)strings + error:(NSError**)error { + if (strings.count <= 0) { + [TFLCommonUtils + customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError + description: + @"Invalid length of strings found for list type options." + error:error]; + return NULL; + } + + char** cStrings = (char**)calloc(strings.count, sizeof(char*)); + + if (!cStrings) { + [TFLCommonUtils + customErrorWithCode:TFLSupportErrorCodeInternalError + description:@"Could not initialize list type options." + error:error]; + return nil; + } + + for (NSInteger i = 0; i < strings.count; i++) { + char* cString = [TFLCommonUtils + mallocWithSize:[strings[i] + lengthOfBytesUsingEncoding:NSUTF8StringEncoding] + + 1 + error:error]; + if (!cString) + return nil; + + strcpy(cString, strings[i].UTF8String); + } + + return cStrings; +} + ++ (void)deleteCStringsArray:(char**)cStrings count:(int)count { + for (NSInteger i = 0; i < count; i++) { + free(cStrings[i]); + } + + free(cStrings); +} + +- (BOOL)copyClassificationOptionsToCClassificationOptions: + (TfLiteClassificationOptions*)cClassificationOptions + error:(NSError**)error { + cClassificationOptions->score_threshold = self.scoreThreshold; + cClassificationOptions->max_results = (int)self.maxResults; + + if (self.labelDenyList) { + char** cClassNameBlackList = + [TFLClassificationOptions cStringArrayFromNSArray:self.labelDenyList + error:error]; + if (!cClassNameBlackList) { + return NO; + } + cClassificationOptions->label_denylist.list = cClassNameBlackList; + cClassificationOptions->label_denylist.length = + (int)self.labelDenyList.count; + } + + if (self.labelAllowList) { + char** cClassNameWhiteList = + [TFLClassificationOptions cStringArrayFromNSArray:self.labelAllowList + error:error]; + if (!cClassNameWhiteList) { + return NO; + } + + cClassificationOptions->label_allowlist.list = cClassNameWhiteList; + cClassificationOptions->label_allowlist.length = + (int)self.labelAllowList.count; + } + + if (self.displayNamesLocal) { + cClassificationOptions->display_names_local = + (char*)self.displayNamesLocal.UTF8String; + } + + return YES; +} + +- (void)deleteCStringArraysOfClassificationOptions: + (TfLiteClassificationOptions*)cClassificationOptions { + if (self.labelAllowList) { + [TFLClassificationOptions + deleteCStringsArray:cClassificationOptions->label_allowlist.list + count:cClassificationOptions->label_allowlist.length]; + } + + if (self.labelDenyList) { + [TFLClassificationOptions + deleteCStringsArray:cClassificationOptions->label_denylist.list + count:cClassificationOptions->label_denylist.length]; + } +} +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h new file mode 100644 index 0000000..cc0c8a8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
@@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds settings for any single classification task. + */ +@interface TFLClassificationOptions : NSObject <NSCopying> + +/** If set, all classes in this list will be filtered out from the results . */ +@property(nonatomic, copy) NSArray* labelDenyList; + +/** If set, all classes not in this list will be filtered out from the results . + */ +@property(nonatomic, copy) NSArray* labelAllowList; + +/** Display names local for display names*/ +@property(nonatomic, copy) NSString* displayNamesLocal; + +/** Results with score threshold greater than this value are returned . */ +@property(nonatomic, assign) float scoreThreshold; + +/** Limit to the number of classes that can be returned in results. */ +@property(nonatomic, assign) NSInteger maxResults; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m new file mode 100644 index 0000000..dca232d6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m
@@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h" + +@implementation TFLClassificationOptions +@synthesize scoreThreshold; +@synthesize maxResults; +@synthesize labelAllowList; +@synthesize labelDenyList; +@synthesize displayNamesLocal; + +- (instancetype)init { + self = [super init]; + if (self) { + self.maxResults = -1; + self.scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone*)zone { + TFLClassificationOptions* classificationOptions = + [[TFLClassificationOptions alloc] init]; + + [classificationOptions setScoreThreshold:self.scoreThreshold]; + [classificationOptions setMaxResults:self.maxResults]; + [classificationOptions setLabelDenyList:self.labelDenyList]; + [classificationOptions setLabelAllowList:self.labelAllowList]; + [classificationOptions setDisplayNamesLocal:self.displayNamesLocal]; + + return classificationOptions; +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h new file mode 100644 index 0000000..c0d6fb3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
@@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** Encapsulates information about a class in the classification results. */ +@interface TFLCategory : NSObject + +/** Display name of the class. */ +@property(nonatomic, copy) NSString* displayName; + +/** Class name of the class . */ +@property(nonatomic, copy) NSString* label; + +/** Confidence score for this class . */ +@property(nonatomic, assign) float score; + +/** The index of the class in the corresponding label map, usually packed in the + * TFLite Model Metadata. */ +@property(nonatomic, assign) NSInteger classIndex; + +@end + +/** Encapsulates list of predicted classes (aka labels) for a given image + * classifier head. */ +@interface TFLClassifications : NSObject + +/** + * The index of the image classifier head these classes refer to. This is useful + * for multi-head models. + */ +@property(nonatomic, assign) int headIndex; + +/** The array of predicted classes, usually sorted by descending scores + * (e.g.from high to low probability). */ +@property(nonatomic, copy) NSArray<TFLCategory*>* categories; + +@end + +/** Encapsulates results of any classification task. */ +@interface TFLClassificationResult : NSObject + +@property(nonatomic, copy) NSArray<TFLClassifications*>* classifications; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m new file mode 100644 index 0000000..febf230b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
@@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h" + +@implementation TFLCategory +@synthesize displayName; +@synthesize label; +@synthesize score; +@synthesize classIndex; +@end + +@implementation TFLClassifications +@synthesize headIndex; +@synthesize categories; + +@end + +@implementation TFLClassificationResult +@synthesize classifications; + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/BUILD new file mode 100644 index 0000000..13b9809 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/BUILD
@@ -0,0 +1,19 @@ +package( + default_visibility = ["//tensorflow_lite_support:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLClassificationUtils", + srcs = [ + "sources/TFLClassificationUtils.m", + ], + hdrs = [ + "sources/TFLClassificationUtils.h", + ], + module_name = "TFLClassificationUtils", + deps = [ + "//tensorflow_lite_support/c/task/processor:classification_result", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationResult", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h new file mode 100644 index 0000000..c52876e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h
@@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> +#import "third_party/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h" + +#include "tensorflow_lite_support/c/task/processor/classification_result.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Helper utility for conversion between TFLite Task C Library Classification + * Results and iOS Classification Results . */ +@interface TFLClassificationUtils : NSObject + +/** + * Creates and retrurns a TFLClassificationResult from a + * TfLiteClassificationResult returned by TFLite Task C Library Classification + * tasks. + * + * @param cClassificationResult Classification results returned by TFLite Task C + * Library Classification tasks + * + * @return Classification Result of type TFLClassificationResult to be returned + * by inference methods of the iOS TF Lite Task Classification tasks. + */ ++ (TFLClassificationResult*)classificationResultFromCClassificationResults: + (TfLiteClassificationResult*)cClassificationResult; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m new file mode 100644 index 0000000..b5d884d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m
@@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h" + +@implementation TFLClassificationUtils + ++ (TFLClassificationResult*)classificationResultFromCClassificationResults: + (TfLiteClassificationResult*)cClassificationResult { + if (cClassificationResult == nil) + return nil; + + NSMutableArray* classificationHeads = [[NSMutableArray alloc] init]; + for (int i = 0; i < cClassificationResult->size; i++) { + TfLiteClassifications cClassifications = + cClassificationResult->classifications[i]; + NSMutableArray* classes = [[NSMutableArray alloc] init]; + for (int j = 0; j < cClassifications.size; j++) { + TfLiteCategory cCategory = cClassifications.categories[j]; + TFLCategory* resultCategory = [[TFLCategory alloc] init]; + + if (cCategory.display_name != nil) { + resultCategory.displayName = + [NSString stringWithCString:cCategory.display_name + encoding:NSUTF8StringEncoding]; + } + + if (cCategory.label != nil) { + resultCategory.label = + [NSString stringWithCString:cCategory.label + encoding:NSUTF8StringEncoding]; + } + + resultCategory.score = cCategory.score; + resultCategory.classIndex = (NSInteger)cCategory.index; + [classes addObject:resultCategory]; + } + TFLClassifications* classificationHead = [[TFLClassifications alloc] init]; + classificationHead.categories = classes; + classificationHead.headIndex = i; + [classificationHeads addObject:classificationHead]; + } + + TFLClassificationResult* classificationResult = + [[TFLClassificationResult alloc] init]; + classificationResult.classifications = classificationHeads; + return classificationResult; +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD index fb369e9..1d11afd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD
@@ -1,10 +1,10 @@ load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") -load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -13,8 +13,9 @@ srcs = ["Sources/TFLBertNLClassifier.m"], hdrs = ["Sources/TFLBertNLClassifier.h"], module_name = "TFLBertNLClassifier", + visibility = ["//tensorflow_lite_support:internal"], deps = [ - "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api", + "//tensorflow_lite_support/c/task/text:bert_nl_classifier", "@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -24,12 +25,11 @@ testonly = 1, srcs = ["Tests/TFLBertNLClassifierTest.swift"], data = [ - "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + "//tensorflow_lite_support/cc/test/testdata/task/text:bert_nl_classifier_models", ], tags = TFL_DEFAULT_TAGS, deps = [ ":TFLBertNLClassifier", - "//third_party/swift/xctest", ], ) @@ -48,7 +48,7 @@ testonly = 1, srcs = ["Tests/TFLBertNLClassifierTest.m"], data = [ - "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + "//tensorflow_lite_support/cc/test/testdata/task/text:bert_nl_classifier_models", ], tags = TFL_DEFAULT_TAGS, deps = [ @@ -71,8 +71,9 @@ srcs = ["Sources/TFLNLClassifier.m"], hdrs = ["Sources/TFLNLClassifier.h"], module_name = "TFLNLClassifier", + visibility = ["//tensorflow_lite_support:internal"], deps = [ - "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api", + "//tensorflow_lite_support/c/task/text:nl_classifier", "@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -87,7 +88,6 @@ tags = TFL_DEFAULT_TAGS, deps = [ ":TFLNLClassifier", - "//third_party/swift/xctest", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h index 3ec5716..ac81a15 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
@@ -17,6 +17,16 @@ NS_ASSUME_NONNULL_BEGIN /** + * Options to configure TFLBertNLClassifier. + */ +@interface TFLBertNLClassifierOptions : NSObject + +// @deprecated maxSeqLen is now read from the model (i.e. input tensor size) +// automatically. +@property(nonatomic) int maxSeqLen; +@end + +/** * Classifier API for NLClassification tasks with Bert models, categorizes * string into different classes. The API expects a Bert based TFLite model with * metadata populated. @@ -41,6 +51,17 @@ NS_SWIFT_NAME(bertNLClassifier(modelPath:)); /** + * Creates TFLBertNLClassifier from a model file. + * + * @param modelPath Path to the classification model. + * @return A TFLBertNLClassifier instance. + */ ++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath + options: + (TFLBertNLClassifierOptions*)options + NS_SWIFT_NAME(bertNLClassifier(modelPath:options:)); + +/** * Performs classification on a NSString input, returns <NSString *, NSNumber *> * for categories and socres. *
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m index e5dfaaf..8c45ee6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
@@ -14,30 +14,48 @@ ==============================================================================*/ #import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h" #import "GTMDefines.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" +#include "tensorflow_lite_support/c/task/text/bert_nl_classifier.h" +#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h" NS_ASSUME_NONNULL_BEGIN +@implementation TFLBertNLClassifierOptions +@synthesize maxSeqLen; +@end + @interface TFLBertNLClassifier () /** BertNLClassifier backed by C API */ -@property(nonatomic) BertNLClassifier* bertNLClassifier; +@property(nonatomic) TfLiteBertNLClassifier* bertNLClassifier; @end @implementation TFLBertNLClassifier - (void)dealloc { - BertNLClassifierDelete(_bertNLClassifier); + TfLiteBertNLClassifierDelete(_bertNLClassifier); } + (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath { - BertNLClassifier* classifier = BertNLClassifierFromFile(modelPath.UTF8String); + TfLiteBertNLClassifier* classifier = + TfLiteBertNLClassifierCreate(modelPath.UTF8String); _GTMDevAssert(classifier, @"Failed to create BertNLClassifier"); return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier]; } -- (instancetype)initWithBertNLClassifier:(BertNLClassifier*)bertNLClassifier { ++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath + options: + (TFLBertNLClassifierOptions*)options { + // Note that maxSeqLen has been deprecated. Passing it to the C API is a + // no-op. + TfLiteBertNLClassifierOptions cOptions = {.max_seq_len = options.maxSeqLen}; + TfLiteBertNLClassifier* classifier = + TfLiteBertNLClassifierCreateFromOptions(modelPath.UTF8String, &cOptions); + _GTMDevAssert(classifier, @"Failed to create BertNLClassifier"); + return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier]; +} + +- (instancetype)initWithBertNLClassifier: + (TfLiteBertNLClassifier*)bertNLClassifier { self = [super init]; if (self) { _bertNLClassifier = bertNLClassifier; @@ -46,12 +64,12 @@ } - (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text { - struct Categories* cCategories = - BertNLClassifierClassify(_bertNLClassifier, text.UTF8String); + Categories* cCategories = + TfLiteBertNLClassifierClassify(_bertNLClassifier, text.UTF8String); NSMutableDictionary<NSString*, NSNumber*>* ret = [NSMutableDictionary dictionary]; for (int i = 0; i < cCategories->size; i++) { - struct Category cCategory = cCategories->categories[i]; + Category cCategory = cCategories->categories[i]; [ret setValue:[NSNumber numberWithDouble:cCategory.score] forKey:[NSString stringWithUTF8String:cCategory.text]]; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m index d15dfa3..39eb15c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
@@ -14,8 +14,8 @@ ==============================================================================*/ #import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h" #import "GTMDefines.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h" -#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" +#include "tensorflow_lite_support/c/task/text/nl_classifier.h" +#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h" NS_ASSUME_NONNULL_BEGIN @@ -30,31 +30,31 @@ @interface TFLNLClassifier () /** NLClassifier backed by C API */ -@property(nonatomic) NLClassifier* nlClassifier; +@property(nonatomic) TfLiteNLClassifier* nlClassifier; @end @implementation TFLNLClassifier - (void)dealloc { - NLClassifierDelete(_nlClassifier); + TfLiteNLClassifierDelete(_nlClassifier); } + (instancetype)nlClassifierWithModelPath:(NSString*)modelPath options:(TFLNLClassifierOptions*)options { - struct NLClassifierOptions cOptions = { + TfLiteNLClassifierOptions cOptions = { .input_tensor_index = options.inputTensorIndex, .output_score_tensor_index = options.outputScoreTensorIndex, .output_label_tensor_index = options.outputLabelTensorIndex, .input_tensor_name = options.inputTensorName.UTF8String, .output_score_tensor_name = options.outputScoreTensorName.UTF8String, .output_label_tensor_name = options.outputLabelTensorName.UTF8String}; - NLClassifier* classifier = - NLClassifierFromFileAndOptions(modelPath.UTF8String, &cOptions); + TfLiteNLClassifier* classifier = + TfLiteNLClassifierCreateFromOptions(modelPath.UTF8String, &cOptions); _GTMDevAssert(classifier, @"Failed to create NLClassifier"); return [[TFLNLClassifier alloc] initWithNLClassifier:classifier]; } -- (instancetype)initWithNLClassifier:(NLClassifier*)nlClassifier { +- (instancetype)initWithNLClassifier:(TfLiteNLClassifier*)nlClassifier { self = [super init]; if (self) { _nlClassifier = nlClassifier; @@ -63,12 +63,12 @@ } - (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text { - struct Categories* cCategories = - NLClassifierClassify(_nlClassifier, text.UTF8String); + Categories* cCategories = + TfLiteNLClassifierClassify(_nlClassifier, text.UTF8String); NSMutableDictionary<NSString*, NSNumber*>* ret = [NSMutableDictionary dictionary]; for (int i = 0; i < cCategories->size; i++) { - struct Category cCategory = cCategories->categories[i]; + Category cCategory = cCategories->categories[i]; [ret setValue:[NSNumber numberWithDouble:cCategory.score] forKey:[NSString stringWithUTF8String:cCategory.text]]; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m index 4f121b8..407be10 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m
@@ -20,6 +20,7 @@ @interface TFLBertNLClassifierTest : XCTestCase @property(nonatomic, nullable) NSString* bertModelPath; +@property(nonatomic, nullable) TFLBertNLClassifierOptions* modelOptions; @end @implementation TFLBertNLClassifierTest @@ -28,7 +29,7 @@ - (void)setUp { [super setUp]; NSBundle* bundle = [NSBundle bundleForClass:[self class]]; - self.bertModelPath = [bundle pathForResource:@"test_model_nl_classifier_bert" + self.bertModelPath = [bundle pathForResource:@"bert_nl_classifier" ofType:@"tflite"]; } @@ -57,5 +58,39 @@ XCTAssertGreaterThan([categories[@"negative"] doubleValue], [categories[@"positive"] doubleValue]); } + +- (void)testCreateFromOptionsClassifyPositiveResult { + self.modelOptions = [[TFLBertNLClassifierOptions alloc] init]; + [self.modelOptions setMaxSeqLen:128]; + + TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath + options:self.modelOptions]; + + XCTAssertNotNil(bertNLClassifier); + + NSDictionary<NSString*, NSNumber*>* categories = [bertNLClassifier + classifyWithText:@"it's a charming and often affecting journey"]; + + XCTAssertGreaterThan([categories[@"positive"] doubleValue], + [categories[@"negative"] doubleValue]); +} + +- (void)testCreateFromOptionsClassifyNegativeResult { + self.modelOptions = [[TFLBertNLClassifierOptions alloc] init]; + [self.modelOptions setMaxSeqLen:128]; + + TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath + options:self.modelOptions]; + + XCTAssertNotNil(bertNLClassifier); + + NSDictionary<NSString*, NSNumber*>* categories = + [bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"]; + + XCTAssertGreaterThan([categories[@"negative"] doubleValue], + [categories[@"positive"] doubleValue]); +} @end NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift index d331b04..133b8f1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift
@@ -19,7 +19,7 @@ class TFLBertNLClassifierTest: XCTestCase { static let bundle = Bundle(for: TFLBertNLClassifierTest.self) - static let bertModelPath = bundle.path(forResource: "test_model_nl_classifier_bert", ofType: "tflite")! + static let bertModelPath = bundle.path(forResource: "bert_nl_classifier", ofType: "tflite")! func testClassifyPositiveResult() { let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier( @@ -42,4 +42,32 @@ XCTAssertGreaterThan(categories["negative"]!.doubleValue, categories["positive"]!.doubleValue) } + + func testCreateFromOptionsClassifyPositiveResult() { + let modelOptions = TFLBertNLClassifierOptions() + modelOptions.maxSeqLen = 128 + let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier( + modelPath: TFLBertNLClassifierTest.bertModelPath, + options: modelOptions) + + XCTAssertNotNil(bertNLClassifier) + + let categories = bertNLClassifier.classify(text: "it's a charming and often affecting journey") + + XCTAssertGreaterThan(categories["positive"]!.doubleValue, categories["negative"]!.doubleValue) + } + + func testCreateFromOptionsClassifyNegativeResult() { + let modelOptions = TFLBertNLClassifierOptions() + modelOptions.maxSeqLen = 128 + let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier( + modelPath: TFLBertNLClassifierTest.bertModelPath, + options: modelOptions) + + XCTAssertNotNil(bertNLClassifier) + + let categories = bertNLClassifier.classify(text: "unflinchingly bleak and desperate") + + XCTAssertGreaterThan(categories["negative"]!.doubleValue, categories["positive"]!.doubleValue) + } }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/BUILD index 7998a8e..4b15aed 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/BUILD
@@ -1,10 +1,10 @@ load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") -load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -13,8 +13,9 @@ srcs = ["Sources/TFLBertQuestionAnswerer.m"], hdrs = ["Sources/TFLBertQuestionAnswerer.h"], module_name = "TFLBertQuestionAnswerer", + visibility = ["//tensorflow_lite_support:internal"], deps = [ - "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api", + "//tensorflow_lite_support/c/task/text:bert_question_answerer", "@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -30,7 +31,6 @@ tags = TFL_DEFAULT_TAGS, deps = [ ":TFLBertQuestionAnswerer", - "//third_party/swift/xctest", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m index 2f410f4..b470c46 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m
@@ -14,7 +14,7 @@ ==============================================================================*/ #import "tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h" #import "GTMDefines.h" -#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h" +#include "tensorflow_lite_support/c/task/text/bert_question_answerer.h" NS_ASSUME_NONNULL_BEGIN @@ -25,25 +25,25 @@ @interface TFLBertQuestionAnswerer () /** BertQuestionAnswerer backed by C API */ -@property(nonatomic) BertQuestionAnswerer* bertQuestionAnswerer; +@property(nonatomic) TfLiteBertQuestionAnswerer* bertQuestionAnswerer; @end @implementation TFLBertQuestionAnswerer - (void)dealloc { - BertQuestionAnswererDelete(_bertQuestionAnswerer); + TfLiteBertQuestionAnswererDelete(_bertQuestionAnswerer); } + (instancetype)questionAnswererWithModelPath:(NSString*)modelPath { - BertQuestionAnswerer* bert_qa = - BertQuestionAnswererFromFile(modelPath.UTF8String); + TfLiteBertQuestionAnswerer* bert_qa = + TfLiteBertQuestionAnswererCreate(modelPath.UTF8String); _GTMDevAssert(bert_qa, @"Failed to create BertQuestionAnswerer"); return [[TFLBertQuestionAnswerer alloc] initWithBertQuestionAnswerer:bert_qa]; } - (instancetype)initWithBertQuestionAnswerer: - (BertQuestionAnswerer*)bertQuestionAnswerer { + (TfLiteBertQuestionAnswerer*)bertQuestionAnswerer { self = [super init]; if (self) { _bertQuestionAnswerer = bertQuestionAnswerer; @@ -53,12 +53,12 @@ - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context question:(NSString*)question { - struct QaAnswers* cAnswers = BertQuestionAnswererAnswer( + TfLiteQaAnswers* cAnswers = TfLiteBertQuestionAnswererAnswer( _bertQuestionAnswerer, context.UTF8String, question.UTF8String); NSMutableArray<TFLQAAnswer*>* ret = [NSMutableArray arrayWithCapacity:cAnswers->size]; for (int i = 0; i < cAnswers->size; i++) { - struct QaAnswer cAnswer = cAnswers->answers[i]; + TfLiteQaAnswer cAnswer = cAnswers->answers[i]; TFLQAAnswer* answer = [[TFLQAAnswer alloc] init]; struct TFLPos pos = { .start = cAnswer.start, .end = cAnswer.end, .logit = cAnswer.logit}; @@ -66,7 +66,7 @@ [answer setText:[NSString stringWithUTF8String:cAnswer.text]]; [ret addObject:answer]; } - BertQuestionAnswererQaAnswersDelete(cAnswers); + TfLiteQaAnswersDelete(cAnswers); return ret; } @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD new file mode 100644 index 0000000..c17dbc8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD
@@ -0,0 +1,26 @@ +package( + default_visibility = ["//tensorflow_lite_support:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLImageClassifier", + srcs = [ + "sources/TFLImageClassifier.m", + ], + hdrs = [ + "sources/TFLImageClassifier.h", + ], + module_name = "TFLImageClassifier", + deps = [ + "//tensorflow_lite_support/c/task/vision:image_classifier", + "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptions", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptions", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationResult", + "//tensorflow_lite_support/ios/task/processor/utils:TFLClassificationUtils", + "//tensorflow_lite_support/ios/task/vision/utils:GMLImageUtils", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h new file mode 100644 index 0000000..1b988f2b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
@@ -0,0 +1,113 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> + +#import "third_party/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h" +#import "third_party/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h" +#import "third_party/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h" +#import "third_party/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options to configure TFLImageClassifier. + */ +@interface TFLImageClassifierOptions : NSObject + +/** + * Base options that is used for creation of any type of task. + * @seealso TFLBaseOptions + */ +@property(nonatomic, copy) TFLBaseOptions* baseOptions; + +/** + * Options that configure the display and filtering of results. + * @seealso TFLClassificationOptions + */ +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions; + +/** + * Initializes TFLImageClassifierOptions with the model path set to the + * specified path to a model file. + * @description The external model file, must be a single standalone TFLite + * file. It could be packed with TFLite Model Metadata[1] and associated files + * if exist. Fail to provide the necessary metadata and associated files might + * result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the + * specific requirement. + * + * @param modelPath Path to a TFLite model file. + * @return An instance of TFLImageClassifierOptions set to the specified + * modelPath. + */ +- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath; + +@end + +/** + * A TensorFlow Lite Task Image Classifiier. + */ +@interface TFLImageClassifier : NSObject + +/** + * Creates TFLImageClassifier from a model file and specified options . + * + * @param options TFLImageClassifierOptions instance with the necessary + * properties set. + * + * @return A TFLImageClassifier instance. + */ ++ (nullable instancetype)imageClassifierWithOptions: + (nonnull TFLImageClassifierOptions*)options + error:(NSError**)error + NS_SWIFT_NAME(imageClassifier(options:)); + +/** + * Performs classification on a GMLImage input, returns an array of + * categorization results where each member in the array is an array of + * TFLClass objects for each classification head. + * + * @param image input to the model. + * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + */ +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + error:(NSError* _Nullable*) + error + NS_SWIFT_NAME(classify(gmlImage:)); + +/** + * Performs classification on a GMLImage input on the pixels in the + * specified bounding box, returns an array of categorization results + * where each member in the array is an array of TFLClass objects for + * each classification head. + * + * @param image input to the model. + * @param roi CGRect specifying region of interest in image. + * + * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + */ +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + regionOfInterest:(CGRect)roi + error:(NSError* _Nullable*) + error + NS_SWIFT_NAME(classify(gmlImage:regionOfInterest:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m new file mode 100644 index 0000000..06d6793 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
@@ -0,0 +1,143 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" +#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h" +#import "tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h" +#import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h" + +#include "tensorflow_lite_support/c/task/vision/image_classifier.h" + +@interface TFLImageClassifier () +/** ImageClassifier backed by C API */ +@property(nonatomic) TfLiteImageClassifier* imageClassifier; +@end + +@implementation TFLImageClassifierOptions +@synthesize baseOptions; +@synthesize classificationOptions; + +- (instancetype)init { + self = [super init]; + if (self) { + self.baseOptions = [[TFLBaseOptions alloc] init]; + self.classificationOptions = [[TFLClassificationOptions alloc] init]; + } + return self; +} + +- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath { + self = [self init]; + if (self) { + self.baseOptions.modelFile.filePath = modelPath; + } + return self; +} + +@end + +@implementation TFLImageClassifier +- (void)dealloc { + TfLiteImageClassifierDelete(_imageClassifier); +} + +- (instancetype)initWithImageClassifier: + (TfLiteImageClassifier*)imageClassifier { + self = [super init]; + if (self) { + _imageClassifier = imageClassifier; + } + return self; +} + ++ (nullable instancetype)imageClassifierWithOptions: + (nonnull TFLImageClassifierOptions*)options + error:(NSError**)error { + TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate(); + if (![options.classificationOptions + copyClassificationOptionsToCClassificationOptions: + &(cOptions.classification_options) + error:error]) + return nil; + + [options.baseOptions copyBaseOptionsToCBaseOptions:&(cOptions.base_options)]; + + TfLiteSupportError* createClassifierError = nil; + TfLiteImageClassifier* imageClassifier = + TfLiteImageClassifierFromOptions(&cOptions, &createClassifierError); + + [options.classificationOptions deleteCStringArraysOfClassificationOptions: + &(cOptions.classification_options)]; + + if (!imageClassifier) { + [TFLCommonUtils errorFromTfLiteSupportError:createClassifierError + error:error]; + TfLiteSupportErrorDelete(createClassifierError); + return nil; + } + + return [[TFLImageClassifier alloc] initWithImageClassifier:imageClassifier]; +} + +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + error:(NSError* _Nullable*) + error { + return [self classifyWithGMLImage:image + regionOfInterest:CGRectMake(0, 0, image.width, image.height) + error:error]; +} + +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + regionOfInterest:(CGRect)roi + error:(NSError* _Nullable*) + error { + TfLiteFrameBuffer* cFrameBuffer = + [GMLImageUtils cFrameBufferFromGMLImage:image error:error]; + + if (!cFrameBuffer) { + return nil; + } + + TfLiteBoundingBox boundingBox = {.origin_x = roi.origin.x, + .origin_y = roi.origin.y, + .width = roi.size.width, + .height = roi.size.height}; + + TfLiteSupportError* classifyError = nil; + TfLiteClassificationResult* cClassificationResult = + TfLiteImageClassifierClassifyWithRoi(_imageClassifier, cFrameBuffer, + &boundingBox, &classifyError); + + free(cFrameBuffer->buffer); + cFrameBuffer->buffer = nil; + + free(cFrameBuffer); + cFrameBuffer = nil; + + if (!cClassificationResult) { + [TFLCommonUtils errorFromTfLiteSupportError:classifyError error:error]; + TfLiteSupportErrorDelete(classifyError); + return nil; + } + + TFLClassificationResult* classificationHeadsResults = [TFLClassificationUtils + classificationResultFromCClassificationResults:cClassificationResult]; + TfLiteClassificationResultDelete(cClassificationResult); + + return classificationHeadsResults; +} +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/BUILD new file mode 100644 index 0000000..4b2ba81 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/BUILD
@@ -0,0 +1,22 @@ +package( + default_visibility = ["//tensorflow_lite_support:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "GMLImageUtils", + srcs = [ + "sources/GMLImageUtils.m", + ], + hdrs = [ + "sources/GMLImageUtils.h", + ], + module_name = "GMLImageUtils", + sdk_frameworks = ["Accelerate"], + deps = [ + "//tensorflow_lite_support/c/task/vision/core:frame_buffer", + "//tensorflow_lite_support/ios:TFLCommon", + "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/odml/ios/image:MLImage", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h new file mode 100644 index 0000000..298485b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h
@@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> + +#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h" +#import "third_party/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Helper utility for performing operations on GMLImage specific to the + * TF Lite Task Vision library + */ +@interface GMLImageUtils : NSObject +/** + * Creates and returns a TfLiteFrameBuffer from a GMLImage. TfLiteFrameBuffer + * is used by the TFLite Task Vision C library to hold the backing buffer of + * any image. Image inputs to the TFLite Task Vision C library is of type + * TfLiteFrameBuffer. + * + * @param gmlImage Image of type GMLImage which is to be converted into a + * TfLiteFrameBuffer. + * @param error Pointer to the memory location where errors if any should be + * saved. If `nil`, no error will be saved. + * + * @return The TfLiteFrameBuffer created from the gmlImage which can be used + * with the TF Lite Task Vision C library. + */ ++ (nullable TfLiteFrameBuffer*)cFrameBufferFromGMLImage:(GMLImage*)gmlImage + error:(NSError* _Nullable*) + error; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m new file mode 100644 index 0000000..72425b3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m
@@ -0,0 +1,333 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" + +#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h" + +#import <Accelerate/Accelerate.h> +#import <CoreGraphics/CoreGraphics.h> +#import <CoreImage/CoreImage.h> +#import <CoreVideo/CoreVideo.h> + +@interface TFLCVPixelBufferUtils : NSObject ++ (uint8_t* _Nullable) + convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer + error:(NSError**)error; +@end + +@interface UIImage (RawPixelDataUtils) +- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error; +@end + +@implementation TFLCVPixelBufferUtils + ++ (uint8_t*)convertBGRAtoRGBforPixelBufferBaseAddress: + (CVPixelBufferRef)pixelBuffer + error:(NSError**)error { + size_t width = CVPixelBufferGetWidth(pixelBuffer); + size_t height = CVPixelBufferGetHeight(pixelBuffer); + size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); + + int destinationChannelCount = 3; + size_t destinationBytesPerRow = destinationChannelCount * width; + + uint8_t* pixelBufferBaseAddress = + (uint8_t*)CVPixelBufferGetBaseAddress(pixelBuffer); + + uint8_t* destPixelBufferAddress = + [TFLCommonUtils mallocWithSize:height * destinationBytesPerRow + error:error]; + + if (!destPixelBufferAddress) { + return NULL; + } + + vImage_Buffer srcBuffer = {.data = pixelBufferBaseAddress, + .height = height, + .width = width, + .rowBytes = stride}; + + vImage_Buffer destBuffer = {.data = destPixelBufferAddress, + .height = height, + .width = width, + .rowBytes = destinationBytesPerRow}; + + vImage_Error convertError = kvImageNoError; + convertError = + vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); + + if (convertError != kvImageNoError) { + [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeImageProcessingError + description:@"Image format conversion failed." + error:error]; + + return NULL; + } + + return destPixelBufferAddress; +} + +@end + +@implementation UIImage (RawPixelDataUtils) + +- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error { + TfLiteFrameBuffer* frameBuffer = NULL; + + if (self.CGImage) { + frameBuffer = [self frameBufferFromCGImage:self.CGImage error:error]; + } else if (self.CIImage) { + frameBuffer = [self frameBufferFromCIImage:self.CIImage error:error]; + } else { + [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError + description:@"UIImage should be initialized from" + " CIImage or CGImage." + error:error]; + } + + return frameBuffer; +} + ++ (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage + error:(NSError**)error { + long width = CGImageGetWidth(cgImage); + long height = CGImageGetHeight(cgImage); + + int bitsPerComponent = 8; + UInt8* buffer_to_return = NULL; + + CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); + CGContextRef context = + CGBitmapContextCreate(nil, width, height, bitsPerComponent, 0, colorSpace, + kCGImageAlphaNoneSkipLast); + + if (context) { + CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage); + buffer_to_return = [UIImage + populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context) + width:width + height:height]; + CGContextRelease(context); + } + + if (buffer_to_return == NULL) { + [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeImageProcessingError + description:@"Image format conversion failed." + error:error]; + } + + CGColorSpaceRelease(colorSpace); + + return buffer_to_return; +} + ++ (nullable UInt8*)populateRGBBufferFromSourceRGBABuffer:(UInt8*)buffer + width:(size_t)width + height:(size_t)height { + if (!buffer) + return nil; + + int sourceChannelCount = 4; + int destChannelCount = 3; + + UInt8* buffer_to_return = malloc(height * destChannelCount * width); + if (!buffer_to_return) { + return nil; + } + for (int row = 0; row < width; row++) { + for (int col = 0; col < height; col++) { + long offset = sourceChannelCount * (col * width + row); + long rgbOffset = destChannelCount * (col * width + row); + buffer_to_return[rgbOffset] = buffer[offset]; + buffer_to_return[rgbOffset + 1] = buffer[offset + 1]; + buffer_to_return[rgbOffset + 2] = buffer[offset + 2]; + } + } + return buffer_to_return; +} + +- (TfLiteFrameBuffer*)frameBufferFromCGImage:(CGImageRef)cgImage + error:(NSError**)error { + UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error]; + + if (buffer == NULL) { + return NULL; + } + + TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); + + cFrameBuffer->dimension.width = (int)CGImageGetWidth(cgImage); + cFrameBuffer->dimension.height = (int)CGImageGetHeight(cgImage); + cFrameBuffer->buffer = buffer; + + enum TfLiteFrameBufferFormat cPixelFormat = kRGB; + cFrameBuffer->format = cPixelFormat; + + return cFrameBuffer; +} + +- (TfLiteFrameBuffer*)frameBufferFromCIImage:(CIImage*)ciImage + error:(NSError**)error { + uint8_t* buffer = nil; + + int width = 0; + int height = 0; + if (ciImage.pixelBuffer) { + buffer = [TFLCVPixelBufferUtils + convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer + error:error]; + width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer); + height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer); + + } else if (ciImage.CGImage) { + buffer = [UIImage pixelDataFromCGImage:ciImage.CGImage error:error]; + width = (int)CGImageGetWidth(ciImage.CGImage); + height = (int)CGImageGetWidth(ciImage.CGImage); + } else { + [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError + description:@"CIImage should have CGImage or " + "CVPixelBuffer info." + error:error]; + } + + if (buffer == NULL) { + return NULL; + } + + TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); + cFrameBuffer->buffer = buffer; + cFrameBuffer->dimension.width = width; + cFrameBuffer->dimension.height = height; + + enum TfLiteFrameBufferFormat cPixelFormat = kRGBA; + cFrameBuffer->format = cPixelFormat; + + return cFrameBuffer; +} + +@end + +@implementation GMLImageUtils + ++ (nullable TfLiteFrameBuffer*)cFrameBufferFromGMLImage:(GMLImage*)gmlImage + error:(NSError* _Nullable*) + error { + TfLiteFrameBuffer* cFrameBuffer = NULL; + + switch (gmlImage.imageSourceType) { + case GMLImageSourceTypeSampleBuffer: { + CVPixelBufferRef sampleImagePixelBuffer = + CMSampleBufferGetImageBuffer(gmlImage.sampleBuffer); + cFrameBuffer = + [GMLImageUtils bufferFromCVPixelBuffer:sampleImagePixelBuffer + error:error]; + break; + } + case GMLImageSourceTypePixelBuffer: { + cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:gmlImage.pixelBuffer + error:error]; + break; + } + case GMLImageSourceTypeImage: { + cFrameBuffer = [GMLImageUtils frameBufferFromUIImage:gmlImage.image + error:error]; + } + + default: + [TFLCommonUtils + customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError + description:@"Invalid source type for GMLImage." + error:error]; + break; + } + + return cFrameBuffer; +} + ++ (TfLiteFrameBuffer*)frameBufferFromUIImage:(UIImage*)image + error:(NSError**)error { + return [image frameBufferWithError:error]; +} + ++ (TfLiteFrameBuffer*)bufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer + error:(NSError**)error { + uint8_t* buffer = nil; + enum TfLiteFrameBufferFormat cPixelFormat = kRGB; + + CVPixelBufferLockBaseAddress(pixelBuffer, 0); + OSType pixelBufferFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + + switch (pixelBufferFormat) { + case kCVPixelFormatType_24RGB: { + cPixelFormat = kRGB; + buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer + error:error]; + break; + } + case kCVPixelFormatType_32RGBA: { + cPixelFormat = kRGBA; + buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer + error:error]; + break; + } + case kCVPixelFormatType_32BGRA: { + cPixelFormat = kRGB; + buffer = [TFLCVPixelBufferUtils + convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer + error:error]; + break; + } + + default: { + [TFLCommonUtils + customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError + description: + @"Unsupported pixel format for TfLiteFrameBufferFormat." + error:error]; + break; + } + } + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); + + if (!buffer) { + return nil; + } + + TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer)); + + cFrameBuffer->dimension.width = (int)CVPixelBufferGetWidth(pixelBuffer); + cFrameBuffer->dimension.height = (int)CVPixelBufferGetHeight(pixelBuffer); + cFrameBuffer->buffer = buffer; + cFrameBuffer->format = cPixelFormat; + + return cFrameBuffer; +} + ++ (UInt8*)copyPixelufferDataForInference:(CVPixelBufferRef)pixelBuffer + error:(NSError**)error { + size_t height = CVPixelBufferGetHeight(pixelBuffer); + size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); + UInt8* buffer = [TFLCommonUtils mallocWithSize:height * stride error:error]; + + if (buffer) + memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride); + + return buffer; +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/BUILD new file mode 100644 index 0000000..6ae128e1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/BUILD
@@ -0,0 +1,57 @@ +load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package( + default_visibility = ["//tensorflow_lite_support:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +swift_library( + name = "TFLImageClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TFLImageClassifierTests.swift"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//tensorflow_lite_support/ios/task/vision:TFLImageClassifier", + ], +) + +ios_unit_test( + name = "TFLImageClassifierSwiftTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLImageClassifierSwiftTestLibrary", + ], +) + +objc_library( + name = "TFLImageClassifierObjcTestLibrary", + testonly = 1, + srcs = ["TFLImageClassifierTests.m"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//tensorflow_lite_support/ios/task/vision:TFLImageClassifier", + ], +) + +ios_unit_test( + name = "TFLImageClassifierObjcTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLImageClassifierObjcTestLibrary", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m new file mode 100644 index 0000000..f269594 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
@@ -0,0 +1,160 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <XCTest/XCTest.h> +#import "tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLImageClassifierTests : XCTestCase +@property(nonatomic, nullable) NSString* modelPath; +@end + +@implementation TFLImageClassifierTests + +- (GMLImage*)imageFromBundleWithName:(NSString*)name ofType:(NSString*)type { + NSString* imagePath = + [[NSBundle bundleForClass:[self class]] pathForResource:name ofType:type]; + XCTAssertNotNil(imagePath); + UIImage* image = [[UIImage alloc] initWithContentsOfFile:imagePath]; + XCTAssertNotNil(image); + + GMLImage* gmlImage = [[GMLImage alloc] initWithImage:image]; + XCTAssertNotNil(gmlImage); + + return gmlImage; +} +- (void)setUp { + // Put setup code here. This method is called before the invocation of each + // test method in the class. static let bundle = Bundle(for: + // TFLSentencepieceTokenizerTest.self) + self.modelPath = [[NSBundle bundleForClass:[self class]] + pathForResource:@"mobilenet_v2_1.0_224" + ofType:@"tflite"]; + XCTAssertNotNil(self.modelPath); +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each + // test method in the class. +} + +- (void)testSuccessfullImageInferenceOnMLImageWithUIImage { + TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + + TFLImageClassifier* imageClassifier = + [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions + error:nil]; + XCTAssertNotNil(imageClassifier); + GMLImage* gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"]; + + TFLClassificationResult* classificationResults = + [imageClassifier classifyWithGMLImage:gmlImage error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); + XCTAssertTrue([classificationResults.classifications[0].categories count] > + 0); + + TFLCategory* category = + classificationResults.classifications[0].categories[0]; + XCTAssertTrue([category.label isEqual:@"cheeseburger"]); + // TODO: match the score as image_classifier_test.cc + XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001); +} + +- (void)testModelOptionsWithMaxResults { + TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + int maxResults = 3; + imageClassifierOptions.classificationOptions.maxResults = maxResults; + + TFLImageClassifier* imageClassifier = + [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions + error:nil]; + XCTAssertNotNil(imageClassifier); + + GMLImage* gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"]; + + TFLClassificationResult* classificationResults = + [imageClassifier classifyWithGMLImage:gmlImage error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); + XCTAssertLessThanOrEqual( + [classificationResults.classifications[0].categories count], maxResults); + + TFLCategory* category = + classificationResults.classifications[0].categories[0]; + XCTAssertTrue([category.label isEqual:@"cheeseburger"]); + // TODO: match the score as image_classifier_test.cc + XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001); +} + +- (void)testInferenceWithBoundingBox { + TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + int maxResults = 3; + imageClassifierOptions.classificationOptions.maxResults = maxResults; + + TFLImageClassifier* imageClassifier = + [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions + error:nil]; + XCTAssertNotNil(imageClassifier); + + GMLImage* gmlImage = [self imageFromBundleWithName:@"multi_objects" + ofType:@"jpg"]; + + CGRect roi = CGRectMake(406, 110, 148, 153); + TFLClassificationResult* classificationResults = + [imageClassifier classifyWithGMLImage:gmlImage + regionOfInterest:roi + error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); + XCTAssertTrue([classificationResults.classifications[0].categories count] > + 0); + + TFLCategory* category = + classificationResults.classifications[0].categories[0]; + // TODO: match the label and score as image_classifier_test.cc + // XCTAssertTrue([category.label isEqual:@"soccer ball"]); + // XCTAssertEqualWithAccuracy(category.score, 0.256512, 0.001); +} + +- (void)testInferenceWithRGBAImage { + TFLImageClassifierOptions* imageClassifierOptions = + [[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath]; + + TFLImageClassifier* imageClassifier = + [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions + error:nil]; + XCTAssertNotNil(imageClassifier); + + GMLImage* gmlImage = [self imageFromBundleWithName:@"sparrow" ofType:@"png"]; + XCTAssertNotNil(gmlImage); + + TFLClassificationResult* classificationResults = + [imageClassifier classifyWithGMLImage:gmlImage error:nil]; + XCTAssertTrue([classificationResults.classifications count] > 0); + XCTAssertTrue([classificationResults.classifications[0].categories count] > + 0); + + TFLCategory* category = + classificationResults.classifications[0].categories[0]; + XCTAssertTrue([category.label isEqual:@"junco"]); + // TODO: inspect if score is correct. Better to test againest "burger", + // because we know the expected result for "burger.jpg". + XCTAssertEqualWithAccuracy(category.score, 0.253016, 0.001); +} + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift new file mode 100644 index 0000000..0375168 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift
@@ -0,0 +1,136 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +import XCTest + +@testable import TFLImageClassifier + +class TFLImageClassifierTests: XCTestCase { + + static let bundle = Bundle(for: TFLImageClassifierTests.self) + static let modelPath = bundle.path( + forResource: "mobilenet_v2_1.0_224", + ofType: "tflite")! + + func testSuccessfullInferenceOnMLImageWithUIImage() throws { + + let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + + let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) + XCTAssertNotNil(imageClassifierOptions) + + let imageClassifier = + try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) + + let gmlImage = try gmlImage(withName: "burger", ofType: "jpg") + let classificationResults: TFLClassificationResult = + try imageClassifier.classify(gmlImage: gmlImage) + + XCTAssertNotNil(classificationResults) + XCTAssertEqual(classificationResults.classifications.count, 1) + XCTAssertGreaterThan(classificationResults.classifications[0].categories.count, 0) + // TODO: match the score as image_classifier_test.cc + let category = classificationResults.classifications[0].categories[0] + XCTAssertEqual(category.label, "cheeseburger") + XCTAssertEqual(category.score, 0.748976, accuracy: 0.001) + } + + func testModelOptionsWithMaxResults() throws { + + let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + + let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) + XCTAssertNotNil(imageClassifierOptions) + + let maxResults = 3 + imageClassifierOptions!.classificationOptions.maxResults = maxResults + + let imageClassifier = + try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) + + let gmlImage = try gmlImage(withName: "burger", ofType: "jpg") + + let classificationResults: TFLClassificationResult = try imageClassifier.classify( + gmlImage: gmlImage) + + XCTAssertNotNil(classificationResults) + XCTAssertEqual(classificationResults.classifications.count, 1) + XCTAssertLessThanOrEqual(classificationResults.classifications[0].categories.count, maxResults) + + // TODO: match the score as image_classifier_test.cc + let category = classificationResults.classifications[0].categories[0] + XCTAssertEqual(category.label, "cheeseburger") + XCTAssertEqual(category.score, 0.748976, accuracy: 0.001) + } + + func testInferenceWithBoundingBox() throws { + + let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + + let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) + XCTAssertNotNil(imageClassifierOptions) + + let imageClassifier = + try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) + + let gmlImage = try gmlImage(withName: "multi_objects", ofType: "jpg") + + let roi = CGRect(x: 406, y: 110, width: 148, height: 153) + let classificationResults = + try imageClassifier.classify(gmlImage: gmlImage, regionOfInterest: roi) + + XCTAssertNotNil(classificationResults) + XCTAssertEqual(classificationResults.classifications.count, 1) + XCTAssertGreaterThan(classificationResults.classifications[0].categories.count, 0) + + // TODO: match the label and score as image_classifier_test.cc + // let category = classificationResults.classifications[0].categories[0] + // XCTAssertEqual(category.label, "soccer ball") + // XCTAssertEqual(category.score, 0.256512, accuracy:0.001); + } + + func testInferenceWithRGBAImage() throws { + + let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + + let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) + XCTAssertNotNil(imageClassifierOptions) + + let imageClassifier = + try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) + + let gmlImage = try gmlImage(withName: "sparrow", ofType: "png") + + let classificationResults = + try imageClassifier.classify(gmlImage: gmlImage) + + XCTAssertNotNil(classificationResults) + XCTAssertEqual(classificationResults.classifications.count, 1) + XCTAssertGreaterThan(classificationResults.classifications[0].categories.count, 0) + + let category = classificationResults.classifications[0].categories[0] + XCTAssertEqual(category.label, "junco") + XCTAssertEqual(category.score, 0.253016, accuracy: 0.001) + } + + private func gmlImage(withName name: String, ofType type: String) throws -> MLImage { + let imagePath = + try XCTUnwrap(TFLImageClassifierTests.bundle.path(forResource: name, ofType: type)) + let image = UIImage(contentsOfFile: imagePath) + let imageForInference = try XCTUnwrap(image) + let gmlImage = try XCTUnwrap(MLImage(image: imageForInference)) + + return gmlImage + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/BUILD index 34ba9c6..b894465 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/BUILD
@@ -1,5 +1,5 @@ load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") -load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") @@ -50,7 +50,6 @@ tags = TFL_DEFAULT_TAGS, deps = [ ":TFLBertTokenizer", - "//third_party/swift/xctest", ], ) @@ -91,7 +90,6 @@ tags = TFL_DEFAULT_TAGS, deps = [ ":TFLSentencepieceTokenizer", - "//third_party/swift/xctest", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift index e805f30..b69af12e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift
@@ -18,7 +18,7 @@ class TFLBertTokenizerTest: XCTestCase { static let bundle = Bundle(for: TFLBertTokenizerTest.self) - static let mobileBertVocabPath = bundle.path(forResource: "vocab", ofType: "txt")! + static let mobileBertVocabPath = bundle.path(forResource: "mobilebert_vocab", ofType: "txt")! func testInitBertTokenizerFromPath() { let bertTokenizer = TFLBertTokenizer(vocabPath: TFLBertTokenizerTest.mobileBertVocabPath)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm deleted file mode 100644 index 6fb0a7f..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm +++ /dev/null
@@ -1,27 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#inmport "base/strings/sys_string_conversions.h" -#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" - -std::string MakeString(NSString* str) { - return SysNSStringToUTF8(str); -} - -NSString* MakeNSString(const std::string& str) { - return [[NSString alloc] - initWithBytes:const_cast<void*>(static_cast<const void*>(str.data())) - length:str.length() - encoding:NSUTF8StringEncoding]; -}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/BUILD index 8dfbf8f..4076307c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/BUILD
@@ -1,7 +1,6 @@ # Description: # TensorFlow Lite Support API in Java. -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("@build_bazel_rules_android//android:rules.bzl", "android_library") package( @@ -15,36 +14,44 @@ "debug_version_script.lds", ]) +filegroup( + name = "java_srcs", + srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]), +) + +# Android Library target for TFLite Support Library. It depends on TensorFlow +# Lite runtime (tensorflow/lite/java:tensorflowlite). If you don't want to +# introduce the native library into dependencies, use +# "tensorflowlite_support_java" instead, which depends on +# tensorflow/lite/java:tensorflowlite_java_stable. android_library( name = "tensorflowlite_support", - srcs = glob( - ["src/java/org/tensorflow/lite/support/**/*.java"], - ), - javacopts = JAVACOPTS, + srcs = [], + javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", + exports = [ + ":tensorflowlite_support_java", + ], deps = [ - "@org_checkerframework_qual", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite", + ":tensorflowlite_support_java", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", ], ) android_library( name = "tensorflowlite_support_java", - srcs = glob( - ["src/java/org/tensorflow/lite/support/**/*.java"], - ), - javacopts = JAVACOPTS, + srcs = [":java_srcs"], + javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", + # LINT.IfChange(dep) deps = [ + "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_android_odml_image", "@org_checkerframework_qual", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", # TODO(b/198734357): Use api when TF dep catches up. ], -) - -# TODO(b/156482505): Remove this target. -alias( - name = "tensorflow-lite-support-nogpu", - actual = ":tensorflow-lite-support", + # LINT.ThenChange(<INTERNAL>/release/build_support_api_pom.sh:dep) ) # This alias matches the associated .aar library name output style. @@ -52,22 +59,3 @@ name = "tensorflow-lite-support", actual = ":tensorflowlite_support", ) - -java_library( - name = "tensorflowlite_support_precondition_lib", - srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"], - javacopts = JAVACOPTS, - deps = [ - "@org_checkerframework_qual", - ], -) - -android_library( - name = "tensorflowlite_support_precondition", - srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"], - javacopts = JAVACOPTS, - manifest = "AndroidManifest.xml", - deps = [ - "@org_checkerframework_qual", - ], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/README.md b/third_party/tflite_support/src/tensorflow_lite_support/java/README.md index 8d37bf8..0d604d33 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/README.md +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/README.md
@@ -5,7 +5,7 @@ tensorflow.org](https://www.tensorflow.org/lite/inference_with_metadata/overview) for more information about all the efforts under TensorFlow Lite Support. -This directory contains the Java code for the TensorFlow Lite SupportLibrary +This directory contains the Java code for the TensorFlow Lite Support Library and TensorFlow Lite Task Library. ## TensorFlow Lite Android Support Library @@ -28,6 +28,7 @@ ## TensorFlow Lite Android Task Library + TensorFlow Lite Task Library provides optimized ready-to-use model interfaces for popular machine learning tasks, such as image classification, question and answer, etc. The model interfaces are specifically designed for each task to
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/default_version_script.lds b/third_party/tflite_support/src/tensorflow_lite_support/java/default_version_script.lds index 46bbffe..d7701cf2 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/default_version_script.lds +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/default_version_script.lds
@@ -1,6 +1,8 @@ VERS_1.0 { # Export JNI and native C symbols. global: + # Required for libunwind. This is needed if built and then run internally. + google_find_phdr; Java_*; JNI_OnLoad; JNI_OnUnload;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java new file mode 100644 index 0000000..e066146 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
@@ -0,0 +1,346 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.audio; + +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + +import static java.lang.System.arraycopy; + +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.os.Build; + +import androidx.annotation.RequiresApi; + +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + +/** + * Defines a ring buffer and some utility functions to prepare the input audio samples. + * + * <p>It maintains a <a href="https://en.wikipedia.org/wiki/Circular_buffer">Ring Buffer</a> to hold + * input audio data. Clients could feed input audio data via `load` methods and access the + * aggregated audio samples via `getTensorBuffer` method. + * + * <p>Note that this class can only handle input audio in Float (in {@link + * android.media.AudioFormat#ENCODING_PCM_16BIT}) or Short (in {@link + * android.media.AudioFormat#ENCODING_PCM_FLOAT}). Internally it converts and stores all the audio + * samples in PCM Float encoding. + * + * <p>Typical usage in Kotlin + * + * <pre> + * val tensor = TensorAudio.create(format, modelInputLength) + * tensor.load(newData) + * interpreter.run(tensor.getTensorBuffer(), outputBuffer); + * </pre> + * + * <p>Another sample usage with {@link android.media.AudioRecord} + * + * <pre> + * val tensor = TensorAudio.create(format, modelInputLength) + * Timer().scheduleAtFixedRate(delay, period) { + * tensor.load(audioRecord) + * interpreter.run(tensor.getTensorBuffer(), outputBuffer) + * } + * </pre> + */ +public class TensorAudio { + private static final String TAG = TensorAudio.class.getSimpleName(); + private final FloatRingBuffer buffer; + private final TensorAudioFormat format; + + /** + * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code + * sampleCounts} * {@code format.getChannels()}. + * + * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class. + * @param sampleCounts the number of samples to be fed into the model + */ + public static TensorAudio create(TensorAudioFormat format, int sampleCounts) { + return new TensorAudio(format, sampleCounts); + } + + /** + * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts} + * * + * {@code format.getChannelCount()}. + * + * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines + * the number of channels and sample rate. + * @param sampleCounts the number of samples to be fed into the model + */ + public static TensorAudio create(AudioFormat format, int sampleCounts) { + return new TensorAudio(TensorAudioFormat.create(format), sampleCounts); + } + + /** + * Wraps a few constants describing the format of the incoming audio samples, namely number of + * channels and the sample rate. By default, channels is set to 1. + */ + @AutoValue + public abstract static class TensorAudioFormat { + private static final int DEFAULT_CHANNELS = 1; + + /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */ + @RequiresApi(Build.VERSION_CODES.M) + public static TensorAudioFormat create(AudioFormat format) { + return TensorAudioFormat.builder() + .setChannels(format.getChannelCount()) + .setSampleRate(format.getSampleRate()) + .build(); + } + + public abstract int getChannels(); + + public abstract int getSampleRate(); + + public static Builder builder() { + return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels( + DEFAULT_CHANNELS); + } + + /** Builder for {@link TensorAudioFormat} */ + @AutoValue.Builder + public abstract static class Builder { + /* By default, it's set to have 1 channel. */ + public abstract Builder setChannels(int value); + + public abstract Builder setSampleRate(int value); + + abstract TensorAudioFormat autoBuild(); + + public TensorAudioFormat build() { + TensorAudioFormat format = autoBuild(); + checkArgument( + format.getChannels() > 0, "Number of channels should be greater than 0"); + checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0"); + return format; + } + } + } + + /** + * Stores the input audio samples {@code src} in the ring buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For + * multi-channel input, the array is interleaved. + */ + public void load(float[] src) { + load(src, 0, src.length); + } + + /** + * Stores the input audio samples {@code src} in the ring buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For + * multi-channel input, the array is interleaved. + * @param offsetInFloat starting position in the {@code src} array + * @param sizeInFloat the number of float values to be copied + * @throws IllegalArgumentException for incompatible audio format or incorrect input size + */ + public void load(float[] src, int offsetInFloat, int sizeInFloat) { + checkArgument(sizeInFloat % format.getChannels() == 0, + String.format("Size (%d) needs to be a multiplier of the number of channels (%d)", + sizeInFloat, format.getChannels())); + buffer.load(src, offsetInFloat, sizeInFloat); + } + + /** + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the + * ring buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For + * multi-channel input, the array is interleaved. + */ + public void load(short[] src) { + load(src, 0, src.length); + } + + /** + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the + * ring buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For + * multi-channel input, the array is interleaved. + * @param offsetInShort starting position in the src array + * @param sizeInShort the number of short values to be copied + * @throws IllegalArgumentException if the source array can't be copied + */ + public void load(short[] src, int offsetInShort, int sizeInShort) { + checkArgument(offsetInShort + sizeInShort <= src.length, + String.format( + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", + offsetInShort, sizeInShort, src.length)); + float[] floatData = new float[sizeInShort]; + for (int i = offsetInShort; i < sizeInShort; i++) { + // Convert the data to PCM Float encoding i.e. values between -1 and 1 + floatData[i] = src[i] / Short.MAX_VALUE; + } + load(floatData); + } + + /** + * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only + * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT. + * + * @param record an instance of {@link android.media.AudioRecord} + * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If + * there was no new data in the AudioRecord or an error occurred, this method will return 0. + * @throws IllegalArgumentException for unsupported audio encoding format + * @throws IllegalStateException if reading from AudioRecord failed + */ + @RequiresApi(Build.VERSION_CODES.M) + public int load(AudioRecord record) { + checkArgument(this.format.equals(TensorAudioFormat.create(record.getFormat())), + "Incompatible audio format."); + int loadedValues = 0; + if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) { + float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()]; + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); + if (loadedValues > 0) { + load(newData, 0, loadedValues); + return loadedValues; + } + } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) { + short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()]; + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); + if (loadedValues > 0) { + load(newData, 0, loadedValues); + return loadedValues; + } + } else { + throw new IllegalArgumentException( + "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT."); + } + + switch (loadedValues) { + case AudioRecord.ERROR_INVALID_OPERATION: + throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION"); + + case AudioRecord.ERROR_BAD_VALUE: + throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE"); + + case AudioRecord.ERROR_DEAD_OBJECT: + throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT"); + + case AudioRecord.ERROR: + throw new IllegalStateException("AudioRecord.ERROR"); + + default: + return 0; + } + } + + /** + * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link + * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1]. + */ + public TensorBuffer getTensorBuffer() { + ByteBuffer byteBuffer = buffer.getBuffer(); + TensorBuffer tensorBuffer = TensorBuffer.createFixedSize( + new int[] {/* batch= */ 1, + /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()}, + DataType.FLOAT32); + tensorBuffer.loadBuffer(byteBuffer); + return tensorBuffer; + } + + /* Returns the {@link TensorAudioFormat} associated with the tensor. */ + public TensorAudioFormat getFormat() { + return format; + } + + private TensorAudio(TensorAudioFormat format, int sampleCounts) { + this.format = format; + this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels()); + } + + /** Actual implementation of the ring buffer. */ + private static class FloatRingBuffer { + private final float[] buffer; + private int nextIndex = 0; + + public FloatRingBuffer(int flatSize) { + buffer = new float[flatSize]; + } + + /** + * Loads the entire float array to the ring buffer. If the float array is longer than ring + * buffer's capacity, samples with lower indices in the array will be ignored. + */ + public void load(float[] newData) { + load(newData, 0, newData.length); + } + + /** + * Loads a slice of the float array to the ring buffer. If the float array is longer than + * ring buffer's capacity, samples with lower indices in the array will be ignored. + */ + public void load(float[] newData, int offset, int size) { + checkArgument(offset + size <= newData.length, + String.format( + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", + offset, size, newData.length)); + // If buffer can't hold all the data, only keep the most recent data of size + // buffer.length + if (size > buffer.length) { + offset = size - buffer.length; + size = buffer.length; + } + if (nextIndex + size < buffer.length) { + // No need to wrap nextIndex, just copy newData[offset:offset + size] + // to buffer[nextIndex:nextIndex+size] + arraycopy(newData, offset, buffer, nextIndex, size); + } else { + // Need to wrap nextIndex, perform copy in two chunks. + int firstChunkSize = buffer.length - nextIndex; + // First copy newData[offset:offset+firstChunkSize] to + // buffer[nextIndex:buffer.length] + arraycopy(newData, offset, buffer, nextIndex, firstChunkSize); + // Then copy newData[offset+firstChunkSize:offset+size] to + // buffer[0:size-firstChunkSize] + arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize); + } + + nextIndex = (nextIndex + size) % buffer.length; + } + + public ByteBuffer getBuffer() { + // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, + // which can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around + // 0.01ms), so generally we don't create direct buffer for every invocation. + ByteBuffer byteBuffer = + ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length); + byteBuffer.order(ByteOrder.nativeOrder()); + FloatBuffer result = byteBuffer.asFloatBuffer(); + result.put(buffer, nextIndex, buffer.length - nextIndex); + result.put(buffer, 0, nextIndex); + byteBuffer.rewind(); + return byteBuffer; + } + + public int getCapacity() { + return buffer.length; + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java index 6e44108..6090f85 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
@@ -19,6 +19,7 @@ import android.content.res.AssetFileDescriptor; import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import java.io.BufferedReader; import java.io.FileInputStream; @@ -132,8 +133,10 @@ /** * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column - * text file). See details for vocabulary files in {@link FileUtil#loadVocabularyFile(Context, - * String)}. + * text file). + * + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, + * and each line is an individual value. The file should be in assets of the context. * * @param inputStream the input stream of an opened vocabulary file. * @return a list of vocabulary words.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java index 3f3d58a..a94adb89 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
@@ -15,9 +15,7 @@ package org.tensorflow.lite.support.common; -/** - * Processes T object with prepared {@link Operator<T>}. - */ +/** Processes T object with prepared {@code Operator<T>}. */ public interface Processor<T> { T process(T input); }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java index b374ff5..aa900b7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
@@ -16,6 +16,7 @@ package org.tensorflow.lite.support.common; import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import java.util.ArrayList; import java.util.Collections; @@ -24,7 +25,7 @@ import java.util.Map; /** - * A processor base class that chains a serial of {@link Operator<T>} and executes them. + * A processor base class that chains a serial of {@code Operator<T>} and executes them. * * <p>Typically, users could use its subclasses, e.g. {@link * org.tensorflow.lite.support.image.ImageProcessor} rather than directly use this one.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java deleted file mode 100644 index 7f47848..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java +++ /dev/null
@@ -1,188 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.lite.support.common; - -import org.checkerframework.checker.nullness.qual.Nullable; - -/** Static error checking util methods. */ -public final class SupportPreconditions { - /** - * Ensures that an object reference passed as a parameter to the calling method is not null. - * - * @param reference an object reference - * @return the non-null reference that was validated - * @throws NullPointerException if {@code reference} is null - */ - public static <T extends Object> T checkNotNull(T reference) { - if (reference == null) { - throw new NullPointerException("The object reference is null."); - } - return reference; - } - - /** - * Ensures that an object reference passed as a parameter to the calling method is not null. - * - * @param reference an object reference - * @param errorMessage the exception message to use if the check fails; will be converted to a - * string using {@link String#valueOf(Object)} - * @return the non-null reference that was validated - * @throws NullPointerException if {@code reference} is null - */ - public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { - if (reference == null) { - throw new NullPointerException(String.valueOf(errorMessage)); - } - return reference; - } - - /** - * Ensures that the given String is not empty and not null. - * - * @param string the String to test - * @return the non-null non-empty String that was validated - * @throws IllegalArgumentException if {@code string} is null or empty - */ - public static String checkNotEmpty(String string) { - if (string == null || string.length() == 0) { - throw new IllegalArgumentException("Given String is empty or null."); - } - return string; - } - - /** - * Ensures that the given String is not empty and not null. - * - * @param string the String to test - * @param errorMessage the exception message to use if the check fails; will be converted to a - * string using {@link String#valueOf(Object)} - * @return the non-null non-empty String that was validated - * @throws IllegalArgumentException if {@code string} is null or empty - */ - public static String checkNotEmpty(String string, Object errorMessage) { - if (string == null || string.length() == 0) { - throw new IllegalArgumentException(String.valueOf(errorMessage)); - } - return string; - } - - /** - * Ensures the truth of an expression involving one or more parameters to the calling method. - * - * @param expression a boolean expression. - * @throws IllegalArgumentException if {@code expression} is false. - */ - public static void checkArgument(boolean expression) { - if (!expression) { - throw new IllegalArgumentException(); - } - } - - /** - * Ensures the truth of an expression involving one or more parameters to the calling method. - * - * @param expression a boolean expression. - * @param errorMessage the exception message to use if the check fails; will be converted to a - * string using {@link String#valueOf(Object)}. - * @throws IllegalArgumentException if {@code expression} is false. - */ - public static void checkArgument(boolean expression, @Nullable Object errorMessage) { - if (!expression) { - throw new IllegalArgumentException(String.valueOf(errorMessage)); - } - } - - /** - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of - * size - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. - * - * @param index a user-supplied index identifying an element of an array, list or string - * @param size the size of that array, list or string - * @return the value of {@code index} - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code - * size} - * @throws IllegalArgumentException if {@code size} is negative - */ - public static int checkElementIndex(int index, int size) { - return checkElementIndex(index, size, "index"); - } - - /** - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of - * size - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. - * - * @param index a user-supplied index identifying an element of an array, list or string - * @param size the size of that array, list or string - * @param desc the text to use to describe this index in an error message - * @return the value of {@code index} - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code - * size} - * @throws IllegalArgumentException if {@code size} is negative - */ - public static int checkElementIndex(int index, int size, @Nullable String desc) { - // Carefully optimized for execution by hotspot (explanatory comment above) - if (index < 0 || index >= size) { - throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); - } - return index; - } - - /** - * Ensures the truth of an expression involving the state of the calling instance, but not - * involving any parameters to the calling method. - * - * @param expression a boolean expression - * @throws IllegalStateException if {@code expression} is false - * @see Verify#verify Verify.verify() - */ - public static void checkState(boolean expression) { - if (!expression) { - throw new IllegalStateException(); - } - } - - /** - * Ensures the truth of an expression involving the state of the calling instance, but not - * involving any parameters to the calling method. - * - * @param expression a boolean expression - * @param errorMessage the exception message to use if the check fails; will be converted to a - * string using {@link String#valueOf(Object)} - * @throws IllegalStateException if {@code expression} is false - * @see Verify#verify Verify.verify() - */ - public static void checkState(boolean expression, @Nullable Object errorMessage) { - if (!expression) { - throw new IllegalStateException(String.valueOf(errorMessage)); - } - } - - private static String badElementIndex(int index, int size, @Nullable String desc) { - if (index < 0) { - return String.format("%s (%s) must not be negative", desc, index); - } else if (size < 0) { - throw new IllegalArgumentException("negative size: " + size); - } else { // index >= size - return String.format("%s (%s) must be less than size (%s)", desc, index, size); - } - } - - private SupportPreconditions() { - throw new AssertionError("SupportPreconditions is Uninstantiable."); - } -}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java index cb22f57..4391c45 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
@@ -29,7 +29,7 @@ * </pre> * * @see TensorProcessor.Builder to build a {@link TensorProcessor} instance. - * @see TensorProcessor#process(TensorBuffer) to apply the processor on a {@link TensorBuffer}. + * @see TensorProcessor#process to apply the processor on a {@link TensorBuffer}. */ public class TensorProcessor extends SequentialProcessor<TensorBuffer> { private TensorProcessor(Builder builder) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java new file mode 100644 index 0000000..29faa545 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
@@ -0,0 +1,186 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Static error checking util methods. */ +public final class SupportPreconditions { + /** + * Ensures that an object reference passed as a parameter to the calling method is not null. + * + * @param reference an object reference + * @return the non-null reference that was validated + * @throws NullPointerException if {@code reference} is null + */ + public static <T extends Object> T checkNotNull(T reference) { + if (reference == null) { + throw new NullPointerException("The object reference is null."); + } + return reference; + } + + /** + * Ensures that an object reference passed as a parameter to the calling method is not null. + * + * @param reference an object reference + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)} + * @return the non-null reference that was validated + * @throws NullPointerException if {@code reference} is null + */ + public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { + if (reference == null) { + throw new NullPointerException(String.valueOf(errorMessage)); + } + return reference; + } + + /** + * Ensures that the given String is not empty and not null. + * + * @param string the String to test + * @return the non-null non-empty String that was validated + * @throws IllegalArgumentException if {@code string} is null or empty + */ + public static String checkNotEmpty(String string) { + if (string == null || string.length() == 0) { + throw new IllegalArgumentException("Given String is empty or null."); + } + return string; + } + + /** + * Ensures that the given String is not empty and not null. + * + * @param string the String to test + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)} + * @return the non-null non-empty String that was validated + * @throws IllegalArgumentException if {@code string} is null or empty + */ + public static String checkNotEmpty(String string, Object errorMessage) { + if (string == null || string.length() == 0) { + throw new IllegalArgumentException(String.valueOf(errorMessage)); + } + return string; + } + + /** + * Ensures the truth of an expression involving one or more parameters to the calling method. + * + * @param expression a boolean expression. + * @throws IllegalArgumentException if {@code expression} is false. + */ + public static void checkArgument(boolean expression) { + if (!expression) { + throw new IllegalArgumentException(); + } + } + + /** + * Ensures the truth of an expression involving one or more parameters to the calling method. + * + * @param expression a boolean expression. + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)}. + * @throws IllegalArgumentException if {@code expression} is false. + */ + public static void checkArgument(boolean expression, @Nullable Object errorMessage) { + if (!expression) { + throw new IllegalArgumentException(String.valueOf(errorMessage)); + } + } + + /** + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of + * size + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. + * + * @param index a user-supplied index identifying an element of an array, list or string + * @param size the size of that array, list or string + * @return the value of {@code index} + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code + * size} + * @throws IllegalArgumentException if {@code size} is negative + */ + public static int checkElementIndex(int index, int size) { + return checkElementIndex(index, size, "index"); + } + + /** + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of + * size + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. + * + * @param index a user-supplied index identifying an element of an array, list or string + * @param size the size of that array, list or string + * @param desc the text to use to describe this index in an error message + * @return the value of {@code index} + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code + * size} + * @throws IllegalArgumentException if {@code size} is negative + */ + public static int checkElementIndex(int index, int size, @Nullable String desc) { + // Carefully optimized for execution by hotspot (explanatory comment above) + if (index < 0 || index >= size) { + throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); + } + return index; + } + + /** + * Ensures the truth of an expression involving the state of the calling instance, but not + * involving any parameters to the calling method. + * + * @param expression a boolean expression + * @throws IllegalStateException if {@code expression} is false + */ + public static void checkState(boolean expression) { + if (!expression) { + throw new IllegalStateException(); + } + } + + /** + * Ensures the truth of an expression involving the state of the calling instance, but not + * involving any parameters to the calling method. + * + * @param expression a boolean expression + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)} + * @throws IllegalStateException if {@code expression} is false + */ + public static void checkState(boolean expression, @Nullable Object errorMessage) { + if (!expression) { + throw new IllegalStateException(String.valueOf(errorMessage)); + } + } + + private static String badElementIndex(int index, int size, @Nullable String desc) { + if (index < 0) { + return String.format("%s (%s) must not be negative", desc, index); + } else if (size < 0) { + throw new IllegalArgumentException("negative size: " + size); + } else { // index >= size + return String.format("%s (%s) must be less than size (%s)", desc, index, size); + } + } + + private SupportPreconditions() { + throw new AssertionError("SupportPreconditions is Uninstantiable."); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/package-info.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/package-info.java new file mode 100644 index 0000000..48b43d3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/package-info.java
@@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * @hide This package is for classes that are implementation details of org.tensorflow.lite.support + * AND are used from different packages within org.tensorflow.lite.support (and which cannot + * therefore simply be declared as package-private). Classes in this package should only be used + * from within other classes in org.tensorflow.lite.support. + */ +package org.tensorflow.lite.support.common.internal;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java index 7722abd..a14cd1f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
@@ -16,8 +16,8 @@ package org.tensorflow.lite.support.common.ops; import org.tensorflow.lite.DataType; -import org.tensorflow.lite.support.common.SupportPreconditions; import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; /** Casts a {@link TensorBuffer} to a specified data type. */ @@ -33,9 +33,9 @@ * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code * destinationType}, the original buffer will be directly returned. * - * @param destinationType: The type of the casted {@link TensorBuffer}. + * @param destinationType The type of the casted {@link TensorBuffer}. * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8} - * nor {@link DataType#FLOAT32}. + * nor {@link DataType#FLOAT32}. */ public CastOp(DataType destinationType) { SupportPreconditions.checkArgument(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java index e5c89e8..912df13 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
@@ -17,8 +17,8 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.tensorflow.lite.DataType; -import org.tensorflow.lite.support.common.SupportPreconditions; import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat; @@ -49,9 +49,9 @@ * happen, and original input will be directly returned in execution. * * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at - * present, except that the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 + * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 * and - * {@code stddev} is set to 1. + * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned. * * @param mean the mean value to be subtracted first. * @param stddev the standard deviation value to divide then.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java index 0268f82..f9b6a1f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
@@ -15,11 +15,12 @@ package org.tensorflow.lite.support.image; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull; import android.graphics.Bitmap; import android.graphics.Bitmap.Config; +import android.media.Image; import org.tensorflow.lite.DataType; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -64,6 +65,12 @@ } @Override + public Image getMediaImage() { + throw new UnsupportedOperationException( + "Converting from Bitmap to android.media.Image is unsupported."); + } + + @Override public int getWidth() { return bitmap.getWidth(); }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java index fe53692..a2e833b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
@@ -15,7 +15,7 @@ package org.tensorflow.lite.support.image; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; import android.graphics.RectF; @@ -31,8 +31,8 @@ /** * Helper class for converting values that represents bounding boxes into rectangles. * - * <p>The class provides a static function to create bounding boxes as {@link RectF} from different - * types of configurations. + * <p>The class provides a static function to create bounding boxes as {@link + * android.graphics.RectF} from different types of configurations. * * <p>Generally, a bounding box could be represented by 4 float values, but the values could be * interpreted in many ways. We now support 3 {@link Type} of configurations, and the order of
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java index 79f7c421..716cacd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
@@ -15,10 +15,11 @@ package org.tensorflow.lite.support.image; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; import android.graphics.Bitmap; import android.graphics.Bitmap.Config; +import android.graphics.ImageFormat; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -27,7 +28,7 @@ /** Represents the type of color space of an image. */ public enum ColorSpaceType { /** Each pixel has red, green, and blue color components. */ - RGB { + RGB(0) { // The channel axis should always be 3 for RGB images. private static final int CHANNEL_VALUE = 3; @@ -56,6 +57,11 @@ } @Override + int getNumElements(int height, int width) { + return height * width * CHANNEL_VALUE; + } + + @Override String getShapeInfoMessage() { return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" + " representing R, G, B in order. "; @@ -68,7 +74,7 @@ }, /** Each pixel is a single element representing only the amount of light. */ - GRAYSCALE { + GRAYSCALE(1) { // The channel axis should always be 1 for grayscale images. private static final int CHANNEL_VALUE = 1; @@ -103,6 +109,11 @@ } @Override + int getNumElements(int height, int width) { + return height * width; + } + + @Override String getShapeInfoMessage() { return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). "; } @@ -111,6 +122,54 @@ Config toBitmapConfig() { return Config.ALPHA_8; } + }, + + /** YUV420sp format, encoded as "YYYYYYYY UVUV". */ + NV12(2) { + @Override + int getNumElements(int height, int width) { + return getYuv420NumElements(height, width); + } + }, + + /** + * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1 + * preview. + */ + NV21(3) { + @Override + int getNumElements(int height, int width) { + return getYuv420NumElements(height, width); + } + }, + + /** YUV420p format, encoded as "YYYYYYYY VV UU". */ + YV12(4) { + @Override + int getNumElements(int height, int width) { + return getYuv420NumElements(height, width); + } + }, + + /** YUV420p format, encoded as "YYYYYYYY UU VV". */ + YV21(5) { + @Override + int getNumElements(int height, int width) { + return getYuv420NumElements(height, width); + } + }, + + /** + * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual + * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image. + * + * <p>Use this format only when you load an {@link android.media.Image}. + */ + YUV_420_888(6) { + @Override + int getNumElements(int height, int width) { + return getYuv420NumElements(height, width); + } }; private static final int BATCH_DIM = 0; // The first element of the normalizaed shape. @@ -118,6 +177,11 @@ private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape. private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape. private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape. + private final int value; + + ColorSpaceType(int value) { + this.value = value; + } /** * Converts a bitmap configuration into the corresponding color space type. @@ -137,30 +201,79 @@ } /** + * Converts an {@link ImageFormat} value into the corresponding color space type. + * + * @throws IllegalArgumentException if the config is unsupported + */ + static ColorSpaceType fromImageFormat(int imageFormat) { + switch (imageFormat) { + case ImageFormat.NV21: + return ColorSpaceType.NV21; + case ImageFormat.YV12: + return ColorSpaceType.YV12; + case ImageFormat.YUV_420_888: + return ColorSpaceType.YUV_420_888; + default: + throw new IllegalArgumentException( + "ImageFormat: " + imageFormat + ", is not supported yet."); + } + } + + public int getValue() { + return value; + } + + /** * Verifies if the given shape matches the color space type. * * @throws IllegalArgumentException if {@code shape} does not match the color space type + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE */ void assertShape(int[] shape) { + assertRgbOrGrayScale("assertShape()"); + int[] normalizedShape = getNormalizedShape(shape); checkArgument(isValidNormalizedShape(normalizedShape), getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); } /** + * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code + * width} under this color space type. For example, the {@code numElements} of an RGB image of + * 30 x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x + * 20 should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}. + * + * @throws IllegalArgumentException if {@code shape} does not match the color space type + */ + void assertNumElements(int numElements, int height, int width) { + checkArgument(numElements >= getNumElements(height, width), + String.format( + "The given number of elements (%d) does not match the image (%s) in %d x %d. The" + + " expected number of elements should be at least %d.", + numElements, this.name(), height, width, getNumElements(height, width))); + } + + /** * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space * type. * - * @throws IllegalArgumentException if the shape of buffer does not match the color space type + * @throws IllegalArgumentException if the shape of buffer does not match the color space type, + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE */ - abstract Bitmap convertTensorBufferToBitmap(TensorBuffer buffer); + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { + throw new UnsupportedOperationException( + "convertTensorBufferToBitmap() is unsupported for the color space type " + + this.name()); + } /** * Returns the width of the given shape corresponding to the color space type. * * @throws IllegalArgumentException if {@code shape} does not match the color space type + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE */ int getWidth(int[] shape) { + assertRgbOrGrayScale("getWidth()"); assertShape(shape); return getNormalizedShape(shape)[WIDTH_DIM]; } @@ -169,24 +282,65 @@ * Returns the height of the given shape corresponding to the color space type. * * @throws IllegalArgumentException if {@code shape} does not match the color space type + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE */ int getHeight(int[] shape) { + assertRgbOrGrayScale("getHeight()"); assertShape(shape); return getNormalizedShape(shape)[HEIGHT_DIM]; } - abstract int getChannelValue(); - + /** + * Returns the channel value corresponding to the color space type. + * + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE + */ + int getChannelValue() { + throw new UnsupportedOperationException( + "getChannelValue() is unsupported for the color space type " + this.name()); + } /** * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have * batch or channel axis. + * + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE */ - abstract int[] getNormalizedShape(int[] shape); + int[] getNormalizedShape(int[] shape) { + throw new UnsupportedOperationException( + "getNormalizedShape() is unsupported for the color space type " + this.name()); + } - abstract String getShapeInfoMessage(); + /** + * Returns the shape information corresponding to the color space type. + * + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE + */ + String getShapeInfoMessage() { + throw new UnsupportedOperationException( + "getShapeInfoMessage() is unsupported for the color space type " + this.name()); + } - /** Converts the color space type to the corresponding bitmap config. */ - abstract Config toBitmapConfig(); + /** + * Converts the color space type to the corresponding bitmap config. + * + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE + */ + Config toBitmapConfig() { + throw new UnsupportedOperationException( + "toBitmapConfig() is unsupported for the color space type " + this.name()); + } + + /** + * Gets the number of elements given the height and width of an image. For example, the number + * of elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements + * of a NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}. + */ + abstract int getNumElements(int height, int width); + + private static int getYuv420NumElements(int height, int width) { + // Height and width of U/V planes are half of the Y plane. + return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2; + } /** Inserts a value at the specified position and return the new array. */ private static int[] insertValue(int[] array, int pos, int value) { @@ -202,10 +356,15 @@ } protected boolean isValidNormalizedShape(int[] shape) { - if (shape[BATCH_DIM] == BATCH_VALUE && shape[HEIGHT_DIM] > 0 && shape[WIDTH_DIM] > 0 - && shape[CHANNEL_DIM] == getChannelValue()) { - return true; + return shape[BATCH_DIM] == BATCH_VALUE && shape[HEIGHT_DIM] > 0 && shape[WIDTH_DIM] > 0 + && shape[CHANNEL_DIM] == getChannelValue(); + } + + /** Some existing methods are only valid for RGB and GRAYSCALE images. */ + private void assertRgbOrGrayScale(String unsupportedMethodName) { + if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) { + throw new UnsupportedOperationException(unsupportedMethodName + + " only supports RGB and GRAYSCALE formats, but not " + this.name()); } - return false; } }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java index b4481fb..5c097da 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
@@ -16,6 +16,7 @@ package org.tensorflow.lite.support.image; import android.graphics.Bitmap; +import android.media.Image; import org.tensorflow.lite.DataType; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -50,6 +51,9 @@ */ TensorBuffer getTensorBuffer(DataType dataType); + /** Gets the {@link Image} representation of the underlying image format. */ + Image getMediaImage(); + /** Returns the color space type of the image. */ ColorSpaceType getColorSpaceType(); }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java index c555b6a..7ed5306 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
@@ -109,7 +109,6 @@ int[] intValues = new int[w * h]; bitmap.getPixels(intValues, 0, w, 0, 0, w, h); // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. - int flatSize = w * h * 3; int[] shape = new int[] {h, w, 3}; switch (buffer.getDataType()) { case UINT8: @@ -119,9 +118,8 @@ byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff); byteArr[j++] = (byte) (intValues[i] & 0xff); } - ByteBuffer byteBuffer = ByteBuffer.allocateDirect(flatSize); + ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr); byteBuffer.order(ByteOrder.nativeOrder()); - byteBuffer.put(byteArr); buffer.loadBuffer(byteBuffer, shape); break; case FLOAT32:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java index 5e253265..d7a853ee 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
@@ -15,13 +15,16 @@ package org.tensorflow.lite.support.image; +import static java.lang.Math.max; +import static java.lang.Math.min; + import android.graphics.PointF; import android.graphics.RectF; import org.tensorflow.lite.support.common.Operator; import org.tensorflow.lite.support.common.SequentialProcessor; -import org.tensorflow.lite.support.common.SupportPreconditions; import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import org.tensorflow.lite.support.image.ops.Rot90Op; import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper; @@ -110,8 +113,17 @@ new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth); PointF p2 = inverseTransform( new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth); - return new RectF(Math.min(p1.x, p2.x), Math.min(p1.y, p2.y), Math.max(p1.x, p2.x), - Math.max(p1.y, p2.y)); + return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y)); + } + + /** + * Processes a {@link TensorImage} object with prepared {@link TensorOperator}. + * + * @throws IllegalArgumentException if the image is not supported by any op. + */ + @Override + public TensorImage process(TensorImage image) { + return super.process(image); } /**
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java new file mode 100644 index 0000000..f61f59f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
@@ -0,0 +1,76 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState; + +import com.google.auto.value.AutoValue; + +/** + * Represents the properties of an image object when being loaded to a {@link TensorImage}. See + * {@link TensorImage#load}. {@link ImageProperties} currently is only used with {@link + * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. + */ +@AutoValue +public abstract class ImageProperties { + private static final int DEFAULT_HEIGHT = -1; + private static final int DEFAULT_WIDTH = -1; + + public abstract int getHeight(); + + public abstract int getWidth(); + + public abstract ColorSpaceType getColorSpaceType(); + + public static Builder builder() { + return new AutoValue_ImageProperties.Builder() + .setHeight(DEFAULT_HEIGHT) + .setWidth(DEFAULT_WIDTH); + } + + /** + * Builder for {@link ImageProperties}. Different image objects may require different + * properties. See the detais below: + * + * <ul> + * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}: + * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer + * object will not be used to determine image height and width. + * </ul> + */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setHeight(int height); + + public abstract Builder setWidth(int width); + + public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType); + + abstract ImageProperties autoBuild(); + + public ImageProperties build() { + ImageProperties properties = autoBuild(); + // If width or hight are not configured by the Builder, they will be -1. + // Enforcing all properties to be populated (AutoValue will error out if objects, like + // colorSpaceType, are not set up), since they are required for TensorBuffer images. + // If in the future we have some image object types that only require a portion of these + // properties, we can delay the check when TensorImage#load() is executed. + checkState(properties.getHeight() >= 0, "Negative image height is not allowed."); + checkState(properties.getWidth() >= 0, "Negative image width is not allowed."); + return properties; + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java new file mode 100644 index 0000000..519aaca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
@@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull; + +import android.graphics.Bitmap; +import android.graphics.ImageFormat; +import android.media.Image; + +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Holds an {@link Image} and converts it to other image formats as needed. */ +final class MediaImageContainer implements ImageContainer { + private final Image image; + + /** + * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}. + * + * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888 + */ + static MediaImageContainer create(Image image) { + return new MediaImageContainer(image); + } + + private MediaImageContainer(Image image) { + checkNotNull(image, "Cannot load null Image."); + checkArgument(image.getFormat() == ImageFormat.YUV_420_888, + "Only supports loading YUV_420_888 Image."); + this.image = image; + } + + @Override + public MediaImageContainer clone() { + throw new UnsupportedOperationException( + "android.media.Image is an abstract class and cannot be cloned."); + } + + @Override + public Bitmap getBitmap() { + throw new UnsupportedOperationException( + "Converting an android.media.Image to Bitmap is not supported."); + } + + @Override + public TensorBuffer getTensorBuffer(DataType dataType) { + throw new UnsupportedOperationException( + "Converting an android.media.Image to TesorBuffer is not supported."); + } + + @Override + public Image getMediaImage() { + return image; + } + + @Override + public int getWidth() { + return image.getWidth(); + } + + @Override + public int getHeight() { + return image.getHeight(); + } + + @Override + public ColorSpaceType getColorSpaceType() { + return ColorSpaceType.fromImageFormat(image.getFormat()); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java new file mode 100644 index 0000000..03017bf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
@@ -0,0 +1,119 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import com.google.android.odml.image.BitmapExtractor; +import com.google.android.odml.image.ByteBufferExtractor; +import com.google.android.odml.image.MediaImageExtractor; +import com.google.android.odml.image.MlImage; +import com.google.android.odml.image.MlImage.ImageFormat; +import com.google.auto.value.AutoValue; + +import java.nio.ByteBuffer; + +/** Converts {@code MlImage} to {@link TensorImage} and vice versa. */ +public class MlImageAdapter { + /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */ + @AutoValue + abstract static class ImageFormatProxy { + abstract ColorSpaceType getColorSpaceType(); + + @ImageFormat + abstract int getImageFormat(); + + static ImageFormatProxy createFromImageFormat(@ImageFormat int format) { + switch (format) { + case MlImage.IMAGE_FORMAT_RGB: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.RGB, format); + case MlImage.IMAGE_FORMAT_NV12: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.NV12, format); + case MlImage.IMAGE_FORMAT_NV21: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.NV21, format); + case MlImage.IMAGE_FORMAT_YV12: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.YV12, format); + case MlImage.IMAGE_FORMAT_YV21: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.YV21, format); + case MlImage.IMAGE_FORMAT_YUV_420_888: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.YUV_420_888, format); + case MlImage.IMAGE_FORMAT_ALPHA: + return new AutoValue_MlImageAdapter_ImageFormatProxy( + ColorSpaceType.GRAYSCALE, format); + case MlImage.IMAGE_FORMAT_RGBA: + case MlImage.IMAGE_FORMAT_JPEG: + case MlImage.IMAGE_FORMAT_UNKNOWN: + throw new IllegalArgumentException( + "Cannot create ColorSpaceType from MlImage format: " + format); + default: + throw new AssertionError("Illegal @ImageFormat: " + format); + } + } + } + + /** + * Creates a {@link TensorImage} from an {@link MlImage}. + * + * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not + * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its + * contained data are immutable. Also, callers should use {@code + * MlImage#getInternal()#acquire()} and {@code MlImage#release()} to avoid the {@code mlImage} + * being released unexpectedly. + * + * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported + * container. + */ + public static TensorImage createTensorImageFrom(MlImage mlImage) { + // TODO(b/190670174): Choose the best storage from multiple containers. + com.google.android.odml.image.ImageProperties mlImageProperties = + mlImage.getContainedImageProperties().get(0); + switch (mlImageProperties.getStorageType()) { + case MlImage.STORAGE_TYPE_BITMAP: + return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage)); + case MlImage.STORAGE_TYPE_MEDIA_IMAGE: + TensorImage mediaTensorImage = new TensorImage(); + mediaTensorImage.load(MediaImageExtractor.extract(mlImage)); + return mediaTensorImage; + case MlImage.STORAGE_TYPE_BYTEBUFFER: + ByteBuffer buffer = ByteBufferExtractor.extract(mlImage); + ImageFormatProxy formatProxy = + ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat()); + TensorImage byteBufferTensorImage = new TensorImage(); + ImageProperties properties = + ImageProperties.builder() + .setColorSpaceType(formatProxy.getColorSpaceType()) + .setHeight(mlImage.getHeight()) + .setWidth(mlImage.getWidth()) + .build(); + byteBufferTensorImage.load(buffer, properties); + return byteBufferTensorImage; + default: + throw new IllegalArgumentException( + "Illegal storage type: " + mlImageProperties.getStorageType()); + } + } + + /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */ + public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) { + return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType(); + } + + private MlImageAdapter() {} +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java index 0ed88e8..6dfef70 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
@@ -15,7 +15,10 @@ package org.tensorflow.lite.support.image; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + import android.graphics.Bitmap; +import android.media.Image; import android.util.Log; import org.tensorflow.lite.DataType; @@ -25,28 +28,53 @@ final class TensorBufferContainer implements ImageContainer { private final TensorBuffer buffer; private final ColorSpaceType colorSpaceType; + private final int height; + private final int width; private static final String TAG = TensorBufferContainer.class.getSimpleName(); /** * Creates a {@link TensorBufferContainer} object with the specified {@link * TensorImage#ColorSpaceType}. * + * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link + * #create(TensorBuffer, ImageProperties)} for other color space types. + * * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the - * specified color space type + * specified color space type, or if the color space type is not supported */ static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) { - return new TensorBufferContainer(buffer, colorSpaceType); + checkArgument( + colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE, + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" + + " `create(TensorBuffer, ImageProperties)` for other color space types."); + + return new TensorBufferContainer(buffer, colorSpaceType, + colorSpaceType.getHeight(buffer.getShape()), + colorSpaceType.getWidth(buffer.getShape())); } - private TensorBufferContainer(TensorBuffer buffer, ColorSpaceType colorSpaceType) { - colorSpaceType.assertShape(buffer.getShape()); + static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) { + return new TensorBufferContainer(buffer, imageProperties.getColorSpaceType(), + imageProperties.getHeight(), imageProperties.getWidth()); + } + + private TensorBufferContainer( + TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) { + checkArgument(colorSpaceType != ColorSpaceType.YUV_420_888, + "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12," + + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image."); + + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); this.buffer = buffer; this.colorSpaceType = colorSpaceType; + this.height = height; + this.width = width; } @Override public TensorBufferContainer clone() { - return create(TensorBuffer.createFrom(buffer, buffer.getDataType()), colorSpaceType); + return new TensorBufferContainer(TensorBuffer.createFrom(buffer, buffer.getDataType()), + colorSpaceType, getHeight(), getWidth()); } @Override @@ -73,13 +101,23 @@ } @Override + public Image getMediaImage() { + throw new UnsupportedOperationException( + "Converting from TensorBuffer to android.media.Image is unsupported."); + } + + @Override public int getWidth() { - return colorSpaceType.getWidth(buffer.getShape()); + // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created. + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); + return width; } @Override public int getHeight() { - return colorSpaceType.getHeight(buffer.getShape()); + // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created. + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width); + return height; } @Override
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java index 0ab62ae..a5a1252 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
@@ -15,9 +15,10 @@ package org.tensorflow.lite.support.image; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; import android.graphics.Bitmap; +import android.media.Image; import org.tensorflow.lite.DataType; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -32,15 +33,16 @@ * <p>At present, only RGB images are supported, and the A channel is always ignored. * * <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a - * {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only - * converts one to the other when needed. A typical use case of {@link TensorImage} is to first load - * a {@link Bitmap} image, then process it using {@link ImageProcessor}, and finally get the - * underlying {@link ByteBuffer} of the {@link TensorBuffer} and feed it into the TFLite - * interpreter. + * {@link android.graphics.Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the + * state and only converts one to the other when needed. A typical use case of {@link TensorImage} + * is to first load a {@link android.graphics.Bitmap} image, then process it using {@link + * ImageProcessor}, and finally get the underlying {@link ByteBuffer} of the {@link TensorBuffer} + * and feed it into the TFLite interpreter. * * <p>IMPORTANT: to achieve the best performance, {@link TensorImage} avoids copying data whenever * it's possible. Therefore, it doesn't own its data. Callers should not modify data objects those - * are passed to {@link TensorImage#load(Bitmap)} or {@link TensorImage#load(TensorBuffer)}. + * are passed to {@link TensorImage#load(Bitmap)} or {@link TensorImage#load(TensorBuffer, + * ColorSpaceType)}. * * <p>IMPORTANT: all methods are not proved thread-safe. * @@ -87,10 +89,11 @@ } /** - * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link Bitmap} . + * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link + * android.graphics.Bitmap} . * - * @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects - * frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}. + * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently, + * because every call of {@code fromBitmap} creates a new {@link TensorImage}. */ public static TensorImage fromBitmap(Bitmap bitmap) { TensorImage image = new TensorImage(); @@ -113,11 +116,12 @@ } /** - * Loads a {@link Bitmap} image object into this {@link TensorImage}. + * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}. * * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link - * #getBuffer}, where the {@link Bitmap} will be converted into a {@link TensorBuffer}. + * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link + * TensorBuffer}. * * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. * The @@ -187,21 +191,84 @@ /** * Loads a {@link TensorBuffer} containing pixel values with the specific {@link - * ColorSapceType}. + * ColorSpaceType}. + * + * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link + * #load(TensorBuffer, ImageProperties)} for other color space types. * * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link * #getBuffer}. * - * @throws IllegalArgumentException if the shape of buffer does not match the color space type - * @see ColorSpaceType#assertShape + * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or + * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images + * @throws IllegalArgumentException if the shape of buffer does not match the color space type, + * or + * if the color space type is not supported */ public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) { + checkArgument( + colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE, + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" + + " `load(TensorBuffer, ImageProperties)` for other color space types."); + container = TensorBufferContainer.create(buffer, colorSpaceType); } /** - * Returns a {@link Bitmap} representation of this {@link TensorImage}. + * Loads a {@link TensorBuffer} containing pixel values with the specific {@link + * ImageProperties}. + * + * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and + * width. Set image properties through {@link ImageProperties}. + * + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link + * #getBuffer}. + * + * @throws IllegalArgumentException if buffer size is less than the image size indicated by + * image + * height, width, and color space type in {@link ImageProperties} + */ + public void load(TensorBuffer buffer, ImageProperties imageProperties) { + container = TensorBufferContainer.create(buffer, imageProperties); + } + + /** + * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}. + * + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link + * #getBuffer}. + * + * @throws IllegalArgumentException if buffer size is less than the image size indicated by + * image + * height, width, and color space type in {@link ImageProperties} + */ + public void load(ByteBuffer buffer, ImageProperties imageProperties) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); + tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()}); + container = TensorBufferContainer.create(tensorBuffer, imageProperties); + } + + /** + * Loads an {@link android.media.Image} object into this {@link TensorImage}. + * + * <p>The main usage of this method is to load an {@link android.media.Image} object as model + * input to the <a href="TFLite Task + * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>. + * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link + * ImageProcessor}. + * + * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code + * image} is not YUV_420_888 + */ + public void load(Image image) { + container = MediaImageContainer.create(image); + } + + /** + * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}. * * <p>Numeric casting and clamping will be applied if the stored data is not uint8. * @@ -211,9 +278,9 @@ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for * performance concern, but if modification is necessary, please make a copy. * - * @return a reference to a {@link Bitmap} in {@code ARGB_8888} config ("A" channel is always - * opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of this {@link - * TensorBuffer}. + * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A" + * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} + * of this {@link TensorBuffer}. * @throws IllegalStateException if the {@link TensorImage} never loads data */ public Bitmap getBitmap() { @@ -265,6 +332,29 @@ } /** + * Returns an {@link android.media.Image} representation of this {@link TensorImage}. + * + * <p>This method only works when the {@link TensorImage} is backed by an {@link + * android.media.Image}, meaning you need to first load an {@link android.media.Image} through + * {@link #load(Image)}. + * + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for + * performance concern, but if modification is necessary, please make a copy. + * + * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A" + * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} + * of this {@link TensorBuffer}. + * @throws IllegalStateException if the {@link TensorImage} never loads data + */ + public Image getMediaImage() { + if (container == null) { + throw new IllegalStateException("No image has been loaded yet."); + } + + return container.getMediaImage(); + } + + /** * Gets the data type of this {@link TensorImage}. * * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java index 8105e1b..adccf23d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
@@ -15,10 +15,13 @@ package org.tensorflow.lite.support.image.ops; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + import android.graphics.Bitmap; import android.graphics.PointF; import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.image.ColorSpaceType; import org.tensorflow.lite.support.image.ImageOperator; import org.tensorflow.lite.support.image.TensorImage; @@ -40,9 +43,9 @@ /** * Creates a ResizeOp which can resize images to specified size in specified method. * - * @param targetHeight: The expected height of resized image. - * @param targetWidth: The expected width of resized image. - * @param resizeMethod: The algorithm to use for resizing. Options: {@link ResizeMethod} + * @param targetHeight The expected height of resized image. + * @param targetWidth The expected width of resized image. + * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod} */ public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) { this.targetHeight = targetHeight; @@ -62,6 +65,9 @@ @Override @NonNull public TensorImage apply(@NonNull TensorImage image) { + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, + "Only RGB images are supported in ResizeOp, but not " + + image.getColorSpaceType().name()); Bitmap scaled = Bitmap.createScaledBitmap( image.getBitmap(), targetWidth, targetHeight, useBilinear); image.load(scaled);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java index c3fd264..e5de5bb 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
@@ -15,6 +15,8 @@ package org.tensorflow.lite.support.image.ops; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + import android.graphics.Bitmap; import android.graphics.Bitmap.Config; import android.graphics.Canvas; @@ -22,6 +24,7 @@ import android.graphics.Rect; import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.image.ColorSpaceType; import org.tensorflow.lite.support.image.ImageOperator; import org.tensorflow.lite.support.image.TensorImage; @@ -43,8 +46,8 @@ * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts * center-crop and zero-padding. * - * @param targetHeight: The expected height of cropped/padded image. - * @param targetWidth: The expected width of cropped/padded image. + * @param targetHeight The expected height of cropped/padded image. + * @param targetWidth The expected width of cropped/padded image. */ public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) { this.targetHeight = targetHeight; @@ -65,6 +68,9 @@ @Override @NonNull public TensorImage apply(@NonNull TensorImage image) { + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, + "Only RGB images are supported in ResizeWithCropOrPadOp, but not " + + image.getColorSpaceType().name()); Bitmap input = image.getBitmap(); int srcL; int srcR;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java index 72e5d0d..86413c9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
@@ -15,11 +15,14 @@ package org.tensorflow.lite.support.image.ops; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + import android.graphics.Bitmap; import android.graphics.Matrix; import android.graphics.PointF; import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.image.ColorSpaceType; import org.tensorflow.lite.support.image.ImageOperator; import org.tensorflow.lite.support.image.TensorImage; @@ -36,7 +39,7 @@ * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times * counter-clockwise. * - * @param k: The number of times the image is rotated by 90 degrees. If it's positive, the image + * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise. */ public Rot90Op(int k) { @@ -55,6 +58,9 @@ @NonNull @Override public TensorImage apply(@NonNull TensorImage image) { + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, + "Only RGB images are supported in Rot90Op, but not " + + image.getColorSpaceType().name()); Bitmap input = image.getBitmap(); if (numRotation == 0) { return image;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java index 90d457b..feb2b3b7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
@@ -18,8 +18,9 @@ import android.graphics.PointF; import org.checkerframework.checker.nullness.qual.NonNull; -import org.tensorflow.lite.support.common.SupportPreconditions; import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; +import org.tensorflow.lite.support.image.ColorSpaceType; import org.tensorflow.lite.support.image.ImageOperator; import org.tensorflow.lite.support.image.TensorImage; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -52,9 +53,11 @@ SupportPreconditions.checkNotNull(image, "Op cannot apply on null image."); TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer()); // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. - // Therefore, need to create a new TensorImage with the correct data type. + // Therefore, need to create a new TensorImage with the correct data type. However the + // underlying ops should not touch the color type. + ColorSpaceType colorSpaceType = image.getColorSpaceType(); TensorImage resImage = new TensorImage(resBuffer.getDataType()); - resImage.load(resBuffer); + resImage.load(resBuffer, colorSpaceType); return resImage; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java new file mode 100644 index 0000000..1a6f905 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
@@ -0,0 +1,114 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; + +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.ColorFilter; +import android.graphics.ColorMatrixColorFilter; +import android.graphics.Paint; +import android.graphics.PointF; + +import org.tensorflow.lite.support.image.ColorSpaceType; +import org.tensorflow.lite.support.image.ImageOperator; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Transforms an image to GrayScale as an image processing unit. + * + * <p>Supported color spaces: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * </ul> + * + * <p>The conversion is based on OpenCV RGB to GRAY conversion + * https://docs.opencv.org/master/de/d25/imgproc_color_conversions.html#color_convert_rgb_gray + */ +public class TransformToGrayscaleOp implements ImageOperator { + // A matrix is created that will be applied later to canvas to generate grayscale image + // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values + // Y = 0.299R + 0.587G + 0.114B + private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION = + new float[] {0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, + 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F}; + + /** Creates a TransformToGrayscaleOp. */ + public TransformToGrayscaleOp() {} + + /** + * Applies the transformation to grayscale and returns a {@link TensorImage}. + * + * <p>If the input image is already {@link + * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op. + * + * @throws IllegalArgumentException if the {@code image} is not {@link + * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link + * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}. + */ + @Override + public TensorImage apply(TensorImage image) { + if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) { + return image; + } else { + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB, + "Only RGB images are supported in TransformToGrayscaleOp, but not " + + image.getColorSpaceType().name()); + } + int h = image.getHeight(); + int w = image.getWidth(); + Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888); + Canvas canvas = new Canvas(bmpGrayscale); + Paint paint = new Paint(); + ColorMatrixColorFilter colorMatrixFilter = + new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION); + paint.setColorFilter((ColorFilter) colorMatrixFilter); + canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint); + + // Get the pixels from the generated grayscale image + int[] intValues = new int[w * h]; + bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h); + // Shape with one channel + int[] shape = new int[] {1, h, w, 1}; + + // Get R channel from ARGB color + for (int i = 0; i < intValues.length; i++) { + intValues[i] = ((intValues[i] >> 16) & 0xff); + } + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType()); + buffer.loadArray(intValues, shape); + image.load(buffer, ColorSpaceType.GRAYSCALE); + return image; + } + + @Override + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { + return inputImageHeight; + } + + @Override + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { + return inputImageWidth; + } + + @Override + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { + return point; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java index 07116cb..af56b70 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
@@ -15,13 +15,20 @@ package org.tensorflow.lite.support.label; +import org.tensorflow.lite.annotations.UsedByReflection; + import java.util.Objects; /** - * Category is a util class, contains a label, its display name and a float value as score. - * Typically it's used as result of classification tasks. + * Category is a util class, contains a label, its display name, a float value as score, and the + * index of the label in the corresponding label file. Typically it's used as result of + * classification tasks. */ +@UsedByReflection("TFLiteSupport/Task") public final class Category { + private static final int DEFAULT_INDEX = -1; + private static final float TOLERANCE = 1e-6f; + private final int index; private final String label; private final String displayName; private final float score; @@ -29,23 +36,37 @@ /** * Constructs a {@link Category} object. * + * @param label the label of this category object * @param displayName the display name of the label, which may be translated for different * locales. For exmaple, a label, "apple", may be translated into Spanish for display * purpose, so that the displayName is "manzana". + * @param score the probability score of this label category + * @param index the index of the label in the corresponding label file */ + @UsedByReflection("TFLiteSupport/Task") + public static Category create(String label, String displayName, float score, int index) { + return new Category(label, displayName, score, index); + } + + /** Constructs a {@link Category} object with the default index (-1). */ + @UsedByReflection("TFLiteSupport/Task") public static Category create(String label, String displayName, float score) { - return new Category(label, displayName, score); + return new Category(label, displayName, score, DEFAULT_INDEX); } - /** Constructs a {@link Category} object with an empty displayName. */ + /** + * Constructs a {@link Category} object with an empty displayName and the default index (-1). + */ + @UsedByReflection("TFLiteSupport/Task") public Category(String label, float score) { - this(label, /*displayName=*/"", score); + this(label, /*displayName=*/"", score, DEFAULT_INDEX); } - private Category(String label, String displayName, float score) { + private Category(String label, String displayName, float score, int index) { this.label = label; this.displayName = displayName; this.score = score; + this.index = index; } /** Gets the reference of category's label. */ @@ -68,25 +89,34 @@ return score; } + /** + * Gets the index of the category. The index value might be -1, which means it has not been set + * up properly and is invalid. + */ + public int getIndex() { + return index; + } + @Override public boolean equals(Object o) { if (o instanceof Category) { Category other = (Category) o; return (other.getLabel().equals(this.label) && other.getDisplayName().equals(this.displayName) - && other.getScore() == this.score); + && Math.abs(other.getScore() - this.score) < TOLERANCE + && other.getIndex() == this.index); } return false; } @Override public int hashCode() { - return Objects.hash(label, displayName, score); + return Objects.hash(label, displayName, score, index); } @Override public String toString() { - return "<Category \"" + label + "\" (displayName=" + displayName + "\" (score=" + score - + ")>"; + return "<Category \"" + label + "\" (displayName=" + displayName + " score=" + score + + " index=" + index + ")>"; } }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java index d135b4b..56ee89f0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
@@ -18,7 +18,7 @@ import android.util.Log; import org.checkerframework.checker.nullness.qual.NonNull; -import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; import java.util.ArrayList; @@ -32,14 +32,15 @@ * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background", * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"]. * - * @param tensorBuffer: A tensor with index values. The values should be non-negative integers, - * and each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor - * is given as a float {@link TensorBuffer}, values will be cast to integers. All values that + * @param tensorBuffer A tensor with index values. The values should be non-negative integers, + * and + * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is + * given as a float {@link TensorBuffer}, values will be cast to integers. All values that * are out of bound will map to empty string. - * @param labels: A list of strings, used as a dictionary to look up. The index of the array + * @param labels A list of strings, used as a dictionary to look up. The index of the array * element will be used as the key. To get better performance, use an object that implements * RandomAccess, such as {@link ArrayList}. - * @param offset: The offset value when look up int values in the {@code labels}. + * @param offset The offset value when look up int values in the {@code labels}. * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}. * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null. */
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java index ff78f40..edd683c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
@@ -19,7 +19,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.tensorflow.lite.DataType; -import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; import java.nio.ByteBuffer;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java index 33411f1..e44edc6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
@@ -19,7 +19,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.tensorflow.lite.support.common.FileUtil; -import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import org.tensorflow.lite.support.label.TensorLabel; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java index 96ddb39..af2061e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
@@ -19,10 +19,11 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.InterpreterApi; +import org.tensorflow.lite.InterpreterFactory; import org.tensorflow.lite.Tensor; import org.tensorflow.lite.support.common.FileUtil; -import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import java.io.IOException; import java.nio.MappedByteBuffer; @@ -79,7 +80,7 @@ } /** An instance of the driver class to run model inference with Tensorflow Lite. */ - private final Interpreter interpreter; + private final InterpreterApi interpreter; /** Path to tflite model file in asset folder. */ private final String modelPath; @@ -104,8 +105,8 @@ /** * Creates a builder which loads tflite model from asset folder using memory-mapped files. * - * @param context: Application context to access assets. - * @param modelPath: Asset path of the model (.tflite file). + * @param context Application context to access assets. + * @param modelPath Asset path of the model (.tflite file). * @throws IOException if an I/O error occurs when loading the tflite model. */ @NonNull @@ -182,7 +183,7 @@ */ public static Model createModel(@NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) { - Interpreter.Options interpreterOptions = new Interpreter.Options(); + InterpreterApi.Options interpreterOptions = new InterpreterApi.Options(); GpuDelegateProxy gpuDelegateProxy = null; switch (options.device) { case NNAPI: @@ -198,7 +199,7 @@ break; } interpreterOptions.setNumThreads(options.numThreads); - Interpreter interpreter = new Interpreter(byteModel, interpreterOptions); + InterpreterApi interpreter = new InterpreterFactory().create(byteModel, interpreterOptions); return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy); } @@ -215,7 +216,7 @@ } /** - * Gets the Tensor associated with the provdied input index. + * Gets the Tensor associated with the provided input index. * * @throws IllegalStateException if the interpreter is closed. */ @@ -224,7 +225,7 @@ } /** - * Gets the Tensor associated with the provdied output index. + * Gets the Tensor associated with the provided output index. * * @throws IllegalStateException if the interpreter is closed. */ @@ -269,7 +270,7 @@ } private Model(@NonNull String modelPath, @NonNull MappedByteBuffer byteModel, - @NonNull Interpreter interpreter, @Nullable GpuDelegateProxy gpuDelegateProxy) { + @NonNull InterpreterApi interpreter, @Nullable GpuDelegateProxy gpuDelegateProxy) { this.modelPath = modelPath; this.byteModel = byteModel; this.interpreter = interpreter;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java index 7e1355ef..ec6c800e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -15,9 +15,9 @@ package org.tensorflow.lite.support.tensorbuffer; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkState; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState; import org.checkerframework.checker.nullness.qual.NonNull; import org.tensorflow.lite.DataType; @@ -50,19 +50,19 @@ * some examples: * * <pre> - * Creating a float TensorBuffer with shape {2, 3}: + * // Creating a float TensorBuffer with shape {2, 3}: * int[] shape = new int[] {2, 3}; * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); * </pre> * * <pre> - * Creating an uint8 TensorBuffer of a scalar: + * // Creating an uint8 TensorBuffer of a scalar: * int[] shape = new int[] {}; * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); * </pre> * * <pre> - * Creating an empty uint8 TensorBuffer: + * // Creating an empty uint8 TensorBuffer: * int[] shape = new int[] {0}; * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); * </pre> @@ -91,7 +91,25 @@ * the created {@link TensorBuffer} is {0}. * * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of - * different buffer sizes. + * different buffer sizes. Here are some examples: + * + * <pre> + * // Creating a float dynamic TensorBuffer: + * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + * // Loading a float array: + * float[] arr1 = new float[] {1, 2, 3}; + * tensorBuffer.loadArray(arr, new int[] {arr1.length}); + * // loading another float array: + * float[] arr2 = new float[] {1, 2, 3, 4, 5}; + * tensorBuffer.loadArray(arr, new int[] {arr2.length}); + * // loading a third float array with the same size as arr2, assuming shape doesn't change: + * float[] arr3 = new float[] {5, 4, 3, 2, 1}; + * tensorBuffer.loadArray(arr); + * // loading a forth float array with different size as arr3 and omitting the shape will result + * // in error: + * float[] arr4 = new float[] {3, 2, 1}; + * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match. + * </pre> * * @param dataType The dataType of the {@link TensorBuffer} to be created. */ @@ -144,12 +162,12 @@ } /** - * Gets the {@link TensorBuffer#flatSize} of the buffer. + * Gets the flatSize of the buffer. * * @throws IllegalStateException if the underlying data is corrupted */ public int getFlatSize() { - assertShapeIsCorect(); + assertShapeIsCorrect(); return flatSize; } @@ -160,7 +178,7 @@ */ @NonNull public int[] getShape() { - assertShapeIsCorect(); + assertShapeIsCorrect(); return Arrays.copyOf(shape, shape.length); } @@ -185,7 +203,7 @@ * For example, a TensorBuffer with shape {2, 3} that represents the following array, * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. * - * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by: + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by: * float v = tensorBuffer.getFloatValue(3); * </pre> * @@ -212,7 +230,7 @@ * For example, a TensorBuffer with shape {2, 3} that represents the following array, * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. * - * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by: + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by: * int v = tensorBuffer.getIntValue(3); * Note that v is converted from 3.0f to 3 as a result of type conversion. * </pre> @@ -255,8 +273,10 @@ * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by * {255, 0}. * - * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for - * both fixed-size and dynamic {@link TensorBuffer}. + * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link + * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape. * * @param src The source array to be loaded. */ @@ -287,8 +307,10 @@ * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to * uint8 by {255, 0}. * - * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for - * both fixed-size and dynamic {@link TensorBuffer}. + * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link + * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape. * * @param src The source array to be loaded. */ @@ -302,6 +324,9 @@ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here * for performance concern, but if modification is necessary, please make a copy. * + * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer} + * backed by an array. + * * @param buffer The byte buffer to load. * @throws NullPointerException if {@code buffer} is null. * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not @@ -309,11 +334,22 @@ */ public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) { checkNotNull(buffer, "Byte buffer cannot be null."); + checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); + int flatSize = computeFlatSize(shape); checkArgument((buffer.limit() == getTypeSize() * flatSize), - "The size of byte buffer and the shape do not match."); + "The size of byte buffer and the shape do not match. Expected: " + + getTypeSize() * flatSize + " Actual: " + buffer.limit()); - resize(shape); + if (!isDynamic) { + // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. + checkArgument(Arrays.equals(shape, this.shape)); + } + + // Update to the new shape, since shape dim values might change. + this.shape = shape.clone(); + this.flatSize = flatSize; + buffer.rewind(); this.buffer = buffer; } @@ -322,9 +358,21 @@ * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of * this {@link TensorBuffer}. * + * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of + * this + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link + * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different + * shape. + * * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here * for performance concern, but if modification is necessary, please make a copy. * + * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer} + * backed by an array. + * + * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance. + * * @param buffer The byte buffer to load. */ public void loadBuffer(@NonNull ByteBuffer buffer) { @@ -373,6 +421,18 @@ } } + /** Copies the underlying {@link ByteBuffer} if it's readonly. */ + protected synchronized void copyByteBufferIfReadOnly() { + if (!buffer.isReadOnly()) { + return; + } + ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity()); + newByteBuffer.order(buffer.order()); + newByteBuffer.put(buffer); + newByteBuffer.rewind(); + buffer = newByteBuffer; + } + /** * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, * this @@ -402,7 +462,7 @@ * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link * ByteBuffer}. */ - private void assertShapeIsCorect() { + private void assertShapeIsCorrect() { int flatSize = computeFlatSize(shape); checkState((buffer.limit() == getTypeSize() * flatSize), String.format(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java index 74d7e18..632db6c8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
@@ -17,7 +17,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.tensorflow.lite.DataType; -import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; import java.nio.FloatBuffer; @@ -89,6 +89,7 @@ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), "The size of the array to be loaded does not match the specified shape."); + copyByteBufferIfReadOnly(); resize(shape); buffer.rewind(); @@ -101,6 +102,7 @@ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), "The size of the array to be loaded does not match the specified shape."); + copyByteBufferIfReadOnly(); resize(shape); buffer.rewind();
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java index 8bbbf95..2924ef0a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
@@ -17,7 +17,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.tensorflow.lite.DataType; -import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.internal.SupportPreconditions; /** Represents data buffer with 8-bit unsigned integer values. */ public final class TensorBufferUint8 extends TensorBuffer { @@ -90,6 +90,7 @@ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), "The size of the array to be loaded does not match the specified shape."); + copyByteBufferIfReadOnly(); resize(shape); buffer.rewind(); @@ -106,6 +107,7 @@ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); SupportPreconditions.checkArgument(src.length == computeFlatSize(shape), "The size of the array to be loaded does not match the specified shape."); + copyByteBufferIfReadOnly(); resize(shape); buffer.rewind();
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/AndroidManifest.xml new file mode 100644 index 0000000..cecae01 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/AndroidManifest.xml
@@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.audio"> + <uses-sdk android:minSdkVersion="23" android:targetSdkVersion="29"/> +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/BUILD new file mode 100644 index 0000000..1769dac --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/BUILD
@@ -0,0 +1,40 @@ +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", +]) + +android_library( + name = "task_library_audio", + srcs = [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier:audio_classifier_src", + ], + # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" wierdly break the build. + javacopts = ["-source 7 -target 7"], + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + # LINT.IfChange(dep) + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/audio:task_audio_native", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", + ], + # LINT.ThenChange(<INTERNAL>/release/build_task_pom.sh:dep) +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio:task-library-audio +aar_with_jni( + name = "task-library-audio", + android_library = ":task_library_audio", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java new file mode 100644 index 0000000..b3eb11f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
@@ -0,0 +1,510 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.audio.classifier; + +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState; + +import android.content.Context; +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.media.MediaRecorder; +import android.os.ParcelFileDescriptor; + +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.audio.TensorAudio; +import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; +import org.tensorflow.lite.task.core.BaseOptions; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Performs classification on audio waveforms. + * + * <p>The API expects a TFLite model with <a + * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. + * + * <p>The API supports models with one audio input tensor and one classification output tensor. To + * be more specific, here are the requirements. + * + * <ul> + * <li>Input audio tensor ({@code kTfLiteFloat32}) + * <ul> + * <li>input audio buffer of size {@code [batch x samples]}. + * <li>batch inference is not supported ({@code batch} is required to be 1). + * </ul> + * <li>Output score tensor ({@code kTfLiteFloat32}) + * <ul> + * <li>with {@code N} classes of either 2 or 4 dimensions, such as {@code [1 x N]} or {@code + * [1 x 1 x 1 x N]} + * <li>the label file is required to be packed to the metadata. See the <a + * href="https://www.tensorflow.org/lite/convert/metadata#label_output">example of + * creating metadata for an image classifier</a>. If no label files are packed, it will + * use index as label in the result. + * </ul> + * </ul> + * + * See <a href="https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1">an example</a> + * of such model, and <a + * href="https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/audio/desktop">a + * CLI demo tool</a> for easily trying out this API. + */ +public final class AudioClassifier extends BaseTaskApi { + private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; + + /** + * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}. + * + * @param modelPath path of the classification model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static AudioClassifier createFromFile(Context context, String modelPath) + throws IOException { + return createFromFileAndOptions( + context, modelPath, AudioClassifierOptions.builder().build()); + } + + /** + * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}. + * + * @param modelFile the classification model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static AudioClassifier createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build()); + } + + /** + * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link + * AudioClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build()); + } + + /** + * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}. + * + * @param modelPath path of the classification model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static AudioClassifier createFromFileAndOptions( + Context context, String modelPath, AudioClassifierOptions options) throws IOException { + return new AudioClassifier(TaskJniUtils.createHandleFromFdAndOptions( + context, new FdAndOptionsHandleProvider<AudioClassifierOptions>() { + @Override + public long createHandle(int fileDescriptor, long fileDescriptorLength, + long fileDescriptorOffset, AudioClassifierOptions options) { + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, + fileDescriptorOffset, options, + TaskJniUtils.createProtoBaseOptionsHandle( + options.getBaseOptions())); + } + }, AUDIO_CLASSIFIER_NATIVE_LIB, modelPath, options)); + } + + /** + * Creates an {@link AudioClassifier} instance. + * + * @param modelFile the classification model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static AudioClassifier createFromFileAndOptions( + File modelFile, final AudioClassifierOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new AudioClassifier( + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions(descriptor.getFd(), + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options, + TaskJniUtils.createProtoBaseOptionsHandle( + options.getBaseOptions())); + } + }, AUDIO_CLASSIFIER_NATIVE_LIB)); + } + } + + /** + * Creates an {@link AudioClassifier} instance with a model buffer and {@link + * AudioClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static AudioClassifier createFromBufferAndOptions( + final ByteBuffer modelBuffer, final AudioClassifierOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new AudioClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options, + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); + } + }, AUDIO_CLASSIFIER_NATIVE_LIB)); + } + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++ + */ + private AudioClassifier(long nativeHandle) { + super(nativeHandle); + } + + /** Options for setting up an {@link AudioClassifier}. */ + @UsedByReflection("audio_classifier_jni.cc") + public static class AudioClassifierOptions { + // Not using AutoValue for this class because scoreThreshold cannot have default value + // (otherwise, the default value would override the one in the model metadata) and + // `Optional` is not an option here, because + // 1. java.util.Optional require Java 8 while we need to support Java 7. + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See + // the comments for labelAllowList. + private final BaseOptions baseOptions; + private final String displayNamesLocale; + private final int maxResults; + private final float scoreThreshold; + private final boolean isScoreThresholdSet; + // As an open source project, we've been trying avoiding depending on common java libraries, + // such as Guava, because it may introduce conflicts with clients who also happen to use + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less + // vulnerable. + private final List<String> labelAllowList; + private final List<String> labelDenyList; + + public static Builder builder() { + return new Builder(); + } + + /** A builder that helps to configure an instance of AudioClassifierOptions. */ + public static class Builder { + private BaseOptions baseOptions = BaseOptions.builder().build(); + private String displayNamesLocale = "en"; + private int maxResults = -1; + private float scoreThreshold; + private boolean isScoreThresholdSet; + private List<String> labelAllowList = new ArrayList<>(); + private List<String> labelDenyList = new ArrayList<>(); + + private Builder() {} + + /** Sets the general options to configure Task APIs, such as accelerators. */ + public Builder setBaseOptions(BaseOptions baseOptions) { + this.baseOptions = baseOptions; + return this; + } + + /** + * Sets the locale to use for display names specified through the TFLite Model Metadata, + * if any. + * + * <p>Defaults to English({@code "en"}). See the <a + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite + * Metadata schema file.</a> for the accepted pattern of locale. + */ + public Builder setDisplayNamesLocale(String displayNamesLocale) { + this.displayNamesLocale = displayNamesLocale; + return this; + } + + /** + * Sets the maximum number of top scored results to return. + * + * @param maxResults if < 0, all results will be returned. If 0, an invalid argument + * error is + * returned. Defaults to -1. + * @throws IllegalArgumentException if maxResults is 0 + */ + public Builder setMaxResults(int maxResults) { + if (maxResults == 0) { + throw new IllegalArgumentException("maxResults cannot be 0."); + } + this.maxResults = maxResults; + return this; + } + + /** + * Sets the score threshold. + * + * <p>It overrides the one provided in the model metadata (if any). Results below this + * value are rejected. + */ + public Builder setScoreThreshold(float scoreThreshold) { + this.scoreThreshold = scoreThreshold; + isScoreThresholdSet = true; + return this; + } + + /** + * Sets the optional allowlist of labels. + * + * <p>If non-empty, classifications whose label is not in this set will be filtered out. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. + */ + public Builder setLabelAllowList(List<String> labelAllowList) { + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); + return this; + } + + /** + * Sets the optional denylist of labels. + * + * <p>If non-empty, classifications whose label is in this set will be filtered out. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList. + */ + public Builder setLabelDenyList(List<String> labelDenyList) { + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); + return this; + } + + public AudioClassifierOptions build() { + return new AudioClassifierOptions(this); + } + } + + @UsedByReflection("audio_classifier_jni.cc") + public String getDisplayNamesLocale() { + return displayNamesLocale; + } + + @UsedByReflection("audio_classifier_jni.cc") + public int getMaxResults() { + return maxResults; + } + + @UsedByReflection("audio_classifier_jni.cc") + public float getScoreThreshold() { + return scoreThreshold; + } + + @UsedByReflection("audio_classifier_jni.cc") + public boolean getIsScoreThresholdSet() { + return isScoreThresholdSet; + } + + @UsedByReflection("audio_classifier_jni.cc") + public List<String> getLabelAllowList() { + return new ArrayList<>(labelAllowList); + } + + @UsedByReflection("audio_classifier_jni.cc") + public List<String> getLabelDenyList() { + return new ArrayList<>(labelDenyList); + } + + public BaseOptions getBaseOptions() { + return baseOptions; + } + + private AudioClassifierOptions(Builder builder) { + displayNamesLocale = builder.displayNamesLocale; + maxResults = builder.maxResults; + scoreThreshold = builder.scoreThreshold; + isScoreThresholdSet = builder.isScoreThresholdSet; + labelAllowList = builder.labelAllowList; + labelDenyList = builder.labelDenyList; + baseOptions = builder.baseOptions; + } + } + + /** + * Performs actual classification on the provided audio tensor. + * + * @param tensor a {@link TensorAudio} containing the input audio clip in float with values + * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite + * model's input tensor. It's recommended to create {@code tensor} using {@code + * createInputTensorAudio} method. + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if error occurs when classifying the audio clip from the native + * code + */ + public List<Classifications> classify(TensorAudio tensor) { + TensorBuffer buffer = tensor.getTensorBuffer(); + TensorAudioFormat format = tensor.getFormat(); + checkState(buffer.getBuffer().hasArray(), + "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly" + + " buffer)."); + return classifyNative(getNativeHandle(), buffer.getBuffer().array(), format.getChannels(), + format.getSampleRate()); + } + + /** + * Creates a {@link TensorAudio} instance to store input audio samples. + * + * @return a {@link TensorAudio} with the same size as model input tensor + * @throws IllegalArgumentException if the model is not compatible + */ + public TensorAudio createInputTensorAudio() { + TensorAudioFormat format = getRequiredTensorAudioFormat(); + + long bufferSize = getRequiredInputBufferSize(); + long samples = bufferSize / format.getChannels(); + return TensorAudio.create(format, (int) samples); + } + + /** Returns the required input buffer size in number of float elements. */ + public long getRequiredInputBufferSize() { + return getRequiredInputBufferSizeNative(getNativeHandle()); + } + + /** + * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned + * AudioRecord instance is initialized and client needs to call {@link + * android.media.AudioRecord#startRecording} method to start recording. + * + * @return an {@link android.media.AudioRecord} instance in {@link + * android.media.AudioRecord#STATE_INITIALIZED} + * @throws IllegalArgumentException if the model required channel count is unsupported + * @throws IllegalStateException if AudioRecord instance failed to initialize + */ + public AudioRecord createAudioRecord() { + TensorAudioFormat format = getRequiredTensorAudioFormat(); + int channelConfig = 0; + + switch (format.getChannels()) { + case 1: + channelConfig = AudioFormat.CHANNEL_IN_MONO; + break; + case 2: + channelConfig = AudioFormat.CHANNEL_IN_STEREO; + break; + default: + throw new IllegalArgumentException(String.format( + "Number of channels required by the model is %d. getAudioRecord method only" + + " supports 1 or 2 audio channels.", + format.getChannels())); + } + + int bufferSizeInBytes = AudioRecord.getMinBufferSize( + format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT); + if (bufferSizeInBytes == AudioRecord.ERROR + || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) { + throw new IllegalStateException(String.format( + "AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes)); + } + // The buffer of AudioRecord should be strictly longer than what model requires so that + // clients could run `TensorAudio::load(record)` together with `AudioClassifier::classify`. + int bufferSizeMultiplier = 2; + int modelRequiredBufferSize = (int) getRequiredInputBufferSize() + * DataType.FLOAT32.byteSize() * bufferSizeMultiplier; + if (bufferSizeInBytes < modelRequiredBufferSize) { + bufferSizeInBytes = modelRequiredBufferSize; + } + AudioRecord audioRecord = new AudioRecord( + // including MIC, UNPROCESSED, and CAMCORDER. + MediaRecorder.AudioSource.VOICE_RECOGNITION, format.getSampleRate(), channelConfig, + AudioFormat.ENCODING_PCM_FLOAT, bufferSizeInBytes); + checkState(audioRecord.getState() == AudioRecord.STATE_INITIALIZED, + "AudioRecord failed to initialize"); + return audioRecord; + } + + /** Returns the {@link TensorAudioFormat} required by the model. */ + public TensorAudioFormat getRequiredTensorAudioFormat() { + return TensorAudioFormat.builder() + .setChannels(getRequiredChannels()) + .setSampleRate(getRequiredSampleRate()) + .build(); + } + + private int getRequiredChannels() { + return getRequiredChannelsNative(getNativeHandle()); + } + + private int getRequiredSampleRate() { + return getRequiredSampleRateNative(getNativeHandle()); + } + + // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms + // each time. Consider combining the native getter methods into 1 and cache it in Java layer. + private static native long getRequiredInputBufferSizeNative(long nativeHandle); + + private static native int getRequiredChannelsNative(long nativeHandle); + + private static native int getRequiredSampleRateNative(long nativeHandle); + + private static native List<Classifications> classifyNative( + long nativeHandle, byte[] audioBuffer, int channels, int sampleRate); + + private static native long initJniWithModelFdAndOptions(int fileDescriptor, + long fileDescriptorLength, long fileDescriptorOffset, AudioClassifierOptions options, + long baseOptionsHandle); + + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle); + + /** + * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance. + * + * @param nativeHandle pointer to memory allocated + */ + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native method to release memory pointed by {@code nativeHandle}, namely a C++ + * `AudioClassifier` instance. + * + * @param nativeHandle pointer to memory allocated + */ + private static native void deinitJni(long nativeHandle); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/BUILD new file mode 100644 index 0000000..d5678f7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/BUILD
@@ -0,0 +1,37 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "audio_classifier_src", + srcs = glob(["**/*.java"]), +) + +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library( + name = "audio_classifier", + exports = [ + ":audio_classifier_java", + "//tensorflow_lite_support/java/src/native/task/audio/classifier:audio_classifier_native", + ], +) + +# Java-only target, needs to be used together with a native target similar to +# //third_party/tensorflow_lite_support/java/src/native/task/audio/classifier:audio_classifier_native. +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "audio_classifier_java", + srcs = [":audio_classifier_src"], + javacopts = ["-source 7 -target 7"], + manifest = "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio:AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java new file mode 100644 index 0000000..7d5b07f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
@@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.audio.classifier; + +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * The classification results of one head in a multihead (a.k.a. multi-output) {@link + * AudioClassifier}. A multihead {@link AudioClassifier} can perform classification for multiple + * purposes, such as a fine grained classifier to distinguish different bird sounds. + */ +// TODO(b/183343074): Create a common class that can be used for both Audio and Vision tasks. +@AutoValue +@UsedByReflection("audio_classifier_jni.cc") +public abstract class Classifications { + @UsedByReflection("audio_classifier_jni.cc") + static Classifications create(List<Category> categories, int headIndex, String headName) { + return new AutoValue_Classifications( + Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, + headName); + } + + // Same reason for not using ImmutableList as stated in + // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}. + public abstract List<Category> getCategories(); + + public abstract int getHeadIndex(); + + public abstract String getHeadName(); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD index f82b800..4f3e538 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD
@@ -1,5 +1,4 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( default_visibility = ["//tensorflow_lite_support:users"], @@ -9,11 +8,13 @@ android_library( name = "base-task-api", srcs = glob(["**/*.java"]), - javacopts = JAVACOPTS, + javacopts = ["-source 7 -target 7"], visibility = ["//visibility:public"], + # LINT.IfChange(dep) deps = [ "@com_google_auto_value", ], + # LINT.ThenChange(<INTERNAL>/release/build_task_base_pom.sh:dep) ) alias(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java new file mode 100644 index 0000000..b2d7223 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
@@ -0,0 +1,85 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.core; + +import com.google.auto.value.AutoValue; + +/** Options to configure Task APIs in general. */ +@AutoValue +public abstract class BaseOptions { + private static final int DEFAULT_NUM_THREADS = -1; + + /** Builder for {@link BaseOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets the advanced accelerator options. + * + * <p>Note: this method will override those highlevel API to choose an delegate, such as + * {@link #useGpu} and {@link #useNnapi}. + */ + public abstract Builder setComputeSettings(ComputeSettings computeSettings); + + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading when + * running inference with CPU. Defaults to -1. + * + * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1 + * has the effect to let TFLite runtime set the value. + */ + public abstract Builder setNumThreads(int numThreads); + + /** + * Uses GPU for inference. The advanced GPU configuration settings will be set to default + * values. + * + * <p>Note: this method will override the settings from {@link #setComputeSettings}. + * + * <p>To manipulate the advanced GPU configuration settings, use {@link + * #setComputeSettings}. + */ + public Builder useGpu() { + return setComputeSettings( + ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build()); + } + + /** + * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to + * default values. + * + * <p>Note: this method will override the settings from {@link #setComputeSettings}. + * + * <p>To manipulate the advanced NNAPI configuration settings, use {@link + * #setComputeSettings}. + */ + public Builder useNnapi() { + return setComputeSettings( + ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build()); + } + + public abstract BaseOptions build(); + } + + public static Builder builder() { + return new AutoValue_BaseOptions.Builder() + .setComputeSettings(ComputeSettings.builder().build()) + .setNumThreads(DEFAULT_NUM_THREADS); + } + + abstract ComputeSettings getComputeSettings(); + + abstract int getNumThreads(); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java new file mode 100644 index 0000000..0c2d042 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
@@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.core; + +import com.google.auto.value.AutoValue; + +/** Options to configure how to accelerate the model inference using dedicated delegates. */ +@AutoValue +public abstract class ComputeSettings { + /** TFLite accelerator delegate options. */ + public enum Delegate { + NONE(0), + NNAPI(1), + GPU(2); + + private final int value; + + Delegate(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + } + + /** Builder for {@link ComputeSettings}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setDelegate(Delegate delegate); + + public abstract ComputeSettings build(); + } + + public static Builder builder() { + return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE); + } + + public abstract Delegate getDelegate(); + + private static final Delegate DEFAULT_DELEGATE = Delegate.NONE; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java index 7dd1022..9d5b7754 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
@@ -84,7 +84,7 @@ tryLoadLibrary(libName); try { return provider.createHandle(); - } catch (Exception e) { + } catch (RuntimeException e) { String errorMessage = "Error getting native address of native library: " + libName; Log.e(TAG, errorMessage, e); throw new IllegalStateException(errorMessage, e); @@ -137,14 +137,12 @@ } } - private TaskJniUtils() {} - /** - * Try load a native library, if it's already loaded return directly. + * Try loading a native library, if it's already loaded return directly. * * @param libName name of the lib */ - static void tryLoadLibrary(String libName) { + public static void tryLoadLibrary(String libName) { try { System.loadLibrary(libName); } catch (UnsatisfiedLinkError e) { @@ -153,4 +151,22 @@ throw new UnsatisfiedLinkError(errorMessage); } } + + public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) { + return createProtoBaseOptionsHandleWithLegacyNumThreads( + baseOptions, /*legacyNumThreads =*/-1); + } + + public static long createProtoBaseOptionsHandleWithLegacyNumThreads( + BaseOptions baseOptions, int legacyNumThreads) { + // NumThreads should be configured through BaseOptions. However, if NumThreads is configured + // through the legacy API of the Task Java API (then it will not equal to -1, the default + // value), use it to overide the one in baseOptions. + return createProtoBaseOptions(baseOptions.getComputeSettings().getDelegate().getValue(), + legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads); + } + + private TaskJniUtils() {} + + private static native long createProtoBaseOptions(int delegate, int numThreads); }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java index 1107348..b1784d02 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
@@ -41,7 +41,8 @@ * expressed in the unrotated frame of reference coordinates system, i.e. in {@code [0, * TensorImage.getWidth()) x [0, TensorImage.getHeight())}, which are the dimensions of the * underlying image data before any orientation gets applied. If the region is out of these bounds, - * the inference method, such as {@link ImageClassifier#classify}, will return error. + * the inference method, such as {@link + * org.tensorflow.lite.task.vision.classifier.ImageClassifier#classify}, will return error. */ @AutoValue public abstract class ImageProcessingOptions {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml index d4d1dbad..107e933 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml
@@ -1,5 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.task.text"> - <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> </manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD index 695e1bef..2d9837e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD
@@ -1,30 +1,34 @@ -load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) exports_files(["AndroidManifest.xml"]) -android_library( +android_library_with_tflite( name = "task_library_text", srcs = [ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:nl_classifier_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa:bert_question_answerer_src", ], - javacopts = JAVACOPTS, + javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/text:task_text_native", + ], visibility = ["//visibility:public"], + # LINT.IfChange(dep) deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/text:task_text_native", "@com_google_auto_value", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@maven//:androidx_annotation_annotation", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], + # LINT.ThenChange(<INTERNAL>/release/build_task_pom.sh:dep) ) # AAR target for OSS release. @@ -34,4 +38,10 @@ aar_with_jni( name = "task-library-text", android_library = ":task_library_text", + headers = [ + "//tensorflow_lite_support/c/task/text:bert_nl_classifier.h", + "//tensorflow_lite_support/c/task/text:nl_classifier.h", + "//tensorflow_lite_support/c/task/text:nl_classifier_common.h", + "//tensorflow_lite_support/c/task/text:bert_question_answerer.h", + ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD index a1d78d8..0079bfda9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD
@@ -1,9 +1,9 @@ -load("@build_bazel_rules_android//android:rules.bzl", "android_library") -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) @@ -12,6 +12,17 @@ srcs = glob(["**/*.java"]), ) +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library_with_tflite( + name = "nl_classifier", + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native", + ], + exports = [ + ":nl_classifier_java", + ], +) + # Java-only target, need to be used together with a native target similar to # third_party/tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native. # Use this target when you want to provide a MutableOpResolver with customized @@ -21,28 +32,13 @@ srcs = [ "NLClassifier.java", ], - javacopts = JAVACOPTS, + javacopts = ["-source 7 -target 7"], deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", "@com_google_auto_value", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", - ], -) - -# Default target that uses BuiltInOpResolver, registers all built-in OPs. -android_library( - name = "nl_classifier", - srcs = [ - "NLClassifier.java", - ], - javacopts = JAVACOPTS, - deps = [ - "//tensorflow_lite_support/java:tensorflowlite_support_java", - "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native", - "@com_google_auto_value", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@maven//:androidx_annotation_annotation", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], ) @@ -56,16 +52,20 @@ ) # Default target that uses BuiltInOpResolver, registers all built-in OPs. -android_library( +android_library_with_tflite( name = "bert_nl_classifier", srcs = [ "BertNLClassifier.java", ], - javacopts = JAVACOPTS, + javacopts = ["-source 7 -target 7"], + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert:bert_nl_classifier_native", + ], deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier:bert_nl_classifier_native", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java index 74cb686..ce912c9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
@@ -18,7 +18,11 @@ import android.content.Context; import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.BaseTaskApi; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; @@ -46,6 +50,171 @@ public class BertNLClassifier extends BaseTaskApi { private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; + /** Options to configure BertNLClassifier. */ + @AutoValue + @UsedByReflection("bert_nl_classifier_jni.cc") + public abstract static class BertNLClassifierOptions { + static final int DEFAULT_MAX_SEQ_LEN = 128; + + abstract int getMaxSeqLen(); + + abstract BaseOptions getBaseOptions(); + + public static Builder builder() { + return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder() + .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN) + .setBaseOptions(BaseOptions.builder().build()); + } + + /** Builder for {@link BertNLClassifierOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the general options to configure Task APIs, such as accelerators. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Set the maximum sequence length. + * + * @deprecated maximum sequence length is now read from the model (i.e. input tensor + * size) + * automatically + */ + @Deprecated + public abstract Builder setMaxSeqLen(int value); + + public abstract BertNLClassifierOptions build(); + } + } + + /** + * Creates {@link BertNLClassifier} from a model file with metadata and default {@link + * BertNLClassifierOptions}. + * + * @param context Android context + * @param modelPath Path to the classification model + * @return a {@link BertNLClassifier} instance + * @throws IOException If model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertNLClassifier createFromFile(final Context context, final String modelPath) + throws IOException { + return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath)); + } + + /** + * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link + * BertNLClassifierOptions}. + * + * @param modelFile The classification model {@link File} instance + * @return a {@link BertNLClassifier} instance + * @throws IOException If model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertNLClassifier createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build()); + } + + /** + * Creates {@link BertNLClassifier} from a model file with metadata and {@link + * BertNLClassifierOptions}. + * + * @param context Android context. + * @param modelPath Path to the classification model + * @param options to configure the classifier + * @return a {@link BertNLClassifier} instance + * @throws IOException If model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertNLClassifier createFromFileAndOptions(final Context context, + final String modelPath, BertNLClassifierOptions options) throws IOException { + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options); + } + + /** + * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link + * BertNLClassifierOptions}. + * + * @param modelFile The classification model {@link File} instance + * @param options to configure the classifier + * @return a {@link BertNLClassifier} instance + * @throws IOException If model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertNLClassifier createFromFileAndOptions( + File modelFile, final BertNLClassifierOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new BertNLClassifier( + TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithFileDescriptor(descriptor.getFd(), options, + TaskJniUtils.createProtoBaseOptionsHandle( + options.getBaseOptions())); + } + }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); + } + } + + /** + * Creates {@link BertNLClassifier} with a model buffer and default {@link + * BertNLClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model + * @return a {@link BertNLClassifier} instance + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build()); + } + + /** + * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model + * @param options to configure the classifier + * @return a {@link BertNLClassifier} instance + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertNLClassifier createFromBufferAndOptions( + final ByteBuffer modelBuffer, final BertNLClassifierOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new BertNLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options, + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions())); + } + }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); + } + + /** + * Performs classification on a string input, returns classified {@link Category}s. + * + * @param text input text to the model. + * @return A list of Category results. + */ + public List<Category> classify(String text) { + return classifyNative(getNativeHandle(), text); + } + /** * Constructor to initialize the JNI with a pointer from C++. * @@ -55,67 +224,11 @@ super(nativeHandle); } - /** - * Create {@link BertNLClassifier} from a model file with metadata. - * - * @param context Android context - * @param pathToModel Path to the classification model. - * @return {@link BertNLClassifier} instance. - * @throws IOException If model file fails to load. - */ - public static BertNLClassifier createFromFile(final Context context, final String pathToModel) - throws IOException { - return createFromBuffer(TaskJniUtils.loadMappedFile(context, pathToModel)); - } + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle); - /** - * Create {@link BertNLClassifier} from a {@link File} object with metadata. - * - * @param modelFile The classification model {@link File} instance. - * @return {@link BertNLClassifier} instance. - * @throws IOException If model file fails to load. - */ - public static BertNLClassifier createFromFile(File modelFile) throws IOException { - try (ParcelFileDescriptor descriptor = - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { - return new BertNLClassifier( - TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { - @Override - public long createHandle() { - return initJniWithFileDescriptor(descriptor.getFd()); - } - }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); - } - } - - /** - * Create {@link BertNLClassifier} with {@link MappedByteBuffer}. - * - * @param modelBuffer In memory buffer of the model. - * @return {@link BertNLClassifier} instance. - */ - public static BertNLClassifier createFromBuffer(final MappedByteBuffer modelBuffer) { - return new BertNLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { - @Override - public long createHandle() { - return initJniWithByteBuffer(modelBuffer); - } - }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); - } - - /** - * Perform classification on a string input, returns classified {@link Category}s. - * - * @param text input text to the model. - * @return A list of Category results. - */ - public List<Category> classify(String text) { - return classifyNative(getNativeHandle(), text); - } - - private static native long initJniWithByteBuffer(ByteBuffer modelBuffer); - - private static native long initJniWithFileDescriptor(int fd); + private static native long initJniWithFileDescriptor( + int fd, BertNLClassifierOptions options, long baseOptionsHandle); private static native List<Category> classifyNative(long nativeHandle, String text);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java index 082b662..b8aa32be 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
@@ -18,10 +18,13 @@ import android.content.Context; import android.os.ParcelFileDescriptor; +import androidx.annotation.Nullable; + import com.google.auto.value.AutoValue; import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.BaseTaskApi; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; @@ -82,22 +85,25 @@ private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL"; @UsedByReflection("nl_classifier_jni.cc") - abstract int inputTensorIndex(); + abstract int getInputTensorIndex(); @UsedByReflection("nl_classifier_jni.cc") - abstract int outputScoreTensorIndex(); + abstract int getOutputScoreTensorIndex(); @UsedByReflection("nl_classifier_jni.cc") - abstract int outputLabelTensorIndex(); + abstract int getOutputLabelTensorIndex(); @UsedByReflection("nl_classifier_jni.cc") - abstract String inputTensorName(); + abstract String getInputTensorName(); @UsedByReflection("nl_classifier_jni.cc") - abstract String outputScoreTensorName(); + abstract String getOutputScoreTensorName(); @UsedByReflection("nl_classifier_jni.cc") - abstract String outputLabelTensorName(); + abstract String getOutputLabelTensorName(); + + @Nullable + abstract BaseOptions getBaseOptions(); public static Builder builder() { return new AutoValue_NLClassifier_NLClassifierOptions.Builder() @@ -112,17 +118,101 @@ /** Builder for {@link NLClassifierOptions}. */ @AutoValue.Builder public abstract static class Builder { - public abstract Builder setInputTensorIndex(int value); + /** Sets the general options to configure Task APIs, such as accelerators. */ + public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions); - public abstract Builder setOutputScoreTensorIndex(int value); + /** + * Configure the input/output tensors for NLClassifier: + * + * <p>- No special configuration is needed if the model has only one input tensor and + * one output tensor. + * + * <p>- When the model has multiple input or output tensors, use the following + * configurations to specifiy the desired tensors: <br> + * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code + * outputLabelTensorName}<br> + * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code + * outputLabelTensorIndex} <br> + * Tensor names has higher priorities than tensor indices in locating the tensors. It + * means the tensors will be first located according to tensor names. If not found, then + * the tensors will be located according to tensor indices. + * + * <p>- Failing to match the input text tensor or output score tensor with neither + * tensor names nor tensor indices will trigger a runtime error. However, failing to + * locate the output label tensor will not trigger an error because the label tensor is + * optional. + */ - public abstract Builder setOutputLabelTensorIndex(int value); + /** + * Set the name of the input text tensor, if the model has multiple inputs. Only the + * input tensor specified will be used for inference; other input tensors will be + * ignored. Dafualt to {@code "INPUT"}. + * + * <p>See the section, Configure the input/output tensors for NLClassifier, for more + * details. + */ + public abstract Builder setInputTensorName(String inputTensorName); - public abstract Builder setInputTensorName(String value); + /** + * Set the name of the output score tensor, if the model has multiple outputs. Dafualt + * to + * {@code "OUTPUT_SCORE"}. + * + * <p>See the section, Configure the input/output tensors for NLClassifier, for more + * details. + */ + public abstract Builder setOutputScoreTensorName(String outputScoreTensorName); - public abstract Builder setOutputScoreTensorName(String value); + /** + * Set the name of the output label tensor, if the model has multiple outputs. Dafualt + * to + * {@code "OUTPUT_LABEL"}. + * + * <p>See the section, Configure the input/output tensors for NLClassifier, for more + * details. + * + * <p>By default, label file should be packed with the output score tensor through Model + * Metadata. See the <a + * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter + * for NLClassifier</a>. NLClassifier reads and parses labels from the label file + * automatically. However, some models may output a specific label tensor instead. In + * this case, NLClassifier reads labels from the output label tensor. + */ + public abstract Builder setOutputLabelTensorName(String outputLabelTensorName); - public abstract Builder setOutputLabelTensorName(String value); + /** + * Set the index of the input text tensor among all input tensors, if the model has + * multiple inputs. Only the input tensor specified will be used for inference; other + * input tensors will be ignored. Dafualt to 0. + * + * <p>See the section, Configure the input/output tensors for NLClassifier, for more + * details. + */ + public abstract Builder setInputTensorIndex(int inputTensorIndex); + + /** + * Set the index of the output score tensor among all output tensors, if the model has + * multiple outputs. Dafualt to 0. + * + * <p>See the section, Configure the input/output tensors for NLClassifier, for more + * details. + */ + public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex); + + /** + * Set the index of the optional output label tensor among all output tensors, if the + * model has multiple outputs. + * + * <p>See the document above {@code outputLabelTensorName} for more information about + * what the output label tensor is. + * + * <p>See the section, Configure the input/output tensors for NLClassifier, for more + * details. + * + * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label + * tensor. + */ + public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex); public abstract NLClassifierOptions build(); } @@ -131,6 +221,121 @@ private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; /** + * Creates {@link NLClassifier} from default {@link NLClassifierOptions}. + * + * @param context Android context + * @param modelPath path to the classification model relative to asset dir + * @return an {@link NLClassifier} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static NLClassifier createFromFile(Context context, String modelPath) + throws IOException { + return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build()); + } + + /** + * Creates {@link NLClassifier} from default {@link NLClassifierOptions}. + * + * @param modelFile the classification model {@link File} instance + * @return an {@link NLClassifier} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static NLClassifier createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); + } + + /** + * Creates {@link NLClassifier} from {@link NLClassifierOptions}. + * + * @param context Android context + * @param modelPath path to the classification model relative to asset dir + * @param options configurations for the model. + * @return an {@link NLClassifier} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static NLClassifier createFromFileAndOptions( + Context context, String modelPath, NLClassifierOptions options) throws IOException { + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options); + } + + /** + * Creates {@link NLClassifier} from {@link NLClassifierOptions}. + * + * @param modelFile the classification model {@link File} instance + * @param options configurations for the model + * @return an {@link NLClassifier} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static NLClassifier createFromFileAndOptions( + File modelFile, final NLClassifierOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + long baseOptionsHandle = options.getBaseOptions() == null + ? 0 // pass an invalid native handle + : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()); + return initJniWithFileDescriptor( + options, descriptor.getFd(), baseOptionsHandle); + } + }, NL_CLASSIFIER_NATIVE_LIBNAME)); + } + } + + /** + * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @param options configurations for the model + * @return {@link NLClassifier} instance + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static NLClassifier createFromBufferAndOptions( + final ByteBuffer modelBuffer, final NLClassifierOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + long baseOptionsHandle = options.getBaseOptions() == null + ? 0 // pass an invalid native handle + : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()); + return initJniWithByteBuffer(options, modelBuffer, baseOptionsHandle); + } + }, NL_CLASSIFIER_NATIVE_LIBNAME)); + } + + /** + * Performs classification on a string input, returns classified {@link Category}s. + * + * @param text input text to the model + * @return a list of Category results + */ + public List<Category> classify(String text) { + return classifyNative(getNativeHandle(), text); + } + + /** * Constructor to initialize the JNI with a pointer from C++. * * @param nativeHandle a pointer referencing memory allocated in C++. @@ -139,106 +344,19 @@ super(nativeHandle); } - /** - * Create {@link NLClassifier} from default {@link NLClassifierOptions}. - * - * @param context Android context. - * @param pathToModel Path to the classification model relative to asset dir. - * @return {@link NLClassifier} instance. - * @throws IOException If model file fails to load. - */ - public static NLClassifier createFromFile(Context context, String pathToModel) - throws IOException { - return createFromFileAndOptions( - context, pathToModel, NLClassifierOptions.builder().build()); - } - - /** - * Create {@link NLClassifier} from default {@link NLClassifierOptions}. - * - * @param modelFile The classification model {@link File} instance. - * @return {@link NLClassifier} instance. - * @throws IOException If model file fails to load. - */ - public static NLClassifier createFromFile(File modelFile) throws IOException { - return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); - } - - /** - * Create {@link NLClassifier} from {@link NLClassifierOptions}. - * - * @param context Android context - * @param pathToModel Path to the classification model relative to asset dir. - * @param options Configurations for the model. - * @return {@link NLClassifier} instance. - * @throws IOException If model file fails to load. - */ - public static NLClassifier createFromFileAndOptions( - Context context, String pathToModel, NLClassifierOptions options) throws IOException { - return createFromBufferAndOptions( - TaskJniUtils.loadMappedFile(context, pathToModel), options); - } - - /** - * Create {@link NLClassifier} from {@link NLClassifierOptions}. - * - * @param modelFile The classification model {@link File} instance. - * @param options Configurations for the model. - * @return {@link NLClassifier} instance. - * @throws IOException If model file fails to load. - */ - public static NLClassifier createFromFileAndOptions( - File modelFile, final NLClassifierOptions options) throws IOException { - try (ParcelFileDescriptor descriptor = - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { - return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { - @Override - public long createHandle() { - return initJniWithFileDescriptor(options, descriptor.getFd()); - } - }, NL_CLASSIFIER_NATIVE_LIBNAME)); - } - } - - /** - * Create {@link NLClassifier} with {@link MappedByteBuffer} from {@link NLClassifierOptions}. - * - * @param modelBuffer In memory buffer of the classification model. - * @param options Configurations for the model. - * @return {@link NLClassifier} instance. - */ - public static NLClassifier createFromBufferAndOptions( - final ByteBuffer modelBuffer, final NLClassifierOptions options) { - return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { - @Override - public long createHandle() { - return initJniWithByteBuffer(options, modelBuffer); - } - }, NL_CLASSIFIER_NATIVE_LIBNAME)); - } - - /** - * Perform classification on a string input, returns classified {@link Category}s. - * - * @param text input text to the model. - * @return A list of Category results. - */ - public List<Category> classify(String text) { - return classifyNative(getNativeHandle(), text); - } - - private static native long initJniWithByteBuffer( - NLClassifierOptions options, ByteBuffer modelBuffer); - - private static native long initJniWithFileDescriptor(NLClassifierOptions options, int fd); - - private static native List<Category> classifyNative(long nativeHandle, String text); - @Override protected void deinit(long nativeHandle) { deinitJni(nativeHandle); } + private static native long initJniWithByteBuffer( + NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle); + + private static native long initJniWithFileDescriptor( + NLClassifierOptions options, int fd, long baseOptionsHandle); + + private static native List<Category> classifyNative(long nativeHandle, String text); + /** * Native implementation to release memory pointed by the pointer. *
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD index 3dad142..0d35eb8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD
@@ -1,9 +1,9 @@ -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") -load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) @@ -12,14 +12,29 @@ srcs = glob(["**/*.java"]), ) -android_library( +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library_with_tflite( name = "bert_question_answerer", + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_native", + ], + exports = [ + ":bert_question_answerer_java", + ], +) + +# Java-only target, need to be used together with a native target similar to +# //third_party/tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_native. +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "bert_question_answerer_java", srcs = glob(["*.java"]), - javacopts = JAVACOPTS, + javacopts = ["-source 7 -target 7"], deps = [ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_native", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java index fb91562..39648d9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
@@ -18,9 +18,13 @@ import android.content.Context; import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.BaseTaskApi; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider; import java.io.File; @@ -28,122 +32,186 @@ import java.nio.ByteBuffer; import java.util.List; -/** Task API for BertQA models. */ +/** + * Returns the most possible answers on a given question for QA models (BERT, Albert, etc.). + * + * <p>The API expects a Bert based TFLite model with metadata containing the following information: + * + * <ul> + * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be used + * for a <a + * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> model, + * Sentencepiece Tokenizer Tokenizer can be used for an <a + * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> model. + * <li>3 input tensors with names "ids", "mask" and "segment_ids". + * <li>2 output tensors with names "end_logits" and "start_logits". + * </ul> + */ public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer { private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni"; - - private BertQuestionAnswerer(long nativeHandle) { - super(nativeHandle); - } + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; /** - * Generic API to create the QuestionAnswerer for bert models with metadata populated. The API - * expects a Bert based TFLite model with metadata containing the following information: - * - * <ul> - * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be - * used for a <a - * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> - * model, Sentencepiece Tokenizer Tokenizer can be used for an <a - * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> - * model. - * <li>3 input tensors with names "ids", "mask" and "segment_ids". - * <li>2 output tensors with names "end_logits" and "start_logits". - * </ul> + * Creates a {@link BertQuestionAnswerer} instance from the default {@link + * BertQuestionAnswererOptions}. * * @param context android context - * @param pathToModel file path to the model with metadata. Note: The model should not be - * compressed - * @return {@link BertQuestionAnswerer} instance - * @throws IOException If model file fails to load. + * @param modelPath file path to the model with metadata. Note: The model should not be + * compressed + * @return a {@link BertQuestionAnswerer} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ - public static BertQuestionAnswerer createFromFile(Context context, String pathToModel) + public static BertQuestionAnswerer createFromFile(Context context, String modelPath) throws IOException { - return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( - context, new MultipleBuffersHandleProvider() { - @Override - public long createHandle(ByteBuffer... buffers) { - return BertQuestionAnswerer.initJniWithModelWithMetadataByteBuffers( - buffers); - } - }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel)); + return createFromFileAndOptions( + context, modelPath, BertQuestionAnswererOptions.builder().build()); } /** - * Generic API to create the QuestionAnswerer for bert models with metadata populated. The API - * expects a Bert based TFLite model with metadata containing the following information: + * Creates a {@link BertQuestionAnswerer} instance from the default {@link + * BertQuestionAnswererOptions}. * - * <ul> - * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be - * used for a <a - * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> - * model, Sentencepiece Tokenizer Tokenizer can be used for an <a - * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> - * model. - * <li>3 input tensors with names "ids", "mask" and "segment_ids". - * <li>2 output tensors with names "end_logits" and "start_logits". - * </ul> - * - * @param modelFile {@link File} object of the model - * @return {@link BertQuestionAnswerer} instance - * @throws IOException If model file fails to load. + * @param modelFile a {@link File} object of the model + * @return a {@link BertQuestionAnswerer} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build()); + } + + /** + * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}. + * + * @param context android context + * @param modelPath file path to the model with metadata. Note: The model should not be + * compressed + * @return a {@link BertQuestionAnswerer} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertQuestionAnswerer createFromFileAndOptions(Context context, String modelPath, + BertQuestionAnswererOptions options) throws IOException { + return new BertQuestionAnswerer(TaskJniUtils.createHandleFromFdAndOptions( + context, new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() { + @Override + public long createHandle(int fileDescriptor, long fileDescriptorLength, + long fileDescriptorOffset, BertQuestionAnswererOptions options) { + return initJniWithFileDescriptor(fileDescriptor, fileDescriptorLength, + fileDescriptorOffset, + TaskJniUtils.createProtoBaseOptionsHandle( + options.getBaseOptions())); + } + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, options)); + } + + /** + * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}. + * + * @param modelFile a {@link File} object of the model + * @return a {@link BertQuestionAnswerer} instance + * @throws IOException if model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static BertQuestionAnswerer createFromFileAndOptions( + File modelFile, final BertQuestionAnswererOptions options) throws IOException { try (ParcelFileDescriptor descriptor = ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { return new BertQuestionAnswerer( TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { @Override public long createHandle() { - return initJniWithFileDescriptor(descriptor.getFd()); + return initJniWithFileDescriptor( + /*fileDescriptor=*/descriptor.getFd(), + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, + TaskJniUtils.createProtoBaseOptionsHandle( + options.getBaseOptions())); } }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME)); } } /** - * Creates the API instance with a bert model and vocabulary file. + * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file. * * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 * * @param context android context - * @param pathToModel file path to the bert model. Note: The model should not be compressed - * @param pathToVocab file path to the vocabulary file. Note: The file should not be compressed - * @return {@link BertQuestionAnswerer} instance - * @throws IOException If model file fails to load. + * @param modelPath file path to the Bert model. Note: The model should not be compressed + * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed + * @return a {@link BertQuestionAnswerer} instance + * @throws IOException If model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static BertQuestionAnswerer createBertQuestionAnswererFromFile( - Context context, String pathToModel, String pathToVocab) throws IOException { + Context context, String modelPath, String vocabPath) throws IOException { return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( context, new MultipleBuffersHandleProvider() { @Override public long createHandle(ByteBuffer... buffers) { - return BertQuestionAnswerer.initJniWithBertByteBuffers(buffers); + return initJniWithBertByteBuffers(buffers); } - }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel, pathToVocab)); + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, vocabPath)); } /** - * Creates the API instance with an albert model and sentence piece model file. + * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece + * model file. * * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 * * @param context android context - * @param pathToModel file path to the albert model. Note: The model should not be compressed - * @param pathToSentencePieceModel file path to the sentence piece model file. Note: The model + * @param modelPath file path to the Albert model. Note: The model should not be compressed + * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model * should not be compressed - * @return {@link BertQuestionAnswerer} instance - * @throws IOException If model file fails to load. + * @return a {@link BertQuestionAnswerer} instance + * @throws IOException If model file fails to load + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ - public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(Context context, - String pathToModel, String pathToSentencePieceModel) throws IOException { + public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile( + Context context, String modelPath, String sentencePieceModelPath) throws IOException { return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( context, new MultipleBuffersHandleProvider() { @Override public long createHandle(ByteBuffer... buffers) { - return BertQuestionAnswerer.initJniWithAlbertByteBuffers(buffers); + return initJniWithAlbertByteBuffers(buffers); } - }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel, pathToSentencePieceModel)); + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, sentencePieceModelPath)); + } + + /** Options for setting up a {@link BertQuestionAnswerer}. */ + @AutoValue + public abstract static class BertQuestionAnswererOptions { + abstract BaseOptions getBaseOptions(); + + public static Builder builder() { + return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder() + .setBaseOptions(BaseOptions.builder().build()); + } + + /** Builder for {@link BertQuestionAnswererOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the general options to configure Task APIs, such as accelerators. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + public abstract BertQuestionAnswererOptions build(); + } } @Override @@ -152,6 +220,10 @@ return answerNative(getNativeHandle(), context, question); } + private BertQuestionAnswerer(long nativeHandle) { + super(nativeHandle); + } + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); @@ -159,10 +231,8 @@ // buffer. private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers); - // modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use. - private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer... modelBuffers); - - private static native long initJniWithFileDescriptor(int fd); + private static native long initJniWithFileDescriptor(int fileDescriptor, + long fileDescriptorLength, long fileDescriptorOffset, long baseOptionsHandle); private static native List<QaAnswer> answerNative( long nativeHandle, String context, String question);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml index e77a0734..f331580 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml
@@ -1,5 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.task.vision"> - <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> </manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD index 661a766..70657c1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD
@@ -1,8 +1,8 @@ -load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) @@ -10,25 +10,30 @@ "AndroidManifest.xml", ]) -android_library( +android_library_with_tflite( name = "task_library_vision", srcs = [ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier:image_classifier_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core:base_vision_api_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector:object_detector_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter:image_segmenter_src", ], - # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" wierdly break the build. javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/vision:task_vision_native", + ], visibility = ["//visibility:public"], + # LINT.IfChange(dep) deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/vision:task_vision_native", "@com_google_auto_value", "@maven//:androidx_annotation_annotation", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@maven//:com_google_android_odml_image", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], + # LINT.ThenChange(<INTERNAL>/release/build_task_pom.sh:dep) ) # AAR target for OSS release.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml index ce07182..c0033ce 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml
@@ -1,5 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.task.vision.classifier"> - <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> </manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD index c6a70a0..a4d48cda 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD
@@ -1,9 +1,9 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) @@ -13,20 +13,44 @@ filegroup( name = "image_classifier_src", - srcs = glob(["**/*.java"]), + srcs = glob( + ["**/*.java"], + ), ) -android_library( +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +# IMPORTANT: In order to use hardware acceleration delegates, you must +# additionally link to the appropriate delegate plugin target as follows: +# 1. GPU delegate plugin: tensorflow_lite_support/acceleration/configuration:gpu_delegate_plugin_android +# 2. NNAPI delegate plugin: included in this target. +android_library_with_tflite( name = "image_classifier", - srcs = glob(["*.java"]), - javacopts = JAVACOPTS, + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_native", + ], + exports = [ + ":image_classifier_java", + ], +) + +# Java-only target, needs to be used together with a native target similar to +# tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_native. +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "image_classifier_java", + srcs = [ + ":image_classifier_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core:base_vision_api_src", + ], + javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_native", "@com_google_auto_value", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@maven//:com_google_android_odml_image", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java index c1007c11..48038f6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
@@ -17,17 +17,25 @@ import android.content.Context; import android.graphics.Rect; +import android.os.ParcelFileDescriptor; -import org.tensorflow.lite.DataType; +import com.google.android.odml.image.MlImage; + import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; -import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -54,6 +62,10 @@ * <ul> * <li>with {@code N} classes of either 2 or 4 dimensions, such as {@code [1 x N]} or {@code * [1 x 1 x 1 x N]} + * <li>the label file is required to be packed to the metadata. See the <a + * href="https://www.tensorflow.org/lite/convert/metadata#label_output">example of + * creating metadata for an image classifier</a>. If no label files are packed, it will + * use index as label in the result. * </ul> * </ul> * @@ -61,16 +73,19 @@ * href="https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1">TensorFlow * Hub.</a>. */ -public final class ImageClassifier extends BaseTaskApi { +public final class ImageClassifier extends BaseVisionTaskApi { private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; /** * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. * * @param modelPath path of the classification model with metadata in the assets * @throws IOException if an I/O error occurs when loading the tflite model - * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native - * code + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static ImageClassifier createFromFile(Context context, String modelPath) throws IOException { @@ -79,12 +94,41 @@ } /** + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. + * + * @param modelFile the classification model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageClassifier createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build()); + } + + /** + * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link + * ImageClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build()); + } + + /** * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}. * * @param modelPath path of the classification model with metadata in the assets * @throws IOException if an I/O error occurs when loading the tflite model - * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native - * code + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static ImageClassifier createFromFileAndOptions( Context context, String modelPath, ImageClassifierOptions options) throws IOException { @@ -94,17 +138,73 @@ public long createHandle(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options) { return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, - fileDescriptorOffset, options); + fileDescriptorOffset, options, + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); } }, IMAGE_CLASSIFIER_NATIVE_LIB, modelPath, options)); } /** + * Creates an {@link ImageClassifier} instance. + * + * @param modelFile the classification model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageClassifier createFromFileAndOptions( + File modelFile, final ImageClassifierOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new ImageClassifier( + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions(descriptor.getFd(), + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options, + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); + } + }, IMAGE_CLASSIFIER_NATIVE_LIB)); + } + } + + /** + * Creates an {@link ImageClassifier} instance with a model buffer and {@link + * ImageClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageClassifier createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ImageClassifierOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new ImageClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options, + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); + } + }, IMAGE_CLASSIFIER_NATIVE_LIB)); + } + + /** * Constructor to initialize the JNI with a pointer from C++. * * @param nativeHandle a pointer referencing memory allocated in C++ */ - private ImageClassifier(long nativeHandle) { + ImageClassifier(long nativeHandle) { super(nativeHandle); } @@ -117,6 +217,7 @@ // 1. java.util.Optional require Java 8 while we need to support Java 7. // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See // the comments for labelAllowList. + private final BaseOptions baseOptions; private final String displayNamesLocale; private final int maxResults; private final float scoreThreshold; @@ -128,6 +229,7 @@ // vulnerable. private final List<String> labelAllowList; private final List<String> labelDenyList; + private final int numThreads; public static Builder builder() { return new Builder(); @@ -135,14 +237,22 @@ /** A builder that helps to configure an instance of ImageClassifierOptions. */ public static class Builder { + private BaseOptions baseOptions = BaseOptions.builder().build(); private String displayNamesLocale = "en"; private int maxResults = -1; private float scoreThreshold; private boolean isScoreThresholdSet = false; private List<String> labelAllowList = new ArrayList<>(); private List<String> labelDenyList = new ArrayList<>(); + private int numThreads = -1; - private Builder() {} + Builder() {} + + /** Sets the general options to configure Task APIs, such as accelerators. */ + public Builder setBaseOptions(BaseOptions baseOptions) { + this.baseOptions = baseOptions; + return this; + } /** * Sets the locale to use for display names specified through the TFLite Model Metadata, @@ -174,7 +284,7 @@ } /** - * Sets the score threshold in [0,1). + * Sets the score threshold. * * <p>It overrides the one provided in the model metadata (if any). Results below this * value are rejected. @@ -207,6 +317,23 @@ return this; } + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading + * when running inference with CPU. Defaults to -1. + * + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has + * the effect to let TFLite runtime set the value. + * + * @deprecated use {@link BaseOptions} to configure number of threads instead. This + * method + * will override the number of threads configured from {@link BaseOptions}. + */ + @Deprecated + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + public ImageClassifierOptions build() { return new ImageClassifierOptions(this); } @@ -242,71 +369,144 @@ return new ArrayList<>(labelDenyList); } - private ImageClassifierOptions(Builder builder) { + @UsedByReflection("image_classifier_jni.cc") + public int getNumThreads() { + return numThreads; + } + + public BaseOptions getBaseOptions() { + return baseOptions; + } + + ImageClassifierOptions(Builder builder) { displayNamesLocale = builder.displayNamesLocale; maxResults = builder.maxResults; scoreThreshold = builder.scoreThreshold; isScoreThresholdSet = builder.isScoreThresholdSet; labelAllowList = builder.labelAllowList; labelDenyList = builder.labelDenyList; + numThreads = builder.numThreads; + baseOptions = builder.baseOptions; } } /** - * Performs actual classification on the provided image. + * Performs actual classification on the provided {@link TensorImage}. * - * @param image a {@link TensorImage} object that represents an RGB image - * @throws AssertionError if error occurs when classifying the image from the native code + * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @throws IllegalArgumentException if the color space type of image is unsupported */ public List<Classifications> classify(TensorImage image) { return classify(image, ImageProcessingOptions.builder().build()); } /** - * Performs actual classification on the provided image with {@link ImageProcessingOptions}. + * Performs actual classification on the provided {@link TensorImage} with {@link + * ImageProcessingOptions}. * * <p>{@link ImageClassifier} supports the following options: * * <ul> - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions#Builder#setRoi}). It + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It * defaults to the entire image. - * <li>image rotation (through {@link ImageProcessingOptions#Builder#setOrientation}). It - * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. * </ul> * - * @param image a {@link TensorImage} object that represents an RGB image - * @throws AssertionError if error occurs when classifying the image from the native code + * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @throws IllegalArgumentException if the color space type of image is unsupported */ public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) { + return run(new InferenceProvider<List<Classifications>>() { + @Override + public List<Classifications> run( + long frameBufferHandle, int width, int height, ImageProcessingOptions options) { + return classify(frameBufferHandle, width, height, options); + } + }, image, options); + } + + /** + * Performs actual classification on the provided {@code MlImage}. + * + * @param image an {@code MlImage} object that represents an image + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<Classifications> classify(MlImage image) { + return classify(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs actual classification on the provided {@code MlImage} with {@link + * ImageProcessingOptions}. + * + * <p>{@link ImageClassifier} supports the following options: + * + * <ul> + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It + * defaults to the entire image. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link + * MlImage#getRotation()} is not effective. + * </ul> + * + * @param image a {@code MlImage} object that represents an image + * @param options configures options including ROI and rotation + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<Classifications> classify(MlImage image, ImageProcessingOptions options) { + image.getInternal().acquire(); + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + List<Classifications> result = classify(tensorImage, options); + image.close(); + return result; + } + + private List<Classifications> classify( + long frameBufferHandle, int width, int height, ImageProcessingOptions options) { checkNotClosed(); - // image_classifier_jni.cc expects an uint8 image. Convert image of other types into uint8. - TensorImage imageUint8 = image.getDataType() == DataType.UINT8 - ? image - : TensorImage.createFrom(image, DataType.UINT8); + Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi(); - Rect roi = options.getRoi().isEmpty() - ? new Rect(0, 0, imageUint8.getWidth(), imageUint8.getHeight()) - : options.getRoi(); - - return classifyNative(getNativeHandle(), imageUint8.getBuffer(), imageUint8.getWidth(), - imageUint8.getHeight(), new int[] {roi.left, roi.top, roi.width(), roi.height()}, - options.getOrientation().getValue()); + return classifyNative(getNativeHandle(), frameBufferHandle, + new int[] {roi.left, roi.top, roi.width(), roi.height()}); } private static native long initJniWithModelFdAndOptions(int fileDescriptor, - long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options); + long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options, + long baseOptionsHandle); + + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle); /** * The native method to classify an image with the ROI and orientation. * * @param roi the ROI of the input image, an array representing the bounding box as {left, top, * width, height} - * @param orientation the integer value corresponding to {@link - * ImageProcessingOptions#Orientation} */ private static native List<Classifications> classifyNative( - long nativeHandle, ByteBuffer image, int width, int height, int[] roi, int orientation); + long nativeHandle, long frameBufferHandle, int[] roi); @Override protected void deinit(long nativeHandle) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BUILD new file mode 100644 index 0000000..19a6c99 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BUILD
@@ -0,0 +1,15 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +# BaseVisionTaskApi is built with the source code of vision tasks together. +# Alternatively, we could build BaseVisionTaskApi into an individual library +# with its own .so, and vision tasks could all depend on it. However, it will +# increate the binary size of task-library-vision.aar by 3 MB, because the +# native dependencies of BaseVisionTaskApi are not shared with those vision +# tasks. +filegroup( + name = "base_vision_api_src", + srcs = glob(["**/*.java"]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java new file mode 100644 index 0000000..59ab62a9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
@@ -0,0 +1,204 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.vision.core; + +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull; + +import android.graphics.ImageFormat; +import android.media.Image; +import android.media.Image.Plane; + +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.ColorSpaceType; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + +import java.nio.ByteBuffer; + +/** Base class for Task Vision APIs. */ +public abstract class BaseVisionTaskApi extends BaseTaskApi { + /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */ + public interface InferenceProvider<T> { + T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options); + } + + protected BaseVisionTaskApi(long nativeHandle) { + super(nativeHandle); + } + + /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */ + protected <T> T run( + InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) { + FrameBufferData frameBufferData = + createFrameBuffer(image, options.getOrientation().getValue()); + T results = provider.run(frameBufferData.getFrameBufferHandle(), image.getWidth(), + image.getHeight(), options); + deleteFrameBuffer(frameBufferData.getFrameBufferHandle(), + frameBufferData.getByteArrayHandle(), frameBufferData.getByteArray()); + return results; + } + + private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) { + ColorSpaceType colorSpaceType = image.getColorSpaceType(); + switch (colorSpaceType) { + case RGB: + case NV12: + case NV21: + case YV12: + case YV21: + // All these types can be converted to ByteBuffer inside TensorImage. Creating + // FrameBuffer base on the image ByteBuffer. + return createFrameBufferFromByteBuffer(image, orientation); + case YUV_420_888: + // YUV_420_888 is a specific type for android.media.Image. + return createFrameBufferFromMediaImage(image, orientation); + default: + throw new IllegalArgumentException( + "Color space type, " + colorSpaceType.name() + ", is unsupported."); + } + } + + /** + * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link + * TensorImage}. + */ + private static FrameBufferData createFrameBufferFromMediaImage( + TensorImage image, int orientation) { + Image mediaImage = image.getMediaImage(); + + checkArgument(mediaImage.getFormat() == ImageFormat.YUV_420_888, + "Only supports loading YUV_420_888 Image."); + + Plane[] planes = mediaImage.getPlanes(); + checkArgument(planes.length == 3, + String.format("The input image should have 3 planes, but got %d plane(s).", + planes.length)); + + // Verify and rewind planes. + for (Plane plane : planes) { + ByteBuffer buffer = plane.getBuffer(); + checkNotNull(buffer, "The image buffer is corrupted and the plane is null."); + // From the public documentation, plane.getBuffer() should always return a direct + // ByteBuffer. See + // https://developer.android.com/reference/android/media/Image.Plane#getBuffer() + checkArgument(buffer.isDirect(), + "The image plane buffer is not a direct ByteBuffer, and is not supported."); + buffer.rewind(); + } + + return FrameBufferData.create( + createFrameBufferFromPlanes(planes[0].getBuffer(), planes[1].getBuffer(), + planes[2].getBuffer(), mediaImage.getWidth(), mediaImage.getHeight(), + planes[0].getRowStride(), + // row_stride and pixel_stride should be identical for U/V planes. + planes[1].getRowStride(), planes[1].getPixelStride(), orientation), + // FrameBuffer created with direct ByteBuffer does not require memory freeing. + /*byteArrayHandle=*/0, + /*byteArray=*/new byte[0]); + } + + /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */ + private static FrameBufferData createFrameBufferFromByteBuffer( + TensorImage image, int orientation) { + // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8. + TensorImage imageUint8 = image.getDataType() == DataType.UINT8 + ? image + : TensorImage.createFrom(image, DataType.UINT8); + + ByteBuffer byteBuffer = imageUint8.getBuffer(); + byteBuffer.rewind(); + ColorSpaceType colorSpaceType = image.getColorSpaceType(); + if (byteBuffer.isDirect()) { + return FrameBufferData.create( + createFrameBufferFromByteBuffer(byteBuffer, imageUint8.getWidth(), + imageUint8.getHeight(), orientation, colorSpaceType.getValue()), + // FrameBuffer created with direct ByteBuffer does not require memory freeing. + /*byteArrayHandle=*/0, + /*byteArray=*/new byte[0]); + } else { + // If the byte array is copied in jni (during GetByteArrayElements), need to free + // the copied array once inference is done. + long[] byteArrayHandle = new long[1]; + byte[] byteArray = getBytesFromByteBuffer(byteBuffer); + return FrameBufferData.create( + createFrameBufferFromBytes(byteArray, imageUint8.getWidth(), + imageUint8.getHeight(), orientation, colorSpaceType.getValue(), + byteArrayHandle), + byteArrayHandle[0], byteArray); + } + } + + /** Holds the FrameBuffer and the underlying data pointers in C++. */ + @AutoValue + abstract static class FrameBufferData { + /** + * Initializes a {@link FrameBufferData} object. + * + * @param frameBufferHandle the native handle to the FrameBuffer object. + * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer + * object. If the FrameBuffer is created on a byte array, this byte array need to be + * freed after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no + * byte array needs to be freed, and byteArrayHandle will be 0. + * @param byteArray the byte array that is used to create the c++ byte array object, which + * is + * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct + * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code + * byteArray}. + */ + public static FrameBufferData create( + long frameBufferHandle, long byteArrayHandle, byte[] byteArray) { + return new AutoValue_BaseVisionTaskApi_FrameBufferData( + frameBufferHandle, byteArrayHandle, byteArray); + } + + abstract long getFrameBufferHandle(); + + abstract long getByteArrayHandle(); + + // Package private method for transferring data. + @SuppressWarnings("mutable") + abstract byte[] getByteArray(); + } + + private static native long createFrameBufferFromByteBuffer( + ByteBuffer image, int width, int height, int orientation, int colorSpaceType); + + private static native long createFrameBufferFromBytes(byte[] image, int width, int height, + int orientation, int colorSpaceType, long[] byteArrayHandle); + + private static native long createFrameBufferFromPlanes(ByteBuffer yBuffer, ByteBuffer uBuffer, + ByteBuffer vBuffer, int width, int height, int yRowStride, int uvRowStride, + int uvPixelStride, int orientation); + + private static native void deleteFrameBuffer( + long frameBufferHandle, long byteArrayHandle, byte[] byteArray); + + private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) { + // If the ByteBuffer has a back up array, use it directly without copy. + if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) { + return byteBuffer.array(); + } + // Copy out the data otherwise. + byteBuffer.rewind(); + byte[] bytes = new byte[byteBuffer.limit()]; + byteBuffer.get(bytes, 0, bytes.length); + return bytes; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml index 5fefccd..9d585a2b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml
@@ -1,5 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.task.vision.detector"> - <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> </manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD index d0d541ab..81249e8f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD
@@ -1,9 +1,9 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) @@ -16,17 +16,34 @@ srcs = glob(["**/*.java"]), ) -android_library( +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library_with_tflite( name = "object_detector", - srcs = glob(["*.java"]), - javacopts = JAVACOPTS, + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_native", + ], + exports = [ + ":object_detector_java", + ], +) + +# Java-only target, need to be used together with a native target similar to +# //third_party/tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_native. +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "object_detector_java", + srcs = glob(["*.java"]) + [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core:base_vision_api_src", + ], + javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_native", "@com_google_auto_value", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@maven//:com_google_android_odml_image", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java index ebcab12..c0585b8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
@@ -16,17 +16,24 @@ package org.tensorflow.lite.task.vision.detector; import android.content.Context; +import android.os.ParcelFileDescriptor; -import org.tensorflow.lite.DataType; +import com.google.android.odml.image.MlImage; + import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; -import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -81,16 +88,19 @@ * href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow * Hub.</a>. */ -public final class ObjectDetector extends BaseTaskApi { +public final class ObjectDetector extends BaseVisionTaskApi { private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; /** * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. * * @param modelPath path to the detection model with metadata in the assets * @throws IOException if an I/O error occurs when loading the tflite model - * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native - * code + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static ObjectDetector createFromFile(Context context, String modelPath) throws IOException { @@ -99,12 +109,40 @@ } /** + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. + * + * @param modelFile the detection model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ObjectDetector createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build()); + } + + /** + * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link + * ObjectDetectorOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build()); + } + + /** * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. * * @param modelPath path to the detection model with metadata in the assets * @throws IOException if an I/O error occurs when loading the tflite model - * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native - * code + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static ObjectDetector createFromFileAndOptions( Context context, String modelPath, ObjectDetectorOptions options) throws IOException { @@ -114,12 +152,68 @@ public long createHandle(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options) { return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, - fileDescriptorOffset, options); + fileDescriptorOffset, options, + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); } }, OBJECT_DETECTOR_NATIVE_LIB, modelPath, options)); } /** + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. + * + * @param modelFile the detection model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ObjectDetector createFromFileAndOptions( + File modelFile, final ObjectDetectorOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new ObjectDetector( + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions(descriptor.getFd(), + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options, + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); + } + }, OBJECT_DETECTOR_NATIVE_LIB)); + } + } + + /** + * Creates an {@link ObjectDetector} instance with a model buffer and {@link + * ObjectDetectorOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ObjectDetector createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ObjectDetectorOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new ObjectDetector(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options, + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); + } + }, OBJECT_DETECTOR_NATIVE_LIB)); + } + + /** * Constructor to initialize the JNI with a pointer from C++. * * @param nativeHandle a pointer referencing memory allocated in C++ @@ -137,6 +231,7 @@ // 1. java.util.Optional require Java 8 while we need to support Java 7. // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See // the comments for labelAllowList. + private final BaseOptions baseOptions; private final String displayNamesLocale; private final int maxResults; private final float scoreThreshold; @@ -148,6 +243,7 @@ // vulnerable. private final List<String> labelAllowList; private final List<String> labelDenyList; + private final int numThreads; public static Builder builder() { return new Builder(); @@ -155,15 +251,23 @@ /** A builder that helps to configure an instance of ObjectDetectorOptions. */ public static class Builder { + private BaseOptions baseOptions = BaseOptions.builder().build(); private String displayNamesLocale = "en"; private int maxResults = -1; private float scoreThreshold; private boolean isScoreThresholdSet = false; private List<String> labelAllowList = new ArrayList<>(); private List<String> labelDenyList = new ArrayList<>(); + private int numThreads = -1; private Builder() {} + /** Sets the general options to configure Task APIs, such as accelerators. */ + public Builder setBaseOptions(BaseOptions baseOptions) { + this.baseOptions = baseOptions; + return this; + } + /** * Sets the locale to use for display names specified through the TFLite Model Metadata, * if any. @@ -210,9 +314,9 @@ * * <p>If non-empty, detection results whose label is not in this set will be filtered * out. Duplicate or unknown labels are ignored. Mutually exclusive with {@code - * labelDenyList}. It will cause {@link AssertionError} when calling {@link - * #createFromFileAndOptions}, if both - * {@code labelDenyList} and {@code labelAllowList} are set. + * labelDenyList}. It will cause {@link IllegalStateException} when calling {@link + * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList} + * are set. */ public Builder setLabelAllowList(List<String> labelAllowList) { this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); @@ -224,15 +328,32 @@ * * <p>If non-empty, detection results whose label is in this set will be filtered out. * Duplicate or unknown labels are ignored. Mutually exclusive with {@code - * labelAllowList}. It will cause {@link AssertionError} when calling {@link - * #createFromFileAndOptions}, if both - * {@code labelDenyList} and {@code labelAllowList} are set. + * labelAllowList}. It will cause {@link IllegalStateException} when calling {@link + * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList} + * are set. */ public Builder setLabelDenyList(List<String> labelDenyList) { this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); return this; } + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading + * when running inference with CPU. Defaults to -1. + * + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has + * the effect to let TFLite runtime set the value. + * + * @deprecated use {@link BaseOptions} to configure number of threads instead. This + * method + * will override the number of threads configured from {@link BaseOptions}. + */ + @Deprecated + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + public ObjectDetectorOptions build() { return new ObjectDetectorOptions(this); } @@ -268,6 +389,15 @@ return new ArrayList<>(labelDenyList); } + @UsedByReflection("object_detector_jni.cc") + public int getNumThreads() { + return numThreads; + } + + public BaseOptions getBaseOptions() { + return baseOptions; + } + private ObjectDetectorOptions(Builder builder) { displayNamesLocale = builder.displayNamesLocale; maxResults = builder.maxResults; @@ -275,14 +405,28 @@ isScoreThresholdSet = builder.isScoreThresholdSet; labelAllowList = builder.labelAllowList; labelDenyList = builder.labelDenyList; + numThreads = builder.numThreads; + baseOptions = builder.baseOptions; } } /** * Performs actual detection on the provided image. * - * @param image a {@link TensorImage} object that represents a RGB image - * @throws AssertionError if error occurs when processing the image from the native code + * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the color space type of image is unsupported */ public List<Detection> detect(TensorImage image) { return detect(image, ImageProcessingOptions.builder().build()); @@ -291,28 +435,91 @@ /** * Performs actual detection on the provided image. * - * @param image a {@link TensorImage} object that represents a RGB image - * @param options {@link ObjectDetector} only supports image rotation (through {@link - * ImageProcessingOptions#Builder#setOrientation}) currently. The orientation of an image - * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. - * @throws AssertionError if error occurs when processing the image from the native code + * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * <p>{@link ObjectDetector} supports the following options: + * + * <ul> + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @param options the options to configure how to preprocess the image + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the color space type of image is unsupported */ public List<Detection> detect(TensorImage image, ImageProcessingOptions options) { + return run(new InferenceProvider<List<Detection>>() { + @Override + public List<Detection> run( + long frameBufferHandle, int width, int height, ImageProcessingOptions options) { + return detect(frameBufferHandle, options); + } + }, image, options); + } + + /** + * Performs actual detection on the provided {@code MlImage}. + * + * @param image an {@code MlImage} object that represents an image + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<Detection> detect(MlImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs actual detection on the provided {@code MlImage} with {@link + * ImageProcessingOptions}. + * + * <p>{@link ObjectDetector} supports the following options: + * + * <ul> + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link + * MlImage#getRotation()} is not effective. + * </ul> + * + * @param image an {@code MlImage} object that represents an image + * @param options the options to configure how to preprocess the image + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<Detection> detect(MlImage image, ImageProcessingOptions options) { + image.getInternal().acquire(); + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + List<Detection> result = detect(tensorImage, options); + image.close(); + return result; + } + + private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) { checkNotClosed(); - // object_detector_jni.cc expects an uint8 image. Convert image of other types into uint8. - TensorImage imageUint8 = image.getDataType() == DataType.UINT8 - ? image - : TensorImage.createFrom(image, DataType.UINT8); - return detectNative(getNativeHandle(), imageUint8.getBuffer(), imageUint8.getWidth(), - imageUint8.getHeight(), options.getOrientation().getValue()); + return detectNative(getNativeHandle(), frameBufferHandle); } private static native long initJniWithModelFdAndOptions(int fileDescriptor, - long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options); + long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options, + long baseOptionsHandle); - private static native List<Detection> detectNative( - long nativeHandle, ByteBuffer image, int width, int height, int orientation); + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle); + + private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle); @Override protected void deinit(long nativeHandle) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml index 991d481..d9c4568 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml
@@ -1,5 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.task.vision.segmenter"> - <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> </manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD index 506d5bf..31a2b11 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD
@@ -1,8 +1,9 @@ load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") load("@build_bazel_rules_android//android:rules.bzl", "android_library") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 ) @@ -15,19 +16,35 @@ srcs = glob(["**/*.java"]), ) -android_library( +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library_with_tflite( name = "image_segmenter", - srcs = glob(["*.java"]), - # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" wierdly break the build. + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_native", + ], + exports = [ + ":image_segmenter_java", + ], +) + +# Java-only target, need to be used together with a native target similar to +# //third_party/tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_native", +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "image_segmenter_java", + srcs = glob(["*.java"]) + [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core:base_vision_api_src", + ], javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", deps = [ "//tensorflow_lite_support/java:tensorflowlite_support_java", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", - "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_native", "@com_google_auto_value", "@maven//:androidx_annotation_annotation", - "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + "@maven//:com_google_android_odml_image", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java index b175e0b..991fede 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
@@ -34,7 +34,7 @@ * @param label the label string, as provided in the label map packed in the TFLite Model * Metadata. * @param displayName the display name of label, as configured through {@link - * ImageSegmenter#ImageSegmenterOptions#Builder#setDisplayNamesLocale} + * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale} * @param argb the color components for the label in ARGB. See <a * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android * Color ints.</a> for more details. @@ -45,12 +45,12 @@ } /** - * Creates a {@link ColoredLabel} object with a {@link Color} instance. + * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance. * * @param label the label string, as provided in the label map packed in the TFLite Model * Metadata. * @param displayName the display name of label, as configured through {@link - * ImageSegmenter#ImageSegmenterOptions#Builder#setDisplayNamesLocale} + * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale} * @param color the color components for the label. The Color instatnce is supported on Android * API level 26 and above. For API level lower than 26, use {@link #create(String, String, * int)}. See <a @@ -76,7 +76,7 @@ public abstract int getArgb(); /** - * Gets the {@link Color} instance of the underlying color. + * Gets the {@link android.graphics.Color} instance of the underlying color. * * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower * than 26, use {@link #getArgb()}. See <a
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java index b7776dc..4c3b363 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
@@ -17,19 +17,25 @@ import android.content.Context; import android.content.res.AssetFileDescriptor; +import android.os.ParcelFileDescriptor; +import com.google.android.odml.image.MlImage; import com.google.auto.value.AutoValue; -import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; -import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -60,7 +66,7 @@ * is the number of classes supported by the model. * <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if - * any) is used to fill the class name, i.e. {@link ColoredLabel#getClassName} of the + * any) is used to fill the class name, i.e. {@link ColoredLabel#getlabel} of the * results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from * the AssociatedFile (if any) whose locale matches the `display_names_locale` field of * the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If @@ -71,8 +77,10 @@ * <p>An example of such model can be found on <a * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>. */ -public final class ImageSegmenter extends BaseTaskApi { +public final class ImageSegmenter extends BaseVisionTaskApi { private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; private final OutputType outputType; @@ -81,8 +89,9 @@ * * @param modelPath path of the segmentation model with metadata in the assets * @throws IOException if an I/O error occurs when loading the tflite model - * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native - * code + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static ImageSegmenter createFromFile(Context context, String modelPath) throws IOException { @@ -91,32 +100,101 @@ } /** + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. + * + * @param modelFile the segmentation model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageSegmenter createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build()); + } + + /** + * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link + * ImageSegmenterOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * segmentation model + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build()); + } + + /** * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. * * @param modelPath path of the segmentation model with metadata in the assets * @throws IOException if an I/O error occurs when loading the tflite model - * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native - * code + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error */ public static ImageSegmenter createFromFileAndOptions(Context context, String modelPath, final ImageSegmenterOptions options) throws IOException { try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { - long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { - @Override - public long createHandle() { - return initJniWithModelFdAndOptions( - /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor() - .getFd(), - /*fileDescriptorLength=*/assetFileDescriptor.getLength(), - /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), - options.getDisplayNamesLocale(), options.getOutputType().getValue()); - } - }, IMAGE_SEGMENTER_NATIVE_LIB); - return new ImageSegmenter(nativeHandle, options.getOutputType()); + return createFromModelFdAndOptions( + /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(), + /*fileDescriptorLength=*/assetFileDescriptor.getLength(), + /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); } } /** + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. + * + * @param modelFile the segmentation model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageSegmenter createFromFileAndOptions( + File modelFile, final ImageSegmenterOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromModelFdAndOptions( + /*fileDescriptor=*/descriptor.getFd(), + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options); + } + } + + /** + * Creates an {@link ImageSegmenter} instance with a model buffer and {@link + * ImageSegmenterOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * segmentation model + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ImageSegmenter createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ImageSegmenterOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new ImageSegmenter(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options.getDisplayNamesLocale(), + options.getOutputType().getValue(), + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); + } + }, IMAGE_SEGMENTER_NATIVE_LIB), options.getOutputType()); + } + + /** * Constructor to initialize the JNI with a pointer from C++. * * @param nativeHandle a pointer referencing memory allocated in C++ @@ -131,20 +209,30 @@ public abstract static class ImageSegmenterOptions { private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en"; private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK; + private static final int NUM_THREADS = -1; + + public abstract BaseOptions getBaseOptions(); public abstract String getDisplayNamesLocale(); public abstract OutputType getOutputType(); + public abstract int getNumThreads(); + public static Builder builder() { return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE) - .setOutputType(DEFAULT_OUTPUT_TYPE); + .setOutputType(DEFAULT_OUTPUT_TYPE) + .setNumThreads(NUM_THREADS) + .setBaseOptions(BaseOptions.builder().build()); } /** Builder for {@link ImageSegmenterOptions}. */ @AutoValue.Builder public abstract static class Builder { + /** Sets the general options to configure Task APIs, such as accelerators. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + /** * Sets the locale to use for display names specified through the TFLite Model Metadata, * if any. @@ -157,6 +245,20 @@ public abstract Builder setOutputType(OutputType outputType); + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading + * when running inference with CPU. Defaults to -1. + * + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has + * the effect to let TFLite runtime set the value. + * + * @deprecated use {@link BaseOptions} to configure number of threads instead. This + * method + * will override the number of threads configured from {@link BaseOptions}. + */ + @Deprecated + public abstract Builder setNumThreads(int numThreads); + public abstract ImageSegmenterOptions build(); } } @@ -164,12 +266,24 @@ /** * Performs actual segmentation on the provided image. * - * @param image a {@link TensorImage} object that represents an RGB image + * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image * @return results of performing image segmentation. Note that at the time, a single {@link * Segmentation} element is expected to be returned. The result is stored in a {@link List} * for later extension to e.g. instance segmentation models, which may return one * segmentation per object. - * @throws AssertionError if error occurs when segmenting the image from the native code + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the color space type of image is unsupported */ public List<Segmentation> segment(TensorImage image) { return segment(image, ImageProcessingOptions.builder().build()); @@ -178,29 +292,97 @@ /** * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}. * - * @param image a {@link TensorImage} object that represents an RGB image - * @param options {@link ImageSegmenter} only supports image rotation (through {@link - * ImageProcessingOptions#Builder#setOrientation}) currently. The orientation of an image - * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. + * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * <p>{@link ImageSegmenter} supports the following options: + * + * <ul> + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @param options the options configure how to preprocess the image * @return results of performing image segmentation. Note that at the time, a single {@link * Segmentation} element is expected to be returned. The result is stored in a {@link List} * for later extension to e.g. instance segmentation models, which may return one * segmentation per object. - * @throws AssertionError if error occurs when segmenting the image from the native code + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the color space type of image is unsupported */ public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) { + return run(new InferenceProvider<List<Segmentation>>() { + @Override + public List<Segmentation> run( + long frameBufferHandle, int width, int height, ImageProcessingOptions options) { + return segment(frameBufferHandle, options); + } + }, image, options); + } + + /** + * Performs actual segmentation on the provided {@code MlImage}. + * + * @param image an {@code MlImage} to segment. + * @return results of performing image segmentation. Note that at the time, a single {@link + * Segmentation} element is expected to be returned. The result is stored in a {@link List} + * for later extension to e.g. instance segmentation models, which may return one + * segmentation per object. + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<Segmentation> segment(MlImage image) { + return segment(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs actual segmentation on the provided {@code MlImage} with {@link + * ImageProcessingOptions}. + * + * <p>{@link ImageSegmenter} supports the following options: + * + * <ul> + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link + * MlImage#getRotation()} is not effective. + * </ul> + * + * @param image an {@code MlImage} to segment. + * @param options the options configure how to preprocess the image. + * @return results of performing image segmentation. Note that at the time, a single {@link + * Segmentation} element is expected to be returned. The result is stored in a {@link List} + * for later extension to e.g. instance segmentation models, which may return one + * segmentation per object. + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + * @throws IllegalArgumentException if the color space type of image is unsupported + */ + public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) { + image.getInternal().acquire(); + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + List<Segmentation> result = segment(tensorImage, options); + image.close(); + return result; + } + + public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) { checkNotClosed(); - // image_segmenter_jni.cc expects an uint8 image. Convert image of other types into uint8. - TensorImage imageUint8 = image.getDataType() == DataType.UINT8 - ? image - : TensorImage.createFrom(image, DataType.UINT8); List<byte[]> maskByteArrays = new ArrayList<>(); List<ColoredLabel> coloredLabels = new ArrayList<>(); int[] maskShape = new int[2]; - segmentNative(getNativeHandle(), imageUint8.getBuffer(), imageUint8.getWidth(), - imageUint8.getHeight(), maskByteArrays, maskShape, coloredLabels, - options.getOrientation().getValue()); + segmentNative( + getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels); List<ByteBuffer> maskByteBuffers = new ArrayList<>(); for (byte[] bytes : maskByteArrays) { @@ -214,9 +396,28 @@ outputType.createMasksFromBuffer(maskByteBuffers, maskShape), coloredLabels)); } + private static ImageSegmenter createFromModelFdAndOptions(final int fileDescriptor, + final long fileDescriptorLength, final long fileDescriptorOffset, + final ImageSegmenterOptions options) { + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength, + fileDescriptorOffset, options.getDisplayNamesLocale(), + options.getOutputType().getValue(), + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads( + options.getBaseOptions(), options.getNumThreads())); + } + }, IMAGE_SEGMENTER_NATIVE_LIB); + return new ImageSegmenter(nativeHandle, options.getOutputType()); + } + private static native long initJniWithModelFdAndOptions(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, String displayNamesLocale, - int outputType); + int outputType, long baseOptionsHandle); + + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, + String displayNamesLocale, int outputType, long baseOptionsHandle); /** * The native method to segment the image. @@ -224,9 +425,8 @@ * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the * native layer. */ - private static native void segmentNative(long nativeHandle, ByteBuffer image, int width, - int height, List<byte[]> maskByteArrays, int[] maskShape, - List<ColoredLabel> coloredLabels, int orientation); + private static native void segmentNative(long nativeHandle, long frameBufferHandle, + List<byte[]> maskByteArrays, int[] maskShape, List<ColoredLabel> coloredLabels); @Override protected void deinit(long nativeHandle) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java index 113cacdd..8c69cf5 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
@@ -17,7 +17,7 @@ import static org.tensorflow.lite.DataType.FLOAT32; import static org.tensorflow.lite.DataType.UINT8; -import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument; import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE; import org.tensorflow.lite.support.image.TensorImage;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/AndroidManifest.xml new file mode 100644 index 0000000..b2e2262 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/AndroidManifest.xml
@@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.support"> + <uses-sdk android:minSdkVersion="19" /> +</manifest> +
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/BUILD new file mode 100644 index 0000000..ecabe42 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/BUILD
@@ -0,0 +1,467 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library", "android_local_test") +load("//tensorflow_lite_support/tools/build_rules/android_test:android_library_instrumentation_tests.bzl", "android_library_instrumentation_tests") + +package( + default_testonly = 1, +) + +licenses(["notice"]) # Apache License 2.0 + +INSTRUMENTED_TESTS = glob(["**/*InstrumentedTest.java"]) + +# Focus testing on the oldest and newest officially supported APIs. +DEFAULT_INSTRUMENTED_DEVICES = [ +] + +exports_files(["AndroidManifest.xml"]) + +android_library( + name = "test_lib", + testonly = 1, + assets = glob(["assets/**"]), + assets_dir = "assets", + manifest = "AndroidManifest.xml", +) + +android_local_test( + name = "GpuDelegateProxyTest", + srcs = ["model/GpuDelegateProxyTest.java"], + manifest = "AndroidManifest.xml", + nocompress_extensions = ["tflite"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + test_class = "org.tensorflow.lite.support.model.GpuDelegateProxyTest", + deps = [ + ":test_lib", + "//tensorflow_lite_support/java:tensorflow-lite-support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +# TODO(138904571): add a bzl file to declare tests automatically for every java test file. +# Note: This test is not able to be run with --test_strategy=local. +android_local_test( + name = "ModelTest", + srcs = ["model/ModelTest.java"], + manifest = "AndroidManifest.xml", + nocompress_extensions = ["tflite"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + test_class = "org.tensorflow.lite.support.model.ModelTest", + deps = [ + ":test_lib", + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorBufferTest", + srcs = [ + "tensorbuffer/TensorBufferTest.java", + ], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.tensorbuffer.TensorBufferTest", + deps = [ + ":test_lib", + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorBufferFloatTest", + srcs = [ + "tensorbuffer/TensorBufferFloatTest.java", + ], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.tensorbuffer.TensorBufferFloatTest", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorBufferUint8Test", + srcs = [ + "tensorbuffer/TensorBufferUint8Test.java", + ], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.tensorbuffer.TensorBufferUint8Test", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorImageTest", + size = "small", + srcs = [ + "image/TensorImageTest.java", + "image/TestImageCreator.java", + ], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.image.TensorImageTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_mockito_mockito_core", + "@maven//:org_mockito_mockito_inline", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "BoundingBoxUtilTest", + size = "small", + srcs = ["image/BoundingBoxUtilTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.image.BoundingBoxUtilTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "ImageConversionsTest", + size = "small", + srcs = [ + "image/ImageConversionsTest.java", + "image/TestImageCreator.java", + ], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.image.ImageConversionsTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "ImageProcessorTest", + srcs = ["image/ImageProcessorTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.image.ImageProcessorTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "NormalizeOpTest", + srcs = ["common/ops/NormalizeOpTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.common.ops.NormalizeOpTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "CastOpTest", + srcs = ["common/ops/CastOpTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.common.ops.CastOpTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "DequantizeOpTest", + srcs = ["common/ops/DequantizeOpTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.common.ops.DequantizeOpTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "QuantizeOpTest", + srcs = ["common/ops/QuantizeOpTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.common.ops.QuantizeOpTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlitelib_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "LabelAxisOpTest", + srcs = ["label/ops/LabelAxisOpTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.label.ops.LabelAxisOpTest", + visibility = ["//visibility:private"], + deps = [ + ":test_lib", + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "CategoryTest", + srcs = ["label/CategoryTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.label.CategoryTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "LabelUtilTest", + srcs = ["label/LabelUtilTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.label.LabelUtilTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorLabelTest", + srcs = ["label/TensorLabelTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.label.TensorLabelTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "FileUtilTest", + srcs = ["common/FileUtilTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.common.FileUtilTest", + visibility = ["//visibility:private"], + deps = [ + ":test_lib", + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorProcessorTest", + srcs = ["common/TensorProcessorTest.java"], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.common.TensorProcessorTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "ColorSpaceTypeTest", + size = "small", + srcs = [ + "image/ColorSpaceTypeTest.java", + "image/TestImageCreator.java", + ], + manifest = "AndroidManifest.xml", + test_class = "org.tensorflow.lite.support.image.ColorSpaceTypeTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_robolectric_robolectric", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TensorAudioTest", + size = "small", + srcs = [ + "audio/TensorAudioTest.java", + ], + manifest = "AndroidManifest.xml", + manifest_values = { + "minSdkVersion": "23", + }, + test_class = "org.tensorflow.lite.support.audio.TensorAudioTest", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", + "@maven//:org_mockito_mockito_core", + "@maven//:org_mockito_mockito_inline", + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +test_suite( + name = "instrumentation_tests", + tags = [ + "no_oss", + "tflite_emulator_test_android", + ], +) + +android_library( + name = "test_image_creator", + testonly = 1, + srcs = ["image/TestImageCreator.java"], + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + ], +) + +# This rule specifies a bundle of tests, in which each test source file converts +# to a test target. +android_library_instrumentation_tests( + name = "instrumented_unittests", + srcs = INSTRUMENTED_TESTS, + binary_args = { + "multidex": "legacy", + }, + tags = [ + "no_oss", + "noasan", # Avoid build breakage + "nomsan", # Avoid build breakage + "notsan", # Avoid build breakage + "nozapfhahn", # Avoid coverage test breakage + "tflite_emulator_test_android", + ], + target_devices = DEFAULT_INSTRUMENTED_DEVICES, + test_java_package = "org.tensorflow.lite.support", + deps = [ + ":test_image_creator", + ":test_lib", + "//tensorflow_lite_support/java:tensorflowlite_support", + "@maven//:androidx_multidex_multidex", + "@maven//:androidx_test_core", + "@maven//:androidx_test_ext_junit", + "@maven//:androidx_test_runner", + "@maven//:com_google_truth_truth", # android + "@maven//:junit_junit", # android + "@maven//:org_mockito_mockito_core", + "@maven//:org_mockito_mockito_inline", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_gpu", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_stable", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/color_grid.png b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/color_grid.png new file mode 100644 index 0000000..3cbfac5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/color_grid.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/flower_labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/flower_labels.txt new file mode 100644 index 0000000..08f0b16 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/flower_labels.txt
@@ -0,0 +1,5 @@ +daisy +dandelion +roses +sunflowers +tulips
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/grey_grid.png b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/grey_grid.png new file mode 100644 index 0000000..5cf066fe --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/grey_grid.png Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/labels.txt new file mode 100644 index 0000000..fe81123 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/labels.txt
@@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/mobilenet_v1_1.0_224_info.txt b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/mobilenet_v1_1.0_224_info.txt new file mode 100644 index 0000000..1a50fa03 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/assets/mobilenet_v1_1.0_224_info.txt
@@ -0,0 +1,3 @@ +Model: mobilenet_v1_1.0_224 +Input: input +Output: MobilenetV1/Predictions/Reshape_1
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java new file mode 100644 index 0000000..903f791 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
@@ -0,0 +1,285 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.audio; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import android.media.AudioFormat; +import android.media.AudioRecord; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat; + +/** Test for {@link TensorAudio}. */ +@RunWith(Suite.class) +@SuiteClasses({ + TensorAudioTest.General.class, +}) +public class TensorAudioTest { + /** General tests of TensorAudio. */ + @RunWith(RobolectricTestRunner.class) + public static final class General extends TensorAudioTest { + @Test + public void createSucceedsWithTensorAudioFormat() throws Exception { + TensorAudio tensor = TensorAudio.create( + TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100); + assertThat(tensor.getFormat().getChannels()).isEqualTo(1); + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2); + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100); + } + + @Test + public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception { + TensorAudio tensor = TensorAudio.create( + TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100); + assertThat(tensor.getFormat().getChannels()).isEqualTo(5); + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2); + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500); + } + + @Test + public void createSucceededsWithDefaultArguments() throws Exception { + TensorAudio tensor = + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000); + // Number of channels defaults to 1. + assertThat(tensor.getFormat().getChannels()).isEqualTo(1); + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20); + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000); + } + + @Test + public void createSucceedsWithAudioFormat() throws Exception { + AudioFormat format = new AudioFormat.Builder() + .setChannelMask(AudioFormat.CHANNEL_IN_STEREO) + .setEncoding(AudioFormat.ENCODING_PCM_16BIT) + .setSampleRate(16000) + .build(); + TensorAudio tensor = TensorAudio.create(format, 100); + // STEREO has 2 channels + assertThat(tensor.getFormat().getChannels()).isEqualTo(2); + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000); + // flatSize = channelCount * sampleCount + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200); + } + + @Test + public void createFailedWithInvalidSampleRate() throws Exception { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () + -> TensorAudio.create( + TensorAudioFormat.builder().setSampleRate(0).build(), 100)); + // Sample rate 0 is not allowed + assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate"); + } + + @Test + public void createFailedWithInvalidChannels() throws Exception { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () + -> TensorAudio.create(TensorAudioFormat.builder() + .setSampleRate(1) + .setChannels(-1) + .build(), + 100)); + // Negative channels is not allowed + assertThat(exception).hasMessageThat().ignoringCase().contains("channels"); + } + + @Test + public void loadSucceedsFromArray() throws Exception { + TensorAudioFormat format = + TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build(); + TensorAudio tensor = TensorAudio.create(format, 2); + assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]); + + tensor.load(new float[] {2.f, 0}); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .usingTolerance(0.001f) + .containsExactly(new float[] {0, 0, 2.f, 0}); + + tensor.load(new float[] {2.f, 3.f}, 0, 2); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .usingTolerance(0.001f) + .containsExactly(new float[] {2.f, 0, 2.f, 3.f}); + + tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE}); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .usingTolerance(0.001f) + .containsExactly(new float[] {2.f, 3.f, 1.f, -1.f}); + + tensor.load(new short[] {1, 2, 3, 0, 1, Short.MIN_VALUE, 3, 4, 5}, 3, 6); + // The entire sequence becomes {2.f, 0, 2.f, 3.f, 1.f, -1.f, 0, 0, -1.f, 0, 0, 0} but + // the ring buffer is only keep the last 4 results. + assertThat(tensor.getTensorBuffer().getFloatArray()) + .usingTolerance(0.001f) + .containsExactly(new float[] {-1.f, 0, 0, 0}); + } + + @Test + public void loadFailsWithIndexOutOfRange() throws Exception { + TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build(); + TensorAudio tensor = TensorAudio.create(format, 5); + + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2)); + + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2)); + } + + @Test + public void loadFailsWithIncompatibleInputSize() throws Exception { + TensorAudioFormat format = + TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build(); + TensorAudio tensor = TensorAudio.create(format, 5); + + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1])); + + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2])); + + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1)); + + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4)); + } + + @Test + public void loadAudioRecordSucceeds() throws Exception { + TensorAudio tensor = + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); + tensor.load(new float[] {1, 2, 3, 4, 5}); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f}); + + AudioRecord record = mock(AudioRecord.class); + when(record.getBufferSizeInFrames()).thenReturn(5); + when(record.getChannelCount()).thenReturn(1); + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT); + when(record.getFormat()) + .thenReturn(new AudioFormat.Builder() + .setChannelMask(AudioFormat.CHANNEL_IN_MONO) + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) + .setSampleRate(16000) + .build()); + // Unused + when(record.read( + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); + // Used + when(record.read( + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) + .thenReturn(1); + assertThat(tensor.load(record)).isEqualTo(1); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .isEqualTo(new float[] {3.f, 4.f, 5.f, 0}); + + record = mock(AudioRecord.class); + when(record.getBufferSizeInFrames()).thenReturn(5); + when(record.getChannelCount()).thenReturn(1); + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT); + when(record.getFormat()) + .thenReturn(new AudioFormat.Builder() + .setChannelMask(AudioFormat.CHANNEL_IN_MONO) + .setEncoding(AudioFormat.ENCODING_PCM_16BIT) + .setSampleRate(16000) + .build()); + // Used + when(record.read( + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) + .thenReturn(2); + // Unused + when(record.read( + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); + assertThat(tensor.load(record)).isEqualTo(2); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .isEqualTo(new float[] {5.f, 0, 0, 0}); + } + + @Test + public void loadAudioRecordFailsWithErrorState() throws Exception { + TensorAudio tensor = + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); + tensor.load(new float[] {1, 2, 3, 4, 5}); + assertThat(tensor.getTensorBuffer().getFloatArray()) + .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f}); + + AudioRecord record = mock(AudioRecord.class); + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT); + when(record.getFormat()) + .thenReturn(new AudioFormat.Builder() + .setChannelMask(AudioFormat.CHANNEL_IN_MONO) + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) + .setSampleRate(16000) + .build()); + // Unused + when(record.read( + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION); + // Used + when(record.read( + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING))) + .thenReturn(AudioRecord.ERROR_DEAD_OBJECT); + IllegalStateException exception = + assertThrows(IllegalStateException.class, () -> tensor.load(record)); + assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT"); + } + + @Test + public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception { + TensorAudio tensor = + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); + AudioRecord record = mock(AudioRecord.class); + when(record.getFormat()) + .thenReturn(new AudioFormat.Builder() + .setChannelMask(AudioFormat.CHANNEL_IN_MONO) + .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported + .setSampleRate(16000) + .build()); + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> tensor.load(record)); + assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding"); + } + + @Test + public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception { + TensorAudio tensor = + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4); + AudioRecord record = mock(AudioRecord.class); + when(record.getFormat()) + .thenReturn(new AudioFormat.Builder() + .setChannelMask(AudioFormat.CHANNEL_IN_MONO) + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) + .setSampleRate(44100) // Mismatch + .build()); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> tensor.load(record)); + assertThat(exception).hasMessageThat().ignoringCase().contains( + "Incompatible audio format"); + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java new file mode 100644 index 0000000..1d26476 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
@@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.Context; + +import androidx.test.core.app.ApplicationProvider; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.MappedByteBuffer; +import java.nio.charset.Charset; +import java.util.List; + +/** Tests of {@link org.tensorflow.lite.support.common.FileUtil}. */ +@RunWith(RobolectricTestRunner.class) +public final class FileUtilTest { + private final Context context = ApplicationProvider.getApplicationContext(); + private static final String LABEL_PATH = "flower_labels.txt"; + + @Test + public void testLoadLabels() throws IOException { + List<String> labels = FileUtil.loadLabels(context, LABEL_PATH); + assertThat(labels) + .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips") + .inOrder(); + } + + @Test + public void testLoadLabelsFromInputStream() throws IOException { + InputStream inputStream = context.getAssets().open(LABEL_PATH); + assertThat(FileUtil.loadLabels(inputStream)) + .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips") + .inOrder(); + } + + @Test + public void whitespaceLabelsShouldNotCount() throws IOException { + String s = "a\nb\n \n\n\nc"; + InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset())); + assertThat(FileUtil.loadLabels(stream)).hasSize(3); + } + + @Test + public void testLoadLabelsNullContext() throws IOException { + Context nullContext = null; + Assert.assertThrows( + NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH)); + } + + @Test + public void testLoadLabelsNullFilePath() throws IOException { + String nullFilePath = null; + Assert.assertThrows( + NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath)); + } + + @Test + public void testLoadMappedFile() throws IOException { + MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH); + assertThat(byteModel).isNotNull(); + } + + @Test + public void testLoadMappedFileWithNullContext() throws IOException { + Context nullContext = null; + Assert.assertThrows( + NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH)); + } + + @Test + public void loadMappedFileWithNullFilePath() throws IOException { + String nullFilePath = null; + Assert.assertThrows( + NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java new file mode 100644 index 0000000..82f97f25 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
@@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.ops.NormalizeOp; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Tests for {@link TensorProcessor}. */ +@RunWith(RobolectricTestRunner.class) +public final class TensorProcessorTest { + private static final int EXAMPLE_NUM_FEATURES = 1000; + private static final float MEAN = 127.5f; + private static final float STDDEV = 127.5f; + + @Test + public void testBuild() { + TensorProcessor processor = + new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); + assertThat(processor).isNotNull(); + } + + @Test + public void testNormalize() { + TensorBuffer input = createExampleTensorBuffer(); + TensorProcessor processor = + new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); + TensorBuffer output = processor.process(input); + + float[] pixels = output.getFloatArray(); + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES); + for (float p : pixels) { + assertThat(p).isAtLeast(-1); + assertThat(p).isAtMost(1); + } + } + + @Test + public void testMultipleNormalize() { + TensorBuffer input = createExampleTensorBuffer(); + TensorProcessor processor = + new TensorProcessor.Builder() + .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1] + .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1] + .build(); + TensorBuffer output = processor.process(input); + + float[] pixels = output.getFloatArray(); + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES); + for (float p : pixels) { + assertThat(p).isAtLeast(0); + assertThat(p).isAtMost(1); + } + } + + // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255]. + private static TensorBuffer createExampleTensorBuffer() { + TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32); + int[] features = new int[EXAMPLE_NUM_FEATURES]; + for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) { + features[i] = i % 256; + } + buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES}); + return buffer; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java new file mode 100644 index 0000000..e8ba24d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
@@ -0,0 +1,81 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Tests of {@link CastOp}. */ +@RunWith(RobolectricTestRunner.class) +public final class CastOpTest { + private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f}; + private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f}; + private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9}; + private static final int[] SHAPE = new int[] {5}; + + @Test + public void castFloat32ToUint8ShouldSuccess() { + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + floatBuffer.loadArray(FLOAT_ARRAY, SHAPE); + CastOp op = new CastOp(DataType.UINT8); + TensorBuffer uint8Buffer = op.apply(floatBuffer); + assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8); + assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY); + } + + @Test + public void castUint8ToFloat32ShouldSuccess() { + TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8); + uint8Buffer.loadArray(INT_ARRAY, SHAPE); + CastOp op = new CastOp(DataType.FLOAT32); + TensorBuffer floatBuffer = op.apply(uint8Buffer); + assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32); + assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY); + } + + @Test + public void castFloat32ToFloat32ShouldNotRecreate() { + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + floatBuffer.loadArray(FLOAT_ARRAY, SHAPE); + CastOp op = new CastOp(DataType.FLOAT32); + TensorBuffer newBuffer = op.apply(floatBuffer); + assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32); + assertThat(newBuffer).isSameInstanceAs(floatBuffer); + } + + @Test + public void castUint8ToUint8ShouldNotRecreate() { + TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8); + uint8Buffer.loadArray(INT_ARRAY, SHAPE); + CastOp op = new CastOp(DataType.UINT8); + TensorBuffer newBuffer = op.apply(uint8Buffer); + assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8); + assertThat(newBuffer).isSameInstanceAs(uint8Buffer); + } + + @Test + public void castToUnsupportedDataTypeShouldThrow() { + for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) { + Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type)); + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java new file mode 100644 index 0000000..a69bcd7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
@@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Tests of {@link DequantizeOp}. */ +@RunWith(RobolectricTestRunner.class) +public final class DequantizeOpTest { + @Test + public void dequantizeShouldSucess() { + int[] originalData = new int[] {191, 159, 63, 127, 255, 0}; + DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128); + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8); + input.loadArray(originalData); + TensorBuffer dequantized = op.apply(input); + assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32); + assertThat(dequantized.getFloatArray()) + .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java new file mode 100644 index 0000000..aabc6be --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
@@ -0,0 +1,151 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import static com.google.common.truth.Truth.assertThat; + +import static org.tensorflow.lite.DataType.FLOAT32; +import static org.tensorflow.lite.DataType.UINT8; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Tests of {@link NormalizeOp}. + */ +@RunWith(RobolectricTestRunner.class) +public final class NormalizeOpTest { + private static final float MEAN = 50; + private static final float STDDEV = 50; + private static final int NUM_ELEMENTS = 100; + + @Test + public void testNormalizeIntBuffer() { + int[] inputArr = new int[NUM_ELEMENTS]; + for (int i = 0; i < NUM_ELEMENTS; i++) { + inputArr[i] = i; + } + TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8); + input.loadArray(inputArr, new int[] {inputArr.length}); + NormalizeOp op = new NormalizeOp(MEAN, STDDEV); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(FLOAT32); + float[] outputArr = output.getFloatArray(); + for (int i = 0; i < NUM_ELEMENTS; i++) { + assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV); + } + } + + @Test + public void testNormalizeFloatBuffer() { + float[] inputArr = new float[NUM_ELEMENTS]; + for (int i = 0; i < NUM_ELEMENTS; i++) { + inputArr[i] = i; + } + TensorBuffer input = TensorBuffer.createDynamic(FLOAT32); + input.loadArray(inputArr, new int[] {inputArr.length}); + NormalizeOp op = new NormalizeOp(MEAN, STDDEV); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(FLOAT32); + float[] outputArr = output.getFloatArray(); + for (int i = 0; i < NUM_ELEMENTS; i++) { + assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV); + } + } + + @Test + public void testZeroStddev() { + Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0)); + } + + @Test + public void testIdentityShortcut() { + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); + NormalizeOp op = new NormalizeOp(0, 1); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(UINT8); + assertThat(output).isSameInstanceAs(input); + } + + @Test + public void testNormalizeOp_zeroMeanAndZeroStddev() { + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); + NormalizeOp op = new NormalizeOp(0, 0); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(UINT8); + assertThat(output).isSameInstanceAs(input); + } + + @Test + public void testNormalizeOp_zeroMeanAndInifityStddev() { + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); + NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(UINT8); + assertThat(output).isSameInstanceAs(input); + } + + @Test + public void testMultiChannelNormalize() { + float[] inputArr = new float[NUM_ELEMENTS]; + for (int i = 0; i < NUM_ELEMENTS; i++) { + inputArr[i] = i; + } + TensorBuffer input = TensorBuffer.createDynamic(FLOAT32); + input.loadArray(inputArr, new int[] {20, 5}); + float[] means = new float[] {1, 2, 3, 4, 5}; + float[] stddevs = new float[] {6, 7, 8, 9, 10}; + NormalizeOp op = new NormalizeOp(means, stddevs); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(FLOAT32); + float[] outputArr = output.getFloatArray(); + for (int i = 0; i < NUM_ELEMENTS; i++) { + assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]); + } + } + + @Test + public void testMultiChannelShortcut() { + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); + NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1}); + TensorBuffer output = op.apply(input); + assertThat(output.getDataType()).isEqualTo(UINT8); + assertThat(output).isSameInstanceAs(input); + } + + @Test + public void testMismatchedNumbersOfMeansAndStddevs() { + Assert.assertThrows(IllegalArgumentException.class, + () -> new NormalizeOp(new float[] {2, 3}, new float[] {1})); + } + + @Test + public void testMismatchedInputTensorChannelNum() { + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8); + NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2}); + Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input)); + } + + @Test + public void testAnyChannelInvalidStddev() { + Assert.assertThrows(IllegalArgumentException.class, + () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0})); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java new file mode 100644 index 0000000..519cd28 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
@@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Tests of {@link QuantizeOp}. */ +@RunWith(RobolectricTestRunner.class) +public final class QuantizeOpTest { + @Test + public void quantizeShouldSuccess() { + float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128 + QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128); + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32); + input.loadArray(originalData); + TensorBuffer quantized = op.apply(input); + assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32); + assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0}); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java new file mode 100644 index 0000000..e8edb58 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
@@ -0,0 +1,169 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import android.graphics.RectF; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.BoundingBoxUtil.CoordinateType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.util.List; + +/** Tests of {@link BoundingBoxUtil}. */ +@RunWith(RobolectricTestRunner.class) +public class BoundingBoxUtilTest { + private TensorBuffer tensorBuffer; + + @Before + public void setUp() { + // 2 bounding boxes with additional batch dimension. + tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32); + } + + @Test + public void convertDefaultRatioBoundaries() { + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400); + + assertThat(boxList).hasSize(2); + assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400)); + assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500)); + } + + @Test + public void convertComplexTensor() { + tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32); + tensorBuffer.loadArray(new float[] {// sub tensor 0 + 0, 1, 10, 11, 20, 21, 30, 31, + // sub tensor 1 + 100, 101, 110, 111, 120, 121, 130, 131, + // sub tensor 2 + 200, 201, 210, 211, 220, 221, 230, 231}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, 1, + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 0, 0); + + assertThat(boxList).hasSize(6); + assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30)); + assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31)); + assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130)); + assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131)); + } + + @Test + public void convertIndexedRatioBoundaries() { + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {1, 0, 3, 2}, -1, + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400); + + assertThat(boxList).hasSize(2); + assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375)); + assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500)); + } + + @Test + public void convertPixelBoundaries() { + tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 500, 400); + + assertThat(boxList) + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) + .inOrder(); + } + + @Test + public void convertRatioUpperLeft() { + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.RATIO, 500, 400); + + assertThat(boxList).hasSize(2); + assertThat(boxList) + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) + .inOrder(); + } + + @Test + public void convertPixelUpperLeft() { + tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.PIXEL, 500, 400); + + assertThat(boxList) + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) + .inOrder(); + } + + @Test + public void convertRatioCenter() { + tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.CENTER, CoordinateType.RATIO, 500, 400); + + assertThat(boxList) + .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500)) + .inOrder(); + } + + @Test + public void convertPixelCenter() { + tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500}); + + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.CENTER, CoordinateType.PIXEL, 500, 400); + + assertThat(boxList) + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500)) + .inOrder(); + } + + @Test + public void convertTensorWithUnexpectedShapeShouldThrow() { + TensorBuffer badShapeTensor = + TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32); + + Assert.assertThrows(IllegalArgumentException.class, + () + -> BoundingBoxUtil.convert(badShapeTensor, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400)); + } + + @Test + public void convertIntTensorShouldThrow() { + TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8); + + Assert.assertThrows(IllegalArgumentException.class, + () + -> BoundingBoxUtil.convert(badTypeTensor, new int[] {0, 1, 2, 3}, -1, + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java new file mode 100644 index 0000000..329b5aa3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
@@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap; +import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer; + +import android.graphics.Bitmap; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +@RunWith(JUnit4.class) +public final class ColorSpaceTypeInstrumentedTest { + @Test + public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() { + TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false); + Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = createGrayscaleBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() { + TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false); + Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = createGrayscaleBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java new file mode 100644 index 0000000..9261225 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
@@ -0,0 +1,390 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap; +import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.ImageFormat; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.util.Arrays; +import java.util.Collection; + +/** Tests of {@link ImageConversions}. */ +@RunWith(Suite.class) +@SuiteClasses({ColorSpaceTypeTest.ValidShapeTest.class, ColorSpaceTypeTest.InvalidShapeTest.class, + ColorSpaceTypeTest.BitmapConfigTest.class, ColorSpaceTypeTest.ImageFormatTest.class, + ColorSpaceTypeTest.YuvImageTest.class, ColorSpaceTypeTest.AssertNumElementsTest.class, + ColorSpaceTypeTest.General.class}) +public class ColorSpaceTypeTest { + /** Parameterized tests for valid shapes. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class ValidShapeTest extends ColorSpaceTypeTest { + @Parameter(0) + public ColorSpaceType colorSpaceType; + + /** The shape that matches the colorSpaceType. */ + @Parameter(1) + public int[] validShape; + + /** The height of validShape. */ + @Parameter(2) + public int expectedHeight; + + /** The width of validShape. */ + @Parameter(3) + public int expectedWidth; + + @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20}, + {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20}, + {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20}, + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20}, + }); + } + + @Test + public void getHeightSucceedsWithValidShape() { + assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight); + } + + @Test + public void getWidthSucceedsWithValidShape() { + assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth); + } + } + + /** Parameterized tests for invalid shapes. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class InvalidShapeTest extends ColorSpaceTypeTest { + private static final String RGB_ASSERT_SHAPE_MESSAGE = + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" + + " representing R, G, B in order. The provided image shape is "; + private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE = + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" + + " shape is "; + + @Parameter(0) + public ColorSpaceType colorSpaceType; + + /** The shape that does not match the colorSpaceType. */ + @Parameter(1) + public int[] invalidShape; + + @Parameter(2) + public String errorMessage; + + @Parameters(name = "colorSpaceType={0}; invalidShape={1}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + }); + } + + @Test + public void assertShapeFaislsWithInvalidShape() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape)); + assertThat(exception).hasMessageThat().contains( + errorMessage + Arrays.toString(invalidShape)); + } + + @Test + public void getHeightFaislsWithInvalidShape() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape)); + assertThat(exception).hasMessageThat().contains( + errorMessage + Arrays.toString(invalidShape)); + } + + @Test + public void getWidthFaislsWithInvalidShape() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape)); + assertThat(exception).hasMessageThat().contains( + errorMessage + Arrays.toString(invalidShape)); + } + } + + /** Parameterized tests for Bitmap Config. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class BitmapConfigTest extends ColorSpaceTypeTest { + @Parameter(0) + public ColorSpaceType colorSpaceType; + + /** The Bitmap configuration match the colorSpaceType. */ + @Parameter(1) + public Config config; + + @Parameters(name = "colorSpaceType={0}; config={1}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.RGB, Config.ARGB_8888}, + {ColorSpaceType.GRAYSCALE, Config.ALPHA_8}, + }); + } + + @Test + public void fromBitmapConfigSucceedsWithSupportedConfig() { + assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType); + } + + @Test + public void toBitmapConfigSucceedsWithSupportedConfig() { + assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config); + } + } + + /** Parameterized tests for ImageFormat. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class ImageFormatTest extends ColorSpaceTypeTest { + @Parameter(0) + public ColorSpaceType colorSpaceType; + + /** The ImageFormat that matches the colorSpaceType. */ + @Parameter(1) + public int imageFormat; + + @Parameters(name = "colorSpaceType={0}; imageFormat={1}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.NV21, ImageFormat.NV21}, + {ColorSpaceType.YV12, ImageFormat.YV12}, + {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888}, + }); + } + + @Test + public void fromImageFormatSucceedsWithSupportedImageFormat() { + assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType); + } + } + + /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class YuvImageTest extends ColorSpaceTypeTest { + @Parameter(0) + public ColorSpaceType colorSpaceType; + + @Parameters(name = "colorSpaceType={0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.NV12}, + {ColorSpaceType.NV21}, + {ColorSpaceType.YV12}, + {ColorSpaceType.YV21}, + {ColorSpaceType.YUV_420_888}, + }); + } + + @Test + public void convertTensorBufferToBitmapShouldFail() { + UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, + () + -> colorSpaceType.convertTensorBufferToBitmap( + TensorBuffer.createDynamic(DataType.FLOAT32))); + assertThat(exception).hasMessageThat().contains( + "convertTensorBufferToBitmap() is unsupported for the color space type " + + colorSpaceType.name()); + } + + @Test + public void getWidthShouldFail() { + UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, + () -> colorSpaceType.getWidth(new int[] {})); + assertThat(exception).hasMessageThat().contains( + "getWidth() only supports RGB and GRAYSCALE formats, but not " + + colorSpaceType.name()); + } + + @Test + public void getHeightShouldFail() { + UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, + () -> colorSpaceType.getHeight(new int[] {})); + assertThat(exception).hasMessageThat().contains( + "getHeight() only supports RGB and GRAYSCALE formats, but not " + + colorSpaceType.name()); + } + + @Test + public void assertShapeShouldFail() { + UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, + () -> colorSpaceType.assertShape(new int[] {})); + assertThat(exception).hasMessageThat().contains( + "assertShape() only supports RGB and GRAYSCALE formats, but not " + + colorSpaceType.name()); + } + + @Test + public void getChannelValueShouldFail() { + UnsupportedOperationException exception = assertThrows( + UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue()); + assertThat(exception).hasMessageThat().contains( + "getChannelValue() is unsupported for the color space type " + + colorSpaceType.name()); + } + + @Test + public void getNormalizedShapeShouldFail() { + UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, + () -> colorSpaceType.getNormalizedShape(new int[] {})); + assertThat(exception).hasMessageThat().contains( + "getNormalizedShape() is unsupported for the color space type " + + colorSpaceType.name()); + } + + @Test + public void getShapeInfoMessageShouldFail() { + UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, + () -> colorSpaceType.getShapeInfoMessage()); + assertThat(exception).hasMessageThat().contains( + "getShapeInfoMessage() is unsupported for the color space type " + + colorSpaceType.name()); + } + + @Test + public void toBitmapConfigShouldFail() { + UnsupportedOperationException exception = assertThrows( + UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig()); + assertThat(exception).hasMessageThat().contains( + "toBitmapConfig() is unsupported for the color space type " + + colorSpaceType.name()); + } + } + + /** Parameterized tests for assertNumElements/getNumElements with all image formats. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class AssertNumElementsTest extends ColorSpaceTypeTest { + private static final int HEIGHT = 2; + private static final int WIDTH = 3; + private static final int LESS_NUM_ELEMENTS = 5; // less than expected + private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK. + @Rule + public ErrorCollector errorCollector = new ErrorCollector(); + + @Parameter(0) + public ColorSpaceType colorSpaceType; + + @Parameter(1) + public int expectedNumElements; + + @Parameters(name = "colorSpaceType={0};expectedNumElements={1}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.RGB, 18}, + {ColorSpaceType.GRAYSCALE, 6}, + {ColorSpaceType.NV12, 10}, + {ColorSpaceType.NV21, 10}, + {ColorSpaceType.YV12, 10}, + {ColorSpaceType.YV21, 10}, + }); + } + + @Test + public void getNumElementsShouldSucceedWithExpectedNumElements() { + assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements); + } + + @Test + public void assertNumElementsShouldSucceedWithMoreNumElements() { + errorCollector.checkSucceeds(() -> { + colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH); + return null; + }); + } + + @Test + public void assertNumElementsShouldFailWithLessNumElements() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH)); + assertThat(exception).hasMessageThat().contains(String.format( + "The given number of elements (%d) does not match the image (%s) in %d x %d. The" + + " expected number of elements should be at least %d.", + LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements)); + } + } + + /** General tests of ColorSpaceTypeTest. */ + @RunWith(RobolectricTestRunner.class) + public static final class General extends ColorSpaceTypeTest { + @Test + public void convertTensorBufferToBitmapShouldSuccessWithRGB() { + TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false); + Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = createRgbBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void fromBitmapConfigFailsWithUnsupportedConfig() { + Config unsupportedConfig = Config.ARGB_4444; + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig)); + assertThat(exception).hasMessageThat().contains( + "Bitmap configuration: " + unsupportedConfig + ", is not supported yet."); + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java new file mode 100644 index 0000000..49efc427 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
@@ -0,0 +1,234 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static android.graphics.Bitmap.Config.ARGB_8888; +import static android.graphics.Color.BLACK; +import static android.graphics.Color.BLUE; +import static android.graphics.Color.GREEN; +import static android.graphics.Color.RED; +import static android.graphics.Color.WHITE; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.tensorflow.lite.support.image.ImageConversions.convertGrayscaleTensorBufferToBitmap; + +import android.content.Context; +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.util.Log; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.io.IOException; +import java.util.Arrays; + +/** Instrumented unit test for {@link ImageConversions}. */ +@RunWith(Suite.class) +@SuiteClasses({ImageConversionsInstrumentedTest.TensorBufferToBitmap.class, + ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class}) +public class ImageConversionsInstrumentedTest { + /** Tests for the TensorBuffer data type and normalized form. */ + // Note that parameterized test with android_library_instrumentation_tests is currently not + // supported internally. + @RunWith(AndroidJUnit4.class) + public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest { + @Test + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() { + DataType dataType = DataType.FLOAT32; + boolean isNormalized = true; + + TensorBuffer buffer = + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() { + DataType dataType = DataType.FLOAT32; + boolean isNormalized = false; + + TensorBuffer buffer = + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() { + DataType dataType = DataType.UINT8; + boolean isNormalized = true; + + TensorBuffer buffer = + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() { + DataType dataType = DataType.UINT8; + boolean isNormalized = false; + + TensorBuffer buffer = + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized); + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void + convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() { + DataType dataType = DataType.FLOAT32; + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> convertGrayscaleTensorBufferToBitmap(buffer)); + assertThat(exception).hasMessageThat().contains( + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" + + " shape is " + Arrays.toString(buffer.getShape())); + } + + @Test + public void + convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() { + DataType dataType = DataType.UINT8; + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> convertGrayscaleTensorBufferToBitmap(buffer)); + assertThat(exception).hasMessageThat().contains( + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" + + " shape is " + Arrays.toString(buffer.getShape())); + } + } + + /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */ + @RunWith(AndroidJUnit4.class) + public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest { + private Bitmap greyGrid; + private Bitmap colorGrid; + private TensorBuffer buffer; + + static final String GREY_GRID_PATH = "grey_grid.png"; + static final String COLOR_GRID_PATH = "color_grid.png"; + + @Before + public void loadAssets() { + Context context = ApplicationProvider.getApplicationContext(); + AssetManager assetManager = context.getAssets(); + try { + greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH)); + colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH)); + } catch (IOException e) { + Log.e("Test", "Cannot load asset files"); + } + Assert.assertEquals(ARGB_8888, greyGrid.getConfig()); + Assert.assertEquals(ARGB_8888, colorGrid.getConfig()); + buffer = TensorBuffer.createDynamic(DataType.UINT8); + } + + @Test + public void testBitmapDimensionLayout() { + // This test is not only for proving the correctness of bitmap -> TensorBuffer + // conversion, but also for us to better understand how Android Bitmap is storing pixels + // - height first or width first. We use a black image which has a white corner to + // understand what happens. By setting up the correct loop to pass the test, we can + // reveal the order of pixels returned from `getPixels`. The result shows that Android + // stores bitmap in an h-first manner. The returned array of `getPixels` is like [ 1st + // row, 2nd row, ... ] which is the same with TFLite. + Assert.assertEquals(100, greyGrid.getWidth()); + Assert.assertEquals(100, greyGrid.getHeight()); + Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top + Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top + Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom + Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom + + ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer); + Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape()); + Assert.assertEquals(DataType.UINT8, buffer.getDataType()); + + int[] pixels = buffer.getIntArray(); + int index = 0; + for (int h = 0; h < 100; h++) { + for (int w = 0; w < 100; w++) { + int expected = (w < 50 && h >= 50) ? 255 : 0; + Assert.assertEquals(expected, pixels[index++]); + Assert.assertEquals(expected, pixels[index++]); + Assert.assertEquals(expected, pixels[index++]); + } + } + } + + @Test + public void testBitmapARGB8888ChannelLayout() { + // This test is not only for proving the correctness of bitmap -> TensorBuffer + // conversion, but also for us to better understand how Android Bitmap is storing pixels + // - RGB channel or other possible ordering. We use an colored grid image to understand + // what happens. It's a simple grid image with 4 grid in different colors. Passed + // through our Bitmap -> TensorBuffer conversion which simply unpack channels from an + // integer returned from `getPixel`, its channel sequence could be revealed directly. + // The result shows that Android Bitmap has no magic when loading channels. If loading + // from PNG images, channel order still remains R-G-B. + Assert.assertEquals(100, colorGrid.getWidth()); + Assert.assertEquals(100, colorGrid.getHeight()); + Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top + Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top + Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom + Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom + + ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer); + Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape()); + Assert.assertEquals(DataType.UINT8, buffer.getDataType()); + + int[] pixels = buffer.getIntArray(); + Assert.assertArrayEquals( + new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top + Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top + Assert.assertArrayEquals( + new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom + Assert.assertArrayEquals( + new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom + } + + /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */ + private static int[] getChannels(int[] pixels, int h, int w) { + int id = (h * 100 + w) * 3; + return new int[] {pixels[id++], pixels[id++], pixels[id]}; + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java new file mode 100644 index 0000000..c91db9d1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
@@ -0,0 +1,127 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.tensorflow.lite.support.image.ImageConversions.convertBitmapToTensorBuffer; +import static org.tensorflow.lite.support.image.ImageConversions.convertRgbTensorBufferToBitmap; + +import android.graphics.Bitmap; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.util.Arrays; +import java.util.Collection; + +/** Tests of {@link ImageConversions}. */ +@RunWith(Suite.class) +@SuiteClasses({ImageConversionsTest.TensorBufferToBitmap.class, ImageConversionsTest.General.class}) +public class ImageConversionsTest { + /** Parameterized tests for the TensorBuffer data type and normalized form. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class TensorBufferToBitmap extends ImageConversionsTest { + /** The data type that used to create the TensorBuffer. */ + @Parameter(0) + public DataType dataType; + + /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */ + @Parameter(1) + public boolean isNormalized; + + @Parameters(name = "dataType={0}; isNormalized={1}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {DataType.FLOAT32, true}, + {DataType.UINT8, true}, + {DataType.FLOAT32, false}, + {DataType.UINT8, false}, + }); + } + + @Test + public void convertRgbTensorBufferToBitmapShouldSuccess() { + TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized); + Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer); + + Bitmap expectedBitmap = TestImageCreator.createRgbBitmap(); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() { + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer)); + assertThat(exception).hasMessageThat().contains( + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" + + " representing R, G, B in order. The provided image shape is " + + Arrays.toString(buffer.getShape())); + } + } + + /** General tests of ImageConversionsTest. */ + @RunWith(RobolectricTestRunner.class) + public static final class General extends ImageConversionsTest { + private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap(); + private static final TensorBuffer rgbTensorBuffer = + TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false); + + @Test + public void convertBitmapToTensorBufferShouldSuccess() { + TensorBuffer intBuffer = + TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8); + convertBitmapToTensorBuffer(rgbBitmap, intBuffer); + assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue(); + } + + @Test + public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() { + TensorBuffer intBuffer = + TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8); + Assert.assertThrows(IllegalArgumentException.class, + () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer)); + } + + @Test + public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() { + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + convertBitmapToTensorBuffer(rgbBitmap, floatBuffer); + assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue(); + } + } + + private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) { + if (!Arrays.equals(tb1.getShape(), tb2.getShape())) { + return false; + } + int[] arr1 = tb1.getIntArray(); + int[] arr2 = tb2.getIntArray(); + return Arrays.equals(arr1, arr2); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java new file mode 100644 index 0000000..e9cbfc1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
@@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import android.graphics.Bitmap; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp; +import org.tensorflow.lite.support.image.ops.Rot90Op; + +/** Instrumented unit test for {@link ImageProcessor}. */ +@RunWith(AndroidJUnit4.class) +public final class ImageProcessorInstrumentedTest { + private Bitmap exampleBitmap; + private TensorImage input; + private ImageProcessor processor; + + private static final int EXAMPLE_WIDTH = 10; + private static final int EXAMPLE_HEIGHT = 15; + + @Before + public void setUp() { + // The default number of rotation is once. + processor = new ImageProcessor.Builder().add(new Rot90Op()).build(); + exampleBitmap = createExampleBitmap(); + input = new TensorImage(DataType.UINT8); + input.load(exampleBitmap); + } + + @Test + public void updateNumberOfRotations_rotateTwice() { + int numberOfRotations = 2; + + processor.updateNumberOfRotations(numberOfRotations); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertExampleBitmapWithTwoRotations(outputBitmap); + } + + @Test + public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() { + int numberOfRotations = 2; + int occurrence = 0; + + processor.updateNumberOfRotations(numberOfRotations, occurrence); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertExampleBitmapWithTwoRotations(outputBitmap); + } + + @Test + public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() { + int numberOfRotations = 2; + int negativeOpIndex = -1; + + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, + () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex)); + assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative"); + } + + @Test + public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() { + int numberOfRotations = 2; + int occurrence = 1; + + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, + () -> processor.updateNumberOfRotations(numberOfRotations, occurrence)); + assertThat(exception).hasMessageThat().isEqualTo( + "occurrence (1) must be less than size (1)"); + } + + @Test + public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() { + int numberOfRotations = 2; + int occurrence = 1; + // Add an op other than Rot90Op into ImageProcessor. + ImageProcessor processor = + new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, + () -> processor.updateNumberOfRotations(numberOfRotations, occurrence)); + assertThat(exception).hasMessageThat().isEqualTo( + "The Rot90Op has not been added to the ImageProcessor."); + } + + @Test + public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() { + // The overall effect of the two rotations is equivalent to rotating for twice. + int numberOfRotations0 = 5; + int numberOfRotations1 = 1; + + // Add two Rot90Ops into ImageProcessor. + ImageProcessor processor = + new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build(); + processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/0); + processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/1); + + TensorImage output = processor.process(input); + Bitmap outputBitmap = output.getBitmap(); + assertExampleBitmapWithTwoRotations(outputBitmap); + } + + private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) { + assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH); + assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT); + for (int i = 0; i < exampleBitmap.getWidth(); i++) { + for (int j = 0; j < exampleBitmap.getHeight(); j++) { + assertThat(exampleBitmap.getPixel(i, j)) + .isEqualTo(bitmapRotated.getPixel( + EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j)); + } + } + } + + private static Bitmap createExampleBitmap() { + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); + } + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java new file mode 100644 index 0000000..a93ba546 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
@@ -0,0 +1,147 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import android.graphics.Bitmap; +import android.graphics.RectF; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.ops.NormalizeOp; +import org.tensorflow.lite.support.image.ops.ResizeOp; +import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp; +import org.tensorflow.lite.support.image.ops.Rot90Op; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Tests for {@link ImageProcessor}. */ +@RunWith(RobolectricTestRunner.class) +public final class ImageProcessorTest { + private static final int EXAMPLE_WIDTH = 10; + private static final int EXAMPLE_HEIGHT = 15; + private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH; + private static final int EXAMPLE_NUM_CHANNELS = 3; + private static final float MEAN = 127.5f; + private static final float STDDEV = 127.5f; + + @Test + public void testBuild() { + ImageProcessor processor = + new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); + assertThat(processor).isNotNull(); + } + + @Test + public void testNormalize() { + TensorImage input = new TensorImage(DataType.FLOAT32); + input.load(createExampleBitmap()); + ImageProcessor processor = + new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build(); + TensorImage output = processor.process(input); + + float[] pixels = output.getTensorBuffer().getFloatArray(); + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS); + for (float p : pixels) { + assertThat(p).isAtLeast(-1); + assertThat(p).isAtMost(1); + } + } + + @Test + public void testMultipleNormalize() { + TensorImage input = new TensorImage(DataType.FLOAT32); + input.load(createExampleBitmap()); + ImageProcessor processor = + new ImageProcessor.Builder() + .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1] + .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1] + .build(); + TensorImage output = processor.process(input); + + float[] pixels = output.getTensorBuffer().getFloatArray(); + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS); + for (float p : pixels) { + assertThat(p).isAtLeast(0); + assertThat(p).isAtMost(1); + } + } + + @Test + public void inverseTransformRectCorrectly() { + ImageProcessor processor = new ImageProcessor.Builder() + .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)) + .add(new ResizeWithCropOrPadOp(100, 200)) + .add(new Rot90Op(1)) + .add(new NormalizeOp(127, 128)) + .build(); + RectF transformed = new RectF(0, 50, 100, 150); + RectF original = processor.inverseTransform(transformed, 400, 600); + assertThat(original.top).isEqualTo(100); + assertThat(original.left).isEqualTo(200); + assertThat(original.right).isEqualTo(400); + assertThat(original.bottom).isEqualTo(300); + } + + @Test + public void resizeShouldFailWithNonRgbImages() { + int[] data = new int[] {1, 2, 3}; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); + tensorBuffer.loadArray(data, new int[] {1, 3}); + TensorImage image = new TensorImage(); + image.load(tensorBuffer, ColorSpaceType.GRAYSCALE); + + ImageProcessor processor = new ImageProcessor.Builder() + .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)) + .build(); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> processor.process(image)); + assertThat(exception).hasMessageThat().contains( + "Only RGB images are supported in ResizeOp, but not " + + image.getColorSpaceType().name()); + } + + @Test + public void normalizeShouldSuccessWithNonRgbImages() { + int[] data = new int[] {1, 2, 3}; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); + tensorBuffer.loadArray(data, new int[] {1, 3}); + TensorImage image = new TensorImage(); + image.load(tensorBuffer, ColorSpaceType.GRAYSCALE); + + ImageProcessor processor = + new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build(); + TensorImage output = processor.process(image); + + float[] pixels = output.getTensorBuffer().getFloatArray(); + assertThat(pixels).isEqualTo(new float[] {0.5f, 1.5f, 2.5f}); + } + + private static Bitmap createExampleBitmap() { + int[] colors = new int[EXAMPLE_NUM_PIXELS]; + for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) { + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); + } + + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java new file mode 100644 index 0000000..e8caefc --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
@@ -0,0 +1,181 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; + +import android.graphics.Bitmap; +import android.media.Image; + +import com.google.android.odml.image.BitmapMlImageBuilder; +import com.google.android.odml.image.ByteBufferMlImageBuilder; +import com.google.android.odml.image.MediaMlImageBuilder; +import com.google.android.odml.image.MlImage; +import com.google.android.odml.image.MlImage.ImageFormat; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.RobolectricTestRunner; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; + +/** Test for {@link MlImageAdapter}. */ +@RunWith(Suite.class) +@SuiteClasses({ + MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class, + MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class, + MlImageAdapterTest.General.class, +}) +public class MlImageAdapterTest { + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class CreateTensorImageFromSupportedByteBufferMlImage + extends MlImageAdapterTest { + @Parameter(0) + @ImageFormat + public int imageFormat; + + @Parameter(1) + public ColorSpaceType colorSpaceType; + + @Parameters(name = "imageFormat={0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB}, + {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE}, + {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21}, + {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12}, + {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12}, + {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21}, + }); + } + + @Test + public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException { + ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer(); + buffer.rewind(); + MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build(); + + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + + assertThat(tensorImage.getWidth()).isEqualTo(1); + assertThat(tensorImage.getHeight()).isEqualTo(2); + assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType); + assertThat(tensorImage.getBuffer().position()).isEqualTo(0); + assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer); + } + } + + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class CreateTensorImageFromUnsupportedByteBufferMlImage + extends MlImageAdapterTest { + @Parameter(0) + @ImageFormat + public int imageFormat; + + @Parameters(name = "imageFormat={0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {MlImage.IMAGE_FORMAT_RGBA}, + {MlImage.IMAGE_FORMAT_JPEG}, + {MlImage.IMAGE_FORMAT_YUV_420_888}, + {MlImage.IMAGE_FORMAT_UNKNOWN}, + }); + } + + @Test + public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException { + ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer(); + buffer.rewind(); + MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build(); + + assertThrows(IllegalArgumentException.class, + () -> MlImageAdapter.createTensorImageFrom(image)); + } + } + + @RunWith(RobolectricTestRunner.class) + public static final class General extends MlImageAdapterTest { + @Mock + Image mediaImageMock; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException { + Bitmap bitmap = Bitmap.createBitmap( + new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888); + MlImage image = new BitmapMlImageBuilder(bitmap).build(); + ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6); + for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) { + expectedBuffer.put(b); + } + expectedBuffer.rewind(); + + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + + assertThat(tensorImage.getWidth()).isEqualTo(1); + assertThat(tensorImage.getHeight()).isEqualTo(2); + assertThat(tensorImage.getBuffer().position()).isEqualTo(0); + assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer); + } + + @Test + public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException { + setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2); + MlImage image = new MediaMlImageBuilder(mediaImageMock).build(); + + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + + assertThat(tensorImage.getWidth()).isEqualTo(1); + assertThat(tensorImage.getHeight()).isEqualTo(2); + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888); + } + + @Test + public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws() + throws IOException { + setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2); + MlImage image = new MediaMlImageBuilder(mediaImageMock).build(); + + assertThrows(IllegalArgumentException.class, + () -> MlImageAdapter.createTensorImageFrom(image)); + } + + private static void setUpMediaImageMock( + Image mediaImageMock, int imageFormat, int width, int height) { + when(mediaImageMock.getFormat()).thenReturn(imageFormat); + when(mediaImageMock.getWidth()).thenReturn(width); + when(mediaImageMock.getHeight()).thenReturn(height); + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java new file mode 100644 index 0000000..83b54d0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
@@ -0,0 +1,142 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.tensorflow.lite.DataType.FLOAT32; +import static org.tensorflow.lite.DataType.UINT8; +import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap; +import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer; +import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap; +import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer; + +import android.graphics.Bitmap; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +@RunWith(JUnit4.class) +public final class TensorImageInstrumentedTest { + /** + * Difference between the pair of float and uint8 values. It is used to test the data + * conversion. + */ + private static final float DELTA = 0.1f; + + // Note that parameterized test with android_library_instrumentation_tests is currently not + // supported in internally. + @Test + public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() { + DataType tensorBufferDataType = FLOAT32; + DataType tensorImageDataType = FLOAT32; + boolean isNormalized = true; + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; + + TensorBuffer tensorBuffer = + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); + TensorImage tensorImage = new TensorImage(tensorImageDataType); + + tensorImage.load(tensorBuffer, colorSpaceType); + Bitmap bitmap = tensorImage.getBitmap(); + + Bitmap expectedBitmap = createBitmap(colorSpaceType); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() { + DataType tensorBufferDataType = FLOAT32; + DataType tensorImageDataType = UINT8; + boolean isNormalized = false; + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; + + TensorBuffer tensorBuffer = + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); + TensorImage tensorImage = new TensorImage(tensorImageDataType); + + tensorImage.load(tensorBuffer, colorSpaceType); + Bitmap bitmap = tensorImage.getBitmap(); + + Bitmap expectedBitmap = createBitmap(colorSpaceType); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() { + DataType tensorBufferDataType = UINT8; + DataType tensorImageDataType = FLOAT32; + boolean isNormalized = true; + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; + + TensorBuffer tensorBuffer = + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); + TensorImage tensorImage = new TensorImage(tensorImageDataType); + + tensorImage.load(tensorBuffer, colorSpaceType); + Bitmap bitmap = tensorImage.getBitmap(); + + Bitmap expectedBitmap = createBitmap(colorSpaceType); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() { + DataType tensorBufferDataType = UINT8; + DataType tensorImageDataType = UINT8; + boolean isNormalized = false; + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE; + + TensorBuffer tensorBuffer = + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); + TensorImage tensorImage = new TensorImage(tensorImageDataType); + + tensorImage.load(tensorBuffer, colorSpaceType); + Bitmap bitmap = tensorImage.getBitmap(); + + Bitmap expectedBitmap = createBitmap(colorSpaceType); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + private static TensorBuffer createTensorBuffer( + DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) { + switch (colorSpaceType) { + case RGB: + return createRgbTensorBuffer(dataType, isNormalized, delta); + case GRAYSCALE: + return createGrayscaleTensorBuffer(dataType, isNormalized, delta); + default: + break; + } + throw new IllegalArgumentException( + "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); + } + + private static Bitmap createBitmap(ColorSpaceType colorSpaceType) { + switch (colorSpaceType) { + case RGB: + return createRgbBitmap(); + case GRAYSCALE: + return createGrayscaleBitmap(); + default: + break; + } + throw new IllegalArgumentException( + "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java new file mode 100644 index 0000000..b3130f4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
@@ -0,0 +1,735 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; +import static org.tensorflow.lite.DataType.FLOAT32; +import static org.tensorflow.lite.DataType.UINT8; +import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap; +import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer; +import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap; +import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Color; +import android.graphics.ImageFormat; +import android.media.Image; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; + +/** Tests of {@link org.tensorflow.lite.support.image.TensorImage}. */ +@RunWith(Suite.class) +@SuiteClasses( + {TensorImageTest.General.class, TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class, + TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class, + TensorImageTest.LoadTensorBufferWithYUV.class, + TensorImageTest.LoadTensorBufferWithImageProperties.class}) +public class TensorImageTest { + @RunWith(RobolectricTestRunner.class) + public static final class General extends TensorImageTest { + private static final Bitmap exampleBitmap = createExampleBitmap(); + private static final float[] exampleFloatPixels = createExampleFloatPixels(); + private static final int[] exampleUint8Pixels = createExampleUint8Pixels(); + + private static final int EXAMPLE_WIDTH = 5; + private static final int EXAMPLE_HEIGHT = 10; + private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH; + private static final int EXAMPLE_NUM_CHANNELS = 3; + private static final int[] EXAMPLE_SHAPE = { + EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS}; + private static final float MEAN = 127.5f; + private static final float STDDEV = 127.5f; + + @Mock + Image imageMock; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void defaultConstructorCreatesUint8TensorImage() { + TensorImage image = new TensorImage(); + assertThat(image.getDataType()).isEqualTo(UINT8); + } + + @Test + public void createFromSucceedsWithUint8TensorImage() { + TensorImage uint8Image = new TensorImage(UINT8); + uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3}); + + TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32); + float[] pixels = floatImage.getTensorBuffer().getFloatArray(); + assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f}); + } + + @Test + public void createFromSucceedsWithFloatTensorImage() { + TensorImage floatImage = new TensorImage(FLOAT32); + floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3}); + + TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8); + int[] pixels = uint8Image.getTensorBuffer().getIntArray(); + assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255}); + } + + @Test + public void loadBitmapSucceedsWithUint8TensorImage() { + Bitmap rgbBitmap = createRgbBitmap(); + TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f); + TensorImage uint8Image = new TensorImage(UINT8); + + uint8Image.load(rgbBitmap); + assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue(); + assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer); + assertThat(uint8Image.getDataType()).isEqualTo(UINT8); + } + + @Test + public void loadBitmapSucceedsWithFloatTensorImage() { + Bitmap rgbBitmap = createRgbBitmap(); + TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f); + TensorImage floatImage = new TensorImage(FLOAT32); + + floatImage.load(rgbBitmap); + assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue(); + assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer); + assertThat(floatImage.getDataType()).isEqualTo(FLOAT32); + } + + @Test + public void loadFloatArrayWithUint8TensorImage() { + TensorImage uint8Image = new TensorImage(UINT8); + + uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE); + assertThat(uint8Image.getBitmap()).isNotNull(); + assertThat(uint8Image.getTensorBuffer().getFloatArray()) + .isEqualTo(new float[exampleFloatPixels.length]); // All zero because of + // normalization and casting + // when loading. + } + + @Test + public void loadFloatArrayWithFloatTensorImage() { + TensorImage floatImage = new TensorImage(FLOAT32); + + floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE); + assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels); + } + + @Test + public void loadUint8ArrayWithUint8TensorImage() { + TensorImage uint8Image = new TensorImage(UINT8); + + uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE); + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); + assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); + } + + @Test + public void loadUint8ArrayWithFloatTensorImage() { + TensorImage floatImage = new TensorImage(FLOAT32); + + floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE); + assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); + } + + @Test + public void loadTensorBufferWithUint8TensorImage() { + TensorImage uint8Image = new TensorImage(UINT8); + + uint8Image.load(exampleBitmap); + TensorBuffer buffer = uint8Image.getTensorBuffer(); + uint8Image.load(buffer); + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); + } + + @Test + public void loadTensorBufferWithFloatTensorImage() { + TensorImage floatImage = new TensorImage(FLOAT32); + + floatImage.load(exampleBitmap); + TensorBuffer buffer = floatImage.getTensorBuffer(); + floatImage.load(buffer); + assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels); + } + + @Test + public void loadAndGetMediaImageSucceedsWithYuv420888Format() { + setUpImageMock(imageMock, ImageFormat.YUV_420_888); + TensorImage tensorImage = new TensorImage(UINT8); + + tensorImage.load(imageMock); + Image imageReturned = tensorImage.getMediaImage(); + + assertThat(imageReturned).isEqualTo(imageMock); + } + + @Test + public void loadMediaImageFailsWithNonYuv420888Format() { + setUpImageMock(imageMock, ImageFormat.YUV_422_888); + TensorImage tensorImage = new TensorImage(UINT8); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock)); + assertThat(exception).hasMessageThat().contains( + "Only supports loading YUV_420_888 Image."); + } + + @Test + public void getBitmapWithUint8TensorImage() { + TensorImage uint8Image = new TensorImage(UINT8); + + uint8Image.load(exampleBitmap); + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue(); + // Also check zero copy is effective here (input and output are references of the same + // object). + assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap); + // Also check we don't create new Bitmap only with reading operations. + assertThat(uint8Image.getBuffer().limit()) + .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS); + assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap); + + uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE); + assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap); + } + + @Test + public void getBitmapWithFloatTensorImage() { + TensorImage floatImage = new TensorImage(FLOAT32); + + floatImage.load(exampleBitmap); + assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap); + } + + @Test + public void getBitmapWithEmptyTensorImage() { + TensorImage uint8Image = new TensorImage(UINT8); + + assertThrows(IllegalStateException.class, uint8Image::getBitmap); + } + + @Test + public void getMediaImageFailsWithBackedBitmap() { + TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap); + + UnsupportedOperationException exception = assertThrows( + UnsupportedOperationException.class, () -> tensorImage.getMediaImage()); + assertThat(exception).hasMessageThat().contains( + "Converting from Bitmap to android.media.Image is unsupported."); + } + + @Test + public void getMediaImageFailsWithBackedTensorBuffer() { + TensorImage tensorImage = new TensorImage(UINT8); + tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE); + + UnsupportedOperationException exception = assertThrows( + UnsupportedOperationException.class, () -> tensorImage.getMediaImage()); + assertThat(exception).hasMessageThat().contains( + "Converting from TensorBuffer to android.media.Image is unsupported."); + } + + @Test + public void getShapeOfInternalBitmapShouldSuccess() { + Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888); + TensorImage image = TensorImage.fromBitmap(bitmap); + + int width = image.getWidth(); + int height = image.getHeight(); + + assertThat(width).isEqualTo(300); + assertThat(height).isEqualTo(400); + } + + @Test + public void getShapeOfInternalTensorBufferShouldSuccess() { + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8); + TensorImage image = new TensorImage(); + image.load(buffer); + + int width = image.getWidth(); + int height = image.getHeight(); + + assertThat(width).isEqualTo(300); + assertThat(height).isEqualTo(400); + } + + @Test + public void getShapeOfNullImageShouldThrow() { + TensorImage image = new TensorImage(); + + assertThrows(IllegalStateException.class, image::getHeight); + } + + @Test + public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() { + int[] data = new int[] {1, 2, 3, 4, 5, 6}; + TensorBuffer buffer = TensorBuffer.createDynamic(UINT8); + buffer.loadArray(data, new int[] {1, 1, 2, 3}); + TensorImage image = new TensorImage(); + image.load(buffer); + // Reload data but with an invalid shape, which leads to `buffer` corrupted. + int[] newData = new int[] {1, 2, 3}; + buffer.loadArray(newData, new int[] {1, 1, 1, 3}); + + assertThrows(IllegalArgumentException.class, image::getHeight); + } + + @Test + public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() { + Bitmap rgbBitmap = createRgbBitmap(); + TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap); + + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); + } + + @Test + public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() { + TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false); + TensorImage tensorImage = new TensorImage(); + tensorImage.load(rgbBuffer); + + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); + } + + @Test + public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() { + TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false); + TensorImage tensorImage = new TensorImage(); + tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE); + + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); + } + + @Test + public void getColorSpaceTypeSucceedsWithRepeatedLoading() { + TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false); + TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false); + Bitmap rgbBitmap = createRgbBitmap(); + TensorImage tensorImage = new TensorImage(); + + tensorImage.load(rgbBuffer); + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); + tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE); + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); + tensorImage.load(rgbBitmap); + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB); + } + + @Test + public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() { + TensorImage tensorImage = new TensorImage(); + + IllegalStateException exception = + assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType); + assertThat(exception).hasMessageThat().contains("No image has been loaded yet."); + } + + /** + * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i] + * = {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index + */ + private static Bitmap createExampleBitmap() { + int[] colors = new int[EXAMPLE_NUM_PIXELS]; + for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) { + colors[i] = Color.rgb(i, i + 1, i + 2); + } + + return Bitmap.createBitmap( + colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } + + private static float[] createExampleFloatPixels() { + float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS]; + for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) { + pixels[j++] = (i - MEAN) / STDDEV; + pixels[j++] = (i + 1 - MEAN) / STDDEV; + pixels[j++] = (i + 2 - MEAN) / STDDEV; + } + return pixels; + } + + private static int[] createExampleUint8Pixels() { + int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS]; + for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) { + pixels[j++] = i; + pixels[j++] = i + 1; + pixels[j++] = i + 2; + } + return pixels; + } + } + + /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest { + /** + * Difference between the pair of float and uint8 values. It is used to test the data + * conversion. + */ + private static final float DELTA = 0.1f; + + /** The data type that used to create the TensorBuffer. */ + @Parameter(0) + public DataType tensorBufferDataType; + + /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */ + @Parameter(1) + public boolean isNormalized; + + /** The color space type of the TensorBuffer. */ + @Parameter(2) + public ColorSpaceType colorSpaceType; + + /** The data type that used to create the TensorImage. */ + @Parameter(3) + public DataType tensorImageDataType; + + @Parameters(name = "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};" + + " tensorImageDataType={3}") + public static Collection<Object[]> + data() { + return Arrays.asList(new Object[][] { + {FLOAT32, true, ColorSpaceType.RGB, FLOAT32}, + {FLOAT32, false, ColorSpaceType.RGB, UINT8}, + {UINT8, true, ColorSpaceType.RGB, FLOAT32}, + {UINT8, false, ColorSpaceType.RGB, UINT8}, + }); + } + + @Test + public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() { + TensorBuffer tensorBuffer = + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); + TensorImage tensorImage = new TensorImage(tensorImageDataType); + + tensorImage.load(tensorBuffer, colorSpaceType); + Bitmap bitmap = tensorImage.getBitmap(); + + Bitmap expectedBitmap = createBitmap(colorSpaceType); + assertThat(bitmap.sameAs(expectedBitmap)).isTrue(); + } + + @Test + public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() { + TensorBuffer tensorBuffer = + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA); + TensorImage tensorImage = new TensorImage(tensorImageDataType); + + tensorImage.load(tensorBuffer, colorSpaceType); + TensorBuffer buffer = tensorImage.getTensorBuffer(); + + // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta. + float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA; + TensorBuffer expectedTensorBuffer = createTensorBuffer( + tensorImageDataType, isNormalized, colorSpaceType, expectedResidual); + assertEqualTensorBuffers(buffer, expectedTensorBuffer); + } + + private static TensorBuffer createTensorBuffer(DataType dataType, boolean isNormalized, + ColorSpaceType colorSpaceType, float delta) { + switch (colorSpaceType) { + case RGB: + return createRgbTensorBuffer(dataType, isNormalized, delta); + case GRAYSCALE: + return createGrayscaleTensorBuffer(dataType, isNormalized, delta); + default: + break; + } + throw new IllegalArgumentException( + "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); + } + + private static Bitmap createBitmap(ColorSpaceType colorSpaceType) { + switch (colorSpaceType) { + case RGB: + return createRgbBitmap(); + case GRAYSCALE: + return createGrayscaleBitmap(); + default: + break; + } + throw new IllegalArgumentException( + "The ColorSpaceType, " + colorSpaceType + ", is unsupported."); + } + } + + /** Parameterized tests for loading TensorBuffers with YUV images. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class LoadTensorBufferWithYUV extends TensorImageTest { + private static final int HEIGHT = 2; + private static final int WIDTH = 3; + + @Parameter(0) + public ColorSpaceType colorSpaceType; + + @Parameters(name = "colorSpaceType={0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.NV12}, + {ColorSpaceType.NV21}, + {ColorSpaceType.YV12}, + {ColorSpaceType.YV21}, + }); + } + + @Test + public void loadTensorBufferWithColorSpaceShouldFail() { + TensorImage tensorImage = new TensorImage(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () + -> tensorImage.load( + TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType)); + assertThat(exception).hasMessageThat().contains( + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use" + + " `load(TensorBuffer, ImageProperties)` for other color space types."); + } + + @Test + public void loadTensorBufferAndGetBitmapShouldFail() { + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadArray(data, new int[] {data.length}); + + ImageProperties imageProperties = ImageProperties.builder() + .setHeight(HEIGHT) + .setWidth(WIDTH) + .setColorSpaceType(colorSpaceType) + .build(); + + TensorImage tensorImage = new TensorImage(DataType.FLOAT32); + tensorImage.load(tensorBuffer, imageProperties); + + UnsupportedOperationException exception = assertThrows( + UnsupportedOperationException.class, () -> tensorImage.getBitmap()); + assertThat(exception).hasMessageThat().contains( + "convertTensorBufferToBitmap() is unsupported for the color space type " + + colorSpaceType.name()); + } + } + + /** Parameterized tests for loading TensorBuffers with ImageProperties. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class LoadTensorBufferWithImageProperties extends TensorImageTest { + private static final int HEIGHT = 2; + private static final int WIDTH = 3; + private static final int WRONG_WIDTH = 10; + + @Parameter(0) + public ColorSpaceType colorSpaceType; + + @Parameters(name = "colorSpaceType={0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.RGB}, + {ColorSpaceType.GRAYSCALE}, + {ColorSpaceType.NV12}, + {ColorSpaceType.NV21}, + {ColorSpaceType.YV12}, + {ColorSpaceType.YV21}, + }); + } + + @Test + public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() { + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadArray(data, new int[] {data.length}); + + ImageProperties imageProperties = ImageProperties.builder() + .setHeight(HEIGHT) + .setWidth(WIDTH) + .setColorSpaceType(colorSpaceType) + .build(); + + TensorImage tensorImage = new TensorImage(DataType.FLOAT32); + tensorImage.load(tensorBuffer, imageProperties); + + assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer); + } + + @Test + public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() { + // Should allow buffer to be greater than the size specified by height and width. + int moreElements = 1; + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements]; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadArray(data, new int[] {data.length}); + + ImageProperties imageProperties = ImageProperties.builder() + .setHeight(HEIGHT) + .setWidth(WIDTH) + .setColorSpaceType(colorSpaceType) + .build(); + + TensorImage tensorImage = new TensorImage(DataType.FLOAT32); + tensorImage.load(tensorBuffer, imageProperties); + + assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer); + } + + @Test + public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() { + ByteBuffer byteBuffer = + ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH)); + + ImageProperties imageProperties = ImageProperties.builder() + .setHeight(HEIGHT) + .setWidth(WIDTH) + .setColorSpaceType(colorSpaceType) + .build(); + + TensorImage tensorImage = new TensorImage(DataType.UINT8); + tensorImage.load(byteBuffer, imageProperties); + + assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer); + } + + @Test + public void loadTensorBufferWithShouldFailWithWrongImageShape() { + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadArray(data, new int[] {data.length}); + + ImageProperties imageProperties = ImageProperties.builder() + .setHeight(HEIGHT) + .setWidth(WRONG_WIDTH) + .setColorSpaceType(colorSpaceType) + .build(); + + TensorImage tensorImage = new TensorImage(DataType.FLOAT32); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> tensorImage.load(tensorBuffer, imageProperties)); + assertThat(exception).hasMessageThat().contains(String.format( + "The given number of elements (%d) does not match the image (%s) in %d x %d. The" + + " expected number of elements should be at least %d.", + data.length, colorSpaceType.name(), HEIGHT, WRONG_WIDTH, + colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH))); + } + + @Test + public void getShapeOfInternalTensorBufferShouldSuccess() { + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)]; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadArray(data, new int[] {data.length}); + + ImageProperties imageProperties = ImageProperties.builder() + .setHeight(HEIGHT) + .setWidth(WIDTH) + .setColorSpaceType(colorSpaceType) + .build(); + + TensorImage tensorImage = new TensorImage(DataType.FLOAT32); + tensorImage.load(tensorBuffer, imageProperties); + + assertThat(tensorImage.getWidth()).isEqualTo(WIDTH); + assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT); + } + } + + /** Parameterized tests for loading TensorBuffer with invalid shapes. */ + @RunWith(ParameterizedRobolectricTestRunner.class) + public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest { + private static final String RGB_ASSERT_SHAPE_MESSAGE = + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" + + " representing R, G, B in order. The provided image shape is "; + private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE = + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image" + + " shape is "; + + @Parameter(0) + public ColorSpaceType colorSpaceType; + + /** The shape that does not match the colorSpaceType. */ + @Parameter(1) + public int[] invalidShape; + + @Parameter(2) + public String errorMessage; + + @Parameters(name = "colorSpaceType={0}; invalidShape={1}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, + GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE}, + }); + } + + @Test + public void loadTensorBufferWithInvalidShape() { + TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8); + TensorImage tensorImage = new TensorImage(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> tensorImage.load(tensorBuffer, colorSpaceType)); + assertThat(exception).hasMessageThat().contains( + errorMessage + Arrays.toString(invalidShape)); + } + } + + private static void assertEqualTensorBuffers( + TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) { + assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer()); + assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape()); + } + + private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) { + buffer1.rewind(); + buffer2.rewind(); + assertThat(buffer1.equals(buffer2)).isTrue(); + } + + private static void setUpImageMock(Image imageMock, int imageFormat) { + when(imageMock.getFormat()).thenReturn(imageFormat); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java new file mode 100644 index 0000000..4ac2eca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
@@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import android.graphics.Bitmap; +import android.graphics.Color; + +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.nio.ByteBuffer; + +/** Creates test images for other test files. */ +final class TestImageCreator { + /** + * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br> + * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index. + */ + static Bitmap createRgbBitmap() { + int[] colors = new int[100]; + for (int i = 0; i < 100; i++) { + colors[i] = Color.rgb(i, i + 1, i + 2); + } + return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888); + } + + /** + * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap. + * + * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is + * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...]. + * + * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3) + */ + static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) { + return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/0.1f); + } + + /** + * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap. + * + * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w) + * @param delta the delta that applied to the float values, such that the float array is [0 + + + * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...] + */ + static TensorBuffer createRgbTensorBuffer( + DataType dataType, boolean isNormalized, float delta) { + float[] rgbValues = new float[300]; + for (int i = 0, j = 0; i < 100; i++) { + rgbValues[j++] = i + delta; + rgbValues[j++] = i + 1 + delta; + rgbValues[j++] = i + 2 + delta; + } + + int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType); + // If dataType is UINT8, rgbValues will be converted into uint8, such as from + // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...]. + buffer.loadArray(rgbValues, shape); + return buffer; + } + + /** + * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br> + * pixel[i] = i, where i is the flatten index. + */ + static Bitmap createGrayscaleBitmap() { + byte[] grayValues = new byte[100]; + for (int i = 0; i < 100; i++) { + grayValues[i] = (byte) i; + } + ByteBuffer buffer = ByteBuffer.wrap(grayValues); + Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8); + buffer.rewind(); + bitmap.copyPixelsFromBuffer(buffer); + return bitmap; + } + + /** + * Creates a 10*10 float or uint8 TensorBuffer representing the same image in + * createGrayscaleBitmap. + * + * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is + * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...]. + * + * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w) + */ + static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) { + return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/0.1f); + } + + /** + * Creates a 10*10 float or uint8 TensorBuffer representing the same image in + * createGrayscaleBitmap. + * + * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w) + * @param delta the delta that applied to the float values, such that the float array is [0 + + * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...] + */ + static TensorBuffer createGrayscaleTensorBuffer( + DataType dataType, boolean isNormalized, float delta) { + float[] grayValues = new float[100]; + for (int i = 0; i < 100; i++) { + grayValues[i] = i + delta; + } + int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType); + // If dataType is UINT8, grayValues will be converted into uint8, such as from + // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...]. + buffer.loadArray(grayValues, shape); + return buffer; + } + + private TestImageCreator() {} +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java new file mode 100644 index 0000000..070e1789 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
@@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import static com.google.common.truth.Truth.assertThat; + +import android.graphics.Bitmap; +import android.graphics.PointF; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; + +/** Instrumented unit test for {@link ResizeOp}. */ +@RunWith(AndroidJUnit4.class) +public class ResizeOpInstrumentedTest { + private static final int EXAMPLE_WIDTH = 10; + private static final int EXAMPLE_HEIGHT = 15; + + private Bitmap exampleBitmap; + private TensorImage input; + + @Before + public void setUp() { + exampleBitmap = createExampleBitmap(); + input = new TensorImage(DataType.UINT8); + input.load(exampleBitmap); + } + + @Test + public void resizeShouldSuccess() { + int targetWidth = EXAMPLE_WIDTH * 2; + int targetHeight = EXAMPLE_HEIGHT * 2; + ImageProcessor processor = + new ImageProcessor.Builder() + .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR)) + .build(); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); + for (int i = 0; i < outputBitmap.getWidth(); i++) { + for (int j = 0; j < outputBitmap.getHeight(); j++) { + int expected = exampleBitmap.getPixel(i / 2, j / 2); + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); + } + } + } + + @Test + public void inverseTransformPointShouldSuccess() { + ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR); + PointF transformed = new PointF(32.0f, 42.0f); + // The original image size is 900x400 assumed + PointF original = op.inverseTransform(transformed, 400, 900); + assertThat(original.x).isEqualTo(96); + assertThat(original.y).isEqualTo(84); + PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900); + assertThat(outside.x).isEqualTo(1500); + assertThat(outside.y).isEqualTo(2000); + } + + /** + * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = + * {A: 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index + */ + private static Bitmap createExampleBitmap() { + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); + } + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java new file mode 100644 index 0000000..85c7779 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
@@ -0,0 +1,159 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import static com.google.common.truth.Truth.assertThat; + +import android.graphics.Bitmap; +import android.graphics.PointF; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; + +/** Instrumented unit test for {@link ResizeWithCropOrPadOp}. */ +@RunWith(AndroidJUnit4.class) +public class ResizeWithCropOrPadOpInstrumentedTest { + private Bitmap exampleBitmap; + private TensorImage input; + + private static final int EXAMPLE_WIDTH = 10; + private static final int EXAMPLE_HEIGHT = 15; + + @Before + public void setUp() { + exampleBitmap = createExampleBitmap(); + input = new TensorImage(DataType.UINT8); + input.load(exampleBitmap); + } + + @Test + public void testResizeWithCrop() { + int targetWidth = 6; + int targetHeight = 5; + ImageProcessor processor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth)) + .build(); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); + for (int i = 0; i < outputBitmap.getWidth(); i++) { + for (int j = 0; j < outputBitmap.getHeight(); j++) { + int expected = exampleBitmap.getPixel(i + (EXAMPLE_WIDTH - targetWidth) / 2, + j + (EXAMPLE_HEIGHT - targetHeight) / 2); + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); + } + } + } + + @Test + public void testResizeWithPad() { + int targetWidth = 15; + int targetHeight = 20; + ImageProcessor processor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth)) + .build(); + TensorImage output = processor.process(input); + // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right + + Bitmap outputBitmap = output.getBitmap(); + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth); + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight); + int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2; + int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2; + for (int i = 0; i < outputBitmap.getWidth(); i++) { + for (int j = 0; j < outputBitmap.getHeight(); j++) { + int expected = 0; // ZERO padding + if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH && j >= topPad + && j < topPad + EXAMPLE_HEIGHT) { + expected = exampleBitmap.getPixel(i - leftPad, j - topPad); + } + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); + } + } + } + + @Test + public void testResizeWithCropAndPad() { + int targetSize = 12; + // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom + ImageProcessor processor = new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(targetSize, targetSize)) + .build(); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertThat(outputBitmap.getWidth()).isEqualTo(targetSize); + assertThat(outputBitmap.getHeight()).isEqualTo(targetSize); + + int leftPad = (targetSize - EXAMPLE_WIDTH) / 2; + int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2; + for (int i = 0; i < outputBitmap.getWidth(); i++) { + for (int j = 0; j < outputBitmap.getHeight(); j++) { + int expected = 0; + if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) { + expected = exampleBitmap.getPixel(i - leftPad, j + topCrop); + } + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected); + } + } + } + + @Test + public void inverseTransformCorrectlyWhenCropped() { + ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300); + // The point (100, 50) is transformed from 600x500 image + PointF original = op.inverseTransform(new PointF(100, 50), 500, 600); + assertThat(original.x).isEqualTo(250); + assertThat(original.y).isEqualTo(150); + PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600); + assertThat(cropped.x).isEqualTo(140); + assertThat(cropped.y).isEqualTo(90); + } + + @Test + public void inverseTransformCorrectlyWhenPadded() { + ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300); + // The point (100, 50) is transformed from 100x200 image + PointF original = op.inverseTransform(new PointF(100, 50), 200, 100); + assertThat(original.x).isEqualTo(0); + assertThat(original.y).isEqualTo(0); + PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100); + assertThat(outside.x).isEqualTo(-50); + assertThat(outside.y).isEqualTo(-40); + } + + /** + * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = + * {A: 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index + */ + private static Bitmap createExampleBitmap() { + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); + } + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java new file mode 100644 index 0000000..d00fe0e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
@@ -0,0 +1,99 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import static com.google.common.truth.Truth.assertThat; + +import android.graphics.Bitmap; +import android.graphics.PointF; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; + +/** Instrumented unit test for {@link Rot90Op}. */ +@RunWith(AndroidJUnit4.class) +public class Rot90OpInstrumentedTest { + private Bitmap exampleBitmap; + private TensorImage input; + + private static final int EXAMPLE_WIDTH = 10; + private static final int EXAMPLE_HEIGHT = 15; + + @Before + public void setUp() { + exampleBitmap = createExampleBitmap(); + input = new TensorImage(DataType.UINT8); + input.load(exampleBitmap); + } + + @Test + public void testRot90() { + ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build(); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT); + assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH); + for (int i = 0; i < exampleBitmap.getWidth(); i++) { + for (int j = 0; j < exampleBitmap.getHeight(); j++) { + assertThat(exampleBitmap.getPixel(i, j)) + .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i)); + } + } + } + + @Test + public void testRot90Twice() { + ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build(); + TensorImage output = processor.process(input); + + Bitmap outputBitmap = output.getBitmap(); + assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH); + assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT); + for (int i = 0; i < exampleBitmap.getWidth(); i++) { + for (int j = 0; j < exampleBitmap.getHeight(); j++) { + assertThat(exampleBitmap.getPixel(i, j)) + .isEqualTo(outputBitmap.getPixel( + EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j)); + } + } + } + + @Test + public void inverseTransformCorrectlyWhenRotated() { + Rot90Op op = new Rot90Op(3); + PointF original = op.inverseTransform(new PointF(20, 10), 200, 100); + assertThat(original.x).isEqualTo(10); + assertThat(original.y).isEqualTo(180); + PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100); + assertThat(outside.x).isEqualTo(110); + assertThat(outside.y).isEqualTo(210); + } + + private static Bitmap createExampleBitmap() { + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT]; + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) { + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2); + } + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java new file mode 100644 index 0000000..f024f68 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
@@ -0,0 +1,97 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.doReturn; +import static org.tensorflow.lite.DataType.UINT8; + +import android.graphics.Bitmap; +import android.graphics.Color; +import android.graphics.ImageFormat; +import android.media.Image; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.ColorSpaceType; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; + +/** Instrumented unit test for {@link TransformToGrayscaleOp}. */ +@RunWith(AndroidJUnit4.class) +public class TransformToGrayScaleOpInstrumentedTest { + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + private TensorImage input; + + private static final int EXAMPLE_WIDTH = 2; + private static final int EXAMPLE_HEIGHT = 3; + @Mock + Image imageMock; + + @Before + public void setUp() { + Bitmap exampleBitmap = createExampleBitmap(); + input = new TensorImage(DataType.UINT8); + input.load(exampleBitmap); + } + + @Test + public void apply_onRgb_succeeds() { + ImageProcessor processor = + new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build(); + + TensorImage output = processor.process(input); + int[] pixels = output.getTensorBuffer().getIntArray(); + + assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH); + assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT); + assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE); + assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179}); + } + + @Test + public void apply_onYuv_throws() { + setUpImageMock(imageMock, ImageFormat.YUV_420_888); + TensorImage tensorImage = new TensorImage(UINT8); + tensorImage.load(imageMock); + ImageProcessor processor = + new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build(); + + assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage)); + } + + private static Bitmap createExampleBitmap() { + int[] colors = new int[] { + Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN}; + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888); + } + + private static void setUpImageMock(Image imageMock, int imageFormat) { + doReturn(imageFormat).when(imageMock).getFormat(); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java new file mode 100644 index 0000000..98d1f92 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
@@ -0,0 +1,121 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +/** Tests of {@link org.tensorflow.lite.support.label.Category}. */ +@RunWith(RobolectricTestRunner.class) +public final class CategoryTest { + private static final String APPLE_LABEL = "apple"; + private static final String DEFAULT_DISPLAY_NAME = ""; + private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish. + private static final float APPLE_SCORE = 0.5f; + private static final int APPLE_INDEX = 10; + + @Test + public void createShouldSucceed() { + Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); + + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); + assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME); + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); + } + + @Test + public void createWithIndexShouldSucceed() { + Category category = + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); + + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); + assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME); + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); + assertThat(category.getIndex()).isEqualTo(APPLE_INDEX); + } + + @Test + public void constructorShouldSucceed() { + Category category = new Category(APPLE_LABEL, APPLE_SCORE); + + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL); + // Using the constructor, displayName will be default to an empty string. + assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME); + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE); + } + + @Test + public void toStringWithCreateShouldProvideReadableResult() { + Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); + String categoryString = category.toString(); + + assertThat(categoryString) + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME + + " score=" + APPLE_SCORE + " index=-1" + + ")>"); + } + + @Test + public void toStringWithCreateIndexShouldProvideReadableResult() { + Category category = + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); + String categoryString = category.toString(); + + assertThat(categoryString) + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME + + " score=" + APPLE_SCORE + " index=" + APPLE_INDEX + ")>"); + } + + @Test + public void toStringWithConstuctorShouldProvideReadableResult() { + Category category = new Category(APPLE_LABEL, APPLE_SCORE); + String categoryString = category.toString(); + + assertThat(categoryString) + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + DEFAULT_DISPLAY_NAME + + " score=" + APPLE_SCORE + " index=-1" + + ")>"); + } + + @Test + public void equalsShouldSucceedWithCreate() { + Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); + Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE); + + assertThat(categoryA).isEqualTo(categoryB); + } + + @Test + public void equalsShouldSucceedWithCreateIndex() { + Category categoryA = + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); + Category categoryB = + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX); + + assertThat(categoryA).isEqualTo(categoryB); + } + + @Test + public void equalsShouldSucceedWithConstructor() { + Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE); + Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE); + + assertThat(categoryA).isEqualTo(categoryB); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java new file mode 100644 index 0000000..91c81c4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
@@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.util.Arrays; +import java.util.List; + +/** Tests of {@link org.tensorflow.lite.support.label.LabelUtil}. */ +@RunWith(RobolectricTestRunner.class) +public class LabelUtilTest { + @Test + public void mapIndexToStringsWithInvalidValues() { + String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"}; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); + tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6}); + List<String> categories = + LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1); + assertThat(categories.toArray()) + .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""}); + } + + @Test + public void mapFloatIndexShouldCast() { + String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"}; + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6}); + List<String> categories = + LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1); + assertThat(categories.toArray()) + .isEqualTo(new String[] { + "background", "apple", "apple", "banana", "banana", "banana"}); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java new file mode 100644 index 0000000..857a77a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
@@ -0,0 +1,203 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Tests of {@link org.tensorflow.lite.support.label.TensorLabel}. */ +@RunWith(RobolectricTestRunner.class) +public final class TensorLabelTest { + @Test + public void createTensorLabelWithNullAxisLabelsShouldFail() { + int[] shape = {2}; + int[] arr = {1, 2}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); + buffer.loadArray(arr, shape); + Map<Integer, List<String>> nullAxisLabels = null; + + Assert.assertThrows( + NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer)); + } + + @Test + public void createTensorLabelWithNullTensorBufferShouldFail() { + Map<Integer, List<String>> axisLabels = new HashMap<>(); + axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); + TensorBuffer nullTensorBuffer = null; + + Assert.assertThrows( + NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer)); + } + + @Test + public void createTensorLabelWithStringListShouldSuccess() { + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32); + + TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer); + + assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull(); + assertThat(tensorLabel.getMapWithTensorBuffer().keySet()) + .contains("c"); // randomly pick one + } + + @Test + public void createTensorLabelWithEmptyShapeShouldFail() { + int[] shape = new int[] {}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + Map<Integer, List<String>> axisLabels = new HashMap<>(); + axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); + + Assert.assertThrows( + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); + } + + @Test + public void createTensorLabelWithMismatchedAxisShouldFail() { + int[] shape = {1, 4}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + Map<Integer, List<String>> axisLabels = new HashMap<>(); + axisLabels.put(0, Arrays.asList("a", "b", "c", "d")); + + Assert.assertThrows( + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); + } + + @Test + public void createTensorLabelWithMismatchedShapeShouldFail() { + int[] shape = {1, 3}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + Map<Integer, List<String>> axisLabels = new HashMap<>(); + axisLabels.put(1, Arrays.asList("a", "b", "c", "d")); + + Assert.assertThrows( + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer)); + } + + @Test + public void getMapWithFloatBufferValuesShouldSuccess() { + int numberLabel = 4; + float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f}; + int[] shape = {1, numberLabel}; + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + input.loadArray(inputArr, shape); + Map<Integer, List<String>> axisLabels = new HashMap<>(); + int labelAxis = 1; + axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d")); + + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); + + for (int i = 0; i < numberLabel; i++) { + String label = axisLabels.get(labelAxis).get(i); + assertThat(map).containsKey(label); + float[] array = map.get(label).getFloatArray(); + assertThat(array).hasLength(1); + assertThat(array[0]).isEqualTo(inputArr[i]); + } + } + + @Test + public void getMapWithIntBufferValuesShouldSuccess() { + int numberLabel = 3; + int[] inputArr = {1, 2, 0}; + int[] shape = {1, 1, numberLabel}; + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8); + input.loadArray(inputArr, shape); + Map<Integer, List<String>> axisLabels = new HashMap<>(); + int labelAxis = 2; + axisLabels.put(labelAxis, Arrays.asList("x", "y", "z")); + + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); + + for (int i = 0; i < numberLabel; i++) { + String label = axisLabels.get(labelAxis).get(i); + assertThat(map).containsKey(label); + int[] array = map.get(label).getIntArray(); + assertThat(array).hasLength(1); + assertThat(array[0]).isEqualTo(inputArr[i]); + } + } + + @Test + public void getFloatMapShouldSuccess() { + int[] shape = {1, 3}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f}); + + TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer); + Map<String, Float> map = tensorLabeled.getMapWithFloatValue(); + + assertThat(map).hasSize(3); + assertThat(map).containsEntry("a", 1.0f); + assertThat(map).containsEntry("b", 2.0f); + assertThat(map).containsEntry("c", 3.0f); + } + + @Test + public void getMapFromMultiDimensionalTensorBufferShouldSuccess() { + int numberLabel = 2; + int numDim = 3; + float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f}; + int[] shape = {numberLabel, numDim}; + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + input.loadArray(inputArr, shape); + Map<Integer, List<String>> axisLabels = new HashMap<>(); + int labelAxis = 0; + axisLabels.put(labelAxis, Arrays.asList("pos", "neg")); + + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input); + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer(); + + for (int i = 0; i < numberLabel; i++) { + String label = axisLabels.get(labelAxis).get(i); + assertThat(map).containsKey(label); + + float[] array = map.get(label).getFloatArray(); + assertThat(array).hasLength(numDim); + for (int j = 0; j < numDim; j++) { + assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]); + } + } + } + + @Test + public void getCategoryListShouldSuccess() { + int[] shape = {1, 3}; + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f}); + + TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer); + List<Category> categories = tensorLabeled.getCategoryList(); + + assertThat(categories).hasSize(3); + assertThat(categories) + .containsExactly( + new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java new file mode 100644 index 0000000..c1afe99 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
@@ -0,0 +1,122 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label.ops; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.Context; + +import androidx.test.core.app.ApplicationProvider; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.label.TensorLabel; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** Tests of {@link org.tensorflow.lite.support.label.ops.LabelAxisOp}. */ +@RunWith(RobolectricTestRunner.class) +public final class LabelAxisOpTest { + private final Context context = ApplicationProvider.getApplicationContext(); + private static final String LABEL_PATH = "flower_labels.txt"; + + @Test + public void testAddAxisLabelByStringList() { + int numberLabel = 2; + float[] inputArr = {0.7f, 0.3f}; + + int[] shape = {numberLabel}; + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + input.loadArray(inputArr, shape); + + List<String> labels = Arrays.asList("pos", "neg"); + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build(); + TensorLabel output = op.apply(input); + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); + + assertThat(map).containsKey("pos"); + float[] array = map.get("pos").getFloatArray(); + assertThat(array).hasLength(1); + assertThat(array[0]).isEqualTo(0.7f); + + assertThat(map).containsKey("neg"); + array = map.get("neg").getFloatArray(); + assertThat(array).hasLength(1); + assertThat(array[0]).isEqualTo(0.3f); + } + + @Test + public void testAddAxisLabelWithMultiDimensionTensor() throws IOException { + int numberLabel = 2; + int numDim = 3; + float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f}; + + int[] shape = {1, numberLabel, numDim}; + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + input.loadArray(inputArr, shape); + + List<String> labels = Arrays.asList("pos", "neg"); + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build(); + + TensorLabel output = op.apply(input); + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); + + assertThat(map).containsKey("pos"); + float[] array = map.get("pos").getFloatArray(); + assertThat(array).hasLength(numDim); + assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f}); + + assertThat(map).containsKey("neg"); + array = map.get("neg").getFloatArray(); + assertThat(array).hasLength(numDim); + assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f}); + } + + @Test + public void testAddAxisLabelByFilePath() throws IOException { + int numberLabel = 5; + int[] inputArr = new int[numberLabel]; + for (int i = 0; i < numberLabel; i++) { + inputArr[i] = i; + } + + int[] shape = {numberLabel}; + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8); + input.loadArray(inputArr, shape); + + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build(); + TensorLabel output = op.apply(input); + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer(); + + List<String> labels = FileUtil.loadLabels(context, LABEL_PATH); + for (int i = 0; i < numberLabel; i++) { + String label = labels.get(i); + + assertThat(map).containsKey(label); + + int[] array = map.get(label).getIntArray(); + assertThat(array).hasLength(1); + assertThat(array[0]).isEqualTo(inputArr[i]); + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java new file mode 100644 index 0000000..d7449187 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
@@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.support.model; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +/** + * Instrumented unit test for {@link GpuDelegateProxy}. + * + * <p>In this test, "tensorflow-lite-gpu" is provided. + */ +@RunWith(AndroidJUnit4.class) +public final class GpuDelegateProxyInstrumentedTest { + @Test + public void createGpuDelegateProxyShouldSuccess() { + GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance(); + + assertThat(proxy).isNotNull(); + proxy.getNativeHandle(); + proxy.close(); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java new file mode 100644 index 0000000..4eb2e29 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
@@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.support.model; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +/** Tests of {@link org.tensorflow.lite.support.model.GpuDelegateProxy}. */ +@RunWith(RobolectricTestRunner.class) +public final class GpuDelegateProxyTest { + @Test + public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() { + GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance(); + + assertThat(proxy).isNull(); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java new file mode 100644 index 0000000..342e82b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
@@ -0,0 +1,160 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.model; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.fail; + +import android.content.Context; + +import androidx.test.core.app.ApplicationProvider; + +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.support.model.Model.Device; +import org.tensorflow.lite.support.model.Model.Options; + +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** Tests of {@link org.tensorflow.lite.support.model.Model}. */ +@RunWith(RobolectricTestRunner.class) +public final class ModelTest { + private final Context context = ApplicationProvider.getApplicationContext(); + private static final String MODEL_PATH = "add.tflite"; + + @Ignore + @Test + public void testLoadLocalModel() throws IOException { + MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData(); + assertThat(byteModel).isNotNull(); + } + + @Ignore + @Test + public void testBuildMultiThreadModel() throws IOException { + MappedByteBuffer byteModel = + new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData(); + assertThat(byteModel).isNotNull(); + } + + @Ignore + @Test + public void buildModelWithOptionsShouldSuccess() throws IOException { + Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build(); + Model model = Model.createModel(context, MODEL_PATH, options); + assertThat(model.getData()).isNotNull(); + } + + @Ignore + @Test + public void testGetModelPath() throws IOException { + String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath(); + assertThat(modelPath).isEqualTo(MODEL_PATH); + } + + @Test + public void testNonExistingLocalModel() { + try { + new Model.Builder(context, "non_exist_model_file").build(); + fail(); + } catch (IOException e) { + assertThat(e).hasMessageThat().contains("non_exist_model_file"); + } + } + + @Test + public void testNullLocalModelPath() throws IOException { + try { + new Model.Builder(context, null).build(); + fail(); + } catch (NullPointerException e) { + assertThat(e).hasMessageThat().contains("File path cannot be null."); + } + } + + @Test + public void testNullContext() throws IOException { + try { + new Model.Builder(null, MODEL_PATH).build(); + fail(); + } catch (NullPointerException e) { + assertThat(e).hasMessageThat().contains("Context should not be null."); + } + } + + @Ignore + @Test + public void testGetInputTensor() throws IOException { + Options options = new Options.Builder().build(); + Model model = Model.createModel(context, MODEL_PATH, options); + assertThat(model.getInputTensor(0)).isNotNull(); + } + + @Ignore + @Test + public void testGetOutputTensor() throws IOException { + Options options = new Options.Builder().build(); + Model model = Model.createModel(context, MODEL_PATH, options); + assertThat(model.getOutputTensor(0)).isNotNull(); + } + + @Ignore + @Test + public void testRun() throws IOException { + Context context = ApplicationProvider.getApplicationContext(); + Model model = new Model.Builder(context, MODEL_PATH).build(); + runModel(model); + } + + @Ignore + @Test + public void testMultiThreadingRun() throws IOException { + Context context = ApplicationProvider.getApplicationContext(); + Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build(); + runModel(model); + } + + @Ignore + @Test + public void testNnApiRun() throws IOException { + Context context = ApplicationProvider.getApplicationContext(); + Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build(); + runModel(model); + } + + private static void runModel(Model model) throws IOException { + // Creates the inputs. + float[] x = {1.5f}; + float[] y = {0.5f}; + float[] expectedSum = {2.0f}; + Object[] inputs = {x, y}; + + // Creates the outputs buffer. + float[] sum = new float[1]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, sum); + + // Runs inference. + model.run(inputs, outputs); + assertThat(sum).isEqualTo(expectedSum); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java new file mode 100644 index 0000000..82b59b3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
@@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.tensorbuffer; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; + +/** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat}. */ +@RunWith(RobolectricTestRunner.class) +public final class TensorBufferFloatTest { + @Test + public void testCreateDynamic() { + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(); + assertThat(tensorBufferFloat).isNotNull(); + } + + @Test + public void testCreateFixedSize() { + int[] shape = new int[] {1, 2, 3}; + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); + assertThat(tensorBufferFloat).isNotNull(); + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6); + } + + @Test + public void testCreateFixedSizeWithScalarShape() { + int[] shape = new int[] {}; + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); + assertThat(tensorBufferFloat).isNotNull(); + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1); + } + + @Test + public void testCreateWithNullShape() { + int[] shape = null; + Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape)); + } + + @Test + public void testCreateWithInvalidShape() { + int[] shape = new int[] {1, -1, 2}; + Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape)); + } + + @Test + public void testCreateUsingShapeWithZero() { + int[] shape = new int[] {1, 0, 2}; + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape); + assertThat(tensorBufferFloat).isNotNull(); + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0); + } + + @Test + public void testGetDataType() { + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(); + assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java new file mode 100644 index 0000000..763356f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
@@ -0,0 +1,893 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.tensorbuffer; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.ArrayList; + +/** Test helper class for inserting and retrieving arrays. */ +class ArrayTestRunner { + // List of TensorBuffer types to be tested. + private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8}; + // List of source arrays to be loaded into TensorBuffer during the tests. + private final ArrayList<Object> srcArrays; + // List of array data type with respect to srcArrays. + private final ArrayList<DataType> arrDataTypes; + // List of array shape with respect to srcArrays. + private final ArrayList<int[]> arrShapes; + private final int[] tensorBufferShape; + private final ExpectedResults expectedResForFloatBuf; + private final ExpectedResults expectedResForByteBuf; + + public ArrayTestRunner(Builder builder) { + if (builder.srcArrays.size() != builder.arrDataTypes.size()) { + throw new IllegalArgumentException( + "Number of source arrays and number of data types do not match."); + } + + this.srcArrays = builder.srcArrays; + this.arrDataTypes = builder.arrDataTypes; + this.arrShapes = builder.arrShapes; + this.tensorBufferShape = builder.tensorBufferShape; + this.expectedResForFloatBuf = builder.expectedResForFloatBuf; + this.expectedResForByteBuf = builder.expectedResForByteBuf; + } + + static class ExpectedResults { + public float[] floatArr; + public int[] intArr; + public int[] shape; + } + + public static class Builder { + private final ArrayList<Object> srcArrays = new ArrayList<>(); + private final ArrayList<DataType> arrDataTypes = new ArrayList<>(); + private final ArrayList<int[]> arrShapes = new ArrayList<>(); + private int[] tensorBufferShape; + private final ExpectedResults expectedResForFloatBuf = new ExpectedResults(); + private final ExpectedResults expectedResForByteBuf = new ExpectedResults(); + + public static Builder newInstance() { + return new Builder(); + } + + private Builder() {} + + /** Loads a test array into the test runner. */ + public Builder addSrcArray(Object src, int[] shape) { + // src should be a primitive 1D array. + DataType dataType = dataTypeOfArray(src); + switch (dataType) { + case INT32: + case FLOAT32: + srcArrays.add(src); + arrDataTypes.add(dataType); + arrShapes.add(shape); + return this; + default: + throw new AssertionError( + "Cannot resolve srouce arrays in the DataType of " + dataType); + } + } + + public Builder setTensorBufferShape(int[] tensorBufferShape) { + this.tensorBufferShape = tensorBufferShape; + return this; + } + + public Builder setExpectedResults( + DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) { + ExpectedResults er; + switch (bufferType) { + case UINT8: + er = expectedResForByteBuf; + break; + case FLOAT32: + er = expectedResForFloatBuf; + break; + default: + throw new AssertionError( + "Cannot test TensorBuffer in the DataType of " + bufferType); + } + + er.floatArr = expectedFloatArr; + er.intArr = expectedIntArr; + return this; + } + + public ArrayTestRunner build() { + int[] expectedShape; + if (arrShapes.isEmpty()) { + // If no array will be loaded, the array is an empty array. + expectedShape = new int[] {0}; + } else { + expectedShape = arrShapes.get(arrShapes.size() - 1); + } + expectedResForByteBuf.shape = expectedShape; + expectedResForFloatBuf.shape = expectedShape; + return new ArrayTestRunner(this); + } + } + + public static DataType[] getBufferTypeList() { + return BUFFER_TYPE_LIST; + } + + /** + * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null, + * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in + * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types + * in TensorBuffer, such as int array and float array for now. Check if the results are + * correct. 4. Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST. + */ + public void run() { + for (DataType bufferDataType : BUFFER_TYPE_LIST) { + TensorBuffer tensorBuffer; + if (tensorBufferShape == null) { + tensorBuffer = TensorBuffer.createDynamic(bufferDataType); + } else { + tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType); + } + for (int i = 0; i < srcArrays.size(); i++) { + switch (arrDataTypes.get(i)) { + case INT32: + int[] arrInt = (int[]) srcArrays.get(i); + tensorBuffer.loadArray(arrInt, arrShapes.get(i)); + break; + case FLOAT32: + float[] arrFloat = (float[]) srcArrays.get(i); + tensorBuffer.loadArray(arrFloat, arrShapes.get(i)); + break; + default: + break; + } + } + checkResults(tensorBuffer); + } + } + + private void checkResults(TensorBuffer tensorBuffer) { + ExpectedResults er; + switch (tensorBuffer.getDataType()) { + case UINT8: + er = expectedResForByteBuf; + break; + case FLOAT32: + er = expectedResForFloatBuf; + break; + default: + throw new AssertionError("Cannot test TensorBuffer in the DataType of " + + tensorBuffer.getDataType()); + } + + // Checks getIntArray() and getFloatArray(). + int[] resIntArr = tensorBuffer.getIntArray(); + assertThat(resIntArr).isEqualTo(er.intArr); + float[] resFloatArr = tensorBuffer.getFloatArray(); + assertThat(resFloatArr).isEqualTo(er.floatArr); + assertThat(tensorBuffer.getShape()).isEqualTo(er.shape); + + // Checks getIntValue(int index) and getFloatValue(int index). + int flatSize = tensorBuffer.getFlatSize(); + float[] resFloatValues = new float[flatSize]; + int[] resIntValues = new int[flatSize]; + for (int i = 0; i < flatSize; i++) { + resFloatValues[i] = tensorBuffer.getFloatValue(i); + resIntValues[i] = tensorBuffer.getIntValue(i); + } + assertThat(resFloatValues).isEqualTo(er.floatArr); + assertThat(resIntValues).isEqualTo(er.intArr); + } + + /** Gets the data type of an 1D array. */ + private static DataType dataTypeOfArray(Object arr) { + if (arr != null) { + Class<?> c = arr.getClass(); + if (c.isArray()) { + c = c.getComponentType(); + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } else if (String.class.equals(c)) { + return DataType.STRING; + } + } + } + throw new IllegalArgumentException( + "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName()); + } +} + +/** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. */ +@RunWith(RobolectricTestRunner.class) +public final class TensorBufferTest { + // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other. + private static final int[] ARRAY1_SHAPE = new int[] {2, 3}; + private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f}; + private static final float[] FLOAT_ARRAY1_ROUNDED = + new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f}; + // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted + // into uint8. + private static final float[] FLOAT_ARRAY1_CAPPED = + new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f}; + private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6}; + private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6}; + // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other. + private static final int[] ARRAY2_SHAPE = new int[] {2, 1}; + private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f}; + private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f}; + private static final int[] INT_ARRAY2 = new int[] {6, 7}; + // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size. + private static final int[] ARRAY3_SHAPE = new int[] {2, 1}; + private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f}; + private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f}; + // INT_ARRAY2 and INT_ARRAY3 have the same size. + private static final int[] INT_ARRAY3 = new int[] {8, 9}; + private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0}; + private static final int[] EMPTY_INT_ARRAY = new int[0]; + private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; + // Single element array which represents a scalar. + private static final int[] SCALAR_ARRAY_SHAPE = new int[] {}; + private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f}; + private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f}; + private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f}; + private static final int[] INT_SCALAR_ARRAY = new int[] {800}; + private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255}; + // Several different ByteBuffer. + private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0); + private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24); + + static { + FLOAT_BYTE_BUFFER1.rewind(); + + FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer(); + floatBuffer.put(FLOAT_ARRAY1); + } + + private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2); + + static { + INT_BYTE_BUFFER2.rewind(); + + for (int a : INT_ARRAY2) { + INT_BYTE_BUFFER2.put((byte) a); + } + } + + @Test + public void testCreateFixedSizeTensorBufferFloat() { + int[] shape = new int[] {1, 2, 3}; + TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + assertThat(tensorBufferFloat).isNotNull(); + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6); + } + + @Test + public void testCreateFixedSizeTensorBufferUint8() { + int[] shape = new int[] {1, 2, 3}; + TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8); + assertThat(tensorBufferUint8).isNotNull(); + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6); + } + + @Test + public void testCreateDynamicTensorBufferFloat() { + TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32); + assertThat(tensorBufferFloat).isNotNull(); + } + + @Test + public void testCreateDynamicTensorBufferUint8() { + TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8); + assertThat(tensorBufferUint8).isNotNull(); + } + + @Test + public void testCreateTensorBufferFromFixedSize() { + int[] shape = new int[] {1, 2, 3}; + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8); + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); + assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3}); + } + + @Test + public void testCreateTensorBufferFromDynamicSize() { + int[] shape = new int[] {1, 2, 3}; + TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8); + src.resize(shape); + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); + assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3}); + } + + @Test + public void testCreateTensorBufferUInt8FromUInt8() { + int[] shape = new int[] {INT_ARRAY1.length}; + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8); + src.loadArray(INT_ARRAY1); + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8); + int[] data = dst.getIntArray(); + assertThat(data).isEqualTo(INT_ARRAY1_CAPPED); + } + + @Test + public void testCreateTensorBufferUInt8FromFloat32() { + TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32); + src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE); + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8); + int[] data = dst.getIntArray(); + assertThat(data).isEqualTo(INT_ARRAY1_CAPPED); + } + + @Test + public void testCreateTensorBufferFloat32FromUInt8() { + TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8); + src.loadArray(INT_ARRAY1, ARRAY1_SHAPE); + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); + float[] data = dst.getFloatArray(); + assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED); + } + + @Test + public void testCreateTensorBufferFloat32FromFloat32() { + int[] shape = new int[] {FLOAT_ARRAY1.length}; + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + src.loadArray(FLOAT_ARRAY1); + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32); + float[] data = dst.getFloatArray(); + assertThat(data).isEqualTo(FLOAT_ARRAY1); + } + + @Test + public void testGetBuffer() throws IOException { + int[] shape = new int[] {1, 2, 3}; + TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8); + assertThat(tensorBufferUint8.getBuffer()).isNotNull(); + } + + @Test + public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE) + .setTensorBufferShape(SCALAR_ARRAY_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_ROUNDED, + /*expectedIntArr=*/INT_SCALAR_ARRAY) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED, + /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED) + .build() + .run(); + } + + @Test + public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE) + .setTensorBufferShape(SCALAR_ARRAY_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY, + /*expectedIntArr=*/INT_SCALAR_ARRAY) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED, + /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED) + .build() + .run(); + } + + @Test + public void testLoadAndGetIntArrayWithFixedSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) + .setTensorBufferShape(ARRAY1_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED, + /*expectedIntArr=*/INT_ARRAY1) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, + /*expectedIntArr=*/INT_ARRAY1_CAPPED) + .build() + .run(); + } + + @Test + public void testLoadAndGetFloatArrayWithFixedSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) + .setTensorBufferShape(ARRAY1_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY1, + /*expectedIntArr=*/INT_ARRAY1) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, + /*expectedIntArr=*/INT_ARRAY1_CAPPED) + .build() + .run(); + } + + @Test + public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE) + .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE) + .setTensorBufferShape(ARRAY2_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED, + /*expectedIntArr=*/INT_ARRAY3) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED, + /*expectedIntArr=*/INT_ARRAY3) + .build() + .run(); + } + + @Test + public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) + .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE) + .setTensorBufferShape(ARRAY2_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY3, + /*expectedIntArr=*/INT_ARRAY3) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED, + /*expectedIntArr=*/INT_ARRAY3) + .build() + .run(); + } + + @Test + public void testRepeatedLoadIntArrayWithDifferentFixedSize() { + int[] srcArr1 = INT_ARRAY1; + int[] srcArr2 = INT_ARRAY2; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = + TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType); + tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length}); + // Load srcArr2 which had different size as srcArr1. + Assert.assertThrows(IllegalArgumentException.class, + () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length})); + } + } + + @Test + public void testRepeatedLoadFloatArrayWithDifferentFixedSize() { + float[] srcArr1 = FLOAT_ARRAY1; + float[] srcArr2 = FLOAT_ARRAY2; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = + TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType); + tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length}); + // Load srcArr2 which had different size as srcArr1. + Assert.assertThrows(IllegalArgumentException.class, + () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length})); + } + } + + @Test + public void testLoadAndGetIntArrayWithDynamicSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED, + /*expectedIntArr=*/INT_ARRAY1) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, + /*expectedIntArr=*/INT_ARRAY1_CAPPED) + .build() + .run(); + } + + @Test + public void testLoadAndGetFloatArrayWithDynamicSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY1, + /*expectedIntArr=*/INT_ARRAY1) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED, + /*expectedIntArr=*/INT_ARRAY1_CAPPED) + .build() + .run(); + } + + @Test + public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE) + .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED, + /*expectedIntArr=*/INT_ARRAY2) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED, + /*expectedIntArr=*/INT_ARRAY2) + .build() + .run(); + } + + @Test + public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE) + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/FLOAT_ARRAY2, + /*expectedIntArr=*/INT_ARRAY2) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED, + /*expectedIntArr=*/INT_ARRAY2) + .build() + .run(); + } + + @Test + public void testGetForEmptyArrayWithFixedSizeBuffer() { + ArrayTestRunner.Builder.newInstance() + .setTensorBufferShape(EMPTY_ARRAY_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, + /*expectedIntArr=*/EMPTY_INT_ARRAY) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, + /*expectedIntArr=*/EMPTY_INT_ARRAY) + .build() + .run(); + } + + @Test + public void testGetForEmptyArrayWithDynamicBuffer() { + ArrayTestRunner.Builder.newInstance() + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, + /*expectedIntArr=*/EMPTY_INT_ARRAY) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, + /*expectedIntArr=*/EMPTY_INT_ARRAY) + .build() + .run(); + } + + @Test + public void testRepeatedLoadAndGetForEmptyArray() { + ArrayTestRunner.Builder.newInstance() + .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE) + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE) + .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE) + .setExpectedResults( + /*bufferType = */ DataType.FLOAT32, + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, + /*expectedIntArr=*/EMPTY_INT_ARRAY) + .setExpectedResults( + /*bufferType = */ DataType.UINT8, + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY, + /*expectedIntArr=*/EMPTY_INT_ARRAY) + .build() + .run(); + } + + @Test + public void testLoadNullIntArrays() { + int[] nullArray = null; + int[] shape = new int[] {}; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + Assert.assertThrows( + NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape)); + } + } + + @Test + public void testLoadNullFloatArrays() { + float[] nullArray = null; + int[] shape = new int[] {}; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + Assert.assertThrows( + NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape)); + } + } + + @Test + public void testLoadFloatArraysWithNullShape() { + float[] arr = new float[] {1.0f}; + int[] nullShape = null; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + Assert.assertThrows( + NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape)); + } + } + + @Test + public void testLoadIntArraysWithNullShape() { + int[] arr = new int[] {1}; + int[] nullShape = null; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + Assert.assertThrows( + NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape)); + } + } + + @Test + public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() { + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType); + Assert.assertThrows( + IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2)); + } + } + + @Test + public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() { + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType); + Assert.assertThrows(IllegalArgumentException.class, + () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2)); + } + } + + @Test + public void testLoadByteBufferForNullBuffer() { + ByteBuffer byteBuffer = null; + int[] shape = new int[] {}; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + Assert.assertThrows( + NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape)); + } + } + + @Test + public void testLoadByteBufferForEmptyBuffer() { + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE); + assertThat(tensorBuffer.getFlatSize()).isEqualTo(0); + } + } + + @Test + public void testLoadByteBufferWithDifferentFixedSize() { + // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5. + int[] tensorBufferShape = new int[] {2}; + TensorBuffer tensorBuffer = + TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32); + Assert.assertThrows(IllegalArgumentException.class, + () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE)); + } + + @Test + public void testLoadByteBufferWithMisMatchDataType() { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + int[] wrongShape = new int[] {1}; + // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape. + Assert.assertThrows(IllegalArgumentException.class, + () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape)); + } + + @Test + public void testLoadByteBufferForTensorBufferFloat() { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32); + tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE); + assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1); + assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE); + } + + @Test + public void testLoadByteBufferForTensorBufferUint8() { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); + tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE); + assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2); + assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE); + } + + @Test + public void testGetFloatValueWithInvalidIndex() { + float[] arrayWithSixElements = FLOAT_ARRAY1; + int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE; + int[] invalidIndexes = {-1, 7}; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements); + for (int invalidIndex : invalidIndexes) { + Assert.assertThrows(IndexOutOfBoundsException.class, + () -> tensorBuffer.getFloatValue(invalidIndex)); + } + } + } + + @Test + public void testGetFloatValueFromScalarWithInvalidIndex() { + int[] shape = new int[] {}; + float[] arr = new float[] {10.0f}; + int[] invalidIndexes = + new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + tensorBuffer.loadArray(arr, shape); + for (int invalidIndex : invalidIndexes) { + Assert.assertThrows(IndexOutOfBoundsException.class, + () -> tensorBuffer.getFloatValue(invalidIndex)); + } + } + } + + @Test + public void testGetIntValueWithInvalidIndex() { + float[] arrayWithSixElements = FLOAT_ARRAY1; + int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE; + int[] invalidIndexes = {-1, 7}; + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements); + for (int invalidIndex : invalidIndexes) { + Assert.assertThrows(IndexOutOfBoundsException.class, + () -> tensorBuffer.getIntValue(invalidIndex)); + } + } + } + + @Test + public void testGetIntValueFromScalarWithInvalidIndex() { + int[] shape = new int[] {}; + float[] arr = new float[] {10.0f}; + int[] invalidIndexes = + new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType); + tensorBuffer.loadArray(arr, shape); + for (int invalidIndex : invalidIndexes) { + Assert.assertThrows(IndexOutOfBoundsException.class, + () -> tensorBuffer.getIntValue(invalidIndex)); + } + } + } + + @Test + public void testLoadByteBufferSliceForTensorBufferFloat() { + TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32); + original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6}); + ByteBuffer buffer = original.getBuffer(); + // Slice original buffer to 3 sub-buffer, each of which has 2 element + int numBuffers = 3; + int numElements = 2; + int subArrayLength = numElements * original.getTypeSize(); + TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType()); + for (int i = 0; i < numBuffers; i++) { + buffer.position(i * subArrayLength); + ByteBuffer subBuffer = buffer.slice(); + // ByteBuffer.slice doesn't keep order. + subBuffer.order(buffer.order()).limit(subArrayLength); + tensorSlice.loadBuffer(subBuffer, new int[] {numElements}); + float[] arraySlice = tensorSlice.getFloatArray(); + assertThat(arraySlice.length).isEqualTo(numElements); + assertThat(arraySlice[0]).isEqualTo(i * numElements + 1); + assertThat(arraySlice[1]).isEqualTo(i * numElements + 2); + } + } + + @Test + public void testLoadByteBufferSliceForTensorBufferUInt8() { + TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8); + original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6}); + ByteBuffer buffer = original.getBuffer(); + // Slice original buffer to 3 sub-buffer, each of which has 2 element + int numBuffers = 3; + int numElements = 2; + int subArrayLength = numElements * original.getTypeSize(); + TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType()); + for (int i = 0; i < numBuffers; i++) { + buffer.position(i * subArrayLength); + ByteBuffer subBuffer = buffer.slice(); + // ByteBuffer.slice doesn't keep order. + subBuffer.order(buffer.order()).limit(subArrayLength); + tensorSlice.loadBuffer(subBuffer, new int[] {numElements}); + int[] arraySlice = tensorSlice.getIntArray(); + assertThat(arraySlice.length).isEqualTo(numElements); + assertThat(arraySlice[0]).isEqualTo(i * numElements + 1); + assertThat(arraySlice[1]).isEqualTo(i * numElements + 2); + } + } + + @Test + public void getShapeFailsAfterByteBufferChanged() { + TensorBuffer tensorBuffer = + TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32); + ByteBuffer byteBuffer = tensorBuffer.getBuffer(); + byteBuffer.limit(5); + + IllegalStateException exception = + assertThrows(IllegalStateException.class, tensorBuffer::getShape); + assertThat(exception).hasMessageThat().contains( + "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The" + + " ByteBuffer may have been changed."); + } + + @Test + public void getFlatSizeFailsAfterByteBufferChanged() { + TensorBuffer tensorBuffer = + TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32); + ByteBuffer byteBuffer = tensorBuffer.getBuffer(); + byteBuffer.limit(5); + + IllegalStateException exception = + assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize); + assertThat(exception).hasMessageThat().contains( + "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The" + + " ByteBuffer may have been changed."); + } + + @Test + public void loadReadOnlyBuffersCopiesOnWrite() { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8); + ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1); + originalByteBuffer.put(new byte[] {99}); + originalByteBuffer.rewind(); + ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer(); + + tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[] {1}); + assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer); + + tensorBuffer.loadArray(new int[] {42}); + assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer); + assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated + assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java new file mode 100644 index 0000000..1921f4e46 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
@@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.tensorbuffer; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.tensorflow.lite.DataType; + +/** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferUint8}. */ +@RunWith(RobolectricTestRunner.class) +public final class TensorBufferUint8Test { + @Test + public void testCreateDynamic() { + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(); + assertThat(tensorBufferUint8).isNotNull(); + } + + @Test + public void testCreateFixedSize() { + int[] shape = new int[] {1, 2, 3}; + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); + assertThat(tensorBufferUint8).isNotNull(); + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6); + } + + @Test + public void testCreateFixedSizeWithScalarShape() { + int[] shape = new int[] {}; + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); + assertThat(tensorBufferUint8).isNotNull(); + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1); + } + + @Test + public void testCreateWithNullShape() { + int[] shape = null; + Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape)); + } + + @Test + public void testCreateWithInvalidShape() { + int[] shape = new int[] {1, -1, 2}; + Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape)); + } + + @Test + public void testCreateUsingShapeWithZero() { + int[] shape = new int[] {1, 0, 2}; + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape); + assertThat(tensorBufferUint8).isNotNull(); + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0); + } + + @Test + public void testGetDataType() { + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(); + assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD new file mode 100644 index 0000000..120b396 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD
@@ -0,0 +1,36 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library_with_tflite( + name = "task_audio_native", + srcs = [ + ":libtask_audio_jni.so", + ], +) + +jni_binary_with_tflite( + name = "libtask_audio_jni.so", + srcs = [ + "//tensorflow_lite_support/java/src/native/task/audio/classifier:audio_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + "//tensorflow_lite_support/cc/task/audio:audio_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/java/jni", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/BUILD new file mode 100644 index 0000000..af9bff9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/BUILD
@@ -0,0 +1,47 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["audio_classifier_jni.cc"]) + +cc_library_with_tflite( + name = "audio_classifier_native", + srcs = [ + ":libtask_audio_jni.so", + ], +) + +jni_binary_with_tflite( + name = "libtask_audio_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + ":native_without_resolver", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + ], +) + +# Shared native logic for AudioClassifier. Combine this target and customized +# version of op_resolver to build customized audio_classifier_native target. +cc_library_with_tflite( + name = "native_without_resolver", + srcs = [ + "audio_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/audio:audio_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/java/jni", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc new file mode 100644 index 0000000..c3c21fa4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
@@ -0,0 +1,373 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> + +#include <memory> +#include <string> + +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h" +#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.pb.h" +#include "tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); + +} // namespace task +} // namespace tflite + +namespace { + +using ::tflite::support::StatusOr; +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; +using ::tflite::support::utils::kIllegalArgumentException; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::StringListToVector; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::audio::AudioBuffer; +using ::tflite::task::audio::AudioClassifier; +using ::tflite::task::audio::AudioClassifierOptions; +using ::tflite::task::audio::Class; +using ::tflite::task::audio::ClassificationResult; +using ::tflite::task::core::BaseOptions; + +// TODO(b/183343074): Share the code below with ImageClassifier. + +constexpr char kCategoryClassName[] = + "org/tensorflow/lite/support/label/Category"; +constexpr char kStringClassName[] = "Ljava/lang/String;"; +constexpr char kEmptyString[] = ""; + +jobject ConvertToCategory(JNIEnv* env, const Class& classification) { + // jclass and init of Category. + jclass category_class = env->FindClass(kCategoryClassName); + jmethodID category_create = env->GetStaticMethodID( + category_class, "create", + absl::StrCat("(", kStringClassName, kStringClassName, "FI)L", + kCategoryClassName, ";") + .c_str()); + + std::string label_string = classification.has_class_name() + ? classification.class_name() + : std::to_string(classification.index()); + jstring label = env->NewStringUTF(label_string.c_str()); + std::string display_name_string = classification.has_display_name() + ? classification.display_name() + : kEmptyString; + jstring display_name = env->NewStringUTF(display_name_string.c_str()); + jobject jcategory = env->CallStaticObjectMethod( + category_class, category_create, label, display_name, + classification.score(), classification.index()); + env->DeleteLocalRef(category_class); + env->DeleteLocalRef(label); + env->DeleteLocalRef(display_name); + return jcategory; +} + +jobject ConvertToClassificationResults(JNIEnv* env, + const ClassificationResult& results) { + // jclass and init of Classifications. + jclass classifications_class = env->FindClass( + "org/tensorflow/lite/task/audio/classifier/Classifications"); + jmethodID classifications_create = env->GetStaticMethodID( + classifications_class, "create", + "(Ljava/util/List;ILjava/lang/String;)Lorg/tensorflow/lite/" + "task/audio/classifier/Classifications;"); + + // jclass, init, and add of ArrayList. + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_init = + env->GetMethodID(array_list_class, "<init>", "(I)V"); + jmethodID array_list_add_method = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + + jobject classifications_list = + env->NewObject(array_list_class, array_list_init, + static_cast<jint>(results.classifications_size())); + for (int i = 0; i < results.classifications_size(); i++) { + auto classifications = results.classifications(i); + jobject jcategory_list = env->NewObject(array_list_class, array_list_init, + classifications.classes_size()); + for (const auto& classification : classifications.classes()) { + jobject jcategory = ConvertToCategory(env, classification); + env->CallBooleanMethod(jcategory_list, array_list_add_method, jcategory); + + env->DeleteLocalRef(jcategory); + } + + std::string head_name_string = + classifications.has_head_name() + ? classifications.head_name() + : std::to_string(classifications.head_index()); + jstring head_name = env->NewStringUTF(head_name_string.c_str()); + + jobject jclassifications = env->CallStaticObjectMethod( + classifications_class, classifications_create, jcategory_list, + classifications.head_index(), head_name); + env->CallBooleanMethod(classifications_list, array_list_add_method, + jclassifications); + + env->DeleteLocalRef(head_name); + env->DeleteLocalRef(jcategory_list); + env->DeleteLocalRef(jclassifications); + } + return classifications_list; +} + +// Creates an AudioClassifierOptions proto based on the Java class. +AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, + jobject java_options, + jlong base_options_handle) { + AudioClassifierOptions proto_options; + + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + + jclass java_options_class = env->FindClass( + "org/tensorflow/lite/task/audio/classifier/" + "AudioClassifier$AudioClassifierOptions"); + + jmethodID display_names_locale_id = env->GetMethodID( + java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;"); + jstring display_names_locale = static_cast<jstring>( + env->CallObjectMethod(java_options, display_names_locale_id)); + const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr); + proto_options.set_display_names_locale(pchars); + env->ReleaseStringUTFChars(display_names_locale, pchars); + + jmethodID max_results_id = + env->GetMethodID(java_options_class, "getMaxResults", "()I"); + jint max_results = env->CallIntMethod(java_options, max_results_id); + proto_options.set_max_results(max_results); + + jmethodID is_score_threshold_set_id = + env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z"); + jboolean is_score_threshold_set = + env->CallBooleanMethod(java_options, is_score_threshold_set_id); + if (is_score_threshold_set) { + jmethodID score_threshold_id = + env->GetMethodID(java_options_class, "getScoreThreshold", "()F"); + jfloat score_threshold = + env->CallFloatMethod(java_options, score_threshold_id); + proto_options.set_score_threshold(score_threshold); + } + + jmethodID allow_list_id = env->GetMethodID( + java_options_class, "getLabelAllowList", "()Ljava/util/List;"); + jobject allow_list = env->CallObjectMethod(java_options, allow_list_id); + auto allow_list_vector = StringListToVector(env, allow_list); + for (const auto& class_name : allow_list_vector) { + proto_options.add_class_name_allowlist(class_name); + } + + jmethodID deny_list_id = env->GetMethodID( + java_options_class, "getLabelDenyList", "()Ljava/util/List;"); + jobject deny_list = env->CallObjectMethod(java_options, deny_list_id); + auto deny_list_vector = StringListToVector(env, deny_list); + for (const auto& class_name : deny_list_vector) { + proto_options.add_class_name_denylist(class_name); + } + + return proto_options; +} + +jlong CreateAudioClassifierFromOptions(JNIEnv* env, + const AudioClassifierOptions& options) { + StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = + AudioClassifier::CreateFromOptions(options, + tflite::task::CreateOpResolver()); + if (audio_classifier_or.ok()) { + // Deletion is handled at deinitJni time. + return reinterpret_cast<jlong>(audio_classifier_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(audio_classifier_or.status().code()), + "Error occurred when initializing AudioClassifier: %s", + audio_classifier_or.status().message().data()); + } + return kInvalidPointer; +} + +} // namespace + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni( + JNIEnv* env, + jobject thiz, + jlong native_handle) { + delete reinterpret_cast<AudioClassifier*>(native_handle); +} + +// Creates an AudioClassifier instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-possitive +// values will be ignored. +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelFdAndOptions( + JNIEnv* env, + jclass thiz, + jint file_descriptor, + jlong file_descriptor_length, + jlong file_descriptor_offset, + jobject java_options, + jlong base_options_handle) { + AudioClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta(); + file_descriptor_meta->set_fd(file_descriptor); + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); + } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateAudioClassifierFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBuffer( + JNIEnv* env, + jclass thiz, + jobject model_buffer, + jobject java_options, + jlong base_options_handle) { + AudioClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + // External proto generated header does not overload `set_file_content` with + // string_view, therefore GetMappedFileBuffer does not apply here. + // Creating a std::string will cause one extra copying of data. Thus, the + // most efficient way here is to set file_content using char* and its size. + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateAudioClassifierFromOptions(env, proto_options); +} + +// TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms +// each time. Consider retrieving the AudioFormat during initialization and +// caching it in JAVA layer. +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSampleRateNative( + JNIEnv* env, + jclass thiz, + jlong native_handle) { + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + StatusOr<AudioBuffer::AudioFormat> format_or = + classifier->GetRequiredAudioFormat(); + if (format_or.ok()) { + return format_or->sample_rate; + } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(format_or.status().code()), + "Error occurred when getting sample rate from AudioClassifier: %s", + format_or.status().message()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChannelsNative( + JNIEnv* env, + jclass thiz, + jlong native_handle) { + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + StatusOr<AudioBuffer::AudioFormat> format_or = + classifier->GetRequiredAudioFormat(); + if (format_or.ok()) { + return format_or->channels; + } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(format_or.status().code()), + "Error occurred when gettng channels from AudioClassifier: %s", + format_or.status().message()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredInputBufferSizeNative( + JNIEnv* env, + jclass thiz, + jlong native_handle) { + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + return classifier->GetRequiredInputBufferSize(); +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_classifyNative( + JNIEnv* env, + jclass thiz, + jlong native_handle, + jbyteArray java_array, + jint channels, + jint sample_rate) { + // Get the primitive native array. Depending on the JAVA runtime, the returned + // array might be a copy of the JAVA array (or not). + jbyte* native_array = env->GetByteArrayElements(java_array, nullptr); + if (native_array == nullptr) { + ThrowException(env, kIllegalArgumentException, + "Error occurred when converting the java audio input array " + "to native array."); + return nullptr; + } + + jobject classification_results = nullptr; + + // Prepare the AudioBuffer. + AudioBuffer::AudioFormat format = {channels, sample_rate}; + const int size_in_bytes = env->GetArrayLength(java_array); + const int size_in_float = size_in_bytes / sizeof(float); + const StatusOr<std::unique_ptr<AudioBuffer>> audio_buffer_or = + AudioBuffer::Create(reinterpret_cast<float*>(native_array), size_in_float, + format); + + if (audio_buffer_or.ok()) { + // Actual classification + auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle); + auto results_or = classifier->Classify(*(audio_buffer_or.value())); + if (results_or.ok()) { + classification_results = + ConvertToClassificationResults(env, results_or.value()); + } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(results_or.status().code()), + "Error occurred when classifying the audio clip: %s", + results_or.status().message().data()); + } + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(audio_buffer_or.status().code()), + "Error occurred when creating the AudioBuffer: %s", + audio_buffer_or.status().message().data()); + } + + // Mark native_array as no longer needed. + // TODO(b/183343074): Wrap this in SimpleCleanUp. + env->ReleaseByteArrayElements(java_array, native_array, /*mode=*/0); + return classification_results; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/BUILD index d4dd7ab3..2653bf6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/BUILD
@@ -1,16 +1,25 @@ +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_library_with_tflite", +) + package( default_visibility = ["//tensorflow_lite_support:users"], licenses = ["notice"], # Apache 2.0 ) +exports_files(["task_jni_utils.cc"]) + # Default provider for BuiltInOpResover. Create your own target, overwrite the # function to provide a MutableOpResolver for customized OPs and/or a subset of # builtin OPs. -cc_library( +cc_library_with_tflite( name = "builtin_op_resolver", srcs = ["builtin_op_resolver.cc"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + ], deps = [ - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite:op_resolver", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc index 050f49f..440a12f73b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc
@@ -13,14 +13,14 @@ limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" namespace tflite { namespace task { // Default provider for OpResolver, provides BuiltinOpResolver. std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT - return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>( - new tflite::ops::builtin::BuiltinOpResolver); + return std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver>( + new tflite_shims::ops::builtin::BuiltinOpResolver); } } // namespace task
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc new file mode 100644 index 0000000..75f93d6f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
@@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace { + +using ::tflite::proto::Delegate; +using ::tflite::support::StatusOr; +using ::tflite::support::utils::ConvertToProtoDelegate; +using ::tflite::support::utils::kIllegalStateException; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_core_TaskJniUtils_createProtoBaseOptions( + JNIEnv* env, + jclass thiz, + jint delegate, + jint num_threads) { + StatusOr<Delegate> delegate_proto_or = ConvertToProtoDelegate(delegate); + if (!delegate_proto_or.ok()) { + ThrowException(env, kIllegalStateException, + "Error occurred when converting to the proto delegate: %s", + delegate_proto_or.status().message().data()); + return kInvalidPointer; + } + + // base_options will be owned by the task proto options, such as + // ImageClassifierOptions. + BaseOptions* base_options = new BaseOptions(); + auto tflite_settings = + base_options->mutable_compute_settings()->mutable_tflite_settings(); + tflite_settings->set_delegate(delegate_proto_or.value()); + tflite_settings->mutable_cpu_settings()->set_num_threads(num_threads); + return reinterpret_cast<jlong>(base_options); +} + +} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD index a27aba5..88498f9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD
@@ -1,34 +1,39 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], licenses = ["notice"], # Apache 2.0 ) -cc_library( +cc_library_with_tflite( name = "task_text_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_text_jni.so", ], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_text_jni.so", srcs = [ + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni.cc", - "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier:bert_nl_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert:bert_nl_classifier_jni.cc", "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_jni.cc", ], linkscript = "//tensorflow_lite_support/java:default_version_script.lds", - deps = [ - "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier", + tflite_deps = [ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", - "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer", + "//tensorflow_lite_support/cc/task/text:bert_nl_classifier", + "//tensorflow_lite_support/cc/task/text:bert_question_answerer", "//tensorflow_lite_support/cc/utils:jni_utils", - "//tensorflow_lite_support/java/jni", - "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils", - "@org_tensorflow//tensorflow/lite:framework", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + ], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc", + "//tensorflow_lite_support/java/jni", + "@org_tensorflow//tensorflow/lite:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:kernel_util", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD index 316c6d6..116df93 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], @@ -10,58 +10,51 @@ ]) # Default native target for nl_classifier to provide BuiltInOpResolver. -cc_library( +cc_library_with_tflite( name = "nl_classifier_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_text_jni.so", ], + deps = [ + "@org_tensorflow//tensorflow/lite/kernels:deprecated_backends", + ], ) # Note: "libtask_text_jni" is hardcoded in Java to look up the .so, therefore # the name should remain the same when creating customized version of # nl_classifier_native -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_text_jni.so", linkscript = "//tensorflow_lite_support/java:default_version_script.lds", - deps = [ + tflite_deps = [ ":native_without_resolver", "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", ], ) -# Custom ops resolver for Wear OS Smart Reply. -cc_library( - name = "custom_ops_resolver_for_smart_reply", - srcs = ["custom_ops_resolver_for_smart_reply.cc"], - deps = [ - "//knowledge/hobbes/chat/tensorflow/tflite:tflite-all-lingua-ops-resolver", - "@com_google_absl//absl/memory", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite:string_util", - "@org_tensorflow//tensorflow/lite/kernels:kernel_util", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], -) - # Shared native logic for NLClassifier. Combine this target and customized # version of op_resolver to build customized nl_classifier_native target. -cc_library( +cc_library_with_tflite( name = "native_without_resolver", srcs = [ "nl_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", ], - deps = [ + tflite_deps = [ ":nl_classifier_jni_utils", "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", "//tensorflow_lite_support/cc/utils:jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/java/jni", - "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:kernel_util", ], alwayslink = 1, ) -cc_library( +cc_library_with_tflite( name = "nl_classifier_jni_utils", srcs = [ "nl_classifier_jni_utils.cc", @@ -69,9 +62,11 @@ hdrs = [ "nl_classifier_jni_utils.h", ], - deps = [ + tflite_deps = [ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", "//tensorflow_lite_support/cc/utils:jni_utils", + ], + deps = [ "//tensorflow_lite_support/java/jni", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/BUILD new file mode 100644 index 0000000..e104ab89a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/BUILD
@@ -0,0 +1,36 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "bert_nl_classifier_jni.cc", +]) + +cc_library_with_tflite( + name = "bert_nl_classifier_native", + tflite_jni_binaries = [ + ":libtask_text_jni.so", + ], +) + +jni_binary_with_tflite( + name = "libtask_text_jni.so", + srcs = [ + "bert_nl_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text:bert_nl_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc", + "//tensorflow_lite_support/java/jni", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc new file mode 100644 index 0000000..2daacdf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
@@ -0,0 +1,118 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" +#include "tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h" + +namespace { + +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; +using ::tflite::task::text::BertNLClassifier; +using ::tflite::task::text::BertNLClassifierOptions; +using ::tflite::task::text::nlclassifier::RunClassifier; + +BertNLClassifierOptions ConvertJavaBertNLClassifierOptions( + JNIEnv* env, + jobject java_options, + jlong base_options_handle) { + BertNLClassifierOptions proto_options; + + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + return proto_options; +} + +} // namespace + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni( + JNIEnv* env, + jobject thiz, + jlong native_handle) { + delete reinterpret_cast<BertNLClassifier*>(native_handle); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer( + JNIEnv* env, + jclass thiz, + jobject model_buffer, + jobject java_options, + jlong base_options_handle) { + BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions( + env, java_options, base_options_handle); + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + + tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> classifier_or = + BertNLClassifier::CreateFromOptions(proto_options); + if (classifier_or.ok()) { + return reinterpret_cast<jlong>(classifier_or->release()); + } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(classifier_or.status().code()), + "Error occurred when initializing Bert NLClassifier: %s", + classifier_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor( + JNIEnv* env, + jclass thiz, + jint fd, + jobject java_options, + jlong base_options_handle) { + BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions( + env, java_options, base_options_handle); + proto_options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta() + ->set_fd(fd); + + tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> classifier_or = + BertNLClassifier::CreateFromOptions(proto_options); + if (classifier_or.ok()) { + return reinterpret_cast<jlong>(classifier_or->release()); + } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(classifier_or.status().code()), + "Error occurred when initializing Bert NLClassifier: %s", + classifier_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative( + JNIEnv* env, + jclass clazz, + jlong native_handle, + jstring text) { + return RunClassifier(env, native_handle, text); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD deleted file mode 100644 index 49f3f4e..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD +++ /dev/null
@@ -1,31 +0,0 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") - -package( - default_visibility = ["//tensorflow_lite_support:users"], - licenses = ["notice"], # Apache 2.0 -) - -exports_files([ - "bert_nl_classifier_jni.cc", -]) - -cc_library( - name = "bert_nl_classifier_native", - srcs = [ - ":libtask_text_jni.so", - ], -) - -tflite_jni_binary( - name = "libtask_text_jni.so", - srcs = [ - "bert_nl_classifier_jni.cc", - ], - linkscript = "//tensorflow_lite_support/java:default_version_script.lds", - deps = [ - "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier", - "//tensorflow_lite_support/cc/utils:jni_utils", - "//tensorflow_lite_support/java/jni", - "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils", - ], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc deleted file mode 100644 index 1dd3440..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc +++ /dev/null
@@ -1,83 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <jni.h> - -#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" -#include "tensorflow_lite_support/cc/utils/jni_utils.h" -#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h" - -namespace { - -using ::tflite::support::utils::GetMappedFileBuffer; -using ::tflite::support::utils::kAssertionError; -using ::tflite::support::utils::kInvalidPointer; -using ::tflite::support::utils::ThrowException; -using ::tflite::task::text::nlclassifier::BertNLClassifier; -using ::tflite::task::text::nlclassifier::RunClassifier; - -extern "C" JNIEXPORT void JNICALL -Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni( - JNIEnv* env, - jobject thiz, - jlong native_handle) { - delete reinterpret_cast<BertNLClassifier*>(native_handle); -} - -extern "C" JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer( - JNIEnv* env, - jclass thiz, - jobject model_buffer) { - auto model = GetMappedFileBuffer(env, model_buffer); - tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status = - BertNLClassifier::CreateFromBuffer(model.data(), model.size()); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); - } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing Bert NLClassifier: %s", - status.status().message().data()); - return kInvalidPointer; - } -} - -extern "C" JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor( - JNIEnv* env, - jclass thiz, - jint fd) { - tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status = - BertNLClassifier::CreateFromFd(fd); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); - } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing Bert NLClassifier: %s", - status.status().message().data()); - return kInvalidPointer; - } -} - -extern "C" JNIEXPORT jobject JNICALL -Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative( - JNIEnv* env, - jclass clazz, - jlong native_handle, - jstring text) { - return RunClassifier(env, native_handle, text); -} - -} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/custom_ops_resolver_for_smart_reply.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/custom_ops_resolver_for_smart_reply.cc deleted file mode 100644 index 48cbd367..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/custom_ops_resolver_for_smart_reply.cc +++ /dev/null
@@ -1,30 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "absl/memory/memory.h" -#include "knowledge/hobbes/chat/tensorflow/tflite/tflite-all-lingua-ops-resolver.h" -#include "tensorflow/lite/op_resolver.h" - -namespace tflite { -namespace task { -// Provides custom OpResolver for test NLClassifier models. -std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT - MutableOpResolver resolver; - RegisterAllLinguaOps(&resolver); - return absl::make_unique<MutableOpResolver>(resolver); -} - -} // namespace task -} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc index 6066d1d..4c71a80 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
@@ -17,6 +17,7 @@ #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/op_resolver.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h" #include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h" @@ -31,57 +32,66 @@ namespace { +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; using ::tflite::support::utils::GetMappedFileBuffer; using ::tflite::support::utils::JStringToString; -using ::tflite::support::utils::kAssertionError; using ::tflite::support::utils::kInvalidPointer; using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; +using ::tflite::task::text::NLClassifierOptions; using ::tflite::task::text::nlclassifier::NLClassifier; -using ::tflite::task::text::nlclassifier::NLClassifierOptions; using ::tflite::task::text::nlclassifier::RunClassifier; -NLClassifierOptions ConvertJavaNLClassifierOptions( - JNIEnv* env, - jobject java_nl_classifier_options) { +NLClassifierOptions ConvertToProtoOptions(JNIEnv* env, + jobject java_nl_classifier_options, + jlong base_options_handle) { jclass nl_classifier_options_class = env->FindClass( "org/tensorflow/lite/task/text/nlclassifier/" "NLClassifier$NLClassifierOptions"); - jmethodID input_tensor_index_method_id = - env->GetMethodID(nl_classifier_options_class, "inputTensorIndex", "()I"); + jmethodID input_tensor_index_method_id = env->GetMethodID( + nl_classifier_options_class, "getInputTensorIndex", "()I"); jmethodID output_score_tensor_index_method_id = env->GetMethodID( - nl_classifier_options_class, "outputScoreTensorIndex", "()I"); + nl_classifier_options_class, "getOutputScoreTensorIndex", "()I"); jmethodID output_label_tensor_index_method_id = env->GetMethodID( - nl_classifier_options_class, "outputLabelTensorIndex", "()I"); - jmethodID input_tensor_name_method_id = env->GetMethodID( - nl_classifier_options_class, "inputTensorName", "()Ljava/lang/String;"); + nl_classifier_options_class, "getOutputLabelTensorIndex", "()I"); + jmethodID input_tensor_name_method_id = + env->GetMethodID(nl_classifier_options_class, "getInputTensorName", + "()Ljava/lang/String;"); jmethodID output_score_tensor_name_method_id = - env->GetMethodID(nl_classifier_options_class, "outputScoreTensorName", + env->GetMethodID(nl_classifier_options_class, "getOutputScoreTensorName", "()Ljava/lang/String;"); jmethodID output_label_tensor_name_method_id = - env->GetMethodID(nl_classifier_options_class, "outputLabelTensorName", + env->GetMethodID(nl_classifier_options_class, "getOutputLabelTensorName", "()Ljava/lang/String;"); - return { - .input_tensor_index = env->CallIntMethod(java_nl_classifier_options, - input_tensor_index_method_id), - .output_score_tensor_index = env->CallIntMethod( - java_nl_classifier_options, output_score_tensor_index_method_id), - .output_label_tensor_index = env->CallIntMethod( - java_nl_classifier_options, output_label_tensor_index_method_id), - .input_tensor_name = JStringToString( - env, (jstring)env->CallObjectMethod(java_nl_classifier_options, - input_tensor_name_method_id)), - .output_score_tensor_name = JStringToString( - env, - (jstring)env->CallObjectMethod(java_nl_classifier_options, - output_score_tensor_name_method_id)), - .output_label_tensor_name = JStringToString( - env, - (jstring)env->CallObjectMethod(java_nl_classifier_options, - output_label_tensor_name_method_id)), - }; + NLClassifierOptions proto_options; + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + + proto_options.set_input_tensor_index(env->CallIntMethod( + java_nl_classifier_options, input_tensor_index_method_id)); + proto_options.set_output_score_tensor_index(env->CallIntMethod( + java_nl_classifier_options, output_score_tensor_index_method_id)); + proto_options.set_output_label_tensor_index(env->CallIntMethod( + java_nl_classifier_options, output_label_tensor_index_method_id)); + proto_options.set_input_tensor_name(JStringToString( + env, (jstring)env->CallObjectMethod(java_nl_classifier_options, + input_tensor_name_method_id))); + proto_options.set_output_score_tensor_name(JStringToString( + env, (jstring)env->CallObjectMethod(java_nl_classifier_options, + output_score_tensor_name_method_id))); + proto_options.set_output_label_tensor_name(JStringToString( + env, (jstring)env->CallObjectMethod(java_nl_classifier_options, + output_label_tensor_name_method_id))); + + return proto_options; } +} // namespace + extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni( JNIEnv* env, @@ -95,20 +105,25 @@ JNIEnv* env, jclass thiz, jobject nl_classifier_options, - jobject model_buffer) { + jobject model_buffer, + jlong base_options_handle) { auto model = GetMappedFileBuffer(env, model_buffer); - tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status = - NLClassifier::CreateFromBufferAndOptions( - model.data(), model.size(), - ConvertJavaNLClassifierOptions(env, nl_classifier_options), - tflite::task::CreateOpResolver()); + tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or; - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); + NLClassifierOptions proto_options = + ConvertToProtoOptions(env, nl_classifier_options, base_options_handle); + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + model.data(), model.size()); + classifier_or = NLClassifier::CreateFromOptions( + proto_options, tflite::task::CreateOpResolver()); + + if (classifier_or.ok()) { + return reinterpret_cast<jlong>(classifier_or->release()); } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing NLClassifier: %s", - status.status().message().data()); + ThrowException( + env, GetExceptionClassNameForStatusCode(classifier_or.status().code()), + "Error occurred when initializing NLClassifier: %s", + classifier_or.status().message().data()); return kInvalidPointer; } } @@ -118,17 +133,26 @@ JNIEnv* env, jclass thiz, jobject nl_classifier_options, - jint fd) { - tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status = - NLClassifier::CreateFromFdAndOptions( - fd, ConvertJavaNLClassifierOptions(env, nl_classifier_options), - tflite::task::CreateOpResolver()); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); + jint fd, + jlong base_options_handle) { + tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or; + + NLClassifierOptions proto_options = + ConvertToProtoOptions(env, nl_classifier_options, base_options_handle); + proto_options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta() + ->set_fd(fd); + classifier_or = NLClassifier::CreateFromOptions( + proto_options, tflite::task::CreateOpResolver()); + + if (classifier_or.ok()) { + return reinterpret_cast<jlong>(classifier_or->release()); } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing NLClassifier: %s", - status.status().message().data()); + ThrowException( + env, GetExceptionClassNameForStatusCode(classifier_or.status().code()), + "Error occurred when initializing NLClassifier: %s", + classifier_or.status().message().data()); return kInvalidPointer; } } @@ -141,5 +165,3 @@ jstring text) { return RunClassifier(env, native_handle, text); } - -} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/BUILD index 9753e32..e586105 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], @@ -9,22 +9,26 @@ "bert_question_answerer_jni.cc", ]) -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_text_jni.so", srcs = [ "bert_question_answerer_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", ], linkscript = "//tensorflow_lite_support/java:default_version_script.lds", - deps = [ - "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer", + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text:bert_question_answerer", "//tensorflow_lite_support/cc/utils:jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/java/jni", ], ) -cc_library( +cc_library_with_tflite( name = "bert_question_answerer_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_text_jni.so", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc index 9dd1895c..401e6fbd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
@@ -15,20 +15,41 @@ #include <jni.h> -#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h" namespace { +using ::tflite::support::StatusOr; using ::tflite::support::utils::ConvertVectorToArrayList; +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; using ::tflite::support::utils::GetMappedFileBuffer; using ::tflite::support::utils::JStringToString; -using ::tflite::task::text::qa::BertQuestionAnswerer; -using ::tflite::task::text::qa::QaAnswer; -using ::tflite::task::text::qa::QuestionAnswerer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; +using ::tflite::task::text::BertQuestionAnswerer; +using ::tflite::task::text::BertQuestionAnswererOptions; +using ::tflite::task::text::QaAnswer; +using ::tflite::task::text::QuestionAnswerer; constexpr int kInvalidPointer = 0; +// Creates a BertQuestionAnswererOptions proto based on the Java class. +BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) { + BertQuestionAnswererOptions proto_options; + + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + + return proto_options; +} + +} // namespace + extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni( JNIEnv* env, @@ -38,33 +59,35 @@ } extern "C" JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers( - JNIEnv* env, - jclass thiz, - jobjectArray model_buffers) { - absl::string_view model_with_metadata = - GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); - - tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = - BertQuestionAnswerer::CreateFromBuffer(model_with_metadata.data(), - model_with_metadata.size()); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); - } else { - return kInvalidPointer; - } -} - -extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor( JNIEnv* env, jclass thiz, - jint fd) { - tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = - BertQuestionAnswerer::CreateFromFd(fd); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); + jint file_descriptor, + jlong file_descriptor_length, + jlong file_descriptor_offset, + jlong base_options_handle) { + BertQuestionAnswererOptions proto_options = + ConvertToProtoOptions(base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta(); + file_descriptor_meta->set_fd(file_descriptor); + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); + } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + + StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status = + BertQuestionAnswerer::CreateFromOptions(proto_options); + if (qa_status.ok()) { + return reinterpret_cast<jlong>(qa_status->release()); } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(qa_status.status().code()), + "Error occurred when initializing BertQuestionAnswerer: %s", + qa_status.status().message().data()); return kInvalidPointer; } } @@ -79,12 +102,16 @@ absl::string_view vocab = GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1)); - tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = + StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status = BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( model.data(), model.size(), vocab.data(), vocab.size()); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); + if (qa_status.ok()) { + return reinterpret_cast<jlong>(qa_status->release()); } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(qa_status.status().code()), + "Error occurred when initializing BertQuestionAnswerer: %s", + qa_status.status().message().data()); return kInvalidPointer; } } @@ -99,12 +126,16 @@ absl::string_view sp_model = GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1)); - tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = + StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status = BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( model.data(), model.size(), sp_model.data(), sp_model.size()); - if (status.ok()) { - return reinterpret_cast<jlong>(status->release()); + if (qa_status.ok()) { + return reinterpret_cast<jlong>(qa_status->release()); } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(qa_status.status().code()), + "Error occurred when initializing BertQuestionAnswerer: %s", + qa_status.status().message().data()); return kInvalidPointer; } } @@ -136,5 +167,3 @@ return qa_answer; }); } - -} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD index 451a50ca..6f784145 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD
@@ -1,11 +1,11 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], licenses = ["notice"], # Apache 2.0 ) -cc_library( +cc_library_with_tflite( name = "jni_utils", srcs = [ "jni_utils.cc", @@ -13,35 +13,48 @@ hdrs = [ "jni_utils.h", ], + tflite_deps = [ + "//tensorflow_lite_support/cc/utils:jni_utils", + ], deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", - "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/java/jni", "@com_google_absl//absl/strings", ], ) -cc_library( +cc_library_with_tflite( name = "task_vision_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_vision_jni.so", ], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_vision_jni.so", srcs = [ + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc", "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_jni.cc", "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_jni.cc", ], linkscript = "//tensorflow_lite_support/java:default_version_script.lds", - deps = [ - "//tensorflow_lite_support/cc/port:statusor", + tflite_deps = [ + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", "//tensorflow_lite_support/cc/task/vision:image_classifier", "//tensorflow_lite_support/cc/task/vision:image_segmenter", "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", @@ -51,9 +64,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/cc/utils:jni_utils", "//tensorflow_lite_support/java/jni", - "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", "@com_google_absl//absl/strings", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD index 8bddc2a..7fdb8e3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], @@ -7,29 +7,44 @@ exports_files(["image_classifier_jni.cc"]) -cc_library( +cc_library_with_tflite( name = "image_classifier_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_vision_jni.so", ], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_vision_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + ":native_without_resolver", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + ], +) + +# Shared native logic for ImageClassifier. Combine this target and customized +# version of op_resolver to build customized image_classifier_native target. +cc_library_with_tflite( + name = "native_without_resolver", srcs = [ "image_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + "//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc", ], - linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + ], deps = [ "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/cc/utils:jni_utils", "//tensorflow_lite_support/java/jni", - "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc index bb8f8ce..2a713cf 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
@@ -19,6 +19,7 @@ #include <string> #include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/image_classifier.h" #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" @@ -28,27 +29,42 @@ #include "tensorflow_lite_support/cc/utils/jni_utils.h" #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); + +} // namespace task +} // namespace tflite + namespace { using ::tflite::support::StatusOr; -using ::tflite::support::utils::GetMappedFileBuffer; -using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; using ::tflite::support::utils::kInvalidPointer; using ::tflite::support::utils::StringListToVector; using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; using ::tflite::task::vision::BoundingBox; using ::tflite::task::vision::ClassificationResult; using ::tflite::task::vision::Classifications; using ::tflite::task::vision::ConvertToCategory; -using ::tflite::task::vision::ConvertToFrameBufferOrientation; using ::tflite::task::vision::FrameBuffer; using ::tflite::task::vision::ImageClassifier; using ::tflite::task::vision::ImageClassifierOptions; // Creates an ImageClassifierOptions proto based on the Java class. ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, - jobject java_options) { + jobject java_options, + jlong base_options_handle) { ImageClassifierOptions proto_options; + + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + jclass java_options_class = env->FindClass( "org/tensorflow/lite/task/vision/classifier/" "ImageClassifier$ImageClassifierOptions"); @@ -93,7 +109,6 @@ for (const auto& class_name : deny_list_vector) { proto_options.add_class_name_blacklist(class_name); } - return proto_options; } @@ -139,6 +154,26 @@ return classifications_list; } +jlong CreateImageClassifierFromOptions(JNIEnv* env, + const ImageClassifierOptions& options) { + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options, + tflite::task::CreateOpResolver()); + if (image_classifier_or.ok()) { + // Deletion is handled at deinitJni time. + return reinterpret_cast<jlong>(image_classifier_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(image_classifier_or.status().code()), + "Error occurred when initializing ImageClassifier: %s", + image_classifier_or.status().message().data()); + return kInvalidPointer; + } +} + +} // namespace + extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni( JNIEnv* env, @@ -147,6 +182,9 @@ delete reinterpret_cast<ImageClassifier*>(native_handle); } +// Creates an ImageClassifier instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-possitive +// values will be ignored. extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions( JNIEnv* env, @@ -154,26 +192,40 @@ jint file_descriptor, jlong file_descriptor_length, jlong file_descriptor_offset, - jobject java_options) { + jobject java_options, + jlong base_options_handle) { ImageClassifierOptions proto_options = - ConvertToProtoOptions(env, java_options); - auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() + ConvertToProtoOptions(env, java_options, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() + ->mutable_model_file() ->mutable_file_descriptor_meta(); file_descriptor_meta->set_fd(file_descriptor); - file_descriptor_meta->set_length(file_descriptor_length); - file_descriptor_meta->set_offset(file_descriptor_offset); - - StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = - ImageClassifier::CreateFromOptions(proto_options); - if (image_classifier_or.ok()) { - // Deletion is handled at deinitJni time. - return reinterpret_cast<jlong>(image_classifier_or->release()); - } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing ImageClassifier: %s", - image_classifier_or.status().message().data()); - return kInvalidPointer; + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateImageClassifierFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer( + JNIEnv* env, + jclass thiz, + jobject model_buffer, + jobject java_options, + jlong base_options_handle) { + ImageClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + // External proto generated header does not overload `set_file_content` with + // string_view, therefore GetMappedFileBuffer does not apply here. + // Creating a std::string will cause one extra copying of data. Thus, the + // most efficient way here is to set file_content using char* and its size. + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateImageClassifierFromOptions(env, proto_options); } extern "C" JNIEXPORT jobject JNICALL @@ -181,17 +233,12 @@ JNIEnv* env, jclass thiz, jlong native_handle, - jobject image_byte_buffer, - jint width, - jint height, - jintArray jroi, - jint jorientation) { + jlong frame_buffer_handle, + jintArray jroi) { auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle); - auto image = GetMappedFileBuffer(env, image_byte_buffer); - std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( - reinterpret_cast<const uint8*>(image.data()), - FrameBuffer::Dimension{width, height}, - ConvertToFrameBufferOrientation(env, jorientation)); + // frame_buffer will be deleted after inference is done in + // base_vision_api_jni.cc. + auto* frame_buffer = reinterpret_cast<FrameBuffer*>(frame_buffer_handle); int* roi_array = env->GetIntArrayElements(jroi, 0); BoundingBox roi; @@ -205,10 +252,10 @@ if (results_or.ok()) { return ConvertToClassificationResults(env, results_or.value()); } else { - ThrowException(env, kAssertionError, - "Error occurred when classifying the image: %s", - results_or.status().message().data()); + ThrowException( + env, GetExceptionClassNameForStatusCode(results_or.status().code()), + "Error occurred when classifying the image: %s", + results_or.status().message().data()); return nullptr; } } -} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/BUILD new file mode 100644 index 0000000..84d36e6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/BUILD
@@ -0,0 +1,6 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["base_vision_task_api_jni.cc"])
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc new file mode 100644 index 0000000..2cda1b5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
@@ -0,0 +1,122 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" + +namespace { + +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::vision::CreateFrameBufferFromByteBuffer; +using ::tflite::task::vision::CreateFrameBufferFromBytes; +using ::tflite::task::vision::CreateFrameBufferFromYuvPlanes; +using ::tflite::task::vision::FrameBuffer; + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromByteBuffer( + JNIEnv* env, + jclass thiz, + jobject jimage_byte_buffer, + jint width, + jint height, + jint jorientation, + jint jcolor_space_type) { + auto frame_buffer_or = CreateFrameBufferFromByteBuffer( + env, jimage_byte_buffer, width, height, jorientation, jcolor_space_type); + if (frame_buffer_or.ok()) { + return reinterpret_cast<jlong>(frame_buffer_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(frame_buffer_or.status().code()), + "Error occurred when creating FrameBuffer: %s", + frame_buffer_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromBytes( + JNIEnv* env, + jclass thiz, + jbyteArray jimage_bytes, + jint width, + jint height, + jint jorientation, + jint jcolor_space_type, + jlongArray jbyte_array_handle) { + auto frame_buffer_or = + CreateFrameBufferFromBytes(env, jimage_bytes, width, height, jorientation, + jcolor_space_type, jbyte_array_handle); + if (frame_buffer_or.ok()) { + return reinterpret_cast<jlong>(frame_buffer_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(frame_buffer_or.status().code()), + "Error occurred when creating FrameBuffer: %s", + frame_buffer_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromPlanes( + JNIEnv* env, + jclass thiz, + jobject jy_plane, + jobject ju_plane, + jobject jv_plane, + jint width, + jint height, + jint row_stride_y, + jint row_stride_uv, + jint pixel_stride_uv, + jint orientation) { + auto frame_buffer_or = CreateFrameBufferFromYuvPlanes( + env, jy_plane, ju_plane, jv_plane, width, height, row_stride_y, + row_stride_uv, pixel_stride_uv, orientation); + if (frame_buffer_or.ok()) { + return reinterpret_cast<jlong>(frame_buffer_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(frame_buffer_or.status().code()), + "Error occurred when creating FrameBuffer: %s", + frame_buffer_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_deleteFrameBuffer( + JNIEnv* env, + jobject thiz, + jlong frame_buffer_handle, + jlong byte_array_handle, + jbyteArray jbyte_array) { + delete reinterpret_cast<FrameBuffer*>(frame_buffer_handle); + jbyte* bytes_ptr = reinterpret_cast<jbyte*>(byte_array_handle); + if (bytes_ptr != NULL) { + env->ReleaseByteArrayElements(jbyte_array, bytes_ptr, /*mode=*/0); + } +} + +} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD index 5abd3f1..7fb94132b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], @@ -7,30 +7,45 @@ exports_files(["object_detector_jni.cc"]) -cc_library( +cc_library_with_tflite( name = "object_detector_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_vision_jni.so", ], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_vision_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + ":native_without_resolver", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + ], +) + +# Shared native logic for ObjectDetector. Combine this target and customized +# version of op_resolver to build customized object_detector_native target. +cc_library_with_tflite( + name = "native_without_resolver", srcs = [ "object_detector_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + "//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc", ], - linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + ], deps = [ "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/cc/utils:jni_utils", "//tensorflow_lite_support/java/jni", - "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", "@com_google_absl//absl/strings", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc index eca7b26..f720795 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
@@ -18,8 +18,9 @@ #include <memory> #include <string> -#include "absl/strings/string_view.h" +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/object_detector.h" #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" @@ -29,25 +30,41 @@ #include "tensorflow_lite_support/cc/utils/jni_utils.h" #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); + +} // namespace task +} // namespace tflite + namespace { using ::tflite::support::StatusOr; -using ::tflite::support::utils::GetMappedFileBuffer; -using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; using ::tflite::support::utils::kInvalidPointer; using ::tflite::support::utils::StringListToVector; using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; using ::tflite::task::vision::BoundingBox; using ::tflite::task::vision::ConvertToCategory; -using ::tflite::task::vision::ConvertToFrameBufferOrientation; using ::tflite::task::vision::DetectionResult; using ::tflite::task::vision::FrameBuffer; using ::tflite::task::vision::ObjectDetector; using ::tflite::task::vision::ObjectDetectorOptions; // Creates an ObjectDetectorOptions proto based on the Java class. -ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options) { +ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, + jobject java_options, + jlong base_options_handle) { ObjectDetectorOptions proto_options; + + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + jclass java_options_class = env->FindClass( "org/tensorflow/lite/task/vision/detector/" "ObjectDetector$ObjectDetectorOptions"); @@ -93,7 +110,6 @@ for (const auto& class_name : deny_list_vector) { proto_options.add_class_name_blacklist(class_name); } - return proto_options; } @@ -147,6 +163,25 @@ return detections_list; } +jlong CreateObjectDetectorFromOptions(JNIEnv* env, + const ObjectDetectorOptions& options) { + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options, + tflite::task::CreateOpResolver()); + if (object_detector_or.ok()) { + return reinterpret_cast<jlong>(object_detector_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(object_detector_or.status().code()), + "Error occurred when initializing ObjectDetector: %s", + object_detector_or.status().message().data()); + return kInvalidPointer; + } +} + +} // namespace + extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni( JNIEnv* env, @@ -155,6 +190,9 @@ delete reinterpret_cast<ObjectDetector*>(native_handle); } +// Creates an ObjectDetector instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-possitive +// values will be ignored. extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions( JNIEnv* env, @@ -162,25 +200,36 @@ jint file_descriptor, jlong file_descriptor_length, jlong file_descriptor_offset, - jobject java_options) { + jobject java_options, + jlong base_options_handle) { ObjectDetectorOptions proto_options = - ConvertToProtoOptions(env, java_options); - auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() + ConvertToProtoOptions(env, java_options, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() + ->mutable_model_file() ->mutable_file_descriptor_meta(); file_descriptor_meta->set_fd(file_descriptor); - file_descriptor_meta->set_length(file_descriptor_length); - file_descriptor_meta->set_offset(file_descriptor_offset); - - StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = - ObjectDetector::CreateFromOptions(proto_options); - if (object_detector_or.ok()) { - return reinterpret_cast<jlong>(object_detector_or->release()); - } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing ObjectDetector: %s", - object_detector_or.status().message().data()); - return kInvalidPointer; + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateObjectDetectorFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer( + JNIEnv* env, + jclass thiz, + jobject model_buffer, + jobject java_options, + jlong base_options_handle) { + ObjectDetectorOptions proto_options = + ConvertToProtoOptions(env, java_options, base_options_handle); + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateObjectDetectorFromOptions(env, proto_options); } extern "C" JNIEXPORT jobject JNICALL @@ -188,24 +237,19 @@ JNIEnv* env, jclass thiz, jlong native_handle, - jobject image_byte_buffer, - jint width, - jint height, - jint jorientation) { + jlong frame_buffer_handle) { auto* detector = reinterpret_cast<ObjectDetector*>(native_handle); - absl::string_view image = GetMappedFileBuffer(env, image_byte_buffer); - std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( - reinterpret_cast<const uint8*>(image.data()), - FrameBuffer::Dimension{width, height}, - ConvertToFrameBufferOrientation(env, jorientation)); + // frame_buffer will be deleted after inference is done in + // base_vision_api_jni.cc. + auto* frame_buffer = reinterpret_cast<FrameBuffer*>(frame_buffer_handle); auto results_or = detector->Detect(*frame_buffer); if (results_or.ok()) { return ConvertToDetectionResults(env, results_or.value()); } else { - ThrowException(env, kAssertionError, - "Error occurred when detecting the image: %s", - results_or.status().message().data()); + ThrowException( + env, GetExceptionClassNameForStatusCode(results_or.status().code()), + "Error occurred when detecting the image: %s", + results_or.status().message().data()); return nullptr; } } -} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc index eb46b68..e0c94e2e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
@@ -15,15 +15,21 @@ #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" -#include "absl/strings/str_cat.h" +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h" namespace tflite { namespace task { namespace vision { -using ::tflite::support::utils::kAssertionError; +using ::tflite::support::StatusOr; +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::kIllegalStateException; using ::tflite::support::utils::ThrowException; +using ::tflite::task::vision::CreateFromRawBuffer; constexpr char kCategoryClassName[] = "org/tensorflow/lite/support/label/Category"; @@ -35,24 +41,51 @@ jclass category_class = env->FindClass(kCategoryClassName); jmethodID category_create = env->GetStaticMethodID( category_class, "create", - absl::StrCat("(", kStringClassName, kStringClassName, "F)L", + absl::StrCat("(", kStringClassName, kStringClassName, "FI)L", kCategoryClassName, ";") .c_str()); std::string label_string = classification.has_class_name() ? classification.class_name() - : kEmptyString; + : std::to_string(classification.index()); jstring label = env->NewStringUTF(label_string.c_str()); std::string display_name_string = classification.has_display_name() ? classification.display_name() : kEmptyString; jstring display_name = env->NewStringUTF(display_name_string.c_str()); - jobject jcategory = - env->CallStaticObjectMethod(category_class, category_create, label, - display_name, classification.score()); + jobject jcategory = env->CallStaticObjectMethod( + category_class, category_create, label, display_name, + classification.score(), classification.index()); + env->DeleteLocalRef(category_class); + env->DeleteLocalRef(label); + env->DeleteLocalRef(display_name); return jcategory; } +FrameBuffer::Format ConvertToFrameBufferFormat(JNIEnv* env, + jint jcolor_space_type) { + switch (jcolor_space_type) { + case 0: + return FrameBuffer::Format::kRGB; + case 1: + return FrameBuffer::Format::kGRAY; + case 2: + return FrameBuffer::Format::kNV12; + case 3: + return FrameBuffer::Format::kNV21; + case 4: + return FrameBuffer::Format::kYV12; + case 5: + return FrameBuffer::Format::kYV21; + default: + break; + } + // Should never happen. + ThrowException(env, kIllegalStateException, + "The color space type is unsupported: %d", jcolor_space_type); + return FrameBuffer::Format::kRGB; +} + FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env, jint jorientation) { switch (jorientation) { @@ -74,12 +107,103 @@ return FrameBuffer::Orientation::kLeftBottom; } // Should never happen. - ThrowException(env, kAssertionError, + ThrowException(env, kIllegalStateException, "The FrameBuffer Orientation type is unsupported: %d", jorientation); return FrameBuffer::Orientation::kTopLeft; } +// TODO(b/180051417): remove the code, once FrameBuffer can digest YUV buffers +// without format. +// Theoretically, when using CreateFromYuvRawBuffer, "format" can always be set +// to YV12 (or YV21, they are identical). However, prefer to set format to NV12 +// or NV21 whenever it's applicable, because NV12 and NV21 is better optimized +// in performance than YV12 or YV21. +StatusOr<FrameBuffer::Format> GetYUVImageFormat(const uint8* u_buffer, + const uint8* v_buffer, + int uv_pixel_stride) { + intptr_t u = reinterpret_cast<intptr_t>(u_buffer); + intptr_t v = reinterpret_cast<intptr_t>(v_buffer); + if ((std::abs(u - v) == 1) && (uv_pixel_stride == 2)) { + if (u_buffer > v_buffer) { + return FrameBuffer::Format::kNV21; + } else { + return FrameBuffer::Format::kNV12; + } + } + return FrameBuffer::Format::kYV12; +} + +StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer( + JNIEnv* env, + jobject jimage_byte_buffer, + jint width, + jint height, + jint jorientation, + jint jcolor_space_type) { + absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer); + return CreateFromRawBuffer( + reinterpret_cast<const uint8*>(image.data()), + FrameBuffer::Dimension{width, height}, + ConvertToFrameBufferFormat(env, jcolor_space_type), + ConvertToFrameBufferOrientation(env, jorientation)); +} + +StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes( + JNIEnv* env, + jbyteArray jimage_bytes, + jint width, + jint height, + jint jorientation, + jint jcolor_space_type, + jlongArray jbyte_array_handle) { + jbyte* jimage_ptr = env->GetByteArrayElements(jimage_bytes, NULL); + // Free jimage_ptr together with frame_buffer after inference is finished. + jlong jimage_ptr_handle = reinterpret_cast<jlong>(jimage_ptr); + // jbyte_array_handle has only one element, which is a holder for jimage_ptr. + env->SetLongArrayRegion(jbyte_array_handle, 0, 1, &jimage_ptr_handle); + + if (jimage_ptr == NULL) { + ThrowException(env, kIllegalStateException, + "Error occurred when reading image data from byte array."); + return nullptr; + } + + return CreateFromRawBuffer( + reinterpret_cast<const uint8*>(jimage_ptr), + FrameBuffer::Dimension{width, height}, + ConvertToFrameBufferFormat(env, jcolor_space_type), + ConvertToFrameBufferOrientation(env, jorientation)); +} + +StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromYuvPlanes( + JNIEnv* env, + jobject jy_plane, + jobject ju_plane, + jobject jv_plane, + jint width, + jint height, + jint row_stride_y, + jint row_stride_uv, + jint pixel_stride_uv, + jint jorientation) { + const uint8* y_plane = + reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, jy_plane).data()); + const uint8* u_plane = + reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, ju_plane).data()); + const uint8* v_plane = + reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, jv_plane).data()); + + FrameBuffer::Format format; + ASSIGN_OR_RETURN(format, + GetYUVImageFormat(u_plane, v_plane, pixel_stride_uv)); + + return CreateFromYuvRawBuffer( + y_plane, u_plane, v_plane, format, FrameBuffer::Dimension{width, height}, + row_stride_y, row_stride_uv, pixel_stride_uv, + ConvertToFrameBufferOrientation(env, jorientation)); +} + } // namespace vision } // namespace task } // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h index 7cb63f3..4d7ec17 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
@@ -18,6 +18,7 @@ #include <jni.h> +#include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" @@ -31,6 +32,37 @@ FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env, jint jorientation); +// Creates FrameBuffer from a direct ByteBuffer. +::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> +CreateFrameBufferFromByteBuffer(JNIEnv* env, + jobject jimage_byte_buffer, + jint width, + jint height, + jint jorientation, + jint jcolor_space_type); + +// Creates FrameBuffer from a byte array. +::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> +CreateFrameBufferFromBytes(JNIEnv* env, + jbyteArray jimage_bytes, + jint width, + jint height, + jint jorientation, + jint jcolor_space_type, + jlongArray jbyte_array_handle); + +// Creates FrameBuffer from YUV planes. +::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> +CreateFrameBufferFromYuvPlanes(JNIEnv* env, + jobject jy_plane, + jobject ju_plane, + jobject jv_plane, + jint width, + jint height, + jint row_stride_y, + jint row_stride_uv, + jint pixel_stride_uv, + jint jorientation); } // namespace vision } // namespace task } // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD index 23d601df3..341980a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( default_visibility = ["//tensorflow_lite_support:users"], @@ -7,28 +7,44 @@ exports_files(["image_segmenter_jni.cc"]) -cc_library( +cc_library_with_tflite( name = "image_segmenter_native", - srcs = [ + tflite_jni_binaries = [ ":libtask_vision_jni.so", ], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libtask_vision_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + ":native_without_resolver", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + ], +) + +# Shared native logic for ImageSegmenter. Combine this target and customized +# version of op_resolver to build customized image_segmenter_native target. +cc_library_with_tflite( + name = "native_without_resolver", srcs = [ "image_segmenter_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + "//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:image_segmenter", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", ], deps = [ "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/vision:image_segmenter", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/cc/utils:jni_utils", "//tensorflow_lite_support/java/jni", - "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", "@com_google_absl//absl/strings", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc index 6e0bbc5..8d8c8ee 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
@@ -18,8 +18,9 @@ #include <memory> #include <string> -#include "absl/strings/str_cat.h" +#include "absl/strings/str_cat.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" #include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" @@ -28,16 +29,24 @@ #include "tensorflow_lite_support/cc/utils/jni_utils.h" #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); + +} // namespace task +} // namespace tflite + namespace { using ::tflite::support::StatusOr; using ::tflite::support::utils::CreateByteArray; -using ::tflite::support::utils::GetMappedFileBuffer; -using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; using ::tflite::support::utils::kIllegalArgumentException; +using ::tflite::support::utils::kIllegalStateException; using ::tflite::support::utils::kInvalidPointer; using ::tflite::support::utils::ThrowException; -using ::tflite::task::vision::ConvertToFrameBufferOrientation; +using ::tflite::task::core::BaseOptions; using ::tflite::task::vision::FrameBuffer; using ::tflite::task::vision::ImageSegmenter; using ::tflite::task::vision::ImageSegmenterOptions; @@ -59,9 +68,16 @@ // Creates an ImageSegmenterOptions proto based on the Java class. ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env, jstring display_names_locale, - jint output_type) { + jint output_type, + jlong base_options_handle) { ImageSegmenterOptions proto_options; + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr); proto_options.set_display_names_locale(pchars); env->ReleaseStringUTFChars(display_names_locale, pchars); @@ -82,15 +98,15 @@ return proto_options; } -void ConvertToSegmentationResults(JNIEnv* env, - const SegmentationResult& results, - jobject jmask_buffers, - jintArray jmask_shape, - jobject jcolored_labels) { +void ConvertFromSegmentationResults(JNIEnv* env, + const SegmentationResult& results, + jobject jmask_buffers, + jintArray jmask_shape, + jobject jcolored_labels) { if (results.segmentation_size() != 1) { // Should never happen. ThrowException( - env, kAssertionError, + env, kIllegalStateException, "ImageSegmenter only supports one segmentation result, getting %d", results.segmentation_size()); } @@ -157,6 +173,25 @@ } } +jlong CreateImageSegmenterFromOptions(JNIEnv* env, + const ImageSegmenterOptions& options) { + StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = + ImageSegmenter::CreateFromOptions(options, + tflite::task::CreateOpResolver()); + if (image_segmenter_or.ok()) { + return reinterpret_cast<jlong>(image_segmenter_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(image_segmenter_or.status().code()), + "Error occurred when initializing ImageSegmenter: %s", + image_segmenter_or.status().message().data()); + return kInvalidPointer; + } +} + +} // namespace + extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni( JNIEnv* env, @@ -165,6 +200,9 @@ delete reinterpret_cast<ImageSegmenter*>(native_handle); } +// Creates an ImageSegmenter instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-possitive +// values will be ignored. extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions( JNIEnv* env, @@ -173,25 +211,37 @@ jlong file_descriptor_length, jlong file_descriptor_offset, jstring display_names_locale, - jint output_type) { - ImageSegmenterOptions proto_options = - ConvertToProtoOptions(env, display_names_locale, output_type); - auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() + jint output_type, + jlong base_options_handle) { + ImageSegmenterOptions proto_options = ConvertToProtoOptions( + env, display_names_locale, output_type, base_options_handle); + auto file_descriptor_meta = proto_options.mutable_base_options() + ->mutable_model_file() ->mutable_file_descriptor_meta(); file_descriptor_meta->set_fd(file_descriptor); - file_descriptor_meta->set_length(file_descriptor_length); - file_descriptor_meta->set_offset(file_descriptor_offset); - - StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = - ImageSegmenter::CreateFromOptions(proto_options); - if (image_segmenter_or.ok()) { - return reinterpret_cast<jlong>(image_segmenter_or->release()); - } else { - ThrowException(env, kAssertionError, - "Error occurred when initializing ImageSegmenter: %s", - image_segmenter_or.status().message().data()); - return kInvalidPointer; + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateImageSegmenterFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer( + JNIEnv* env, + jclass thiz, + jobject model_buffer, + jstring display_names_locale, + jint output_type, + jlong base_options_handle) { + ImageSegmenterOptions proto_options = ConvertToProtoOptions( + env, display_names_locale, output_type, base_options_handle); + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateImageSegmenterFromOptions(env, proto_options); } extern "C" JNIEXPORT void JNICALL @@ -199,28 +249,23 @@ JNIEnv* env, jclass thiz, jlong native_handle, - jobject jimage_byte_buffer, - jint width, - jint height, + jlong frame_buffer_handle, jobject jmask_buffers, jintArray jmask_shape, - jobject jcolored_labels, - jint jorientation) { + jobject jcolored_labels) { auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle); - absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer); - std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( - reinterpret_cast<const uint8*>(image.data()), - FrameBuffer::Dimension{width, height}, - ConvertToFrameBufferOrientation(env, jorientation)); + // frame_buffer will be deleted after inference is done in + // base_vision_api_jni.cc. + auto* frame_buffer = reinterpret_cast<FrameBuffer*>(frame_buffer_handle); + auto results_or = segmenter->Segment(*frame_buffer); if (results_or.ok()) { - ConvertToSegmentationResults(env, results_or.value(), jmask_buffers, - jmask_shape, jcolored_labels); + ConvertFromSegmentationResults(env, results_or.value(), jmask_buffers, + jmask_shape, jcolored_labels); } else { - ThrowException(env, kAssertionError, - "Error occurred when segmenting the image: %s", - results_or.status().message().data()); + ThrowException( + env, GetExceptionClassNameForStatusCode(results_or.status().code()), + "Error occurred when segmenting the image: %s", + results_or.status().message().data()); } } - -} // namespace
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD index ed5bedc..5e5c25b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD
@@ -1,7 +1,7 @@ load("//tensorflow_lite_support/metadata:build_defs.bzl", "stamp_metadata_parser_version") package( - default_visibility = ["//tensorflow_lite_support:users"], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -15,7 +15,12 @@ name = "metadata_extractor", srcs = ["metadata_extractor.cc"], hdrs = ["metadata_extractor.h"], + visibility = ["//visibility:public"], deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/metadata:metadata_schema_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -23,15 +28,7 @@ "@com_google_absl//absl/strings:str_format", "@flatbuffers", "@org_libzip//:zip", - ] + select({ - "//tensorflow_lite_support/cc:tflite_use_c_api": ["@org_tensorflow//tensorflow/lite/c:c_api"], - "//conditions:default": ["@org_tensorflow//tensorflow/lite:framework"], - }) + [ "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - "//tensorflow_lite_support/cc:common", - "//tensorflow_lite_support/cc/port:status_macros", - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/metadata:metadata_schema_cc", ], ) @@ -51,3 +48,22 @@ "@org_tensorflow//tensorflow/lite/tools:logging", ], ) + +cc_library( + name = "metadata_populator", + srcs = ["metadata_populator.cc"], + hdrs = ["metadata_populator.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc/utils:zip_mem_file", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@flatbuffers//:runtime_cc", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + "@zlib//:zlib_minizip", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc index 42f2a7c..3aae0aa 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
@@ -17,9 +17,9 @@ #include <functional> -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "lib/zip.h" // from @org_libzip #include "tensorflow/lite/schema/schema_generated.h" @@ -27,12 +27,6 @@ #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" -#if TFLITE_USE_C_API -#include "tensorflow/lite/c/c_api.h" -#else -#include "tensorflow/lite/model_builder.h" -#endif - namespace tflite { namespace metadata {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h index 8eafe93..dc9a992 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
@@ -15,9 +15,9 @@ #ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h new file mode 100644 index 0000000..9037f58 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
@@ -0,0 +1,92 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_ +#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_ + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace metadata { + +// TODO(b/185787843): bring to feature parity with Python library. + +// Provides an interface to pack TFLite ModelMetadata [1] and corresponding +// associated files into a TFLite FlatBuffer. +// +// [1]: https://www.tensorflow.org/lite/convert/metadata +class ModelMetadataPopulator { + public: + // Creates a ModelMetadataPopulator from the provided TFLite Model FlatBuffer + // and returns a pointer to the new object. Ownership is transferred to the + // caller. Returns an error if the creation failed, which may happen e.g. if + // the provided buffer is not a valid TFLite FlatBuffer. + // + // It is recommended to obtain and manage the buffer through an + // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having + // to load the entire buffer in memory when provided by path or file + // descriptor. + // + // [1]: + // tensorflow_lite_support/c/task/core/external_file_handler.h + static tflite::support::StatusOr<std::unique_ptr<ModelMetadataPopulator>> + CreateFromModelBuffer(const char* buffer_data, size_t buffer_size); + + // Writes the TFLite ModelMetadata provided as a buffer into the TFLite + // FlatBuffer model. + // + // Warning: this method overwrites any already existing TFLite Model Metadata. + // Calling this method multiple times overwrites the metadata from previous + // calls, so this method should usually be called only once. + void LoadMetadata(const char* metadata_buffer_data, + size_t metadata_buffer_size); + + // Loads associated files into the TFLite FlatBuffer model. The input is a map + // of {filename, file contents}. + // + // Warning: this method removes any previoulsy present associated files. + // Calling this method multiple time removes any associated files from + // previous calls, so this method should usually be called only once. + void LoadAssociatedFiles( + const absl::flat_hash_map<std::string, std::string>& associated_files); + + // Finalizes metadata population. Returns the TFLite FlatBuffer model with + // metadata and associated files as a string buffer. + tflite::support::StatusOr<std::string> Populate(); + + private: + // Private constructor. + explicit ModelMetadataPopulator(const tflite::Model& model); + // Zips and appends associated files to the provided model buffer. Called + // internally by `Populate()`. + tflite::support::StatusOr<std::string> AppendAssociatedFiles( + const char* model_buffer_data, + size_t model_buffer_size); + + // The unpacked model FlatBuffer. + tflite::ModelT model_t_; + // The associated files. + absl::flat_hash_map<std::string, std::string> associated_files_; +}; + +} // namespace metadata +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc index 8c8e1fd..ed75b65 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
@@ -22,8 +22,8 @@ #include <string> #include <vector> -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" +#include "absl/strings/str_join.h" // from @com_google_absl +#include "absl/strings/str_split.h" // from @com_google_absl #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" @@ -45,6 +45,7 @@ kSubGraphMetadataInputTensorGroups = 5, kSubGraphMetadataOutputTensorGroups = 6, kProcessUnitOptionsRegexTokenizerOptions = 7, + kContentPropertiesAudioProperties = 8, }; // Helper class to compare semantic versions in terms of three integers, major, @@ -107,6 +108,8 @@ return Version(1, 2, 0); case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions: return Version(1, 2, 1); + case SchemaMembers::kContentPropertiesAudioProperties: + return Version(1, 3, 0); default: // Should never happen. TFLITE_LOG(FATAL) << "Unsupported schema member: " @@ -182,6 +185,20 @@ } template <> +void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table, + Version* min_version) { + if (table == nullptr) + return; + + // Checks the ContenProperties field. + if (table->content_properties_type() == ContentProperties_AudioProperties) { + UpdateMinimumVersion( + GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties), + min_version); + } +} + +template <> void UpdateMinimumVersionForTable<tflite::TensorMetadata>( const tflite::TensorMetadata* table, Version* min_version) { @@ -195,6 +212,9 @@ // Checks the process_units field. UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(), min_version); + + // Check the content field. + UpdateMinimumVersionForTable<tflite::Content>(table->content(), min_version); } template <>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/python/BUILD index 34e9a4f..7f78543 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/python/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/python/BUILD
@@ -1,9 +1,7 @@ load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") package( - default_visibility = [ - "//tensorflow_lite_support/metadata:__subpackages__", - ], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD new file mode 100644 index 0000000..ced8644 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD
@@ -0,0 +1,16 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "zip_mem_file", + srcs = ["zip_mem_file.cc"], + hdrs = ["zip_mem_file.h"], + deps = [ + "@com_google_absl//absl/strings", + "@zlib//:zlib_minizip", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD index d4171bf..5efd8072 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD
@@ -1,9 +1,7 @@ load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") package( - default_visibility = [ - "//visibility:public", - ], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD index 0fbe028..5f000b4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD
@@ -2,7 +2,6 @@ # TensorFlow Lite Support API in Java for metadata. load("@build_bazel_rules_android//android:rules.bzl", "android_library") -load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( default_visibility = ["//visibility:public"], @@ -16,7 +15,7 @@ android_library( name = "tensorflowlite_support_metadata", srcs = METADATA_SRCS, - javacopts = JAVACOPTS, + javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml", deps = [ "//tensorflow_lite_support/metadata:metadata_schema_fbs_android", @@ -33,16 +32,18 @@ java_library( name = "tensorflowlite_support_metadata_lib", srcs = METADATA_SRCS, - javacopts = JAVACOPTS, + javacopts = ["--release 7"], resource_jars = [ "//tensorflow_lite_support/metadata:libmetadata_schema_java.jar", "//tensorflow_lite_support/metadata:libschema_fbs_java.jar", ], + # LINT.IfChange(dep) deps = [ "//tensorflow_lite_support/metadata:metadata_schema_java", "//tensorflow_lite_support/metadata:schema_fbs_java", "@org_checkerframework_qual", ], + # LINT.ThenChange(<INTERNAL>/release/build_metadata_pom.sh:dep) ) alias(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java index c4537c8..20f5566 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
@@ -21,7 +21,7 @@ * The version of the metadata parser that this metadata extractor library is depending on. The * value should match the value of "Schema Semantic version" in metadata_schema.fbs. */ - public static final String VERSION = "1.2.1"; + public static final String VERSION = "1.3.0"; private MetadataParser() {} }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/BUILD index 204b24b7..22446ca 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/BUILD
@@ -31,10 +31,13 @@ "//tensorflow_lite_support/metadata:metadata_schema_fbs_android", "//tensorflow_lite_support/metadata:schema_fbs_android", "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", - "//third_party/java/jakarta_commons_io", - "//third_party/java/truth:truth-android", "@flatbuffers//:runtime_android", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:commons_io_commons_io", + "@maven//:org_robolectric_robolectric", "@org_checkerframework_qual", + "@robolectric//bazel:android-all", ], ) @@ -48,7 +51,10 @@ test_class = "org.tensorflow.lite.support.metadata.ByteBufferChannelTest", deps = [ "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", - "//third_party/java/truth:truth-android", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", ], ) @@ -65,8 +71,11 @@ deps = [ ":test_lib", "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", - "//third_party/java/jakarta_commons_io", - "//third_party/java/truth:truth-android", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:commons_io_commons_io", + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", ], ) @@ -80,7 +89,10 @@ test_class = "org.tensorflow.lite.support.metadata.BoundedInputStreamTest", deps = [ "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", - "//third_party/java/truth:truth-android", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", ], ) @@ -93,7 +105,7 @@ test_class = "org.tensorflow.lite.support.metadata.MetadataParserTest", deps = [ "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib", - "//third_party/java/truth", - "@junit", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/assets/mobilenet_v1_1.0_224_quant.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/assets/mobilenet_v1_1.0_224_quant.tflite new file mode 100644 index 0000000..8cf2048f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/assets/mobilenet_v1_1.0_224_quant.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java index 5990979..9f1173a1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
@@ -29,6 +29,7 @@ import org.apache.commons.io.IOUtils; import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -107,6 +108,7 @@ assertThat(metadataExtractor.hasMetadata()).isFalse(); } + @Ignore @Test public void getAssociatedFile_validAssociateFile() throws Exception { ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); @@ -121,6 +123,7 @@ .isTrue(); } + @Ignore @Test public void getAssociatedFile_invalidAssociateFile() throws Exception { ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); @@ -131,6 +134,7 @@ "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME)); } + @Ignore @Test public void getAssociatedFile_nullFileName() throws Exception { ByteBuffer mobileNetBuffer = loadMobileNetBuffer(); @@ -165,6 +169,7 @@ "This model does not contain associated files, and is not a Zip file."); } + @Ignore @Test public void getAssociatedFileNames_validFileNames() throws Exception { ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java index af80fba..80d2ddc6f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
@@ -25,6 +25,7 @@ import androidx.test.core.app.ApplicationProvider; import org.apache.commons.io.IOUtils; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.robolectric.RobolectricTestRunner; @@ -78,6 +79,7 @@ assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive."); } + @Ignore @Test public void getFileNames_correctFileName() throws Exception { ByteBufferChannel modelChannel = loadModel(MODEL_PATH); @@ -87,6 +89,7 @@ assertThat(zipFile.getFileNames()).isEqualTo(expectedSet); } + @Ignore @Test public void getRawInputStream_existentFile() throws Exception { ByteBufferChannel modelChannel = loadModel(MODEL_PATH); @@ -98,6 +101,7 @@ assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue(); } + @Ignore @Test public void getRawInputStream_nonExistentFile() throws Exception { ByteBufferChannel modelChannel = loadModel(MODEL_PATH); @@ -109,6 +113,7 @@ "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME)); } + @Ignore @Test public void close_validStatus() throws Exception { ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/metadata_schema.fbs b/third_party/tflite_support/src/tensorflow_lite_support/metadata/metadata_schema.fbs index 8faae0a..e0b95af 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/metadata_schema.fbs +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/metadata_schema.fbs
@@ -50,7 +50,7 @@ // for which they were added. // // LINT.IfChange -// Schema Semantic version: 1.2.1 +// Schema Semantic version: 1.3.0 // LINT.ThenChange(//tensorflow_lite_support/\ // metadata/java/src/java/org/tensorflow/lite/support/metadata/\ // MetadataParser.java) @@ -69,6 +69,7 @@ // 1.2.0 - Added input_tensor_group to SubGraphMetadata. // Added output_tensor_group to SubGraphMetadata. // 1.2.1 - Added RegexTokenizerOptions to ProcessUnitOptions. +// 1.3.0 - Added AudioProperties to ContentProperties. // File extension of any written files. file_extension "tflitemeta"; @@ -80,10 +81,12 @@ // Files such as readme.txt. DESCRIPTIONS = 1, - // Contains labels that annotate certain axis of the tensor. For example, + // Contains a list of labels (characters separated by "\n" or in lines) that + // annotate certain axis of the tensor. For example, // the label file in image classification. Those labels annotate the // the output tensor, such that each value in the output tensor is the - // probability of that corresponding category specified by the label. + // probability of that corresponding category specified by the label. See the + // example label file used in image classification [1]. // // <Codegen usage>: // If an output tensor has an associated file as TENSOR_AXIS_LABELS, return @@ -92,12 +95,16 @@ // If multiple files of the same type are present, the first one is used by // default; additional ones are to be distinguished from one another by their // specified locale. + // + // [1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt TENSOR_AXIS_LABELS = 2, - // Contains labels that tensor values correspond to. For example, in + // Contains a list of labels (characters separated by "\n" or in lines) that + // tensor values correspond to. For example, in // the object detection model, one of the output tensors is the detected // classes. And each value in the tensor refers to the index of label in the - // category label file. + // category label file. See the example label file used in object detection + // [1]. // // <Codegen usage>: // If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert @@ -105,23 +112,33 @@ // If multiple files of the same type are present, the first one is used by // default; additional ones are to be distinguished from one another by their // specified locale. + // + // [1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt TENSOR_VALUE_LABELS = 3, // Contains sigmoid-based score calibration parameters, formatted as CSV. // Lines contain for each index of an output tensor the scale, slope, offset // and (optional) min_score parameters to be used for sigmoid fitting (in this - // order and in `strtof`-compatible [1] format). + // order and in `strtof`-compatible [1] format). Scale should be a + // non-negative value. // A line may be left empty to default calibrated scores for this index to // default_score. // In summary, each line should thus contain 0, 3 or 4 comma-separated values. // + // See the example score calibration file used in image classification [2]. + // // See documentation for ScoreCalibrationOptions for details. // // [1]: https://en.cppreference.com/w/c/string/byte/strtof + // [2]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/score_calibration.txt TENSOR_AXIS_SCORE_CALIBRATION = 4, // Contains a list of unique words (characters separated by "\n" or in lines) // that help to convert natural language words to embedding vectors. + // + // See the example vocab file used in text classification [1]. + // + // [1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/vocab.txt // Added in: 1.0.1 VOCABULARY = 5, } @@ -232,6 +249,16 @@ } +// The properties for audio tensors. +// Added in: 1.3.0 +table AudioProperties { + // The sample rate in Hz when the audio was captured. + sample_rate:uint; + + // The channel count of the audio. + channels:uint; +} + enum CoordinateType : byte { // The coordinates are float values from 0 to 1. RATIO = 0, @@ -267,6 +294,8 @@ FeatureProperties, ImageProperties, BoundingBoxProperties, + // Added in: 1.3.0 + AudioProperties, } table ValueRange { @@ -412,6 +441,9 @@ // An AssociatedFile with type TANSOR_AXIS_SCORE_CALIBRATION specifying the // index-specific parameters must be associated with the corresponding // TensorMetadata for score calibration be applied. +// +// See the example score calibration file used in image classification [1]. +// [1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/score_calibration.txt table ScoreCalibrationOptions { // The function to use for transforming the uncalibrated score before // applying score calibration.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD index d6d9357..3668e49 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD
@@ -1,9 +1,8 @@ load("//tensorflow_lite_support/metadata:build_defs.bzl", "stamp_metadata_parser_version") +# Placeholder for internal Python strict library compatibility macro. package( - default_visibility = [ - "//visibility:public", - ], + default_visibility = ["//tensorflow_lite_support:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -20,14 +19,39 @@ ":metadata_parser_py", ], data = ["//tensorflow_lite_support/metadata:metadata_schema.fbs"], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ - "//tensorflow_lite_support/custom_ops:expect_numpy_installed", "//tensorflow_lite_support/metadata:metadata_schema_py", "//tensorflow_lite_support/metadata:schema_py", "//tensorflow_lite_support/metadata/cc/python:_pywrap_metadata_version", "//tensorflow_lite_support/metadata/flatbuffers_lib:_pywrap_flatbuffers", - "//tensorflow_lite_support/tools:expect_flatbuffers_installed", + "@flatbuffers//:runtime_py", + ], +) + +py_binary( + name = "metadata_displayer", + srcs = ["metadata_displayer.py"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":metadata", + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) + +py_library( + name = "metadata_writer_for_task", + srcs = ["metadata_writer_for_task.py"], + srcs_version = "PY3", + deps = [ + ":metadata", + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_writer", + "//tensorflow_lite_support/metadata/python/metadata_writers:writer_utils", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/__init__.py new file mode 100644 index 0000000..2ae3e34 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/__init__.py
@@ -0,0 +1,13 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata.py index 15a677d..a312974 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata.py
@@ -14,10 +14,6 @@ # ============================================================================== """TensorFlow Lite metadata tools.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import copy import inspect import io @@ -34,6 +30,47 @@ from tensorflow_lite_support.metadata.cc.python import _pywrap_metadata_version from tensorflow_lite_support.metadata.flatbuffers_lib import _pywrap_flatbuffers +try: + # If exists, optionally use TensorFlow to open and check files. Used to + # support more than local file systems. + # In pip requirements, we doesn't necessarily need tensorflow as a dep. + import tensorflow as tf # pylint: disable=g-import-not-at-top + _open_file = tf.io.gfile.GFile + _exists_file = tf.io.gfile.exists +except ImportError as e: + # If TensorFlow package doesn't exist, fall back to original open and exists. + _open_file = open + _exists_file = os.path.exists + + +def _maybe_open_as_binary(filename, mode): + """Maybe open the binary file, and returns a file-like.""" + if hasattr(filename, "read"): # A file-like has read(). + return filename + openmode = mode if "b" in mode else mode + "b" # Add binary explicitly. + return _open_file(filename, openmode) + + +def _open_as_zipfile(filename, mode="r"): + """Open file as a zipfile. + + Args: + filename: str or file-like or path-like, to the zipfile. + mode: str, common file mode for zip. + (See: https://docs.python.org/3/library/zipfile.html) + + Returns: + A ZipFile object. + """ + file_like = _maybe_open_as_binary(filename, mode) + return zipfile.ZipFile(file_like, mode) + + +def _is_zipfile(filename): + """Checks whether it is a zipfile.""" + with _maybe_open_as_binary(filename, "r") as f: + return zipfile.is_zipfile(f) + def get_path_to_datafile(path): """Gets the path to the specified file in the data dependencies. @@ -88,6 +125,8 @@ # populator.load_metadata_buffer(metadata_buf) populator.load_metadata_file(metadata_file) populator.load_associated_files([label.txt]) + # For associated file buffer (bytearray read from the file), use: + # populator.load_associated_file_buffers({"label.txt": b"file content"}) populator.populate() # Populating a metadata file (or a metadta buffer) and associated files to @@ -100,6 +139,14 @@ updated_model_buf = populator.get_model_buffer() with open("updated_model.tflite", "wb") as f: f.write(updated_model_buf) + + # Transferring metadata and associated files from another TFLite model: + populator = MetadataPopulator.with_model_buffer(model_buf) + populator_dst.load_metadata_and_associated_files(src_model_buf) + populator_dst.populate() + updated_model_buf = populator.get_model_buffer() + with open("updated_model.tflite", "wb") as f: + f.write(updated_model_buf) ``` Note that existing metadata buffer (if applied) will be overridden by the new @@ -128,7 +175,8 @@ _assert_model_file_identifier(model_file) self._model_file = model_file self._metadata_buf = None - self._associated_files = set() + # _associated_files is a dict of file name and file buffer. + self._associated_files = {} @classmethod def with_model_file(cls, model_file): @@ -169,7 +217,7 @@ Returns: Model buffer (in bytearray). """ - with open(self._model_file, "rb") as f: + with _open_file(self._model_file, "rb") as f: return f.read() def get_packed_associated_file_list(self): @@ -178,10 +226,10 @@ Returns: List of packed associated files. """ - if not zipfile.is_zipfile(self._model_file): + if not _is_zipfile(self._model_file): return [] - with zipfile.ZipFile(self._model_file, "r") as zf: + with _open_as_zipfile(self._model_file, "r") as zf: return zf.namelist() def get_recorded_associated_file_list(self): @@ -193,47 +241,31 @@ Returns: List of recorded associated files. """ - recorded_files = [] - if not self._metadata_buf: - return recorded_files + return [] - metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata( - self._metadata_buf, 0) + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata( + self._metadata_buf, 0)) - # Add associated files attached to ModelMetadata. - recorded_files += self._get_associated_files_from_table( - metadata, "AssociatedFiles") + return [ + file.name.decode("utf-8") + for file in self._get_recorded_associated_file_object_list(metadata) + ] - # Add associated files attached to each SubgraphMetadata. - for j in range(metadata.SubgraphMetadataLength()): - subgraph = metadata.SubgraphMetadata(j) - recorded_files += self._get_associated_files_from_table( - subgraph, "AssociatedFiles") + def load_associated_file_buffers(self, associated_files): + """Loads the associated file buffers (in bytearray) to be populated. - # Add associated files attached to each input tensor. - for k in range(subgraph.InputTensorMetadataLength()): - recorded_files += self._get_associated_files_from_table( - subgraph.InputTensorMetadata(k), "AssociatedFiles") - recorded_files += self._get_associated_files_from_process_units( - subgraph.InputTensorMetadata(k), "ProcessUnits") + Args: + associated_files: a dictionary of associated file names and corresponding + file buffers, such as {"file.txt": b"file content"}. If pass in file + paths for the file name, only the basename will be populated. + """ - # Add associated files attached to each output tensor. - for k in range(subgraph.OutputTensorMetadataLength()): - recorded_files += self._get_associated_files_from_table( - subgraph.OutputTensorMetadata(k), "AssociatedFiles") - recorded_files += self._get_associated_files_from_process_units( - subgraph.OutputTensorMetadata(k), "ProcessUnits") - - # Add associated files attached to the input_process_units. - recorded_files += self._get_associated_files_from_process_units( - subgraph, "InputProcessUnits") - - # Add associated files attached to the output_process_units. - recorded_files += self._get_associated_files_from_process_units( - subgraph, "OutputProcessUnits") - - return recorded_files + self._associated_files.update({ + os.path.basename(name): buffers + for name, buffers in associated_files.items() + }) def load_associated_files(self, associated_files): """Loads associated files that to be concatenated after the model file. @@ -245,9 +277,10 @@ IOError: File not found. """ - for af in associated_files: - _assert_file_exist(af) - self._associated_files.add(af) + for af_name in associated_files: + _assert_file_exist(af_name) + with _open_file(af_name, "rb") as af: + self.load_associated_file_buffers({af_name: af.read()}) def load_metadata_buffer(self, metadata_buf): """Loads the metadata buffer (in bytearray) to be populated. @@ -277,6 +310,10 @@ _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) metadata.minParserVersion = min_version + # Remove local file directory in the `name` field of `AssociatedFileT`, and + # make it consistent with the name of the actual file packed in the model. + self._use_basename_for_associated_files_in_metadata(metadata) + b = flatbuffers.Builder(0) b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER) metadata_buf_with_version = b.Output() @@ -299,10 +336,28 @@ of input/output tensor metadata. """ _assert_file_exist(metadata_file) - with open(metadata_file, "rb") as f: + with _open_file(metadata_file, "rb") as f: metadata_buf = f.read() self.load_metadata_buffer(bytearray(metadata_buf)) + def load_metadata_and_associated_files(self, src_model_buf): + """Loads the metadata and associated files from another model buffer. + + Args: + src_model_buf: source model buffer (in bytearray) with metadata and + associated files. + """ + # Load the model metadata from src_model_buf if exist. + metadata_buffer = _get_metadata_buffer(src_model_buf) + if metadata_buffer: + self.load_metadata_buffer(metadata_buffer) + + # Load the associated files from src_model_buf if exist. + if _is_zipfile(io.BytesIO(src_model_buf)): + with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf: + self.load_associated_file_buffers( + {f: zf.read(f) for f in zf.namelist()}) + def populate(self): """Populates loaded metadata and associated files into the model file.""" self._assert_validate() @@ -324,9 +379,7 @@ packed_files = self.get_packed_associated_file_list() # Gets the file name of those associated files to be populated. - to_be_populated_files = [] - for af in self._associated_files: - to_be_populated_files.append(os.path.basename(af)) + to_be_populated_files = self._associated_files.keys() # Checks all files recorded in the metadata will be populated. for rf in recorded_files: @@ -340,18 +393,17 @@ if f not in recorded_files: warnings.warn( - "File, '{0}', does not exsit in the metadata. But packing it to " + "File, '{0}', does not exist in the metadata. But packing it to " "tflite model is still allowed.".format(f)) - def _copy_archived_files(self, src_zip, dst_zip, file_list): + def _copy_archived_files(self, src_zip, file_list, dst_zip): """Copy archieved files in file_list from src_zip ro dst_zip.""" - if not zipfile.is_zipfile(src_zip): + if not _is_zipfile(src_zip): raise ValueError("File, '{0}', is not a zipfile.".format(src_zip)) - with zipfile.ZipFile(src_zip, - "r") as src_zf, zipfile.ZipFile(dst_zip, - "a") as dst_zf: + with _open_as_zipfile(src_zip, "r") as src_zf, \ + _open_as_zipfile(dst_zip, "a") as dst_zf: src_list = src_zf.namelist() for f in file_list: if f not in src_list: @@ -373,36 +425,25 @@ either "InputProcessUnits" or "OutputProcessUnits". Returns: - the associated files list. + A list of AssociatedFileT objects. """ if table is None: - return + return [] file_list = [] - length_method = getattr(table, field_name + "Length", None) - member_method = getattr(table, field_name, None) - if length_method is None or member_method is None: - raise ValueError("{0} does not have the field {1}".format( - type(table).__name__, field_name)) - - for k in range(length_method()): - process_unit = member_method(k) - tokenizer = process_unit.Options() - if (process_unit.OptionsType() is - _metadata_fb.ProcessUnitOptions.BertTokenizerOptions): - bert_tokenizer = _metadata_fb.BertTokenizerOptions() - bert_tokenizer.Init(tokenizer.Bytes, tokenizer.Pos) + process_units = getattr(table, field_name) + # If the process_units field is not populated, it will be None. Use an + # empty list to skip the check. + for process_unit in process_units or []: + options = process_unit.options + if isinstance(options, (_metadata_fb.BertTokenizerOptionsT, + _metadata_fb.RegexTokenizerOptionsT)): + file_list += self._get_associated_files_from_table(options, "vocabFile") + elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT): file_list += self._get_associated_files_from_table( - bert_tokenizer, "VocabFile") - elif (process_unit.OptionsType() is - _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions): - sentence_piece = _metadata_fb.SentencePieceTokenizerOptions() - sentence_piece.Init(tokenizer.Bytes, tokenizer.Pos) - file_list += self._get_associated_files_from_table( - sentence_piece, "SentencePieceModel") - file_list += self._get_associated_files_from_table( - sentence_piece, "VocabFile") + options, "sentencePieceModel") + file_list += self._get_associated_files_from_table(options, "vocabFile") return file_list def _get_associated_files_from_table(self, table, field_name): @@ -417,17 +458,62 @@ be "VocabFile". Returns: - the associated files list. + A list of AssociatedFileT objects. """ if table is None: - return - file_list = [] - length_method = getattr(table, field_name + "Length") - member_method = getattr(table, field_name) - for j in range(length_method()): - file_list.append(member_method(j).Name().decode("utf-8")) - return file_list + return [] + + # If the associated file field is not populated, + # `getattr(table, field_name)` will be None. Return an empty list. + return getattr(table, field_name) or [] + + def _get_recorded_associated_file_object_list(self, metadata): + """Gets a list of AssociatedFileT objects recorded in the metadata. + + Associated files may be attached to a model, a subgraph, or an input/output + tensor. + + Args: + metadata: the ModelMetadataT object. + + Returns: + List of recorded AssociatedFileT objects. + """ + recorded_files = [] + + # Add associated files attached to ModelMetadata. + recorded_files += self._get_associated_files_from_table( + metadata, "associatedFiles") + + # Add associated files attached to each SubgraphMetadata. + for subgraph in metadata.subgraphMetadata or []: + recorded_files += self._get_associated_files_from_table( + subgraph, "associatedFiles") + + # Add associated files attached to each input tensor. + for tensor_metadata in subgraph.inputTensorMetadata or []: + recorded_files += self._get_associated_files_from_table( + tensor_metadata, "associatedFiles") + recorded_files += self._get_associated_files_from_process_units( + tensor_metadata, "processUnits") + + # Add associated files attached to each output tensor. + for tensor_metadata in subgraph.outputTensorMetadata or []: + recorded_files += self._get_associated_files_from_table( + tensor_metadata, "associatedFiles") + recorded_files += self._get_associated_files_from_process_units( + tensor_metadata, "processUnits") + + # Add associated files attached to the input_process_units. + recorded_files += self._get_associated_files_from_process_units( + subgraph, "inputProcessUnits") + + # Add associated files attached to the output_process_units. + recorded_files += self._get_associated_files_from_process_units( + subgraph, "outputProcessUnits") + + return recorded_files def _populate_associated_files(self): """Concatenates associated files after TensorFlow Lite model file. @@ -441,10 +527,20 @@ # self._model_file = old_tflite_file | label1.txt | label2.txt # Then after trigger populate() to add label3.txt, self._model_file becomes # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt - with zipfile.ZipFile(self._model_file, "a") as zf: - for af in self._associated_files: - filename = os.path.basename(af) - zf.write(af, filename) + with tempfile.SpooledTemporaryFile() as temp: + # (1) Copy content from model file of to temp file. + with _open_file(self._model_file, "rb") as f: + shutil.copyfileobj(f, temp) + + # (2) Append of to a temp file as a zip. + with _open_as_zipfile(temp, "a") as zf: + for file_name, file_buffer in self._associated_files.items(): + zf.writestr(file_name, file_buffer) + + # (3) Copy temp file to model file. + temp.seek(0) + with _open_file(self._model_file, "wb") as f: + shutil.copyfileobj(temp, f) def _populate_metadata_buffer(self): """Populates the metadata buffer (in bytearray) into the model file. @@ -457,7 +553,7 @@ buffer. """ - with open(self._model_file, "rb") as f: + with _open_file(self._model_file, "rb") as f: model_buf = f.read() model = _schema_fb.ModelT.InitFromObj( @@ -495,18 +591,22 @@ packed_files = self.get_packed_associated_file_list() if packed_files: # Writes the updated model buffer and associated files into a new model - # file. Then overwrites the original model file. - with tempfile.NamedTemporaryFile() as temp: - new_file = temp.name - with open(new_file, "wb") as f: - f.write(model_buf) - self._copy_archived_files(self._model_file, new_file, packed_files) - shutil.copy(new_file, self._model_file) - os.remove(new_file) + # file (in memory). Then overwrites the original model file. + with tempfile.SpooledTemporaryFile() as temp: + temp.write(model_buf) + self._copy_archived_files(self._model_file, packed_files, temp) + temp.seek(0) + with _open_file(self._model_file, "wb") as f: + shutil.copyfileobj(temp, f) else: - with open(self._model_file, "wb") as f: + with _open_file(self._model_file, "wb") as f: f.write(model_buf) + def _use_basename_for_associated_files_in_metadata(self, metadata): + """Removes any associated file local directory (if exists).""" + for file in self._get_recorded_associated_file_object_list(metadata): + file.name = os.path.basename(file.name) + def _validate_metadata(self, metadata_buf): """Validates the metadata to be populated.""" _assert_metadata_buffer_identifier(metadata_buf) @@ -521,7 +621,7 @@ model_meta.SubgraphMetadataLength())) # Verify if the number of tensor metadata matches the number of tensors. - with open(self._model_file, "rb") as f: + with _open_file(self._model_file, "rb") as f: model_buf = f.read() model = _schema_fb.Model.GetRootAsModel(model_buf, 0) @@ -569,10 +669,10 @@ with tempfile.NamedTemporaryFile() as temp: model_file = temp.name - with open(model_file, "wb") as f: + with _open_file(model_file, "wb") as f: f.write(model_buf) - MetadataPopulator.__init__(self, model_file) + super().__init__(model_file) def __del__(self): """Destructor of _MetadataPopulatorWithBuffer. @@ -596,6 +696,7 @@ """ _assert_model_buffer_identifier(model_buffer) _assert_metadata_buffer_identifier(metadata_buffer) + self._model_buffer = model_buffer self._metadata_buffer = metadata_buffer self._associated_file_list = associated_file_list @@ -614,7 +715,7 @@ ValueError: The model does not have metadata. """ _assert_file_exist(model_file) - with open(model_file, "rb") as f: + with _open_file(model_file, "rb") as f: return cls.with_model_buffer(f.read()) @classmethod @@ -629,20 +730,38 @@ """ if not model_buffer: raise ValueError("model_buffer cannot be empty.") - metadata_buffer = cls._get_metadata_buffer(model_buffer) + metadata_buffer = _get_metadata_buffer(model_buffer) + if not metadata_buffer: + raise ValueError("The model does not have metadata.") associated_file_list = cls._parse_packed_associted_file_list(model_buffer) return cls(model_buffer, metadata_buffer, associated_file_list) + def get_associated_file_buffer(self, filename): + """Get the specified associated file content in bytearray. + + Args: + filename: name of the file to be extracted. + + Returns: + The file content in bytearray. + + Raises: + ValueError: if the file does not exist in the model. + """ + if filename not in self._associated_file_list: + raise ValueError( + "The file, {}, does not exist in the model.".format(filename)) + + with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf: + return zf.read(filename) + + def get_metadata_buffer(self): + """Get the metadata buffer in bytearray out from the model.""" + return copy.deepcopy(self._metadata_buffer) + def get_metadata_json(self): """Converts the metadata into a json string.""" - opt = _pywrap_flatbuffers.IDLOptions() - opt.strict_json = True - parser = _pywrap_flatbuffers.Parser(opt) - with open(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f: - metadata_schema_content = f.read() - if not parser.parse(metadata_schema_content): - raise ValueError("Cannot parse metadata schema. Reason: " + parser.error) - return _pywrap_flatbuffers.generate_text(parser, self._metadata_buffer) + return convert_to_json(self._metadata_buffer) def get_packed_associated_file_list(self): """Returns a list of associated files that are packed in the model. @@ -653,32 +772,6 @@ return copy.deepcopy(self._associated_file_list) @staticmethod - def _get_metadata_buffer(model_buf): - """Returns the metadata in the model file as a buffer. - - Args: - model_buf: valid buffer of the model file. - - Returns: - Metadata buffer. - - Raises: - ValueError: The model does not have metadata. - """ - tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0) - - # Gets metadata from the model file. - for i in range(tflite_model.MetadataLength()): - meta = tflite_model.Metadata(i) - if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: - buffer_index = meta.Buffer() - metadata = tflite_model.Buffers(buffer_index) - metadata_buf = metadata.DataAsNumpy().tobytes() - return metadata_buf - - raise ValueError("The model does not have metadata.") - - @staticmethod def _parse_packed_associted_file_list(model_buf): """Gets a list of associated files packed to the model file. @@ -690,22 +783,47 @@ """ try: - with zipfile.ZipFile(io.BytesIO(model_buf)) as zf: + with _open_as_zipfile(io.BytesIO(model_buf)) as zf: return zf.namelist() except zipfile.BadZipFile: return [] +# Create an individual method for getting the metadata json file, so that it can +# be used as a standalone util. +def convert_to_json(metadata_buffer): + """Converts the metadata into a json string. + + Args: + metadata_buffer: valid metadata buffer in bytes. + + Returns: + Metadata in JSON format. + + Raises: + ValueError: error occured when parsing the metadata schema file. + """ + + opt = _pywrap_flatbuffers.IDLOptions() + opt.strict_json = True + parser = _pywrap_flatbuffers.Parser(opt) + with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f: + metadata_schema_content = f.read() + if not parser.parse(metadata_schema_content): + raise ValueError("Cannot parse metadata schema. Reason: " + parser.error) + return _pywrap_flatbuffers.generate_text(parser, metadata_buffer) + + def _assert_file_exist(filename): """Checks if a file exists.""" - if not os.path.exists(filename): + if not _exists_file(filename): raise IOError("File, '{0}', does not exist.".format(filename)) def _assert_model_file_identifier(model_file): """Checks if a model file has the expected TFLite schema identifier.""" _assert_file_exist(model_file) - with open(model_file, "rb") as f: + with _open_file(model_file, "rb") as f: _assert_model_buffer_identifier(f.read()) @@ -723,3 +841,25 @@ raise ValueError( "The metadata buffer does not have the expected identifier, and may not" " be a valid TFLite Metadata.") + + +def _get_metadata_buffer(model_buf): + """Returns the metadata in the model file as a buffer. + + Args: + model_buf: valid buffer of the model file. + + Returns: + Metadata buffer. Returns `None` if the model does not have metadata. + """ + tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0) + + # Gets metadata from the model file. + for i in range(tflite_model.MetadataLength()): + meta = tflite_model.Metadata(i) + if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: + buffer_index = meta.Buffer() + metadata = tflite_model.Buffers(buffer_index) + return metadata.DataAsNumpy().tobytes() + + return None
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_displayer.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_displayer.py new file mode 100644 index 0000000..9d160fc --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_displayer.py
@@ -0,0 +1,34 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""CLI tool for display metadata.""" + +from absl import app +from absl import flags + +from tensorflow_lite_support.metadata.python import metadata + +FLAGS = flags.FLAGS +flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.') +flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.') + + +def main(_): + displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path) + with open(FLAGS.export_json_path, 'w') as f: + f.write(displayer.get_metadata_json()) + + +if __name__ == '__main__': + app.run(main)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writer_for_task.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writer_for_task.py new file mode 100644 index 0000000..8d2a9b8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writer_for_task.py
@@ -0,0 +1,345 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Object oriented generic metadata writer for modular task API.""" + +import collections +import os +import tempfile +from typing import Optional, List + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +CalibrationParameter = collections.namedtuple( + 'CalibrationParameter', ['scale', 'slope', 'offset', 'min_score']) + +LabelItem = collections.namedtuple('LabelItem', ['locale', 'filename', 'names']) + + +class Labels: + """Simple container holding classification labels of a particular tensor.""" + + def __init__(self): + self._labels = [] # [LabelItem] + + def add(self, + labels: List[str], + locale: Optional[str] = None, + use_as_category_name=False, + exported_filename: Optional[str] = None): + """Adds labels in the container.""" + if not labels: + raise ValueError('The list of labels is empty') + + # Prepare the new item to be inserted + if not exported_filename: + exported_filename = 'labels' + if locale: + exported_filename += f'_{locale}' + exported_filename += '.txt' + item = LabelItem(locale, exported_filename, labels) + + if self._labels and use_as_category_name: + # Category names should be the first one in the list + pos = 0 + + # Double check if we need to replace exising category name or insert one. + if self._labels[pos].locale: + # No category names available, insert one + self._labels.insert(pos, item) + else: + # Update the category name + self._labels[pos] = item + + else: + # insert the new element at the end of the list + self._labels.append(item) + return self + + def add_from_file(self, + label_filepath: str, + locale: Optional[str] = None, + use_as_category_name=False, + exported_filename: Optional[str] = None): + """Adds a label file in the container.""" + with open(label_filepath, 'r') as f: + labels = f.read().split('\n') + return self.add(labels, locale, use_as_category_name, exported_filename) + + +class ScoreCalibration: + """Simple container holding score calibration related parameters.""" + + # A shortcut to avoid client side code importing _metadata_fb + transformation_types = _metadata_fb.ScoreTransformationType + + def __init__(self, + transformation_type: _metadata_fb.ScoreTransformationType, + parameters: List[CalibrationParameter], + default_score: int = 0): + self.transformation_type = transformation_type + self.parameters = parameters + self.default_score = default_score + + +class Writer: + """Generic object-oriented Metadata writer. + + Note that this API is experimental and is subject to changes. Also it only + supports limited input and output tensor types for now. More types are being + added. + + Example usage: + + The model has two inputs, audio and image respectively. And generates two + outputs: classification and embedding. + + with open(model_path, 'rb') as f: + with Writer(f.read(), 'model_name', 'model description') as writer: + writer + .add_audio_input(sample_rate=16000, channels=1) + .add_image_input() + .add_classification_output(Labels().add(['A', 'B'])) + .add_embedding_output() + .populate('model.tflite', 'model.json') + """ + + def __init__(self, model_buffer: bytearray, model_name: str, + model_description: str): + self._model_buffer = model_buffer + self._general_md = metadata_info.GeneralMd( + name=model_name, description=model_description) + self._input_mds = [] + self._output_mds = [] + self._associate_files = [] + + def __enter__(self): + self._temp_folder = tempfile.TemporaryDirectory() + return self + + def __exit__(self, unused_exc_type, unused_exc_val, unused_exc_tb): + self._temp_folder.cleanup() + # Delete the attribute so that it errors out outside the `with` statement. + delattr(self, '_temp_folder') + + def populate(self, + tflite_path: Optional[str] = None, + json_path: Optional[str] = None): + """Writes the generated flatbuffer file / json metadata to disk. + + Note that you'll only need the tflite file for deployment. The JSON file + is useful to help you understand what's in the metadata. + + Args: + tflite_path: path to the tflite file. + json_path: path to the JSON file. + + Returns: + A tuple of (tflite_content_in_bytes, metdata_json_content) + """ + tflite_content = None + metadata_json_content = None + + writer = metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer=self._model_buffer, + general_md=self._general_md, + input_md=self._input_mds, + output_md=self._output_mds, + associated_files=self._associate_files) + + if tflite_path: + tflite_content = writer.populate() + writer_utils.save_file(tflite_content, tflite_path) + + if json_path: + displayer = _metadata.MetadataDisplayer.with_model_file(tflite_path) + metadata_json_content = displayer.get_metadata_json() + with open(json_path, 'w') as f: + f.write(metadata_json_content) + + return (tflite_content, metadata_json_content) + + def _export_labels(self, filename: str, index_to_label: List[str]): + filepath = os.path.join(self._temp_folder.name, filename) + with open(filepath, 'w') as f: + f.write('\n'.join(index_to_label)) + self._associate_files.append(filepath) + return filepath + + def _input_tensor_type(self, idx): + return writer_utils.get_input_tensor_types(self._model_buffer)[idx] + + def _output_tensor_type(self, idx): + return writer_utils.get_output_tensor_types(self._model_buffer)[idx] + + _INPUT_AUDIO_NAME = 'audio' + _INPUT_AUDIO_DESCRIPTION = 'Input audio clip to be processed.' + + def add_audio_input(self, + sample_rate: int, + channels: int, + name: str = _INPUT_AUDIO_NAME, + description: str = _INPUT_AUDIO_DESCRIPTION): + """Marks the next input tensor as an audio input.""" + # To make Task Library working properly, sample_rate, channels need to be + # positive. + if sample_rate <= 0: + raise ValueError( + 'sample_rate should be positive, but got {}.'.format(sample_rate)) + if channels <= 0: + raise ValueError( + 'channels should be positive, but got {}.'.format(channels)) + + input_md = metadata_info.InputAudioTensorMd( + name=name, + description=description, + sample_rate=sample_rate, + channels=channels) + self._input_mds.append(input_md) + return self + + _INPUT_IMAGE_NAME = 'image' + _INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.' + color_space_types = _metadata_fb.ColorSpaceType + + def add_image_input( + self, + norm_mean: List[float], + norm_std: List[float], + color_space_type: Optional[ + _metadata_fb.ColorSpaceType] = _metadata_fb.ColorSpaceType.RGB, + name: str = _INPUT_IMAGE_NAME, + description: str = _INPUT_IMAGE_DESCRIPTION): + """Marks the next input tensor as an image input. + + Args: + norm_mean: The mean value used to normalize each input channel. If there + is only one element in the list, its value will be broadcasted to all + channels. Also note that norm_mean and norm_std should have the same + number of elements. [1] + norm_std: The std value used to normalize each input channel. If there is + only one element in the list, its value will be broadcasted to all + channels. [1] + color_space_type: The color space type of the input image. [2] + name: Name of the input tensor. + description: Description of the input tensor. + + Returns: + The Writer instance, can be used for chained operation. + + [1]: + https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172 + """ + input_md = metadata_info.InputImageTensorMd( + name=name, + description=description, + norm_mean=norm_mean, + norm_std=norm_std, + color_space_type=color_space_type, + tensor_type=self._input_tensor_type(len(self._input_mds))) + + self._input_mds.append(input_md) + return self + + _OUTPUT_EMBEDDING_NAME = 'embedding' + _OUTPUT_EMBEDDING_DESCRIPTION = 'Embedding vector of the input.' + + def add_embedding_output(self, + name: str = _OUTPUT_EMBEDDING_NAME, + description: str = _OUTPUT_EMBEDDING_DESCRIPTION): + """Marks the next output tensor as embedding.""" + output_md = metadata_info.TensorMd(name=name, description=description) + self._output_mds.append(output_md) + return self + + def _export_calibration_file(self, filename: str, + calibrations: List[CalibrationParameter]): + """Store calibration parameters in a csv file.""" + filepath = os.path.join(self._temp_folder.name, filename) + with open(filepath, 'w') as f: + for idx, item in enumerate(calibrations): + if idx != 0: + f.write('\n') + if item: + scale, slope, offset, min_score = item + if all(x is not None for x in item): + f.write(f'{scale},{slope},{offset},{min_score}') + elif all(x is not None for x in item[:3]): + f.write(f'{scale},{slope},{offset}') + else: + raise ValueError('scale, slope and offset values can not be set to ' + 'None.') + self._associate_files.append(filepath) + return filepath + + _OUTPUT_CLASSIFICATION_NAME = 'score' + _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively' + + def add_classification_output( + self, + labels: Labels, + score_calibration: Optional[ScoreCalibration] = None, + name=_OUTPUT_CLASSIFICATION_NAME, + description=_OUTPUT_CLASSIFICATION_DESCRIPTION): + """Marks model's next output tensor as a classification head. + + Example usage: + writer.add_classification_output( + Labels() + .add(['cat', 'dog], 'en') + .add(['chat', 'chien], 'fr') + .add(['/m/011l78', '/m/031d23'], use_as_category_name=True)) + + Args: + labels: an instance of Labels helper class. + score_calibration: an instance of ScoreCalibration helper class. + name: Metadata name of the tensor. Note that this is different from tensor + name in the flatbuffer. + description: human readable description of what the tensor does. + + Returns: + The current Writer instance to allow chained operation. + """ + calibration_md = None + if score_calibration: + calibration_md = metadata_info.ScoreCalibrationMd( + score_transformation_type=score_calibration.transformation_type, + default_score=score_calibration.default_score, + file_path=self._export_calibration_file('score_calibration.txt', + score_calibration.parameters)) + + idx = len(self._output_mds) + + label_files = [] + for item in labels._labels: # pylint: disable=protected-access + label_files.append( + metadata_info.LabelFileMd( + self._export_labels(item.filename, item.names), + locale=item.locale)) + + output_md = metadata_info.ClassificationTensorMd( + name=name, + description=description, + label_files=label_files, + tensor_type=self._output_tensor_type(idx), + score_calibration_md=calibration_md, + ) + self._output_mds.append(output_md) + return self
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/BUILD new file mode 100644 index 0000000..e84aac32 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/BUILD
@@ -0,0 +1,142 @@ +# Placeholder for internal Python strict library compatibility macro. + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +py_library( + name = "metadata_writer", + srcs = [ + "metadata_writer.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":writer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata:schema_py", + "//tensorflow_lite_support/metadata/python:metadata", + "@flatbuffers//:runtime_py", + ], +) + +py_library( + name = "metadata_info", + srcs = [ + "metadata_info.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":writer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata:schema_py", + ], +) + +py_library( + name = "writer_utils", + srcs = [ + "writer_utils.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata:schema_py", + ], +) + +py_library( + name = "image_classifier", + srcs = [ + "image_classifier.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":metadata_writer", + ":writer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_py", + ], +) + +py_library( + name = "object_detector", + srcs = [ + "object_detector.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":metadata_writer", + ":writer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata:schema_py", + "//tensorflow_lite_support/metadata/python:metadata", + "@flatbuffers//:runtime_py", + ], +) + +py_library( + name = "image_segmenter", + srcs = [ + "image_segmenter.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":metadata_writer", + ":writer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_py", + ], +) + +py_library( + name = "nl_classifier", + srcs = [ + "nl_classifier.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":metadata_writer", + ":writer_utils", + ], +) + +py_library( + name = "audio_classifier", + srcs = [ + "audio_classifier.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":metadata_writer", + ":writer_utils", + ], +) + +py_library( + name = "bert_nl_classifier", + srcs = [ + "bert_nl_classifier.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metadata_info", + ":metadata_writer", + ":writer_utils", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/__init__.py new file mode 100644 index 0000000..2ae3e34 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/__init__.py
@@ -0,0 +1,13 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/audio_classifier.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/audio_classifier.py new file mode 100644 index 0000000..ce9bfec --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/audio_classifier.py
@@ -0,0 +1,168 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the audio classifier models.""" + +from typing import List, Optional + +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +_MODEL_NAME = "AudioClassifier" +_MODEL_DESCRIPTION = ( + "Identify the most prominent type in the audio clip from a known set of " + "categories.") +_INPUT_NAME = "audio_clip" +_INPUT_DESCRIPTION = "Input audio clip to be classified." +_OUTPUT_NAME = "probability" +_OUTPUT_DESCRIPTION = "Scores of the labels respectively." + + +class MetadataWriter(metadata_writer.MetadataWriter): + """Writes metadata into an audio classifier.""" + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.InputAudioTensorMd] = None, + output_md: Optional[metadata_info.ClassificationTensorMd] = None): + """Creates MetadataWriter based on general/input/output information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. If not specified, default + general metadata will be generated. + input_md: input audio tensor informaton. If not specified, default input + metadata will be generated. + output_md: output classification tensor informaton. If not specified, + default output metadata will be generated. + + Returns: + A MetadataWriter object. + """ + if output_md is None: + output_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION) + + return cls.create_from_metadata_info_for_multihead(model_buffer, general_md, + input_md, [output_md]) + + @classmethod + def create_from_metadata_info_for_multihead( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.InputAudioTensorMd] = None, + output_md_list: Optional[List[ + metadata_info.ClassificationTensorMd]] = None): + """Creates a MetadataWriter instance for multihead models. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. If not specified, default + general metadata will be generated. + input_md: input audio tensor informaton. If not specified, default input + metadata will be generated. + output_md_list: information of each output tensor head. If not specified, + default metadata will be generated for each output tensor. If + `tensor_name` in each `ClassificationTensorMd` instance is not + specified, elements in `output_md_list` need to have one-to-one mapping + with the output tensors [1] in the TFLite model. + [1]: + https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L605-L612 + + Returns: + A MetadataWriter object. + """ + + if general_md is None: + general_md = metadata_info.GeneralMd( + name=_MODEL_NAME, description=_MODEL_DESCRIPTION) + + if input_md is None: + input_md = metadata_info.InputAudioTensorMd( + name=_INPUT_NAME, description=_INPUT_DESCRIPTION) + + associated_files = [] + for md in output_md_list or []: + associated_files.extend( + [file.file_path for file in md.associated_files or []]) + + return super().create_from_metadata_info( + model_buffer=model_buffer, + general_md=general_md, + input_md=[input_md], + output_md=output_md_list, + associated_files=associated_files) + + @classmethod + def create_for_inference( + cls, + model_buffer: bytearray, + sample_rate: int, + channels: int, + label_file_paths: List[str], + score_calibration_md: Optional[metadata_info.ScoreCalibrationMd] = None): + """Creates mandatory metadata for TFLite Support inference. + + The parameters required in this method are mandatory when using TFLite + Support features, such as Task library and Codegen tool (Android Studio ML + Binding). Other metadata fields will be set to default. If other fields need + to be filled, use the method `create_from_metadata_info` to edit them. + + Args: + model_buffer: valid buffer of the model file. + sample_rate: the sample rate in Hz when the audio was captured. + channels: the channel count of the audio. + label_file_paths: paths to the label files [1] in the classification + tensor. Pass in an empty list if the model does not have any label file. + score_calibration_md: information of the score calibration operation [2] + in the classification tensor. Optional if the model does not use score + calibration. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95 + [2]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434 + + Returns: + A MetadataWriter object. + """ + # To make Task Library working properly, sample_rate, channels need to be + # positive. + if sample_rate <= 0: + raise ValueError( + "sample_rate should be positive, but got {}.".format(sample_rate)) + + if channels <= 0: + raise ValueError( + "channels should be positive, but got {}.".format(channels)) + + input_md = metadata_info.InputAudioTensorMd(_INPUT_NAME, _INPUT_DESCRIPTION, + sample_rate, channels) + + output_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_NAME, + description=_OUTPUT_DESCRIPTION, + label_files=[ + metadata_info.LabelFileMd(file_path=file_path) + for file_path in label_file_paths + ], + tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0], + score_calibration_md=score_calibration_md) + + return cls.create_from_metadata_info( + model_buffer, input_md=input_md, output_md=output_md)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/bert_nl_classifier.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/bert_nl_classifier.py new file mode 100644 index 0000000..1f34bcf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/bert_nl_classifier.py
@@ -0,0 +1,150 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the Bert NL classifier models.""" + +from typing import List, Optional, Union + +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +_MODEL_NAME = "BertNLClassifier" +_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.") + +_OUTPUT_NAME = "probability" +_OUTPUT_DESCRIPTION = "Probabilities of the labels respectively." + +# The input tensor names of models created by Model Maker. +_DEFAULT_ID_NAME = "serving_default_input_word_ids:0" +_DEFAULT_MASK_NAME = "serving_default_input_mask:0" +_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0" + + +class MetadataWriter(metadata_writer.MetadataWriter): + """Writes metadata into the Bert NL classifier.""" + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.BertInputTensorsMd] = None, + output_md: Optional[metadata_info.ClassificationTensorMd] = None): + """Creates MetadataWriter based on general/input/output information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. If not specified, default + general metadata will be generated. + input_md: input tensor information. If not specified, default input + metadata will be generated. + output_md: output classification tensor informaton. If not specified, + default output metadata will be generated. + + Returns: + A MetadataWriter object. + """ + if general_md is None: + general_md = metadata_info.GeneralMd( + name=_MODEL_NAME, description=_MODEL_DESCRIPTION) + + if input_md is None: + input_md = metadata_info.BertInputTensorsMd(model_buffer, + _DEFAULT_ID_NAME, + _DEFAULT_MASK_NAME, + _DEFAULT_SEGMENT_ID_NAME) + + if output_md is None: + output_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION) + + if output_md.associated_files is None: + output_md.associated_files = [] + + return cls.create_from_metadata( + model_buffer, + model_metadata=general_md.create_metadata(), + input_metadata=input_md.create_input_tesnor_metadata(), + output_metadata=[output_md.create_metadata()], + associated_files=[ + file.file_path for file in output_md.associated_files + ] + input_md.get_tokenizer_associated_files(), + input_process_units=input_md.create_input_process_unit_metadata()) + + @classmethod + def create_for_inference( + cls, + model_buffer: bytearray, + tokenizer_md: Union[metadata_info.BertTokenizerMd, + metadata_info.SentencePieceTokenizerMd], + label_file_paths: List[str], + ids_name: str = _DEFAULT_ID_NAME, + mask_name: str = _DEFAULT_MASK_NAME, + segment_name: str = _DEFAULT_SEGMENT_ID_NAME, + ): + """Creates mandatory metadata for TFLite Support inference. + + The parameters required in this method are mandatory when using TFLite + Support features, such as Task library and Codegen tool (Android Studio ML + Binding). Other metadata fields will be set to default. If other fields need + to be filled, use the method `create_from_metadata_info` to edit them. + + `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name` + in the TFLite schema, which help to determine the tensor order when + populating metadata. The default values come from Model Maker. + + Args: + model_buffer: valid buffer of the model file. + tokenizer_md: information of the tokenizer used to process the input + string, if any. Supported tokenziers are: `BertTokenizer` [1] and + `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` + [3], refer to `nl_classifier.MetadataWriter`. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436 + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473 + [3]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475 + label_file_paths: paths to the label files [4] in the classification + tensor. Pass in an empty list if the model does not have any label file. + [4]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95 + ids_name: name of the ids tensor, which represents the tokenized ids of + the input text. + mask_name: name of the mask tensor, which represents the mask with 1 for + real tokens and 0 for padding tokens. + segment_name: name of the segment ids tensor, where `0` stands for the + first sequence, and `1` stands for the second sequence if exists. + + Returns: + A MetadataWriter object. + """ + input_md = metadata_info.BertInputTensorsMd( + model_buffer, + ids_name, + mask_name, + segment_name, + tokenizer_md=tokenizer_md) + output_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_NAME, + description=_OUTPUT_DESCRIPTION, + label_files=[ + metadata_info.LabelFileMd(file_path=file_path) + for file_path in label_file_paths + ], + tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0]) + + return cls.create_from_metadata_info( + model_buffer, input_md=input_md, output_md=output_md)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/image_classifier.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/image_classifier.py new file mode 100644 index 0000000..1f07270 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/image_classifier.py
@@ -0,0 +1,138 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the image classifier models.""" + +from typing import List, Optional + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +_MODEL_NAME = "ImageClassifier" +MODEL_DESCRIPTION = ("Identify the most prominent object in the image from a " + "known set of categories.") +INPUT_NAME = "image" +INPUT_DESCRIPTION = "Input image to be classified." +OUTPUT_NAME = "probability" +OUTPUT_DESCRIPTION = "Probabilities of the labels respectively." + + +class MetadataWriter(metadata_writer.MetadataWriter): + """Writes metadata into an image classifier.""" + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.InputImageTensorMd] = None, + output_md: Optional[metadata_info.ClassificationTensorMd] = None): + """Creates MetadataWriter based on general/input/output information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. If not specified, default + general metadata will be generated. + input_md: input image tensor informaton, if not specified, default input + metadata will be generated. + output_md: output classification tensor informaton, if not specified, + default output metadata will be generated. + + Returns: + A MetadataWriter object. + """ + + if general_md is None: + general_md = metadata_info.GeneralMd( + name=_MODEL_NAME, description=MODEL_DESCRIPTION) + + if input_md is None: + input_md = metadata_info.InputImageTensorMd( + name=INPUT_NAME, + description=INPUT_DESCRIPTION, + color_space_type=_metadata_fb.ColorSpaceType.RGB) + + if output_md is None: + output_md = metadata_info.ClassificationTensorMd( + name=OUTPUT_NAME, description=OUTPUT_DESCRIPTION) + + if output_md.associated_files is None: + output_md.associated_files = [] + + return super().create_from_metadata_info( + model_buffer=model_buffer, + general_md=general_md, + input_md=[input_md], + output_md=[output_md], + associated_files=[ + file.file_path for file in output_md.associated_files + ]) + + @classmethod + def create_for_inference( + cls, + model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + label_file_paths: List[str], + score_calibration_md: Optional[metadata_info.ScoreCalibrationMd] = None): + """Creates mandatory metadata for TFLite Support inference. + + The parameters required in this method are mandatory when using TFLite + Support features, such as Task library and Codegen tool (Android Studio ML + Binding). Other metadata fields will be set to default. If other fields need + to be filled, use the method `create_from_metadata_info` to edit them. + + Args: + model_buffer: valid buffer of the model file. + input_norm_mean: the mean value used in the input tensor normalization + [1]. + input_norm_std: the std value used in the input tensor normalizarion [1]. + label_file_paths: paths to the label files [2] in the classification + tensor. Pass in an empty list if the model does not have any label file. + score_calibration_md: information of the score calibration operation [3] + in the classification tensor. Optional if the model does not use score + calibration. + [1]: + https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95 + [3]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434 + + Returns: + A MetadataWriter object. + """ + input_md = metadata_info.InputImageTensorMd( + name=INPUT_NAME, + description=INPUT_DESCRIPTION, + norm_mean=input_norm_mean, + norm_std=input_norm_std, + color_space_type=_metadata_fb.ColorSpaceType.RGB, + tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0]) + + output_md = metadata_info.ClassificationTensorMd( + name=OUTPUT_NAME, + description=OUTPUT_DESCRIPTION, + label_files=[ + metadata_info.LabelFileMd(file_path=file_path) + for file_path in label_file_paths + ], + tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0], + score_calibration_md=score_calibration_md) + + return cls.create_from_metadata_info( + model_buffer, input_md=input_md, output_md=output_md)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/image_segmenter.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/image_segmenter.py new file mode 100644 index 0000000..ea3766f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/image_segmenter.py
@@ -0,0 +1,157 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the image segmenter models.""" + +from typing import List, Optional + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +_MODEL_NAME = "ImageSegmenter" +_MODEL_DESCRIPTION = ("Semantic image segmentation predicts whether each pixel " + "of an image is associated with a certain class.") +_INPUT_NAME = "image" +_INPUT_DESCRIPTION = "Input image to be segmented." +_OUTPUT_NAME = "segmentation_masks" +_OUTPUT_DESCRIPTION = "Masks over the target objects with high accuracy." +# The output tensor is in the shape of [1, ImageHeight, ImageWidth, N], where N +# is the number of objects that the segmentation model can recognize. The output +# tensor is essentially a list of grayscale bitmaps, where each value is the +# probability of the corresponding pixel belonging to a certain object type. +# Therefore, the content dimension range of the output tensor is [1, 2]. +_CONTENT_DIM_MIN = 1 +_CONTENT_DIM_MAX = 2 + + +def _create_segmentation_masks_metadata( + masks_md: metadata_info.TensorMd) -> _metadata_fb.TensorMetadataT: + """Creates the metadata for the segmentation masks tensor.""" + masks_metadata = masks_md.create_metadata() + + # Create tensor content information. + content = _metadata_fb.ContentT() + content.contentProperties = _metadata_fb.ImagePropertiesT() + content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.GRAYSCALE + content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties + # Add the content range. See + # https://github.com/tensorflow/tflite-support/blob/ace5d3f3ce44c5f77c70284fa9c5a4e3f2f92abb/tensorflow_lite_support/metadata/metadata_schema.fbs#L285-L347 + dim_range = _metadata_fb.ValueRangeT() + dim_range.min = _CONTENT_DIM_MIN + dim_range.max = _CONTENT_DIM_MAX + content.range = dim_range + masks_metadata.content = content + + return masks_metadata + + +class MetadataWriter(metadata_writer.MetadataWriter): + """Writes metadata into an image segmenter.""" + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.InputImageTensorMd] = None, + output_md: Optional[metadata_info.TensorMd] = None): + """Creates MetadataWriter based on general/input/outputs information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. + input_md: input image tensor informaton. + output_md: output segmentation mask tensor informaton. This tensor is a + multidimensional array of [1 x mask_height x mask_width x num_classes], + where mask_width and mask_height are the dimensions of the segmentation + masks produced by the model, and num_classes is the number of classes + supported by the model. + + Returns: + A MetadataWriter object. + """ + + if general_md is None: + general_md = metadata_info.GeneralMd( + name=_MODEL_NAME, description=_MODEL_DESCRIPTION) + + if input_md is None: + input_md = metadata_info.InputImageTensorMd( + name=_INPUT_NAME, + description=_INPUT_DESCRIPTION, + color_space_type=_metadata_fb.ColorSpaceType.RGB) + + if output_md is None: + output_md = metadata_info.TensorMd( + name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION) + + if output_md.associated_files is None: + output_md.associated_files = [] + + return super().create_from_metadata( + model_buffer, + model_metadata=general_md.create_metadata(), + input_metadata=[input_md.create_metadata()], + output_metadata=[_create_segmentation_masks_metadata(output_md)], + associated_files=[ + file.file_path for file in output_md.associated_files + ]) + + @classmethod + def create_for_inference(cls, model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + label_file_paths: List[str]): + """Creates mandatory metadata for TFLite Support inference. + + The parameters required in this method are mandatory when using TFLite + Support features, such as Task library and Codegen tool (Android Studio ML + Binding). Other metadata fields will be set to default. If other fields need + to be filled, use the method `create_from_metadata_info` to edit them. + + Args: + model_buffer: valid buffer of the model file. + input_norm_mean: the mean value used in the input tensor normalization + [1]. + input_norm_std: the std value used in the input tensor normalizarion [1]. + label_file_paths: paths to the label files [2] in the category tensor. + Pass in an empty list If the model does not have any label file. + [1]: + https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108 + + Returns: + A MetadataWriter object. + """ + input_md = metadata_info.InputImageTensorMd( + name=_INPUT_NAME, + description=_INPUT_DESCRIPTION, + norm_mean=input_norm_mean, + norm_std=input_norm_std, + color_space_type=_metadata_fb.ColorSpaceType.RGB, + tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0]) + + output_md = metadata_info.TensorMd( + name=_OUTPUT_NAME, + description=_OUTPUT_DESCRIPTION, + associated_files=[ + metadata_info.LabelFileMd(file_path=file_path) + for file_path in label_file_paths + ]) + + return cls.create_from_metadata_info( + model_buffer, input_md=input_md, output_md=output_md)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py new file mode 100644 index 0000000..b1b6c1c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py
@@ -0,0 +1,817 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper classes for common model metadata information.""" + +import collections +import csv +import os +from typing import List, Optional, Type, Union + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +# Min and max values for UINT8 tensors. +_MIN_UINT8 = 0 +_MAX_UINT8 = 255 + +# Default description for vocabulary files. +_VOCAB_FILE_DESCRIPTION = ("Vocabulary file to convert natural language " + "words to embedding vectors.") + + +class GeneralMd: + """A container for common metadata information of a model. + + Attributes: + name: name of the model. + version: version of the model. + description: description of what the model does. + author: author of the model. + licenses: licenses of the model. + """ + + def __init__(self, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + author: Optional[str] = None, + licenses: Optional[str] = None): + self.name = name + self.version = version + self.description = description + self.author = author + self.licenses = licenses + + def create_metadata(self) -> _metadata_fb.ModelMetadataT: + """Creates the model metadata based on the general model information. + + Returns: + A Flatbuffers Python object of the model metadata. + """ + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.name = self.name + model_metadata.version = self.version + model_metadata.description = self.description + model_metadata.author = self.author + model_metadata.license = self.licenses + return model_metadata + + +class AssociatedFileMd: + """A container for common associated file metadata information. + + Attributes: + file_path: path to the associated file. + description: description of the associated file. + file_type: file type of the associated file [1]. + locale: locale of the associated file [2]. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L77 + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L154 + """ + + def __init__( + self, + file_path: str, + description: Optional[str] = None, + file_type: Optional[_metadata_fb.AssociatedFileType] = _metadata_fb + .AssociatedFileType.UNKNOWN, + locale: Optional[str] = None): + self.file_path = file_path + self.description = description + self.file_type = file_type + self.locale = locale + + def create_metadata(self) -> _metadata_fb.AssociatedFileT: + """Creates the associated file metadata. + + Returns: + A Flatbuffers Python object of the associated file metadata. + """ + file_metadata = _metadata_fb.AssociatedFileT() + file_metadata.name = os.path.basename(self.file_path) + file_metadata.description = self.description + file_metadata.type = self.file_type + file_metadata.locale = self.locale + return file_metadata + + +class LabelFileMd(AssociatedFileMd): + """A container for label file metadata information.""" + + _LABEL_FILE_DESCRIPTION = ("Labels for categories that the model can " + "recognize.") + _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS + + def __init__(self, file_path: str, locale: Optional[str] = None): + """Creates a LabelFileMd object. + + Args: + file_path: file_path of the label file. + locale: locale of the label file [1]. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L154 + """ + super().__init__(file_path, self._LABEL_FILE_DESCRIPTION, self._FILE_TYPE, + locale) + + +class RegexTokenizerMd: + """A container for the Regex tokenizer [1] metadata information. + + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L459 + """ + + def __init__(self, delim_regex_pattern: str, vocab_file_path: str): + """Initializes a RegexTokenizerMd object. + + Args: + delim_regex_pattern: the regular expression to segment strings and create + tokens. + vocab_file_path: path to the vocabulary file. + """ + self._delim_regex_pattern = delim_regex_pattern + self._vocab_file_path = vocab_file_path + + def create_metadata(self) -> _metadata_fb.ProcessUnitT: + """Creates the Bert tokenizer metadata based on the information. + + Returns: + A Flatbuffers Python object of the Bert tokenizer metadata. + """ + vocab = _metadata_fb.AssociatedFileT() + vocab.name = self._vocab_file_path + vocab.description = _VOCAB_FILE_DESCRIPTION + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + + # Create the RegexTokenizer. + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = ( + _metadata_fb.ProcessUnitOptions.RegexTokenizerOptions) + tokenizer.options = _metadata_fb.RegexTokenizerOptionsT() + tokenizer.options.delimRegexPattern = self._delim_regex_pattern + tokenizer.options.vocabFile = [vocab] + return tokenizer + + +class BertTokenizerMd: + """A container for the Bert tokenizer [1] metadata information. + + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436 + """ + + def __init__(self, vocab_file_path: str): + """Initializes a BertTokenizerMd object. + + Args: + vocab_file_path: path to the vocabulary file. + """ + self._vocab_file_path = vocab_file_path + + def create_metadata(self) -> _metadata_fb.ProcessUnitT: + """Creates the Bert tokenizer metadata based on the information. + + Returns: + A Flatbuffers Python object of the Bert tokenizer metadata. + """ + vocab = _metadata_fb.AssociatedFileT() + vocab.name = self._vocab_file_path + vocab.description = _VOCAB_FILE_DESCRIPTION + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions + tokenizer.options = _metadata_fb.BertTokenizerOptionsT() + tokenizer.options.vocabFile = [vocab] + return tokenizer + + +class SentencePieceTokenizerMd: + """A container for the sentence piece tokenizer [1] metadata information. + + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473 + """ + + _SP_MODEL_DESCRIPTION = "The sentence piece model file." + _SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + ( + " This file is optional during tokenization, while the sentence piece " + "model is mandatory.") + + def __init__(self, + sentence_piece_model_path: str, + vocab_file_path: Optional[str] = None): + """Initializes a SentencePieceTokenizerMd object. + + Args: + sentence_piece_model_path: path to the sentence piece model file. + vocab_file_path: path to the vocabulary file. + """ + self._sentence_piece_model_path = sentence_piece_model_path + self._vocab_file_path = vocab_file_path + + def create_metadata(self) -> _metadata_fb.ProcessUnitT: + """Creates the sentence piece tokenizer metadata based on the information. + + Returns: + A Flatbuffers Python object of the sentence piece tokenizer metadata. + """ + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = ( + _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions) + tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT() + + sp_model = _metadata_fb.AssociatedFileT() + sp_model.name = self._sentence_piece_model_path + sp_model.description = self._SP_MODEL_DESCRIPTION + tokenizer.options.sentencePieceModel = [sp_model] + if self._vocab_file_path: + vocab = _metadata_fb.AssociatedFileT() + vocab.name = self._vocab_file_path + vocab.description = self._SP_VOCAB_FILE_DESCRIPTION + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer.options.vocabFile = [vocab] + return tokenizer + + +class ScoreCalibrationMd: + """A container for score calibration [1] metadata information. + + [1]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434 + """ + + _SCORE_CALIBRATION_FILE_DESCRIPTION = ( + "Contains sigmoid-based score calibration parameters. The main purposes " + "of score calibration is to make scores across classes comparable, so " + "that a common threshold can be used for all output classes.") + _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_SCORE_CALIBRATION + + def __init__(self, + score_transformation_type: _metadata_fb.ScoreTransformationType, + default_score: float, file_path: str): + """Creates a ScoreCalibrationMd object. + + Args: + score_transformation_type: type of the function used for transforming the + uncalibrated score before applying score calibration. + default_score: the default calibrated score to apply if the uncalibrated + score is below min_score or if no parameters were specified for a given + index. + file_path: file_path of the score calibration file [1]. + [1]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L122 + + Raises: + ValueError: if the score_calibration file is malformed. + """ + self._score_transformation_type = score_transformation_type + self._default_score = default_score + self._file_path = file_path + + # Sanity check the score calibration file. + with open(self._file_path) as calibration_file: + csv_reader = csv.reader(calibration_file, delimiter=",") + for row in csv_reader: + if row and len(row) != 3 and len(row) != 4: + raise ValueError( + f"Expected empty lines or 3 or 4 parameters per line in score" + f" calibration file, but got {len(row)}.") + + if row and float(row[0]) < 0: + raise ValueError( + f"Expected scale to be a non-negative value, but got " + f"{float(row[0])}.") + + def create_metadata(self) -> _metadata_fb.ProcessUnitT: + """Creates the score calibration metadata based on the information. + + Returns: + A Flatbuffers Python object of the score calibration metadata. + """ + score_calibration = _metadata_fb.ProcessUnitT() + score_calibration.optionsType = ( + _metadata_fb.ProcessUnitOptions.ScoreCalibrationOptions) + options = _metadata_fb.ScoreCalibrationOptionsT() + options.scoreTransformation = self._score_transformation_type + options.defaultScore = self._default_score + score_calibration.options = options + return score_calibration + + def create_score_calibration_file_md(self) -> AssociatedFileMd: + return AssociatedFileMd(self._file_path, + self._SCORE_CALIBRATION_FILE_DESCRIPTION, + self._FILE_TYPE) + + +class TensorMd: + """A container for common tensor metadata information. + + Attributes: + name: name of the tensor. + description: description of what the tensor is. + min_values: per-channel minimum value of the tensor. + max_values: per-channel maximum value of the tensor. + content_type: content_type of the tensor. + associated_files: information of the associated files in the tensor. + tensor_name: name of the corresponding tensor [1] in the TFLite model. It is + used to locate the corresponding tensor and decide the order of the tensor + metadata [2] when populating model metadata. + [1]: + https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 + [2]: + https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L595-L612 + """ + + def __init__(self, + name: Optional[str] = None, + description: Optional[str] = None, + min_values: Optional[List[float]] = None, + max_values: Optional[List[float]] = None, + content_type: _metadata_fb.ContentProperties = _metadata_fb + .ContentProperties.FeatureProperties, + associated_files: Optional[List[Type[AssociatedFileMd]]] = None, + tensor_name: Optional[str] = None): + self.name = name + self.description = description + self.min_values = min_values + self.max_values = max_values + self.content_type = content_type + self.associated_files = associated_files + self.tensor_name = tensor_name + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the input tensor metadata based on the information. + + Returns: + A Flatbuffers Python object of the input metadata. + """ + tensor_metadata = _metadata_fb.TensorMetadataT() + tensor_metadata.name = self.name + tensor_metadata.description = self.description + + # Create min and max values + stats = _metadata_fb.StatsT() + stats.max = self.max_values + stats.min = self.min_values + tensor_metadata.stats = stats + + # Create content properties + content = _metadata_fb.ContentT() + if self.content_type is _metadata_fb.ContentProperties.FeatureProperties: + content.contentProperties = _metadata_fb.FeaturePropertiesT() + elif self.content_type is _metadata_fb.ContentProperties.ImageProperties: + content.contentProperties = _metadata_fb.ImagePropertiesT() + elif self.content_type is ( + _metadata_fb.ContentProperties.BoundingBoxProperties): + content.contentProperties = _metadata_fb.BoundingBoxPropertiesT() + elif self.content_type is _metadata_fb.ContentProperties.AudioProperties: + content.contentProperties = _metadata_fb.AudioPropertiesT() + + content.contentPropertiesType = self.content_type + tensor_metadata.content = content + + # TODO(b/174091474): check if multiple label files have populated locale. + # Create associated files + if self.associated_files: + tensor_metadata.associatedFiles = [ + file.create_metadata() for file in self.associated_files + ] + return tensor_metadata + + +class InputImageTensorMd(TensorMd): + """A container for input image tensor metadata information. + + Attributes: + norm_mean: the mean value used in tensor normalization [1]. + norm_std: the std value used in the tensor normalization [1]. norm_mean and + norm_std must have the same dimension. + color_space_type: the color space type of the input image [2]. + [1]: + https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172 + """ + + # Min and max float values for image pixels. + _MIN_PIXEL = 0.0 + _MAX_PIXEL = 255.0 + + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + norm_mean: Optional[List[float]] = None, + norm_std: Optional[List[float]] = None, + color_space_type: Optional[ + _metadata_fb.ColorSpaceType] = _metadata_fb.ColorSpaceType.UNKNOWN, + tensor_type: Optional[_schema_fb.TensorType] = None): + """Initializes the instance of InputImageTensorMd. + + Args: + name: name of the tensor. + description: description of what the tensor is. + norm_mean: the mean value used in tensor normalization [1]. + norm_std: the std value used in the tensor normalization [1]. norm_mean + and norm_std must have the same dimension. + color_space_type: the color space type of the input image [2]. + tensor_type: data type of the tensor. + [1]: + https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172 + + Raises: + ValueError: if norm_mean and norm_std have different dimensions. + """ + if norm_std and norm_mean and len(norm_std) != len(norm_mean): + raise ValueError( + f"norm_mean and norm_std are expected to be the same dim. But got " + f"{len(norm_mean)} and {len(norm_std)}") + + if tensor_type is _schema_fb.TensorType.UINT8: + min_values = [_MIN_UINT8] + max_values = [_MAX_UINT8] + elif tensor_type is _schema_fb.TensorType.FLOAT32 and norm_std and norm_mean: + min_values = [ + float(self._MIN_PIXEL - mean) / std + for mean, std in zip(norm_mean, norm_std) + ] + max_values = [ + float(self._MAX_PIXEL - mean) / std + for mean, std in zip(norm_mean, norm_std) + ] + else: + # Uint8 and Float32 are the two major types currently. And Task library + # doesn't support other types so far. + min_values = None + max_values = None + + super().__init__(name, description, min_values, max_values, + _metadata_fb.ContentProperties.ImageProperties) + self.norm_mean = norm_mean + self.norm_std = norm_std + self.color_space_type = color_space_type + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the input image metadata based on the information. + + Returns: + A Flatbuffers Python object of the input image metadata. + """ + tensor_metadata = super().create_metadata() + tensor_metadata.content.contentProperties.colorSpace = self.color_space_type + # Create normalization parameters + if self.norm_mean and self.norm_std: + normalization = _metadata_fb.ProcessUnitT() + normalization.optionsType = ( + _metadata_fb.ProcessUnitOptions.NormalizationOptions) + normalization.options = _metadata_fb.NormalizationOptionsT() + normalization.options.mean = self.norm_mean + normalization.options.std = self.norm_std + tensor_metadata.processUnits = [normalization] + return tensor_metadata + + +class InputTextTensorMd(TensorMd): + """A container for the input text tensor metadata information. + + Attributes: + tokenizer_md: information of the tokenizer in the input text tensor, if any. + """ + + def __init__(self, + name: Optional[str] = None, + description: Optional[str] = None, + tokenizer_md: Optional[RegexTokenizerMd] = None): + """Initializes the instance of InputTextTensorMd. + + Args: + name: name of the tensor. + description: description of what the tensor is. + tokenizer_md: information of the tokenizer in the input text tensor, if + any. Only `RegexTokenizer` [1] is currenly supported. If the tokenizer + is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to + `bert_nl_classifier.MetadataWriter`. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475 + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436 + [3]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473 + """ + super().__init__(name, description) + self.tokenizer_md = tokenizer_md + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the input text metadata based on the information. + + Returns: + A Flatbuffers Python object of the input text metadata. + + Raises: + ValueError: if the type of tokenizer_md is unsupported. + """ + if not isinstance(self.tokenizer_md, (type(None), RegexTokenizerMd)): + raise ValueError( + f"The type of tokenizer_options, {type(self.tokenizer_md)}, is " + f"unsupported") + + tensor_metadata = super().create_metadata() + if self.tokenizer_md: + tensor_metadata.processUnits = [self.tokenizer_md.create_metadata()] + return tensor_metadata + + +class InputAudioTensorMd(TensorMd): + """A container for the input audio tensor metadata information. + + Attributes: + sample_rate: the sample rate in Hz when the audio was captured. + channels: the channel count of the audio. + """ + + def __init__(self, + name: Optional[str] = None, + description: Optional[str] = None, + sample_rate: int = 0, + channels: int = 0): + """Initializes the instance of InputAudioTensorMd. + + Args: + name: name of the tensor. + description: description of what the tensor is. + sample_rate: the sample rate in Hz when the audio was captured. + channels: the channel count of the audio. + """ + super().__init__( + name, + description, + content_type=_metadata_fb.ContentProperties.AudioProperties) + + self.sample_rate = sample_rate + self.channels = channels + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the input audio metadata based on the information. + + Returns: + A Flatbuffers Python object of the input audio metadata. + + Raises: + ValueError: if any value of sample_rate, channels is negative. + """ + # 0 is the default value in Flatbuffers. + if self.sample_rate < 0: + raise ValueError( + f"sample_rate should be non-negative, but got {self.sample_rate}.") + + if self.channels < 0: + raise ValueError( + f"channels should be non-negative, but got {self.channels}.") + + tensor_metadata = super().create_metadata() + properties = tensor_metadata.content.contentProperties + properties.sampleRate = self.sample_rate + properties.channels = self.channels + + return tensor_metadata + + +class ClassificationTensorMd(TensorMd): + """A container for the classification tensor metadata information. + + Attributes: + label_files: information of the label files [1] in the classification + tensor. + score_calibration_md: information of the score calibration operation [2] in + the classification tensor. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95 + [2]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434 + """ + + # Min and max float values for classification results. + _MIN_FLOAT = 0.0 + _MAX_FLOAT = 1.0 + + def __init__(self, + name: Optional[str] = None, + description: Optional[str] = None, + label_files: Optional[List[LabelFileMd]] = None, + tensor_type: Optional[_schema_fb.TensorType] = None, + score_calibration_md: Optional[ScoreCalibrationMd] = None, + tensor_name: Optional[str] = None): + """Initializes the instance of ClassificationTensorMd. + + Args: + name: name of the tensor. + description: description of what the tensor is. + label_files: information of the label files [1] in the classification + tensor. + tensor_type: data type of the tensor. + score_calibration_md: information of the score calibration files operation + [2] in the classification tensor. + tensor_name: name of the corresponding tensor [3] in the TFLite model. It + is used to locate the corresponding classification tensor and decide the + order of the tensor metadata [4] when populating model metadata. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95 + [2]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434 + [3]: + https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136 + [4]: + https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L595-L612 + """ + self.score_calibration_md = score_calibration_md + + if tensor_type is _schema_fb.TensorType.UINT8: + min_values = [_MIN_UINT8] + max_values = [_MAX_UINT8] + elif tensor_type is _schema_fb.TensorType.FLOAT32: + min_values = [self._MIN_FLOAT] + max_values = [self._MAX_FLOAT] + else: + # Uint8 and Float32 are the two major types currently. And Task library + # doesn't support other types so far. + min_values = None + max_values = None + + associated_files = label_files or [] + if self.score_calibration_md: + associated_files.append( + score_calibration_md.create_score_calibration_file_md()) + + super().__init__(name, description, min_values, max_values, + _metadata_fb.ContentProperties.FeatureProperties, + associated_files, tensor_name) + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the classification tensor metadata based on the information.""" + tensor_metadata = super().create_metadata() + if self.score_calibration_md: + tensor_metadata.processUnits = [ + self.score_calibration_md.create_metadata() + ] + return tensor_metadata + + +class CategoryTensorMd(TensorMd): + """A container for the category tensor metadata information.""" + + def __init__(self, + name: Optional[str] = None, + description: Optional[str] = None, + label_files: Optional[List[LabelFileMd]] = None): + """Initializes a CategoryTensorMd object. + + Args: + name: name of the tensor. + description: description of what the tensor is. + label_files: information of the label files [1] in the category tensor. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108 + """ + # In category tensors, label files are in the type of TENSOR_VALUE_LABELS. + value_label_files = label_files + if value_label_files: + for file in value_label_files: + file.file_type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS + + super().__init__( + name=name, description=description, associated_files=value_label_files) + + +class BertInputTensorsMd: + """A container for the input tensor metadata information of Bert models.""" + + _IDS_NAME = "ids" + _IDS_DESCRIPTION = "Tokenized ids of the input text." + _MASK_NAME = "mask" + _MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding " + "tokens.") + _SEGMENT_IDS_NAME = "segment_ids" + _SEGMENT_IDS_DESCRIPTION = ( + "0 for the first sequence, 1 for the second sequence if exists.") + + def __init__(self, + model_buffer: bytearray, + ids_name: str, + mask_name: str, + segment_name: str, + ids_md: Optional[TensorMd] = None, + mask_md: Optional[TensorMd] = None, + segment_ids_md: Optional[TensorMd] = None, + tokenizer_md: Union[None, BertTokenizerMd, + SentencePieceTokenizerMd] = None): + """Initializes a BertInputTensorsMd object. + + `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name` + in the TFLite schema, which help to determine the tensor order when + populating metadata. + + Args: + model_buffer: valid buffer of the model file. + ids_name: name of the ids tensor, which represents the tokenized ids of + the input text. + mask_name: name of the mask tensor, which represents the mask with 1 for + real tokens and 0 for padding tokens. + segment_name: name of the segment ids tensor, where `0` stands for the + first sequence, and `1` stands for the second sequence if exists. + ids_md: input ids tensor informaton. + mask_md: input mask tensor informaton. + segment_ids_md: input segment tensor informaton. + tokenizer_md: information of the tokenizer used to process the input + string, if any. Supported tokenziers are: `BertTokenizer` [1] and + `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` + [3], refer to `nl_classifier.MetadataWriter`. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436 + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473 + [3]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475 + """ + + self._input_names = [ids_name, mask_name, segment_name] + + # Get the input tensor names in order from the model. Later, we need to + # order the input metadata according to this tensor order. + self._ordered_input_names = writer_utils.get_input_tensor_names( + model_buffer) + + # Verify that self._ordered_input_names (read from the model) and + # self._input_name (collected from users) are aligned. + if collections.Counter(self._ordered_input_names) != collections.Counter( + self._input_names): + raise ValueError( + f"The input tensor names ({self._ordered_input_names}) do not match " + f"the tensor names read from the model ({self._input_names}).") + + if ids_md is None: + ids_md = TensorMd(name=self._IDS_NAME, description=self._IDS_DESCRIPTION) + + if mask_md is None: + mask_md = TensorMd( + name=self._MASK_NAME, description=self._MASK_DESCRIPTION) + + if segment_ids_md is None: + segment_ids_md = TensorMd( + name=self._SEGMENT_IDS_NAME, + description=self._SEGMENT_IDS_DESCRIPTION) + + # The order of self._input_md matches the order of self._input_names. + self._input_md = [ids_md, mask_md, segment_ids_md] + + if not isinstance(tokenizer_md, + (type(None), BertTokenizerMd, SentencePieceTokenizerMd)): + raise ValueError( + f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported" + ) + + self._tokenizer_md = tokenizer_md + + def create_input_tesnor_metadata(self) -> List[_metadata_fb.TensorMetadataT]: + """Creates the input metadata for the three input tesnors.""" + # The order of the three input tensors may vary with each model conversion. + # We need to order the input metadata according to the tensor order in the + # model. + ordered_metadata = [] + name_md_dict = dict(zip(self._input_names, self._input_md)) + for name in self._ordered_input_names: + ordered_metadata.append(name_md_dict[name].create_metadata()) + return ordered_metadata + + def create_input_process_unit_metadata( + self) -> List[_metadata_fb.ProcessUnitT]: + """Creates the input process unit metadata.""" + if self._tokenizer_md: + return [self._tokenizer_md.create_metadata()] + else: + return [] + + def get_tokenizer_associated_files(self) -> List[str]: + """Gets the associated files that are packed in the tokenizer.""" + if self._tokenizer_md: + return writer_utils.get_tokenizer_associated_files( + self._tokenizer_md.create_metadata().options) + else: + return []
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_writer.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_writer.py new file mode 100644 index 0000000..702a254 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_writer.py
@@ -0,0 +1,239 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper class to write metadata into TFLite models.""" + +import collections +from typing import List, Optional, Type + +import flatbuffers +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + + +class MetadataWriter: + """Writes the metadata and associated files into a TFLite model.""" + + def __init__(self, + model_buffer: bytearray, + metadata_buffer: Optional[bytearray] = None, + associated_files: Optional[List[str]] = None): + """Constructs the MetadataWriter. + + Args: + model_buffer: valid buffer of the model file. + metadata_buffer: valid buffer of the metadata. + associated_files: path to the associated files to be populated. + """ + self._model_buffer = model_buffer + self._metadata_buffer = metadata_buffer + self._associated_files = associated_files if associated_files else [] + self._populated_model_buffer = None + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[List[Type[metadata_info.TensorMd]]] = None, + output_md: Optional[List[Type[metadata_info.TensorMd]]] = None, + associated_files: Optional[List[str]] = None): + """Creates MetadataWriter based on the metadata information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. + input_md: metadata information of the input tensors. + output_md: metadata information of the output tensors. + associated_files: path to the associated files to be populated. + + Returns: + A MetadataWriter Object. + + Raises: + ValueError: if the tensor names from `input_md` and `output_md` do not + match the tensor names read from the model. + """ + + if general_md is None: + general_md = metadata_info.GeneralMd() + if input_md is None: + input_md = [] + if output_md is None: + output_md = [] + + # Order the input/output metadata according to tensor orders from the model. + input_md = _order_tensor_metadata( + input_md, writer_utils.get_input_tensor_names(model_buffer)) + output_md = _order_tensor_metadata( + output_md, writer_utils.get_output_tensor_names(model_buffer)) + + model_metadata = general_md.create_metadata() + input_metadata = [m.create_metadata() for m in input_md] + output_metadata = [m.create_metadata() for m in output_md] + return cls.create_from_metadata(model_buffer, model_metadata, + input_metadata, output_metadata, + associated_files) + + @classmethod + def create_from_metadata( + cls, + model_buffer: bytearray, + model_metadata: Optional[_metadata_fb.ModelMetadataT] = None, + input_metadata: Optional[List[_metadata_fb.TensorMetadataT]] = None, + output_metadata: Optional[List[_metadata_fb.TensorMetadataT]] = None, + associated_files: Optional[List[str]] = None, + input_process_units: Optional[List[_metadata_fb.ProcessUnitT]] = None, + output_process_units: Optional[List[_metadata_fb.ProcessUnitT]] = None): + """Creates MetadataWriter based on the metadata Flatbuffers Python Objects. + + Args: + model_buffer: valid buffer of the model file. + model_metadata: general model metadata [1]. The subgraph_metadata will be + refreshed with input_metadata and output_metadata. + input_metadata: a list of metadata of the input tensors [2]. + output_metadata: a list of metadata of the output tensors [3]. + associated_files: path to the associated files to be populated. + input_process_units: a lits of metadata of the input process units [4]. + output_process_units: a lits of metadata of the output process units [5]. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L640-L681 + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L590 + [3]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L599 + [4]: + https://github.com/tensorflow/tflite-support/blob/b5cc57c74f7990d8bc055795dfe8d50267064a57/tensorflow_lite_support/metadata/metadata_schema.fbs#L646 + [5]: + https://github.com/tensorflow/tflite-support/blob/b5cc57c74f7990d8bc055795dfe8d50267064a57/tensorflow_lite_support/metadata/metadata_schema.fbs#L650 + Returns: + A MetadataWriter Object. + """ + # Create empty tensor metadata when input_metadata/output_metadata are None + # to bypass MetadataPopulator verification. + if not input_metadata: + model = _schema_fb.Model.GetRootAsModel(model_buffer, 0) + num_input_tensors = model.Subgraphs(0).InputsLength() + input_metadata = [ + _metadata_fb.TensorMetadataT() for i in range(num_input_tensors) + ] + + if not output_metadata: + model = _schema_fb.Model.GetRootAsModel(model_buffer, 0) + num_output_tensors = model.Subgraphs(0).OutputsLength() + output_metadata = [ + _metadata_fb.TensorMetadataT() for i in range(num_output_tensors) + ] + + _fill_default_tensor_names( + input_metadata, writer_utils.get_input_tensor_names(model_buffer)) + + _fill_default_tensor_names( + output_metadata, writer_utils.get_output_tensor_names(model_buffer)) + + subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata.inputTensorMetadata = input_metadata + subgraph_metadata.outputTensorMetadata = output_metadata + subgraph_metadata.inputProcessUnits = input_process_units + subgraph_metadata.outputProcessUnits = output_process_units + + if model_metadata is None: + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.subgraphMetadata = [subgraph_metadata] + + b = flatbuffers.Builder(0) + b.Finish( + model_metadata.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return cls(model_buffer, b.Output(), associated_files) + + def populate(self) -> bytearray: + """Populates the metadata and label file to the model file. + + Returns: + A new model buffer with the metadata and associated files. + """ + if self._populated_model_buffer: + return self._populated_model_buffer + + populator = _metadata.MetadataPopulator.with_model_buffer( + self._model_buffer) + if self._model_buffer is not None: + populator.load_metadata_buffer(self._metadata_buffer) + if self._associated_files: + populator.load_associated_files(self._associated_files) + populator.populate() + self._populated_model_buffer = populator.get_model_buffer() + return self._populated_model_buffer + + def get_metadata_json(self) -> str: + """Gets the generated JSON metadata string before populated into model. + + This method returns the metadata buffer before populated into the model. + More fields could be filled by MetadataPopulator, such as + min_parser_version. Use get_populated_metadata_json() if you want to get the + final metadata string. + + Returns: + The generated JSON metadata string before populated into model. + """ + return _metadata.convert_to_json(bytes(self._metadata_buffer)) + + def get_populated_metadata_json(self) -> str: + """Gets the generated JSON metadata string after populated into model. + + More fields could be filled by MetadataPopulator, such as + min_parser_version. Use get_metadata_json() if you want to get the + original metadata string. + + Returns: + The generated JSON metadata string after populated into model. + """ + displayer = _metadata.MetadataDisplayer.with_model_buffer(self.populate()) + return displayer.get_metadata_json() + + +# If tensor name in metadata is empty, default to the tensor name saved in +# the model. +def _fill_default_tensor_names( + tensor_metadata: List[_metadata_fb.TensorMetadataT], + tensor_names_from_model: List[str]): + for metadata, name in zip(tensor_metadata, tensor_names_from_model): + metadata.name = metadata.name or name + + +def _order_tensor_metadata( + tensor_md: List[Type[metadata_info.TensorMd]], + tensor_names_from_model: List[str]) -> List[Type[metadata_info.TensorMd]]: + """Orders tensor_md according to the tensor names from the model.""" + tensor_names_from_arg = [ + md.tensor_name for md in tensor_md or [] if md.tensor_name is not None + ] + if not tensor_names_from_arg: + return tensor_md + + if collections.Counter(tensor_names_from_arg) != collections.Counter( + tensor_names_from_model): + raise ValueError( + "The tensor names from arguments ({}) do not match the tensor names" + " read from the model ({}).".format(tensor_names_from_arg, + tensor_names_from_model)) + ordered_tensor_md = [] + name_md_dict = dict(zip(tensor_names_from_arg, tensor_md)) + for name in tensor_names_from_model: + ordered_tensor_md.append(name_md_dict[name]) + return ordered_tensor_md
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/nl_classifier.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/nl_classifier.py new file mode 100644 index 0000000..a20d1fd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/nl_classifier.py
@@ -0,0 +1,134 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the NL classifier models.""" + +from typing import List, Optional + +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +_MODEL_NAME = "NLClassifier" +_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.") +_INPUT_NAME = "input_text" +_INPUT_DESCRIPTION = ("Embedding vectors representing the input text to be " + "classified.") +_OUTPUT_NAME = "probability" +_OUTPUT_DESCRIPTION = "Probabilities of the labels respectively." + + +class MetadataWriter(metadata_writer.MetadataWriter): + """Writes metadata into the NL classifier.""" + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.InputTextTensorMd] = None, + output_md: Optional[metadata_info.ClassificationTensorMd] = None): + """Creates MetadataWriter based on general/input/output information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. If not specified, default + general metadata will be generated. + input_md: input text tensor information, if not specified, default input + metadata will be generated. + output_md: output classification tensor information, if not specified, + default output metadata will be generated. + + Returns: + A MetadataWriter object. + """ + + if general_md is None: + general_md = metadata_info.GeneralMd( + name=_MODEL_NAME, description=_MODEL_DESCRIPTION) + + if input_md is None: + input_md = metadata_info.InputTextTensorMd( + name=_INPUT_NAME, description=_INPUT_DESCRIPTION) + + if output_md is None: + output_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION) + + if output_md.associated_files is None: + output_md.associated_files = [] + + tokenizer_files = [] + if input_md.tokenizer_md: + tokenizer_files = writer_utils.get_tokenizer_associated_files( + input_md.tokenizer_md.create_metadata().options) + + return super().create_from_metadata_info( + model_buffer=model_buffer, + general_md=general_md, + input_md=[input_md], + output_md=[output_md], + associated_files=[ + file.file_path for file in output_md.associated_files + ] + tokenizer_files) + + @classmethod + def create_for_inference( + cls, model_buffer: bytearray, + tokenizer_md: Optional[metadata_info.RegexTokenizerMd], + label_file_paths: List[str]): + """Creates mandatory metadata for TFLite Support inference. + + The parameters required in this method are mandatory when using TFLite + Support features, such as Task library and Codegen tool (Android Studio ML + Binding). Other metadata fields will be set to default. If other fields need + to be filled, use the method `create_from_metadata_info` to edit them. + + Args: + model_buffer: valid buffer of the model file. + tokenizer_md: information of the tokenizer used to process the input + string, if any. Only `RegexTokenizer` [1] is currently supported. If the + tokenizer is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer + to `bert_nl_classifier.MetadataWriter`. + [1]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475 + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436 + [3]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473 + label_file_paths: paths to the label files [4] in the classification + tensor. Pass in an empty list if the model does not have any label + file. + [4]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95 + + Returns: + A MetadataWriter object. + """ + input_md = metadata_info.InputTextTensorMd( + name=_INPUT_NAME, + description=_INPUT_DESCRIPTION, + tokenizer_md=tokenizer_md) + + output_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_NAME, + description=_OUTPUT_DESCRIPTION, + label_files=[ + metadata_info.LabelFileMd(file_path=file_path) + for file_path in label_file_paths + ], + tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0]) + + return cls.create_from_metadata_info( + model_buffer, input_md=input_md, output_md=output_md)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/object_detector.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/object_detector.py new file mode 100644 index 0000000..3caf063 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/object_detector.py
@@ -0,0 +1,295 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the object detector models.""" + +import logging +from typing import List, Optional, Type, Union + +import flatbuffers +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils + +_MODEL_NAME = "ObjectDetector" +_MODEL_DESCRIPTION = ( + "Identify which of a known set of objects might be present and provide " + "information about their positions within the given image or a video " + "stream.") +_INPUT_NAME = "image" +_INPUT_DESCRIPTION = "Input image to be detected." +# The output tensor names shouldn't be changed since these name will be used +# to handle the order of output in TFLite Task Library when doing inference +# in on-device application. +_OUTPUT_LOCATION_NAME = "location" +_OUTPUT_LOCATION_DESCRIPTION = "The locations of the detected boxes." +_OUTPUT_CATRGORY_NAME = "category" +_OUTPUT_CATEGORY_DESCRIPTION = "The categories of the detected boxes." +_OUTPUT_SCORE_NAME = "score" +_OUTPUT_SCORE_DESCRIPTION = "The scores of the detected boxes." +_OUTPUT_NUMBER_NAME = "number of detections" +_OUTPUT_NUMBER_DESCRIPTION = "The number of the detected boxes." +_CONTENT_VALUE_DIM = 2 +_BOUNDING_BOX_INDEX = (1, 0, 3, 2) +_GROUP_NAME = "detection_result" + + +def _create_1d_value_range(dim: int) -> _metadata_fb.ValueRangeT: + """Creates the 1d ValueRange based on the given dimension.""" + value_range = _metadata_fb.ValueRangeT() + value_range.min = dim + value_range.max = dim + return value_range + + +def _create_location_metadata( + location_md: metadata_info.TensorMd) -> _metadata_fb.TensorMetadataT: + """Creates the metadata for the location tensor.""" + location_metadata = location_md.create_metadata() + content = _metadata_fb.ContentT() + content.contentPropertiesType = ( + _metadata_fb.ContentProperties.BoundingBoxProperties) + properties = _metadata_fb.BoundingBoxPropertiesT() + properties.index = list(_BOUNDING_BOX_INDEX) + properties.type = _metadata_fb.BoundingBoxType.BOUNDARIES + properties.coordinateType = _metadata_fb.CoordinateType.RATIO + content.contentProperties = properties + content.range = _create_1d_value_range(_CONTENT_VALUE_DIM) + location_metadata.content = content + return location_metadata + + +# This is needed for both the output category tensor and the score tensor. +def _create_metadata_with_value_range( + tensor_md: metadata_info.TensorMd) -> _metadata_fb.TensorMetadataT: + """Creates tensor metadata with extra value range information.""" + tensor_metadata = tensor_md.create_metadata() + tensor_metadata.content.range = _create_1d_value_range(_CONTENT_VALUE_DIM) + return tensor_metadata + + +def _get_tflite_outputs(model_buffer: bytearray) -> List[int]: + """Gets the tensor indices of output in the TFLite Subgraph.""" + model = _schema_fb.Model.GetRootAsModel(model_buffer, 0) + return model.Subgraphs(0).OutputsAsNumpy() + + +def _extend_new_files( + file_list: List[str], + associated_files: Optional[List[Type[metadata_info.AssociatedFileMd]]]): + """Extends new associated files to the file list.""" + if not associated_files: + return + + for file in associated_files: + if file.file_path not in file_list: + file_list.append(file.file_path) + + +class MetadataWriter(metadata_writer.MetadataWriter): + """Writes metadata into an object detector.""" + + @classmethod + def create_from_metadata_info( + cls, + model_buffer: bytearray, + general_md: Optional[metadata_info.GeneralMd] = None, + input_md: Optional[metadata_info.InputImageTensorMd] = None, + output_location_md: Optional[metadata_info.TensorMd] = None, + output_category_md: Optional[metadata_info.CategoryTensorMd] = None, + output_score_md: Union[None, metadata_info.TensorMd, + metadata_info.ClassificationTensorMd] = None, + output_number_md: Optional[metadata_info.TensorMd] = None): + """Creates MetadataWriter based on general/input/outputs information. + + Args: + model_buffer: valid buffer of the model file. + general_md: general information about the model. + input_md: input image tensor informaton. + output_location_md: output location tensor informaton. The location tensor + is a multidimensional array of [N][4] floating point values between 0 + and 1, the inner arrays representing bounding boxes in the form [top, + left, bottom, right]. + output_category_md: output category tensor information. The category + tensor is an array of N integers (output as floating point values) each + indicating the index of a class label from the labels file. + output_score_md: output score tensor information. The score tensor is an + array of N floating point values between 0 and 1 representing + probability that a class was detected. Use ClassificationTensorMd to + calibrate score. + output_number_md: output number of detections tensor information. This + tensor is an integer value of N. + + Returns: + A MetadataWriter object. + """ + if general_md is None: + general_md = metadata_info.GeneralMd( + name=_MODEL_NAME, description=_MODEL_DESCRIPTION) + + if input_md is None: + input_md = metadata_info.InputImageTensorMd( + name=_INPUT_NAME, + description=_INPUT_DESCRIPTION, + color_space_type=_metadata_fb.ColorSpaceType.RGB) + + warn_message_format = ( + "The output name isn't the default string \"%s\". This may cause the " + "model not work in the TFLite Task Library since the tensor name will " + "be used to handle the output order in the TFLite Task Library.") + if output_location_md is None: + output_location_md = metadata_info.TensorMd( + name=_OUTPUT_LOCATION_NAME, description=_OUTPUT_LOCATION_DESCRIPTION) + elif output_location_md.name != _OUTPUT_LOCATION_NAME: + logging.warning(warn_message_format, _OUTPUT_LOCATION_NAME) + + if output_category_md is None: + output_category_md = metadata_info.CategoryTensorMd( + name=_OUTPUT_CATRGORY_NAME, description=_OUTPUT_CATEGORY_DESCRIPTION) + elif output_category_md.name != _OUTPUT_CATRGORY_NAME: + logging.warning(warn_message_format, _OUTPUT_CATRGORY_NAME) + + if output_score_md is None: + output_score_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_SCORE_NAME, + description=_OUTPUT_SCORE_DESCRIPTION, + ) + elif output_score_md.name != _OUTPUT_SCORE_NAME: + logging.warning(warn_message_format, _OUTPUT_SCORE_NAME) + + if output_number_md is None: + output_number_md = metadata_info.TensorMd( + name=_OUTPUT_NUMBER_NAME, description=_OUTPUT_NUMBER_DESCRIPTION) + elif output_number_md.name != _OUTPUT_NUMBER_NAME: + logging.warning(warn_message_format, _OUTPUT_NUMBER_NAME) + + # Create output tensor group info. + group = _metadata_fb.TensorGroupT() + group.name = _GROUP_NAME + group.tensorNames = [ + output_location_md.name, output_category_md.name, output_score_md.name + ] + + # Gets the tensor inidces of tflite outputs and then gets the order of the + # output metadata by the value of tensor indices. For instance, if the + # output indices are [601, 599, 598, 600], tensor names and indices aligned + # are: + # - location: 598 + # - category: 599 + # - score: 600 + # - number of detections: 601 + # because of the op's ports of TFLITE_DETECTION_POST_PROCESS + # (https://github.com/tensorflow/tensorflow/blob/a4fe268ea084e7d323133ed7b986e0ae259a2bc7/tensorflow/lite/kernels/detection_postprocess.cc#L47-L50). + # Thus, the metadata of tensors are sorted in this way, according to + # output_tensor_indicies correctly. + output_tensor_indices = _get_tflite_outputs(model_buffer) + metadata_list = [ + _create_location_metadata(output_location_md), + _create_metadata_with_value_range(output_category_md), + _create_metadata_with_value_range(output_score_md), + output_number_md.create_metadata() + ] + + # Align indices with tensors. + sorted_indices = sorted(output_tensor_indices) + indices_to_tensors = dict(zip(sorted_indices, metadata_list)) + + # Output metadata according to output_tensor_indices. + output_metadata = [indices_to_tensors[i] for i in output_tensor_indices] + + # Create subgraph info. + subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata.inputTensorMetadata = [input_md.create_metadata()] + subgraph_metadata.outputTensorMetadata = output_metadata + subgraph_metadata.outputTensorGroups = [group] + + # Create model metadata + model_metadata = general_md.create_metadata() + model_metadata.subgraphMetadata = [subgraph_metadata] + + b = flatbuffers.Builder(0) + b.Finish( + model_metadata.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + associated_files = [] + _extend_new_files(associated_files, output_category_md.associated_files) + _extend_new_files(associated_files, output_score_md.associated_files) + return cls(model_buffer, b.Output(), associated_files=associated_files) + + @classmethod + def create_for_inference( + cls, + model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + label_file_paths: List[str], + score_calibration_md: Optional[metadata_info.ScoreCalibrationMd] = None): + """Creates mandatory metadata for TFLite Support inference. + + The parameters required in this method are mandatory when using TFLite + Support features, such as Task library and Codegen tool (Android Studio ML + Binding). Other metadata fields will be set to default. If other fields need + to be filled, use the method `create_from_metadata_info` to edit them. + + Args: + model_buffer: valid buffer of the model file. + input_norm_mean: the mean value used in the input tensor normalization + [1]. + input_norm_std: the std value used in the input tensor normalizarion [1]. + label_file_paths: paths to the label files [2] in the category tensor. + Pass in an empty list, If the model does not have any label file. + score_calibration_md: information of the score calibration operation [3] + in the classification tensor. Optional if the model does not use score + calibration. + [1]: + https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters + [2]: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108 + [3]: + https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434 + + Returns: + A MetadataWriter object. + """ + input_md = metadata_info.InputImageTensorMd( + name=_INPUT_NAME, + description=_INPUT_DESCRIPTION, + norm_mean=input_norm_mean, + norm_std=input_norm_std, + color_space_type=_metadata_fb.ColorSpaceType.RGB, + tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0]) + + output_category_md = metadata_info.CategoryTensorMd( + name=_OUTPUT_CATRGORY_NAME, + description=_OUTPUT_CATEGORY_DESCRIPTION, + label_files=[ + metadata_info.LabelFileMd(file_path=file_path) + for file_path in label_file_paths + ]) + + output_score_md = metadata_info.ClassificationTensorMd( + name=_OUTPUT_SCORE_NAME, + description=_OUTPUT_SCORE_DESCRIPTION, + score_calibration_md=score_calibration_md + ) + + return cls.create_from_metadata_info( + model_buffer, + input_md=input_md, + output_category_md=output_category_md, + output_score_md=output_score_md)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/writer_utils.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/writer_utils.py new file mode 100644 index 0000000..a8815fb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/writer_utils.py
@@ -0,0 +1,182 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper methods for writing metadata into TFLite models.""" + +import array +import functools +from typing import List, Union, Optional + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb + + +def compute_flat_size(tensor_shape: Optional["array.array[int]"]) -> int: + """Computes the flat size (number of elements) of tensor shape. + + Args: + tensor_shape: an array of the tensor shape values. + + Returns: + The flat size of the tensor shape. Return 0 if tensor_shape is None. + """ + if not tensor_shape: + return 0 + return functools.reduce(lambda x, y: x * y, tensor_shape) + + +def get_input_tensor_names(model_buffer: bytearray) -> List[str]: + """Gets a list of the input tensor names.""" + subgraph = _get_subgraph(model_buffer) + tensor_names = [] + for i in range(subgraph.InputsLength()): + index = subgraph.Inputs(i) + tensor_names.append(subgraph.Tensors(index).Name().decode("utf-8")) + return tensor_names + + +def get_output_tensor_names(model_buffer: bytearray) -> List[str]: + """Gets a list of the output tensor names.""" + subgraph = _get_subgraph(model_buffer) + tensor_names = [] + for i in range(subgraph.OutputsLength()): + index = subgraph.Outputs(i) + tensor_names.append(subgraph.Tensors(index).Name().decode("utf-8")) + return tensor_names + + +def get_input_tensor_types( + model_buffer: bytearray) -> List[_schema_fb.TensorType]: + """Gets a list of the input tensor types.""" + subgraph = _get_subgraph(model_buffer) + tensor_types = [] + for i in range(subgraph.InputsLength()): + index = subgraph.Inputs(i) + tensor_types.append(subgraph.Tensors(index).Type()) + return tensor_types + + +def get_output_tensor_types( + model_buffer: bytearray) -> List[_schema_fb.TensorType]: + """Gets a list of the output tensor types.""" + subgraph = _get_subgraph(model_buffer) + tensor_types = [] + for i in range(subgraph.OutputsLength()): + index = subgraph.Outputs(i) + tensor_types.append(subgraph.Tensors(index).Type()) + return tensor_types + + +def get_input_tensor_shape(model_buffer: bytearray, + tensor_index: int) -> array.array: + """Gets the shape of the specified input tensor.""" + subgraph = _get_subgraph(model_buffer) + return subgraph.Tensors(subgraph.Inputs(tensor_index)).ShapeAsNumpy() + + +def load_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: + """Loads file from the file path. + + Args: + file_path: valid file path string. + mode: a string specifies the model in which the file is opened. Use "rt" for + reading in text mode; use "rb" for reading in binary mode. + + Returns: + The loaded file in str or bytes. + """ + with open(file_path, mode) as file: + return file.read() + + +def save_file(file_bytes: Union[bytes, bytearray], + save_to_path: str, + mode: str = "wb"): + """Loads file from the file path. + + Args: + file_bytes: the bytes to be saved to file. + save_to_path: valid file path string. + mode: a string specifies the model in which the file is opened. Use "wt" for + writing in text mode; use "wb" for writing in binary mode. + + Returns: + The loaded file in str or bytes. + """ + with open(save_to_path, mode) as file: + file.write(file_bytes) + + +def get_tokenizer_associated_files( + tokenizer_options: Union[None, _metadata_fb.BertTokenizerOptionsT, + _metadata_fb.SentencePieceTokenizerOptionsT, + _metadata_fb.RegexTokenizerOptionsT] +) -> List[Optional[str]]: + """Gets a list of associated files packed in the tokenzier_options. + + Args: + tokenizer_options: a tokenizer metadata object. Support the following + tokenizer types: + 1. BertTokenizerOptions: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436 + 2. SentencePieceTokenizerOptions: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473 + 3. RegexTokenizerOptions: + https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475 + + Returns: + A list of associated files included in tokenizer_options. + """ + + if not tokenizer_options: + return [] + + def _get_file_path(files: _metadata_fb.AssociatedFileT) -> List[str]: + if not files: + return [] + return [file.name for file in files] + + if isinstance(tokenizer_options, (_metadata_fb.BertTokenizerOptionsT, + _metadata_fb.RegexTokenizerOptionsT)): + return _get_file_path(tokenizer_options.vocabFile) + elif isinstance(tokenizer_options, + _metadata_fb.SentencePieceTokenizerOptionsT): + return _get_file_path(tokenizer_options.vocabFile) + _get_file_path( + tokenizer_options.sentencePieceModel) + else: + return [] + + +def _get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph: + """Gets the subgraph of the model. + + TFLite does not support multi-subgraph. A model should have exactly one + subgraph. + + Args: + model_buffer: valid buffer of the model file. + + Returns: + The subgraph of the model. + + Raises: + ValueError: if the model has more than one subgraph or has no subgraph. + """ + + model = _schema_fb.Model.GetRootAsModel(model_buffer, 0) + + # Use the first subgraph as default. TFLite Interpreter doesn't support + # multiple subgraphs yet, but models with mini-benchmark may have multiple + # subgraphs for acceleration evaluation purpose. + return model.Subgraphs(0)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/BUILD new file mode 100644 index 0000000..2b54974 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/BUILD
@@ -0,0 +1,50 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_test( + name = "metadata_test", + srcs = ["metadata_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + # build rule placeholder: six dep, + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata:schema_py", + "//tensorflow_lite_support/metadata/python:metadata", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "metadata_parser_test", + srcs = ["metadata_parser_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata/python:metadata", + ], +) + +py_test( + name = "metadata_writer_for_task_test", + srcs = ["metadata_writer_for_task_test.py"], + data = [ + "//tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier:test_files", + "//tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder:test_files", + "//tensorflow_lite_support/metadata/python/tests/testdata/image_classifier:test_files", + ], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata/python:metadata_writer_for_task", + "//tensorflow_lite_support/metadata/python/tests/metadata_writers:test_utils", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_parser_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_parser_test.py new file mode 100644 index 0000000..e02d69ae --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_parser_test.py
@@ -0,0 +1,37 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow_lite_support.metadata.metadata_parser.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import tensorflow as tf + +from tensorflow_lite_support.metadata.python import metadata_parser + + +class MetadataParserTest(tf.test.TestCase): + + def testVersionWellFormedSemanticVersion(self): + # Validates that the version is well-formed (x.y.z). + self.assertTrue( + re.match('[0-9]+\\.[0-9]+\\.[0-9]+', + metadata_parser.MetadataParser.VERSION)) + + +if __name__ == '__main__': + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_test.py new file mode 100644 index 0000000..bf8d3b7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_test.py
@@ -0,0 +1,860 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow_lite_support.metadata.metadata.""" + +import enum +import os + +from absl.testing import parameterized +import six +import tensorflow as tf + +import flatbuffers +from tensorflow.python.platform import resource_loader +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb +from tensorflow_lite_support.metadata.python import metadata as _metadata + + +class Tokenizer(enum.Enum): + BERT_TOKENIZER = 0 + SENTENCE_PIECE = 1 + + +class TensorType(enum.Enum): + INPUT = 0 + OUTPUT = 1 + + +def _read_file(file_name, mode="rb"): + with open(file_name, mode) as f: + return f.read() + + +class MetadataTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super(MetadataTest, self).setUp() + self._invalid_model_buf = None + self._invalid_file = "not_existed_file" + self._model_buf = self._create_model_buf() + self._model_file = self.create_tempfile().full_path + with open(self._model_file, "wb") as f: + f.write(self._model_buf) + self._metadata_file = self._create_metadata_file() + self._metadata_file_with_version = self._create_metadata_file_with_version( + self._metadata_file, "1.0.0") + self._file1 = self.create_tempfile("file1").full_path + self._file2 = self.create_tempfile("file2").full_path + self._file2_content = b"file2_content" + with open(self._file2, "wb") as f: + f.write(self._file2_content) + self._file3 = self.create_tempfile("file3").full_path + + def _create_model_buf(self): + # Create a model with two inputs and one output, which matches the metadata + # created by _create_metadata_file(). + metadata_field = _schema_fb.MetadataT() + subgraph = _schema_fb.SubGraphT() + subgraph.inputs = [0, 1] + subgraph.outputs = [2] + + metadata_field.name = "meta" + buffer_field = _schema_fb.BufferT() + model = _schema_fb.ModelT() + model.subgraphs = [subgraph] + # Creates the metadata and buffer fields for testing purposes. + model.metadata = [metadata_field, metadata_field] + model.buffers = [buffer_field, buffer_field, buffer_field] + model_builder = flatbuffers.Builder(0) + model_builder.Finish( + model.Pack(model_builder), + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + return model_builder.Output() + + def _create_metadata_file(self): + associated_file1 = _metadata_fb.AssociatedFileT() + associated_file1.name = b"file1" + associated_file2 = _metadata_fb.AssociatedFileT() + associated_file2.name = b"file2" + self.expected_recorded_files = [ + six.ensure_str(associated_file1.name), + six.ensure_str(associated_file2.name) + ] + + input_meta = _metadata_fb.TensorMetadataT() + output_meta = _metadata_fb.TensorMetadataT() + output_meta.associatedFiles = [associated_file2] + subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. + subgraph.inputTensorMetadata = [input_meta, input_meta] + subgraph.outputTensorMetadata = [output_meta] + + model_meta = _metadata_fb.ModelMetadataT() + model_meta.name = "Mobilenet_quantized" + model_meta.associatedFiles = [associated_file1] + model_meta.subgraphMetadata = [subgraph] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file = self.create_tempfile().full_path + with open(metadata_file, "wb") as f: + f.write(b.Output()) + return metadata_file + + def _create_model_buffer_with_wrong_identifier(self): + wrong_identifier = b"widn" + model = _schema_fb.ModelT() + model_builder = flatbuffers.Builder(0) + model_builder.Finish(model.Pack(model_builder), wrong_identifier) + return model_builder.Output() + + def _create_metadata_buffer_with_wrong_identifier(self): + # Creates a metadata with wrong identifier + wrong_identifier = b"widn" + metadata = _metadata_fb.ModelMetadataT() + metadata_builder = flatbuffers.Builder(0) + metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier) + return metadata_builder.Output() + + def _populate_metadata_with_identifier(self, model_buf, metadata_buf, + identifier): + # For testing purposes only. MetadataPopulator cannot populate metadata with + # wrong identifiers. + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = metadata_buf + model.buffers = [buffer_field] + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata = [metadata_field] + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), identifier) + return b.Output() + + def _create_metadata_file_with_version(self, metadata_file, min_version): + # Creates a new metadata file with the specified min_version for testing + # purposes. + metadata_buf = bytearray(_read_file(metadata_file)) + + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + b = flatbuffers.Builder(0) + b.Finish( + metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file_with_version = self.create_tempfile().full_path + with open(metadata_file_with_version, "wb") as f: + f.write(b.Output()) + return metadata_file_with_version + + +class MetadataPopulatorTest(MetadataTest): + + def _create_bert_tokenizer(self): + vocab_file_name = "bert_vocab" + vocab = _metadata_fb.AssociatedFileT() + vocab.name = vocab_file_name + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions + tokenizer.options = _metadata_fb.BertTokenizerOptionsT() + tokenizer.options.vocabFile = [vocab] + return tokenizer, [vocab_file_name] + + def _create_sentence_piece_tokenizer(self): + sp_model_name = "sp_model" + vocab_file_name = "sp_vocab" + sp_model = _metadata_fb.AssociatedFileT() + sp_model.name = sp_model_name + vocab = _metadata_fb.AssociatedFileT() + vocab.name = vocab_file_name + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = ( + _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions) + tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT() + tokenizer.options.sentencePieceModel = [sp_model] + tokenizer.options.vocabFile = [vocab] + return tokenizer, [sp_model_name, vocab_file_name] + + def _create_tokenizer(self, tokenizer_type): + if tokenizer_type is Tokenizer.BERT_TOKENIZER: + return self._create_bert_tokenizer() + elif tokenizer_type is Tokenizer.SENTENCE_PIECE: + return self._create_sentence_piece_tokenizer() + else: + raise ValueError( + "The tokenizer type, {0}, is unsupported.".format(tokenizer_type)) + + def _create_tempfiles(self, file_names): + tempfiles = [] + for name in file_names: + tempfiles.append(self.create_tempfile(name).full_path) + return tempfiles + + def _create_model_meta_with_subgraph_meta(self, subgraph_meta): + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgraph_meta] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return b.Output() + + def testToValidModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelFile(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataPopulator.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testToValidModelBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelBuffer(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf) + self.assertEqual("model_buf cannot be empty.", str(error.exception)) + + def testToModelBufferWithWrongIdentifier(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + + def testSinglePopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedPopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_associated_files([self._file1, self._file2]) + # Loads file2 multiple times. + populator.load_associated_files([self._file2]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertLen(packed_files, 2) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). + model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(IOError) as error: + populator.load_associated_files([self._invalid_file]) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulatePackedAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + with self.assertRaises(ValueError) as error: + populator.load_associated_files([self._file1]) + populator.populate() + self.assertEqual( + "File, '{0}', has already been packed.".format( + os.path.basename(self._file1)), str(error.exception)) + + def testLoadAssociatedFileBuffers(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + file_buffer = _read_file(self._file1) + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedLoadAssociatedFileBuffers(self): + file_buffer1 = _read_file(self._file1) + file_buffer2 = _read_file(self._file2) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + + populator.load_associated_file_buffers({ + self._file1: file_buffer1, + self._file2: file_buffer2 + }) + # Loads file2 multiple times. + populator.load_associated_file_buffers({self._file2: file_buffer2}) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). + model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testLoadPackedAssociatedFileBuffersFails(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + file_buffer = _read_file(self._file1) + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + + # Load file1 again should fail. + with self.assertRaises(ValueError) as error: + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + self.assertEqual( + "File, '{0}', has already been packed.".format( + os.path.basename(self._file1)), str(error.exception)) + + def testGetPackedAssociatedFileList(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + packed_files = populator.get_packed_associated_file_list() + self.assertEqual(packed_files, []) + + def testPopulateMetadataFileToEmptyModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + + model_buf_from_file = _read_file(self._model_file) + model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) + # self._model_file already has two elements in the metadata field, so the + # populated TFLite metadata will be the third element. + metadata_field = model.Metadata(2) + self.assertEqual( + six.ensure_str(metadata_field.Name()), + six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) + + buffer_index = metadata_field.Buffer() + buffer_data = model.Buffers(buffer_index) + metadata_buf_np = buffer_data.DataAsNumpy() + metadata_buf = metadata_buf_np.tobytes() + expected_metadata_buf = bytearray( + _read_file(self._metadata_file_with_version)) + self.assertEqual(metadata_buf, expected_metadata_buf) + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + # Up to now, we've proved the correctness of the model buffer that read from + # file. Then we'll test if get_model_buffer() gives the same model buffer. + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateMetadataFileWithoutAssociatedFiles(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1]) + # Suppose to populate self._file2, because it is recorded in the metadta. + with self.assertRaises(ValueError) as error: + populator.populate() + self.assertEqual(("File, '{0}', is recorded in the metadata, but has " + "not been loaded into the populator.").format( + os.path.basename(self._file2)), str(error.exception)) + + def testPopulateMetadataBufferWithWrongIdentifier(self): + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(metadata_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def _assert_golden_metadata(self, model_file): + model_buf_from_file = _read_file(model_file) + model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) + # There are two elements in model.Metadata array before the population. + # Metadata should be packed to the third element in the array. + metadata_field = model.Metadata(2) + self.assertEqual( + six.ensure_str(metadata_field.Name()), + six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) + + buffer_index = metadata_field.Buffer() + buffer_data = model.Buffers(buffer_index) + metadata_buf_np = buffer_data.DataAsNumpy() + metadata_buf = metadata_buf_np.tobytes() + expected_metadata_buf = bytearray( + _read_file(self._metadata_file_with_version)) + self.assertEqual(metadata_buf, expected_metadata_buf) + + def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self): + # First, creates a dummy metadata different from self._metadata_file. It + # needs to have the same input/output tensor numbers as self._model_file. + # Populates it and the associated files into the model. + input_meta = _metadata_fb.TensorMetadataT() + output_meta = _metadata_fb.TensorMetadataT() + subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. + subgraph.inputTensorMetadata = [input_meta, input_meta] + subgraph.outputTensorMetadata = [output_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgraph] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + # Populate the metadata. + populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator1.load_metadata_buffer(metadata_buf) + populator1.load_associated_files([self._file1, self._file2]) + populator1.populate() + + # Then, populate the metadata again. + populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator2.load_metadata_file(self._metadata_file) + populator2.populate() + + # Test if the metadata is populated correctly. + self._assert_golden_metadata(self._model_file) + + def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + + # Tests if the metadata is populated correctly. + self._assert_golden_metadata(self._model_file) + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + # Up to now, we've proved the correctness of the model buffer that read from + # file. Then we'll test if get_model_buffer() gives the same model buffer. + model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidMetadataFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(IOError) as error: + populator.load_metadata_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulateInvalidMetadataBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer([]) + self.assertEqual("The metadata to be populated is empty.", + str(error.exception)) + + def testGetModelBufferBeforePopulatingData(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + model_buf = populator.get_model_buffer() + expected_model_buf = self._model_buf + self.assertEqual(model_buf, expected_model_buf) + + def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self): + # Create a dummy metadata without Subgraph. + model_meta = _metadata_fb.ModelMetadataT() + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + "The number of SubgraphMetadata should be exactly one, but got 0.", + str(error.exception)) + + def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self): + # Create a dummy metadata with no input tensor metadata, while the expected + # number is 2. + output_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.outputTensorMetadata = [output_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of input tensors (2) should match the number of " + "input tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self): + # Create a dummy metadata with no output tensor metadata, while the expected + # number is 1. + input_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.inputTensorMetadata = [input_meta, input_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of output tensors (1) should match the number of " + "output tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataAndAssociatedFilesShouldSucceeds(self): + # Create a src model with metadata and two associated files. + src_model_buf = self._create_model_buf() + populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) + populator_src.load_metadata_file(self._metadata_file) + populator_src.load_associated_files([self._file1, self._file2]) + populator_src.populate() + + # Create a model to be populated with the metadata and files from + # src_model_buf. + dst_model_buf = self._create_model_buf() + populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf) + populator_dst.load_metadata_and_associated_files( + populator_src.get_model_buffer()) + populator_dst.populate() + + # Tests if the metadata and associated files are populated correctly. + dst_model_file = self.create_tempfile().full_path + with open(dst_model_file, "wb") as f: + f.write(populator_dst.get_model_buffer()) + self._assert_golden_metadata(dst_model_file) + + recorded_files = populator_dst.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + @parameterized.named_parameters( + { + "testcase_name": "InputTensorWithBert", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "OutputTensorWithBert", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "InputTensorWithSentencePiece", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }, { + "testcase_name": "OutputTensorWithSentencePiece", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }) + def testGetRecordedAssociatedFileListWithSubgraphTensor( + self, tensor_type, tokenizer_type): + # Creates a metadata with the tokenizer in the tensor process units. + tokenizer, expected_files = self._create_tokenizer(tokenizer_type) + + # Create the tensor with process units. + tensor = _metadata_fb.TensorMetadataT() + tensor.processUnits = [tokenizer] + + # Create the subgrah with the tensor. + subgraph = _metadata_fb.SubGraphMetadataT() + dummy_tensor_meta = _metadata_fb.TensorMetadataT() + subgraph.outputTensorMetadata = [dummy_tensor_meta] + if tensor_type is TensorType.INPUT: + subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta] + subgraph.outputTensorMetadata = [dummy_tensor_meta] + elif tensor_type is TensorType.OUTPUT: + subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta] + subgraph.outputTensorMetadata = [tensor] + else: + raise ValueError( + "The tensor type, {0}, is unsupported.".format(tensor_type)) + + # Create a model metadata with the subgraph metadata + meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Creates the tempfiles. + tempfiles = self._create_tempfiles(expected_files) + + # Creates the MetadataPopulator object. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(meta_buffer) + populator.load_associated_files(tempfiles) + populator.populate() + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(expected_files)) + + @parameterized.named_parameters( + { + "testcase_name": "InputTensorWithBert", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "OutputTensorWithBert", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "InputTensorWithSentencePiece", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }, { + "testcase_name": "OutputTensorWithSentencePiece", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }) + def testGetRecordedAssociatedFileListWithSubgraphProcessUnits( + self, tensor_type, tokenizer_type): + # Creates a metadata with the tokenizer in the subgraph process units. + tokenizer, expected_files = self._create_tokenizer(tokenizer_type) + + # Create the subgraph with process units. + subgraph = _metadata_fb.SubGraphMetadataT() + if tensor_type is TensorType.INPUT: + subgraph.inputProcessUnits = [tokenizer] + elif tensor_type is TensorType.OUTPUT: + subgraph.outputProcessUnits = [tokenizer] + else: + raise ValueError( + "The tensor type, {0}, is unsupported.".format(tensor_type)) + + # Creates the input and output tensor meta to match self._model_file. + dummy_tensor_meta = _metadata_fb.TensorMetadataT() + subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta] + subgraph.outputTensorMetadata = [dummy_tensor_meta] + + # Create a model metadata with the subgraph metadata + meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Creates the tempfiles. + tempfiles = self._create_tempfiles(expected_files) + + # Creates the MetadataPopulator object. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(meta_buffer) + populator.load_associated_files(tempfiles) + populator.populate() + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(expected_files)) + + def testPopulatedFullPathAssociatedFileShouldSucceed(self): + # Create AssociatedFileT using the full path file name. + associated_file = _metadata_fb.AssociatedFileT() + associated_file.name = self._file1 + + # Create model metadata with the associated file. + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.associatedFiles = [associated_file] + # Creates the input and output tensor metadata to match self._model_file. + dummy_tensor = _metadata_fb.TensorMetadataT() + subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor] + subgraph.outputTensorMetadata = [dummy_tensor] + md_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Populate the metadata to a model. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(md_buffer) + populator.load_associated_files([self._file1]) + populator.populate() + + # The recorded file name in metadata should only contain file basename; file + # directory should not be included. + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)])) + + +class MetadataDisplayerTest(MetadataTest): + + def setUp(self): + super(MetadataDisplayerTest, self).setUp() + self._model_with_meta_file = ( + self._create_model_with_metadata_and_associated_files()) + + def _create_model_with_metadata_and_associated_files(self): + model_buf = self._create_model_buf() + model_file = self.create_tempfile().full_path + with open(model_file, "wb") as f: + f.write(model_buf) + + populator = _metadata.MetadataPopulator.with_model_file(model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + return model_file + + def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + model_buf = self._populate_metadata_with_identifier( + model_buf, metadata_buf, + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_file = self._create_metadata_file() + wrong_identifier = b"widn" + metadata_buf = bytearray(_read_file(metadata_file)) + model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf, + wrong_identifier) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + + def testLoadModelFileInvalidModelFileThrowsException(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataDisplayer.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testLoadModelFileModelWithoutMetadataThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_file(self._model_file) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def testLoadModelFileModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def testLoadModelBufferInvalidModelBufferThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1)) + self.assertEqual("model_buffer cannot be empty.", str(error.exception)) + + def testLoadModelBufferModelWithOutMetadataThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf()) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def testLoadModelBufferModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_buffer( + _read_file(self._model_with_meta_file)) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def testGetAssociatedFileBufferShouldSucceed(self): + # _model_with_meta_file contains file1 and file2. + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + + actual_content = displayer.get_associated_file_buffer("file2") + self.assertEqual(actual_content, self._file2_content) + + def testGetAssociatedFileBufferFailsWithNonExistentFile(self): + # _model_with_meta_file contains file1 and file2. + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + + non_existent_file = "non_existent_file" + with self.assertRaises(ValueError) as error: + displayer.get_associated_file_buffer(non_existent_file) + self.assertEqual( + "The file, {}, does not exist in the model.".format(non_existent_file), + str(error.exception)) + + def testGetMetadataBufferShouldSucceed(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + actual_buffer = displayer.get_metadata_buffer() + actual_json = _metadata.convert_to_json(actual_buffer) + + # Verifies the generated json file. + golden_json_file_path = resource_loader.get_path_to_datafile( + "testdata/golden_json.json") + with open(golden_json_file_path, "r") as f: + expected = f.read() + self.assertEqual(actual_json, expected) + + def testGetMetadataJsonModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + actual = displayer.get_metadata_json() + + # Verifies the generated json file. + golden_json_file_path = resource_loader.get_path_to_datafile( + "testdata/golden_json.json") + expected = _read_file(golden_json_file_path, "r") + self.assertEqual(actual, expected) + + def testGetPackedAssociatedFileListModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + packed_files = displayer.get_packed_associated_file_list() + + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertLen( + packed_files, 2, + "The following two associated files packed to the model: {0}; {1}" + .format(expected_packed_files[0], expected_packed_files[1])) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + +class MetadataUtilTest(MetadataTest): + + def test_convert_to_json_should_succeed(self): + metadata_buf = _read_file(self._metadata_file_with_version) + metadata_json = _metadata.convert_to_json(metadata_buf) + + # Verifies the generated json file. + golden_json_file_path = resource_loader.get_path_to_datafile( + "testdata/golden_json.json") + expected = _read_file(golden_json_file_path, "r") + self.assertEqual(metadata_json, expected) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writer_for_task_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writer_for_task_test.py new file mode 100644 index 0000000..8c4f331 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writer_for_task_test.py
@@ -0,0 +1,507 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow_lite_support.metadata.metadata_writer_for_task.""" + +import os +import sys +import tensorflow as tf +from tensorflow_lite_support.metadata.python import metadata_writer_for_task as mt +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_AUDIO_CLASSIFICATION_MODEL = '../testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite' +_AUDIO_EMBEDDING_MODEL = '../testdata/audio_embedder/yamnet_embedding.tflite' +_IMAGE_CLASSIFIER_MODEL = '../testdata/image_classifier/mobilenet_v2_1.0_224.tflite' + + +class LabelsTest(tf.test.TestCase): + + def test_category_name(self): + labels = mt.Labels() + self.assertEqual( + labels.add(['a', 'b'], use_as_category_name=True)._labels, + [(None, 'labels.txt', ['a', 'b'])]) + # Overwrite categories + self.assertEqual( + labels.add(['new_a', 'new_b'], use_as_category_name=True)._labels, + [(None, 'labels.txt', ['new_a', 'new_b'])]) + + def test_locale(self): + labels = mt.Labels() + + # Add from file. + en_filepath = self.create_tempfile().full_path + with open(en_filepath, 'w') as f: + f.write('a\nb') + labels.add_from_file(en_filepath, 'en') + + # Customized file name + labels.add(['A', 'B'], 'fr', exported_filename='my_file.txt') + self.assertEqual(labels._labels, [ + ('en', 'labels_en.txt', ['a', 'b']), + ('fr', 'my_file.txt', ['A', 'B']), + ]) + + # Add category name, which should be the first file in the list. + labels.add(['aa', 'bb'], 'cn', use_as_category_name=True) + self.assertEqual(labels._labels, [ + ('cn', 'labels_cn.txt', ['aa', 'bb']), + ('en', 'labels_en.txt', ['a', 'b']), + ('fr', 'my_file.txt', ['A', 'B']), + ]) + + +class MetadataWriterForTaskTest(tf.test.TestCase): + + def test_initialize_without_with_block(self): + writer = mt.Writer( + test_utils.load_file(_AUDIO_CLASSIFICATION_MODEL), + model_name='test_model', + model_description='test_description') + + # Calling `add_classification_output` outside the `with` block fails. + with self.assertRaisesRegex(AttributeError, '_temp_folder'): + writer.add_classification_output(mt.Labels().add(['cat', 'dog'])) + + writer.__enter__() + writer.add_classification_output(mt.Labels().add(['cat', 'dog'])) + writer.__exit__(*sys.exc_info()) + + # Calling `add_classification_output` after `with` block closes also fails. + with self.assertRaisesRegex(AttributeError, '_temp_folder'): + writer.add_classification_output(mt.Labels().add(['cat', 'dog'])) + + def test_initialize_and_populate(self): + with mt.Writer( + test_utils.load_file(_AUDIO_CLASSIFICATION_MODEL), + model_name='my_audio_model', + model_description='my_description') as writer: + out_dir = self.create_tempdir() + _, metadata_json = writer.populate( + os.path.join(out_dir, 'model.tflite'), + os.path.join(out_dir, 'metadata.json')) + self.assertJsonEqual( + metadata_json, """{ + "name": "my_audio_model", + "description": "my_description", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "waveform_binary" + } + ], + "output_tensor_metadata": [ + { + "name": "tower0/network/layer32/final_output" + } + ] + } + ], + "min_parser_version": "1.0.0" +} +""") + + def test_audio_classifier(self): + with mt.Writer( + test_utils.load_file(_AUDIO_CLASSIFICATION_MODEL), + model_name='audio_classifier', + model_description='Classify the input audio clip') as writer: + out_dir = self.create_tempdir() + writer.add_audio_input(sample_rate=16000, channels=1) + writer.add_classification_output(mt.Labels().add( + ['sound1', 'sound2'], 'en', use_as_category_name=True)) + writer.populate( + os.path.join(out_dir, 'model.tflite'), + os.path.join(out_dir, 'metadata.tflite')) + self.assertEqual( + test_utils.load_file(os.path.join(out_dir, 'metadata.tflite'), 'r'), + """{ + "name": "audio_classifier", + "description": "Classify the input audio clip", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio", + "description": "Input audio clip to be processed.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 16000, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels_en.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + } + ] + } + ] + } + ], + "min_parser_version": "1.3.0" +} +""") + + def test_audio_classifier_with_locale(self): + with mt.Writer( + test_utils.load_file(_AUDIO_CLASSIFICATION_MODEL), + model_name='audio_classifier', + model_description='Classify the input audio clip') as writer: + out_dir = self.create_tempdir() + writer.add_audio_input(sample_rate=16000, channels=1) + writer.add_classification_output(mt.Labels().add( + ['/id1', '/id2'], + use_as_category_name=True).add(['sound1', 'sound2'], + 'en').add(['son1', 'son2'], 'fr')) + _, metadata_json = writer.populate( + os.path.join(out_dir, 'model.tflite'), + os.path.join(out_dir, 'metadata.tflite')) + self.assertEqual( + metadata_json, """{ + "name": "audio_classifier", + "description": "Classify the input audio clip", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio", + "description": "Input audio clip to be processed.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 16000, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "labels_en.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_fr.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "fr" + } + ] + } + ] + } + ], + "min_parser_version": "1.3.0" +} +""") + + def test_audio_classifier_with_locale_and_score_calibration(self): + with mt.Writer( + test_utils.load_file(_AUDIO_CLASSIFICATION_MODEL), + model_name='audio_classifier', + model_description='Classify the input audio clip') as writer: + out_dir = self.create_tempdir() + writer.add_audio_input(sample_rate=16000, channels=1) + writer.add_classification_output( + mt.Labels().add(['/id1', '/id2'], use_as_category_name=True).add( + ['sound1', 'sound2'], 'en').add(['son1', 'son2'], 'fr'), + score_calibration=mt.ScoreCalibration( + mt.ScoreCalibration.transformation_types.INVERSE_LOGISTIC, [ + mt.CalibrationParameter(1., 2., 3., None), + mt.CalibrationParameter(1., 2., 3., 4.), + ], + default_score=0.5)) + _, metadata_json = writer.populate( + os.path.join(out_dir, 'model.tflite'), + os.path.join(out_dir, 'metadata.tflite')) + self.assertEqual( + metadata_json, """{ + "name": "audio_classifier", + "description": "Classify the input audio clip", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio", + "description": "Input audio clip to be processed.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 16000, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "INVERSE_LOGISTIC", + "default_score": 0.5 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "labels_en.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_fr.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "fr" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ], + "min_parser_version": "1.3.0" +} +""") + + def test_audio_embedder(self): + with mt.Writer( + test_utils.load_file(_AUDIO_EMBEDDING_MODEL), + model_name='audio_embedder', + model_description='Generate embedding for the input audio clip' + ) as writer: + out_dir = self.create_tempdir() + writer.add_audio_input(sample_rate=16000, channels=1) + writer.add_embedding_output() + _, metadata_json = writer.populate( + os.path.join(out_dir, 'model.tflite'), + os.path.join(out_dir, 'metadata.json')) + self.assertEqual( + metadata_json, """{ + "name": "audio_embedder", + "description": "Generate embedding for the input audio clip", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio", + "description": "Input audio clip to be processed.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 16000, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "embedding", + "description": "Embedding vector of the input.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ], + "min_parser_version": "1.3.0" +} +""") + + def test_image_classifier(self): + with mt.Writer( + test_utils.load_file(_IMAGE_CLASSIFIER_MODEL), + model_name='image_classifier', + model_description='Imagenet classification model') as writer: + out_dir = self.create_tempdir() + writer.add_image_input( + norm_mean=[127.5, 127.5, 127.5], + norm_std=[127.5, 127.5, 127.5], + color_space_type=mt.Writer.color_space_types.RGB) + writer.add_classification_output(mt.Labels().add(['a', 'b', 'c'])) + _, metadata_json = writer.populate( + os.path.join(out_dir, 'model.tflite'), + os.path.join(out_dir, 'metadat.json')) + self.assertEqual( + metadata_json, """{ + "name": "image_classifier", + "description": "Imagenet classification model", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5, + 127.5, + 127.5 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0, + 1.0, + 1.0 + ], + "min": [ + -1.0, + -1.0, + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} +""") + + +if __name__ == '__main__': + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/BUILD new file mode 100644 index 0000000..fc1365d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/BUILD
@@ -0,0 +1,171 @@ +# Placeholder for internal Python strict test compatibility macro. +# Placeholder for internal Python strict library compatibility macro. + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "test_utils", + testonly = 1, + srcs = [ + "test_utils.py", + ], + visibility = ["//tensorflow_lite_support/metadata/python/tests:__subpackages__"], +) + +py_test( + name = "metadata_writer_test", + srcs = ["metadata_writer_test.py"], + data = [ + "//tensorflow_lite_support/metadata/python/tests/testdata:test_files", + "//tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier:test_files", + "//tensorflow_lite_support/metadata/python/tests/testdata/question_answerer:test_files", + ], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata/python:metadata", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_writer", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "metadata_info_test", + srcs = ["metadata_info_test.py"], + data = [ + "//tensorflow_lite_support/metadata/python/tests/testdata:test_files", + "//tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier:test_files", + "//tensorflow_lite_support/metadata/python/tests/testdata/image_classifier:test_files", + ], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata:schema_py", + "//tensorflow_lite_support/metadata/python:metadata", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "@absl_py//absl/testing:parameterized", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "writer_utils_test", + srcs = ["writer_utils_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/object_detector:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:schema_py", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "//tensorflow_lite_support/metadata/python/metadata_writers:writer_utils", + ], +) + +py_test( + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/image_classifier:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata/python/metadata_writers:image_classifier", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "@absl_py//absl/testing:parameterized", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "object_detector_test", + srcs = ["object_detector_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/object_detector:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata/python:metadata", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "//tensorflow_lite_support/metadata/python/metadata_writers:object_detector", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "image_segmenter_test", + srcs = ["image_segmenter_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata/python/metadata_writers:image_segmenter", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "nl_classifier_test", + srcs = ["nl_classifier_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata/python:metadata", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "//tensorflow_lite_support/metadata/python/metadata_writers:nl_classifier", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "audio_classifier_test", + srcs = ["audio_classifier_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata:metadata_schema_py", + "//tensorflow_lite_support/metadata/python/metadata_writers:audio_classifier", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "bert_nl_classifier_test", + srcs = ["bert_nl_classifier_test.py"], + data = ["//tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier:test_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":test_utils", + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/metadata/python:metadata", + "//tensorflow_lite_support/metadata/python/metadata_writers:bert_nl_classifier", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/audio_classifier_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/audio_classifier_test.py new file mode 100644 index 0000000..643f30e7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/audio_classifier_test.py
@@ -0,0 +1,183 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AudioClassifier.MetadataWriter.""" + +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python.metadata_writers import audio_classifier +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_FIXED_INPUT_SIZE_MODEL = "../testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite" +_DYNAMIC_INPUT_SIZE_MODEL = "../testdata/audio_classifier/yamnet_tfhub.tflite" +_MULTIHEAD_MODEL = "../testdata/audio_classifier/two_heads.tflite" +_YAMNET_LABEL_FILE = "../testdata/audio_classifier/yamnet_521_labels.txt" +_LABEL_FILE = "../testdata/audio_classifier/labelmap.txt" +_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2 +_JSON_FOR_INFERENCE_DYNAMIC = "../testdata/audio_classifier/yamnet_tfhub.json" +_JSON_FOR_INFERENCE_FIXED = "../testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.json" +_JSON_DEFAULT = "../testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6_default.json" +_JSON_DEFAULT_MULTIHEAD = "../testdata/audio_classifier/two_heads_default.json" +_JSON_MULTIHEAD = "../testdata/audio_classifier/two_heads.json" +_SAMPLE_RATE = 2 +_CHANNELS = 1 + + +class MetadataWriterTest(tf.test.TestCase): + + def test_create_for_inference_should_succeed_dynamic_input_shape_model(self): + writer = audio_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE, + _CHANNELS, [_LABEL_FILE], + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, + test_utils.create_calibration_file(self.get_temp_dir()))) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_FOR_INFERENCE_DYNAMIC, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_for_inference_should_succeed_with_fixed_input_shape_model( + self): + writer = audio_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_FIXED_INPUT_SIZE_MODEL), _SAMPLE_RATE, _CHANNELS, + [_YAMNET_LABEL_FILE], + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, + test_utils.create_calibration_file(self.get_temp_dir()))) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_FOR_INFERENCE_FIXED, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_from_metadata_info_by_default_should_succeed(self): + writer = audio_classifier.MetadataWriter.create_from_metadata_info( + test_utils.load_file(_FIXED_INPUT_SIZE_MODEL)) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_DEFAULT, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_from_metadata_info_by_default_succeeds_for_multihead(self): + writer = ( + audio_classifier.MetadataWriter.create_from_metadata_info_for_multihead( + test_utils.load_file(_MULTIHEAD_MODEL))) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_DEFAULT_MULTIHEAD, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_from_metadata_info_succeeds_for_multihead(self): + calibration_file1 = test_utils.create_calibration_file( + self.get_temp_dir(), "score_cali_1.txt") + calibration_file2 = test_utils.create_calibration_file( + self.get_temp_dir(), "score_cali_2.txt") + + general_md = metadata_info.GeneralMd(name="AudioClassifier") + input_md = metadata_info.InputAudioTensorMd( + name="audio_clip", sample_rate=_SAMPLE_RATE, channels=_CHANNELS) + # The output tensors in the model are: Identity, Identity_1 + # Create metadata in a different order to test if MetadataWriter can correct + # it. + output_head_md_1 = metadata_info.ClassificationTensorMd( + name="head1", + label_files=[ + metadata_info.LabelFileMd("labels_en_1.txt"), + metadata_info.LabelFileMd("labels_cn_1.txt") + ], + score_calibration_md=metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, calibration_file1), + tensor_name="Identity_1") + output_head_md_2 = metadata_info.ClassificationTensorMd( + name="head2", + label_files=[ + metadata_info.LabelFileMd("labels_en_2.txt"), + metadata_info.LabelFileMd("labels_cn_2.txt") + ], + score_calibration_md=metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, calibration_file2), + tensor_name="Identity") + + writer = ( + audio_classifier.MetadataWriter.create_from_metadata_info_for_multihead( + test_utils.load_file(_MULTIHEAD_MODEL), general_md, input_md, + [output_head_md_1, output_head_md_2])) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_MULTIHEAD, "r") + self.assertEqual(metadata_json, expected_json) + + +class MetadataWriterSampleRateTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + { + "testcase_name": "negative", + "wrong_sample_rate": -1 + }, { + "testcase_name": "zero", + "wrong_sample_rate": 0 + }) + def test_create_for_inference_fails_with_wrong_sample_rate( + self, wrong_sample_rate): + + with self.assertRaises(ValueError) as error: + audio_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), wrong_sample_rate, + _CHANNELS, [_LABEL_FILE], + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, + test_utils.create_calibration_file(self.get_temp_dir()))) + + self.assertEqual( + "sample_rate should be positive, but got {}.".format(wrong_sample_rate), + str(error.exception)) + + +class MetadataWriterChannelsTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + { + "testcase_name": "negative", + "wrong_channels": -1 + }, { + "testcase_name": "zero", + "wrong_channels": 0 + }) + def test_create_for_inference_fails_with_wrong_channels(self, wrong_channels): + + with self.assertRaises(ValueError) as error: + audio_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_DYNAMIC_INPUT_SIZE_MODEL), _SAMPLE_RATE, + wrong_channels, [_LABEL_FILE], + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, + test_utils.create_calibration_file(self.get_temp_dir()))) + + self.assertEqual( + "channels should be positive, but got {}.".format(wrong_channels), + str(error.exception)) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/bert_nl_classifier_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/bert_nl_classifier_test.py new file mode 100644 index 0000000..730af7e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/bert_nl_classifier_test.py
@@ -0,0 +1,70 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bert_nl_classifier.MetadataWriter.""" + +import tensorflow as tf + +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import bert_nl_classifier +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_TEST_DIR = "tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/" +_MODEL = "../testdata/bert_nl_classifier/bert_nl_classifier_no_metadata.tflite" +_LABEL_FILE = _TEST_DIR + "labels.txt" +_VOCAB_FILE = _TEST_DIR + "mobilebert_vocab.txt" +_SP_MODEL_FILE = _TEST_DIR + "30k-clean.model" +_DELIM_REGEX_PATTERN = r"[^\w\']+" +_JSON_FOR_INFERENCE_WITH_BERT = "../testdata/bert_nl_classifier/bert_nl_classifier_with_bert_tokenizer.json" +_JSON_FOR_INFERENCE_WITH_SENTENCE_PIECE = "../testdata/bert_nl_classifier/bert_nl_classifier_with_sentence_piece.json" +_JSON_DEFAULT = "../testdata/bert_nl_classifier/bert_nl_classifier_default.json" + + +class MetadataWriterTest(tf.test.TestCase): + + def test_create_for_inference_with_bert_should_succeed(self): + writer = bert_nl_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_MODEL), + metadata_info.BertTokenizerMd(_VOCAB_FILE), [_LABEL_FILE]) + + displayer = _metadata.MetadataDisplayer.with_model_buffer(writer.populate()) + metadata_json = displayer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_FOR_INFERENCE_WITH_BERT, "r") + + self.assertEqual(metadata_json, expected_json) + + def test_create_for_inference_with_sentence_piece_should_succeed(self): + writer = bert_nl_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_MODEL), + metadata_info.SentencePieceTokenizerMd(_SP_MODEL_FILE), [_LABEL_FILE]) + + displayer = _metadata.MetadataDisplayer.with_model_buffer(writer.populate()) + metadata_json = displayer.get_metadata_json() + expected_json = test_utils.load_file( + _JSON_FOR_INFERENCE_WITH_SENTENCE_PIECE, "r") + + self.assertEqual(metadata_json, expected_json) + + def test_create_from_metadata_info_by_default_should_succeed(self): + writer = bert_nl_classifier.MetadataWriter.create_from_metadata_info( + test_utils.load_file(_MODEL)) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_DEFAULT, "r") + self.assertEqual(metadata_json, expected_json) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/image_classifier_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/image_classifier_test.py new file mode 100644 index 0000000..5035d43d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/image_classifier_test.py
@@ -0,0 +1,83 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ImageClassifier.MetadataWriter.""" + +from absl.testing import parameterized + +import tensorflow as tf + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python.metadata_writers import image_classifier +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_FLOAT_MODEL = "../testdata/image_classifier/mobilenet_v2_1.0_224.tflite" +_QUANT_MODEL = "../testdata/image_classifier/mobilenet_v2_1.0_224_quant.tflite" +_LABEL_FILE = "../testdata/image_classifier/labels.txt" +_SCORE_CALIBRATION_FILE = "../testdata/image_classifier/score_calibration.txt" +_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2 +_NORM_MEAN = 127.5 +_NORM_STD = 127.5 +_FLOAT_JSON_FOR_INFERENCE = "../testdata/image_classifier/mobilenet_v2_1.0_224.json" +_FLOAT_JSON_DEFAULT = "../testdata/image_classifier/mobilenet_v2_1.0_224_default.json" +_QUANT_JSON_FOR_INFERENCE = "../testdata/image_classifier/mobilenet_v2_1.0_224_quant.json" +_JSON_DEFAULT = "../testdata/image_classifier/mobilenet_v2_1.0_224_default.json" + + +class MetadataWriterTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + { + "testcase_name": "float_model", + "model_file": _FLOAT_MODEL, + "golden_json": _FLOAT_JSON_FOR_INFERENCE + }, { + "testcase_name": "quant_model", + "model_file": _QUANT_MODEL, + "golden_json": _QUANT_JSON_FOR_INFERENCE + }) + def test_create_for_inference_should_succeed(self, model_file, golden_json): + writer = image_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(model_file), [_NORM_MEAN], [_NORM_STD], + [_LABEL_FILE], + metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, + _DEFAULT_SCORE_CALIBRATION_VALUE, + test_utils.get_resource_path(_SCORE_CALIBRATION_FILE))) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(golden_json, "r") + self.assertEqual(metadata_json, expected_json) + + @parameterized.named_parameters( + { + "testcase_name": "float_model", + "model_file": _FLOAT_MODEL, + }, { + "testcase_name": "quant_model", + "model_file": _QUANT_MODEL, + }) + def test_create_from_metadata_info_by_default_should_succeed( + self, model_file): + writer = image_classifier.MetadataWriter.create_from_metadata_info( + test_utils.load_file(model_file)) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_DEFAULT, "r") + self.assertEqual(metadata_json, expected_json) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/image_segmenter_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/image_segmenter_test.py new file mode 100644 index 0000000..b7b5147 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/image_segmenter_test.py
@@ -0,0 +1,50 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for image_segmenter.MetadataWriter.""" + +import tensorflow as tf + +from tensorflow_lite_support.metadata.python.metadata_writers import image_segmenter +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_MODEL = "../testdata/image_segmenter/deeplabv3.tflite" +_LABEL_FILE = "../testdata/image_segmenter/labelmap.txt" +_NORM_MEAN = 127.5 +_NORM_STD = 127.5 +_JSON_FOR_INFERENCE = "../testdata/image_segmenter/deeplabv3.json" +_JSON_DEFAULT = "../testdata/image_segmenter/deeplabv3_default.json" + + +class MetadataWriterTest(tf.test.TestCase): + + def test_create_for_inference_should_succeed(self): + writer = image_segmenter.MetadataWriter.create_for_inference( + test_utils.load_file(_MODEL), [_NORM_MEAN], [_NORM_STD], [_LABEL_FILE]) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_FOR_INFERENCE, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_from_metadata_info_by_default_should_succeed(self): + writer = image_segmenter.MetadataWriter.create_from_metadata_info( + test_utils.load_file(_MODEL)) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_DEFAULT, "r") + self.assertEqual(metadata_json, expected_json) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/metadata_info_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/metadata_info_test.py new file mode 100644 index 0000000..f860f5b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/metadata_info_test.py
@@ -0,0 +1,491 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for metadata info classes.""" + +from absl.testing import parameterized +import tensorflow as tf + +import flatbuffers +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_SCORE_CALIBRATION_FILE = test_utils.get_resource_path( + "../testdata/image_classifier/score_calibration.txt") + + +class GeneralMdTest(tf.test.TestCase): + + _EXPECTED_GENERAL_META_JSON = "../testdata/general_meta.json" + + def test_create_metadata_should_succeed(self): + general_md = metadata_info.GeneralMd( + name="model", + version="v1", + description="A ML model.", + author="TensorFlow", + licenses="Apache") + general_metadata = general_md.create_metadata() + + # Create the Flatbuffers object and convert it to the json format. + builder = flatbuffers.Builder(0) + builder.Finish( + general_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_json = _metadata.convert_to_json(bytes(builder.Output())) + + expected_json = test_utils.load_file(self._EXPECTED_GENERAL_META_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + +class AssociatedFileMdTest(tf.test.TestCase): + + _EXPECTED_META_JSON = "../testdata/associated_file_meta.json" + + def test_create_metadata_should_succeed(self): + file_md = metadata_info.AssociatedFileMd( + file_path="label.txt", + description="The label file.", + file_type=_metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS, + locale="en") + file_metadata = file_md.create_metadata() + + # Create the Flatbuffers object and convert it to the json format. + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.associatedFiles = [file_metadata] + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_json = _metadata.convert_to_json(bytes(builder.Output())) + + expected_json = test_utils.load_file(self._EXPECTED_META_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + +class TensorMdTest(tf.test.TestCase, parameterized.TestCase): + + _TENSOR_NAME = "input" + _TENSOR_DESCRIPTION = "The input tensor." + _TENSOR_MIN = 0 + _TENSOR_MAX = 1 + _LABEL_FILE_EN = "labels.txt" + _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. + _EXPECTED_FEATURE_TENSOR_JSON = "../testdata/feature_tensor_meta.json" + _EXPECTED_IMAGE_TENSOR_JSON = "../testdata/image_tensor_meta.json" + _EXPECTED_BOUNDING_BOX_TENSOR_JSON = "../testdata/bounding_box_tensor_meta.json" + + @parameterized.named_parameters( + { + "testcase_name": "feature_tensor", + "content_type": _metadata_fb.ContentProperties.FeatureProperties, + "golden_json": _EXPECTED_FEATURE_TENSOR_JSON + }, { + "testcase_name": "image_tensor", + "content_type": _metadata_fb.ContentProperties.ImageProperties, + "golden_json": _EXPECTED_IMAGE_TENSOR_JSON + }, { + "testcase_name": "bounding_box_tensor", + "content_type": _metadata_fb.ContentProperties.BoundingBoxProperties, + "golden_json": _EXPECTED_BOUNDING_BOX_TENSOR_JSON + }) + def test_create_metadata_should_succeed(self, content_type, golden_json): + associated_file1 = metadata_info.AssociatedFileMd( + file_path=self._LABEL_FILE_EN, locale="en") + associated_file2 = metadata_info.AssociatedFileMd( + file_path=self._LABEL_FILE_CN, locale="cn") + + tensor_md = metadata_info.TensorMd( + name=self._TENSOR_NAME, + description=self._TENSOR_DESCRIPTION, + min_values=[self._TENSOR_MIN], + max_values=[self._TENSOR_MAX], + content_type=content_type, + associated_files=[associated_file1, associated_file2]) + tensor_metadata = tensor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + expected_json = test_utils.load_file(golden_json, "r") + self.assertEqual(metadata_json, expected_json) + + +class InputImageTensorMdTest(tf.test.TestCase, parameterized.TestCase): + + _NAME = "image" + _DESCRIPTION = "The input image." + _NORM_MEAN = (0, 127.5, 255) + _NORM_STD = (127.5, 127.5, 127.5) + _COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB + _EXPECTED_FLOAT_TENSOR_JSON = "../testdata/input_image_tensor_float_meta.json" + _EXPECTED_UINT8_TENSOR_JSON = "../testdata/input_image_tensor_uint8_meta.json" + _EXPECTED_UNSUPPORTED_TENSOR_JSON = "../testdata/input_image_tensor_unsupported_meta.json" + + @parameterized.named_parameters( + { + "testcase_name": "float", + "tensor_type": _schema_fb.TensorType.FLOAT32, + "golden_json": _EXPECTED_FLOAT_TENSOR_JSON + }, { + "testcase_name": "uint8", + "tensor_type": _schema_fb.TensorType.UINT8, + "golden_json": _EXPECTED_UINT8_TENSOR_JSON + }, { + "testcase_name": "unsupported_tensor_type", + "tensor_type": _schema_fb.TensorType.INT16, + "golden_json": _EXPECTED_UNSUPPORTED_TENSOR_JSON + }) + def test_create_metadata_should_succeed(self, tensor_type, golden_json): + tesnor_md = metadata_info.InputImageTensorMd( + name=self._NAME, + description=self._DESCRIPTION, + norm_mean=list(self._NORM_MEAN), + norm_std=list(self._NORM_STD), + color_space_type=self._COLOR_SPACE_TYPE, + tensor_type=tensor_type) + tensor_metadata = tesnor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + expected_json = test_utils.load_file(golden_json, "r") + self.assertEqual(metadata_json, expected_json) + + def test_init_should_throw_exception_with_incompatible_mean_and_std(self): + norm_mean = [0] + norm_std = [1, 2] + with self.assertRaises(ValueError) as error: + metadata_info.InputImageTensorMd(norm_mean=norm_mean, norm_std=norm_std) + self.assertEqual( + f"norm_mean and norm_std are expected to be the same dim. But got " + f"{len(norm_mean)} and {len(norm_std)}", str(error.exception)) + + +class InputTextTensorMdTest(tf.test.TestCase): + + _NAME = "input text" + _DESCRIPTION = "The input string." + _VOCAB_FILE = "vocab.txt" + _DELIM_REGEX_PATTERN = r"[^\w\']+" + _EXPECTED_TENSOR_JSON = "../testdata/input_text_tesnor_meta.json" + _EXPECTED_TENSOR_DEFAULT_JSON = "../testdata/input_text_tesnor_default_meta.json" + + def test_create_metadata_should_succeed(self): + regex_tokenizer_md = metadata_info.RegexTokenizerMd( + self._DELIM_REGEX_PATTERN, self._VOCAB_FILE) + + text_tensor_md = metadata_info.InputTextTensorMd(self._NAME, + self._DESCRIPTION, + regex_tokenizer_md) + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor( + text_tensor_md.create_metadata())) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_metadata_by_default_should_succeed(self): + text_tensor_md = metadata_info.InputTextTensorMd() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor( + text_tensor_md.create_metadata())) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_DEFAULT_JSON, + "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_metadata_throws_exception_with_unsupported_tokenizer(self): + invalid_tokenzier = metadata_info.BertTokenizerMd("vocab.txt") + + with self.assertRaises(ValueError) as error: + tensor_md = metadata_info.InputTextTensorMd( + tokenizer_md=invalid_tokenzier) + tensor_md.create_metadata() + + self.assertEqual( + f"The type of tokenizer_options, {type(invalid_tokenzier)}, is " + f"unsupported", str(error.exception)) + + +class InputAudioTensorMd(tf.test.TestCase): + + _NAME = "input text" + _DESCRIPTION = "The input string." + _SAMPLE_RATE = 10 + _CHANNELS = 2 + _EXPECTED_TENSOR_JSON = "../testdata/input_audio_tesnor_meta.json" + _EXPECTED_TENSOR_DEFAULT_JSON = "../testdata/input_audio_tesnor_default_meta.json" + + def test_create_metadata_should_succeed(self): + text_tensor_md = metadata_info.InputAudioTensorMd(self._NAME, + self._DESCRIPTION, + self._SAMPLE_RATE, + self._CHANNELS) + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor( + text_tensor_md.create_metadata())) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_metadata_by_default_should_succeed(self): + audio_tensor_md = metadata_info.InputAudioTensorMd() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor( + audio_tensor_md.create_metadata())) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_DEFAULT_JSON, + "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_metadata_fail_with_negative_sample_rate(self): + negative_sample_rate = -1 + with self.assertRaises(ValueError) as error: + tensor_md = metadata_info.InputAudioTensorMd( + sample_rate=negative_sample_rate) + tensor_md.create_metadata() + + self.assertEqual( + f"sample_rate should be non-negative, but got {negative_sample_rate}.", + str(error.exception)) + + def test_create_metadata_fail_with_negative_channels(self): + negative_channels = -1 + with self.assertRaises(ValueError) as error: + tensor_md = metadata_info.InputAudioTensorMd(channels=negative_channels) + tensor_md.create_metadata() + + self.assertEqual( + f"channels should be non-negative, but got {negative_channels}.", + str(error.exception)) + + +class ClassificationTensorMdTest(tf.test.TestCase, parameterized.TestCase): + + _NAME = "probability" + _DESCRIPTION = "The classification result tensor." + _LABEL_FILE_EN = "labels.txt" + _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. + _CALIBRATION_DEFAULT_SCORE = 0.2 + _EXPECTED_FLOAT_TENSOR_JSON = "../testdata/classification_tensor_float_meta.json" + _EXPECTED_UINT8_TENSOR_JSON = "../testdata/classification_tensor_uint8_meta.json" + _EXPECTED_UNSUPPORTED_TENSOR_JSON = "../testdata/classification_tensor_unsupported_meta.json" + + @parameterized.named_parameters( + { + "testcase_name": "float", + "tensor_type": _schema_fb.TensorType.FLOAT32, + "golden_json": _EXPECTED_FLOAT_TENSOR_JSON + }, { + "testcase_name": "uint8", + "tensor_type": _schema_fb.TensorType.UINT8, + "golden_json": _EXPECTED_UINT8_TENSOR_JSON + }, { + "testcase_name": "unsupported_tensor_type", + "tensor_type": _schema_fb.TensorType.INT16, + "golden_json": _EXPECTED_UNSUPPORTED_TENSOR_JSON + }) + def test_create_metadata_should_succeed(self, tensor_type, golden_json): + label_file_en = metadata_info.LabelFileMd( + file_path=self._LABEL_FILE_EN, locale="en") + label_file_cn = metadata_info.LabelFileMd( + file_path=self._LABEL_FILE_CN, locale="cn") + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.IDENTITY, + self._CALIBRATION_DEFAULT_SCORE, _SCORE_CALIBRATION_FILE) + + tesnor_md = metadata_info.ClassificationTensorMd( + name=self._NAME, + description=self._DESCRIPTION, + label_files=[label_file_en, label_file_cn], + tensor_type=tensor_type, + score_calibration_md=score_calibration_md) + tensor_metadata = tesnor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + expected_json = test_utils.load_file(golden_json, "r") + self.assertEqual(metadata_json, expected_json) + + +class CategoryTensorMdTest(tf.test.TestCase, parameterized.TestCase): + + _NAME = "category" + _DESCRIPTION = "The category tensor." + _LABEL_FILE_EN = "labels.txt" + _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. + _EXPECTED_TENSOR_JSON = "../testdata/category_tensor_float_meta.json" + + def test_create_metadata_should_succeed(self): + label_file_en = metadata_info.LabelFileMd( + file_path=self._LABEL_FILE_EN, locale="en") + label_file_cn = metadata_info.LabelFileMd( + file_path=self._LABEL_FILE_CN, locale="cn") + tesnor_md = metadata_info.CategoryTensorMd( + name=self._NAME, + description=self._DESCRIPTION, + label_files=[label_file_en, label_file_cn]) + tensor_metadata = tesnor_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(tensor_metadata)) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + +class RegexTokenizerMdTest(tf.test.TestCase): + + _VOCAB_FILE = "vocab.txt" + _DELIM_REGEX_PATTERN = r"[^\w\']+" + _EXPECTED_TENSOR_JSON = "../testdata/regex_tokenizer_meta.json" + + def test_create_metadata_should_succeed(self): + tokenizer_md = metadata_info.RegexTokenizerMd(self._DELIM_REGEX_PATTERN, + self._VOCAB_FILE) + tokenizer_metadata = tokenizer_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_process_uint(tokenizer_metadata)) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + +class BertTokenizerMdTest(tf.test.TestCase): + + _VOCAB_FILE = "vocab.txt" + _EXPECTED_TENSOR_JSON = "../testdata/bert_tokenizer_meta.json" + + def test_create_metadata_should_succeed(self): + tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE) + tokenizer_metadata = tokenizer_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_process_uint(tokenizer_metadata)) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + +class SentencePieceTokenizerMdTest(tf.test.TestCase): + + _VOCAB_FILE = "vocab.txt" + _SP_MODEL = "sp.model" + _EXPECTED_TENSOR_JSON = "../testdata/sentence_piece_tokenizer_meta.json" + + def test_create_metadata_should_succeed(self): + tokenizer_md = metadata_info.SentencePieceTokenizerMd( + self._SP_MODEL, self._VOCAB_FILE) + tokenizer_metadata = tokenizer_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_process_uint(tokenizer_metadata)) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + +class ScoreCalibrationMdTest(tf.test.TestCase): + _DEFAULT_VALUE = 0.2 + _EXPECTED_TENSOR_JSON = "../testdata/score_calibration_tensor_meta.json" + _EXPECTED_MODEL_META_JSON = "../testdata/score_calibration_file_meta.json" + + def test_create_metadata_should_succeed(self): + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE, + _SCORE_CALIBRATION_FILE) + score_calibration_metadata = score_calibration_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_process_uint( + score_calibration_metadata)) + expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_score_calibration_file_md_should_succeed(self): + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE, + _SCORE_CALIBRATION_FILE) + score_calibration_file_md = ( + score_calibration_md.create_score_calibration_file_md()) + file_metadata = score_calibration_file_md.create_metadata() + + # Create the Flatbuffers object and convert it to the json format. + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.associatedFiles = [file_metadata] + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_json = _metadata.convert_to_json(bytes(builder.Output())) + + expected_json = test_utils.load_file(self._EXPECTED_MODEL_META_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_score_calibration_file_fails_with_less_colunms(self): + malformed_calibration_file = test_utils.create_calibration_file( + self.get_temp_dir(), content="1.0,0.2") + + with self.assertRaisesRegex( + ValueError, + "Expected empty lines or 3 or 4 parameters per line in score" + + " calibration file, but got 2."): + metadata_info.ScoreCalibrationMd(_metadata_fb.ScoreTransformationType.LOG, + self._DEFAULT_VALUE, + malformed_calibration_file) + + def test_create_score_calibration_file_fails_with_negative_scale(self): + malformed_calibration_file = test_utils.create_calibration_file( + self.get_temp_dir(), content="-1.0,0.2,0.1") + + with self.assertRaisesRegex( + ValueError, "Expected scale to be a non-negative value, but got -1.0."): + metadata_info.ScoreCalibrationMd(_metadata_fb.ScoreTransformationType.LOG, + self._DEFAULT_VALUE, + malformed_calibration_file) + + +def _create_dummy_model_metadata_with_tensor( + tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: + # Create a dummy model using the tensor metadata. + subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata.inputTensorMetadata = [tensor_metadata] + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.subgraphMetadata = [subgraph_metadata] + + # Create the Flatbuffers object and convert it to the json format. + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return bytes(builder.Output()) + + +def _create_dummy_model_metadata_with_process_uint( + process_unit_metadata: _metadata_fb.ProcessUnitT) -> bytes: + # Create a dummy model using the tensor metadata. + subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata.inputProcessUnits = [process_unit_metadata] + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.subgraphMetadata = [subgraph_metadata] + + # Create the Flatbuffers object and convert it to the json format. + builder = flatbuffers.Builder(0) + builder.Finish( + model_metadata.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return bytes(builder.Output()) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/metadata_writer_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/metadata_writer_test.py new file mode 100644 index 0000000..574da0f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/metadata_writer_test.py
@@ -0,0 +1,203 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MetadataWriter.""" + +import os +import tensorflow as tf + +from tensorflow.python.platform import resource_loader +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_MODEL = "../testdata/mobilenet_v2_1.0_224_quant.tflite" +_MULTI_INPUTS_MODEL = "../testdata/question_answerer/mobilebert_float.tflite" +_MULTI_OUTPUTS_MODEL = "../testdata/audio_classifier/two_heads.tflite" +_MODEL_NAME = "mobilenet_v2_1.0_224_quant" +_INPUT_NAME = "image" +_OUTPUT_NAME = "probability" +_LABEL_FILE = resource_loader.get_path_to_datafile("../testdata/labels.txt") +_EXPECTED_DUMMY_JSON = "../testdata/mobilenet_v2_1.0_224_quant_dummy.json" +_EXPECTED_META_INFO_JSON = "../testdata/mobilenet_v2_1.0_224_quant_meta_info_.json" +_EXPECTED_DEFAULT_JSON = "../testdata/mobilenet_v2_1.0_224_quant_default.json" +# Before populated into the model, metadata does not have the verson string +_EXPECTED_DUMMY_NO_VERSION_JSON = "../testdata/mobilenet_v2_1.0_224_quant_dummy_no_version.json" +_EXPECTED_MULTI_INPUTS_JSON = "../testdata/multi_inputs.json" +_EXPECTED_MULTI_OUTPUTS_JSON = "../testdata/multi_outputs.json" + + +class MetadataWriterTest(tf.test.TestCase): + + def test_populate_from_metadata_should_succeed(self): + model_buffer = test_utils.load_file(_MODEL) + model_metadata, input_metadata, output_metadata = ( + self._create_dummy_metadata()) + + writer = metadata_writer.MetadataWriter.create_from_metadata( + model_buffer, model_metadata, [input_metadata], [output_metadata], + [_LABEL_FILE]) + model_with_metadata = writer.populate() + + self._assert_correct_metadata(model_with_metadata, _EXPECTED_DUMMY_JSON, + _LABEL_FILE) + + def test_create_from_metadata_with_default_value_should_succeed(self): + model_buffer = test_utils.load_file(_MODEL) + + writer = metadata_writer.MetadataWriter.create_from_metadata(model_buffer) + model_with_metadata = writer.populate() + + self._assert_correct_metadata(model_with_metadata, _EXPECTED_DEFAULT_JSON) + + def test_populate_create_from_metadata_info_should_succeed(self): + model_buffer = test_utils.load_file(_MODEL) + general_md = metadata_info.GeneralMd(name=_MODEL_NAME) + input_md = metadata_info.TensorMd(name=_INPUT_NAME) + output_md = metadata_info.TensorMd(name=_OUTPUT_NAME) + + writer = metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer, general_md, [input_md], [output_md], [_LABEL_FILE]) + model_with_metadata = writer.populate() + + self._assert_correct_metadata(model_with_metadata, _EXPECTED_META_INFO_JSON, + _LABEL_FILE) + + def test_create_from_metadata_info_with_default_value_should_succeed(self): + model_buffer = test_utils.load_file(_MODEL) + + writer = metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer) + model_with_metadata = writer.populate() + + self._assert_correct_metadata(model_with_metadata, _EXPECTED_DEFAULT_JSON) + + def test_create_from_metadata_info_with_input_tensor_name_should_succeed( + self): + model_buffer = test_utils.load_file(_MULTI_INPUTS_MODEL) + # The input tensors in the model are: input_ids, input_mask, segment_ids. + input_md_1 = metadata_info.TensorMd(name="ids", tensor_name="input_ids") + input_md_2 = metadata_info.TensorMd(name="mask", tensor_name="input_mask") + input_md_3 = metadata_info.TensorMd( + name="segment", tensor_name="segment_ids") + + # Create input metadata in a different order to test if MetadataWriter can + # correct it. + writer = metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer, input_md=[input_md_2, input_md_3, input_md_1]) + model_with_metadata = writer.populate() + + self._assert_correct_metadata(model_with_metadata, + _EXPECTED_MULTI_INPUTS_JSON) + + def test_create_from_metadata_info_fails_with_wrong_input_tesnor_name(self): + model_buffer = test_utils.load_file(_MODEL) + input_md = metadata_info.TensorMd(tensor_name="wrong_tensor_name") + with self.assertRaises(ValueError) as error: + metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer, input_md=[input_md]) + self.assertEqual( + "The tensor names from arguments (['wrong_tensor_name']) do not match" + " the tensor names read from the model (['input']).", + str(error.exception)) + + def test_create_from_metadata_info_with_output_tensor_name_should_succeed( + self): + model_buffer = test_utils.load_file(_MULTI_OUTPUTS_MODEL) + # The output tensors in the model are: Identity, Identity_1 + output_md_1 = metadata_info.TensorMd( + name="Identity", tensor_name="Identity") + output_md_2 = metadata_info.TensorMd( + name="Identity 1", tensor_name="Identity_1") + + # Create output metadata in a different order to test if MetadataWriter can + # correct it. + writer = metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer, output_md=[output_md_2, output_md_1]) + model_with_metadata = writer.populate() + + self._assert_correct_metadata(model_with_metadata, + _EXPECTED_MULTI_OUTPUTS_JSON) + + def test_create_from_metadata_info_fails_with_wrong_output_tesnor_name(self): + model_buffer = test_utils.load_file(_MODEL) + output_md = metadata_info.TensorMd(tensor_name="wrong_tensor_name") + with self.assertRaises(ValueError) as error: + metadata_writer.MetadataWriter.create_from_metadata_info( + model_buffer, output_md=[output_md]) + self.assertEqual( + "The tensor names from arguments (['wrong_tensor_name']) do not match" + " the tensor names read from the model (['output']).", + str(error.exception)) + + def test_get_metadata_json_should_succeed(self): + model_buffer = test_utils.load_file(_MODEL) + model_metadata, input_metadata, output_metadata = ( + self._create_dummy_metadata()) + + writer = metadata_writer.MetadataWriter.create_from_metadata( + model_buffer, model_metadata, [input_metadata], [output_metadata], + [_LABEL_FILE]) + metadata_json = writer.get_metadata_json() + + expected_json = test_utils.load_file(_EXPECTED_DUMMY_NO_VERSION_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + def test_get_populated_metadata_json_should_succeed(self): + model_buffer = test_utils.load_file(_MODEL) + model_metadata, input_metadata, output_metadata = ( + self._create_dummy_metadata()) + + writer = metadata_writer.MetadataWriter.create_from_metadata( + model_buffer, model_metadata, [input_metadata], [output_metadata], + [_LABEL_FILE]) + metadata_json = writer.get_populated_metadata_json() + + expected_json = test_utils.load_file(_EXPECTED_DUMMY_JSON, "r") + self.assertEqual(metadata_json, expected_json) + + def _assert_correct_metadata(self, + model_with_metadata, + expected_json_file, + expected_label_file=None): + # Verify if the metadata populated is correct. + displayer = _metadata.MetadataDisplayer.with_model_buffer( + model_with_metadata) + metadata_json = displayer.get_metadata_json() + expected_json = test_utils.load_file(expected_json_file, "r") + self.assertEqual(metadata_json, expected_json) + + # Verify if the associated file is packed as expected. + if expected_label_file: + packed_files = displayer.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(expected_label_file)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def _create_dummy_metadata(self): + # Create dummy input metadata + input_metadata = _metadata_fb.TensorMetadataT() + input_metadata.name = _INPUT_NAME + # Create dummy output metadata + output_metadata = _metadata_fb.TensorMetadataT() + output_metadata.name = _OUTPUT_NAME + # Create dummy model_metadata + model_metadata = _metadata_fb.ModelMetadataT() + model_metadata.name = _MODEL_NAME + return model_metadata, input_metadata, output_metadata + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/nl_classifier_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/nl_classifier_test.py new file mode 100644 index 0000000..eda6913 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/nl_classifier_test.py
@@ -0,0 +1,56 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for nl_classifier.MetadataWriter.""" + +import tensorflow as tf + +from tensorflow_lite_support.metadata.python import metadata as _metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import nl_classifier +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_TEST_DIR = "tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/" +_MODEL = "../testdata/nl_classifier/movie_review.tflite" +_LABEL_FILE = _TEST_DIR + "labels.txt" +_VOCAB_FILE = _TEST_DIR + "vocab.txt" +_DELIM_REGEX_PATTERN = r"[^\w\']+" +_JSON_FOR_INFERENCE_REGEX = "../testdata/nl_classifier/movie_review_regex.json" +_JSON_DEFAULT = "../testdata/nl_classifier/movie_review_default.json" + + +class MetadataWriterTest(tf.test.TestCase): + + def test_create_for_inference_should_succeed(self): + writer = nl_classifier.MetadataWriter.create_for_inference( + test_utils.load_file(_MODEL), + metadata_info.RegexTokenizerMd(_DELIM_REGEX_PATTERN, _VOCAB_FILE), + [_LABEL_FILE]) + + displayer = _metadata.MetadataDisplayer.with_model_buffer(writer.populate()) + metadata_json = displayer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_FOR_INFERENCE_REGEX, "r") + self.assertEqual(metadata_json, expected_json) + + def test_create_from_metadata_info_by_default_should_succeed(self): + writer = nl_classifier.MetadataWriter.create_from_metadata_info( + test_utils.load_file(_MODEL)) + + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(_JSON_DEFAULT, "r") + self.assertEqual(metadata_json, expected_json) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/object_detector_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/object_detector_test.py new file mode 100644 index 0000000..ddb6d68 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/object_detector_test.py
@@ -0,0 +1,129 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ObjectDetector.MetadataWriter.""" + +import json +import os +import tempfile + +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow_lite_support.metadata.python import metadata +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import object_detector +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_PATH = "../testdata/object_detector/" +_MODEL = "../testdata/object_detector/ssd_mobilenet_v1.tflite" +_LABEL_FILE = "../testdata/object_detector/labelmap.txt" +_NORM_MEAN = 127.5 +_NORM_STD = 127.5 +_JSON_FOR_INFERENCE = "../testdata/object_detector/ssd_mobilenet_v1.json" +_JSON_DEFAULT = "../testdata/object_detector/ssd_mobilenet_v1_default.json" + +_MODEL_COCO = "../testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite" +_SCORE_CALIBRATION_FILE = "../testdata/object_detector/score_calibration.csv" +_SCORE_CALIBRATION_DEFAULT_SCORE = 0.2 +_JSON_FOR_SCORE_CALIBRATION = "../testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration.json" + +_DUMMY_SCORE_CALIBRATION_FILE = "../testdata/object_detector/score_calibration_dummy.csv" +_DUMMY_SCORE_CALIBRATION_DEFAULT_SCORE = 0.0 +_JSON_FOR_DUMMY_SCORE_CALIBRATION = "../testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration_dummy.json" +_EXPECTED_DUMMY_MODEL = "../testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite" + + +class MetadataWriterTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self._label_file = test_utils.get_resource_path(_LABEL_FILE) + self._score_file = test_utils.get_resource_path(_SCORE_CALIBRATION_FILE) + self._dummy_score_file = test_utils.get_resource_path( + _DUMMY_SCORE_CALIBRATION_FILE) + + @parameterized.parameters( + ("ssd_mobilenet_v1"), + ("efficientdet_lite0_v1"), + ) + def test_create_for_inference_should_succeed(self, model_name): + model_path = os.path.join(_PATH, model_name + ".tflite") + writer = object_detector.MetadataWriter.create_for_inference( + test_utils.load_file(model_path), [_NORM_MEAN], [_NORM_STD], + [self._label_file]) + + json_path = os.path.join(_PATH, model_name + ".json") + self._validate_metadata(writer, json_path) + self._validate_populated_model(writer) + + @parameterized.parameters( + ("ssd_mobilenet_v1"), + ("efficientdet_lite0_v1"), + ) + def test_create_from_metadata_info_by_default_should_succeed( + self, model_name: str): + model_path = os.path.join(_PATH, model_name + ".tflite") + writer = object_detector.MetadataWriter.create_from_metadata_info( + test_utils.load_file(model_path)) + json_path = os.path.join(_PATH, model_name + "_default.json") + self._validate_metadata(writer, json_path) + self._validate_populated_model(writer) + + def test_create_for_inference_score_calibration_should_succeed(self): + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.INVERSE_LOGISTIC, + _SCORE_CALIBRATION_DEFAULT_SCORE, + self._score_file, + ) + writer = object_detector.MetadataWriter.create_for_inference( + test_utils.load_file(_MODEL_COCO), [_NORM_MEAN], [_NORM_STD], + [self._label_file], score_calibration_md) + self._validate_metadata(writer, _JSON_FOR_SCORE_CALIBRATION) + self._validate_populated_model(writer) + + def test_create_for_inference_dummy_score_calibration_should_succeed(self): + score_calibration_md = metadata_info.ScoreCalibrationMd( + _metadata_fb.ScoreTransformationType.INVERSE_LOGISTIC, + _DUMMY_SCORE_CALIBRATION_DEFAULT_SCORE, + self._dummy_score_file, + ) + writer = object_detector.MetadataWriter.create_for_inference( + test_utils.load_file(_MODEL_COCO), [_NORM_MEAN], [_NORM_STD], + [self._label_file], score_calibration_md) + self._validate_metadata(writer, _JSON_FOR_DUMMY_SCORE_CALIBRATION) + self._validate_populated_model(writer) + + # Test if populated model is equivalent to the expected model. + metadata_dict = json.loads(writer.get_metadata_json()) + displayer = metadata.MetadataDisplayer.with_model_buffer( + test_utils.load_file(_EXPECTED_DUMMY_MODEL)) + expected_metadata_dict = json.loads(displayer.get_metadata_json()) + self.assertDictContainsSubset(metadata_dict, expected_metadata_dict) + + def _validate_metadata(self, writer, expected_json_file): + metadata_json = writer.get_metadata_json() + expected_json = test_utils.load_file(expected_json_file, "r") + self.assertEqual(metadata_json, expected_json) + + def _validate_populated_model(self, writer): + with tempfile.NamedTemporaryFile() as tmp: + with open(tmp.name, "wb") as f: + f.write(writer.populate()) + self.assertGreater(os.path.getsize(tmp.name), 0) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/test_utils.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/test_utils.py new file mode 100644 index 0000000..52b0ea5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/test_utils.py
@@ -0,0 +1,41 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utils for MetadataWriter.""" + +import os +from typing import Union +from tensorflow.python.platform import resource_loader + + +def create_calibration_file(file_dir: str, + file_name: str = "score_calibration.txt", + content: str = "1.0,2.0,3.0,4.0") -> str: + """Creates the calibration file.""" + calibration_file = os.path.join(file_dir, file_name) + with open(calibration_file, mode="w") as file: + file.write(content) + return calibration_file + + +def load_file(file_name: str, mode: str = "rb") -> Union[str, bytes]: + """Loads files from resources.""" + file_path = get_resource_path(file_name) + with open(file_path, mode) as file: + return file.read() + + +def get_resource_path(file_name: str) -> str: + """Gets resource path from the loader.""" + return resource_loader.get_path_to_datafile(file_name)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/writer_utils_test.py b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/writer_utils_test.py new file mode 100644 index 0000000..b0502d3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_writers/writer_utils_test.py
@@ -0,0 +1,121 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrtier util methods.""" + +import array +import tensorflow as tf + +from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils +from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils + +_FLOAT_TYPE = _schema_fb.TensorType.FLOAT32 +_UINT8_TYPE = _schema_fb.TensorType.UINT8 +# mobilebert_float.tflite has 1 input tensor and 4 output tensors. +_MODEL_NAME = "../testdata/object_detector/ssd_mobilenet_v1.tflite" +_IMAGE_TENSOR_INDEX = 0 +_EXPECTED_INPUT_TYPES = _UINT8_TYPE +_EXPECTED_INPUT_IMAGE_SHAPE = (1, 300, 300, 3) +_EXPECTED_OUTPUT_TYPES = (_FLOAT_TYPE, _FLOAT_TYPE, _FLOAT_TYPE, _FLOAT_TYPE) +_EXOECTED_INPUT_TENSOR_NAMES = "normalized_input_image_tensor" +_EXOECTED_OUTPUT_TENSOR_NAMES = ("TFLite_Detection_PostProcess", + "TFLite_Detection_PostProcess:1", + "TFLite_Detection_PostProcess:2", + "TFLite_Detection_PostProcess:3") + + +class WriterUtilsTest(tf.test.TestCase): + + def test_compute_flat_size(self): + shape = array.array("i", [1, 2, 3]) + expected_flat_size = 6 + + flat_size = writer_utils.compute_flat_size(shape) + self.assertEqual(flat_size, expected_flat_size) + + def test_compute_flat_size_with_none_shape(self): + shape = None + expected_flat_size = 0 + + flat_size = writer_utils.compute_flat_size(shape) + self.assertEqual(flat_size, expected_flat_size) + + def test_get_input_tensor_names(self): + tensor_names = writer_utils.get_input_tensor_names( + model_buffer=test_utils.load_file(_MODEL_NAME)) + self.assertEqual(tensor_names, [_EXOECTED_INPUT_TENSOR_NAMES]) + + def test_get_output_tensor_names(self): + tensor_names = writer_utils.get_output_tensor_names( + model_buffer=test_utils.load_file(_MODEL_NAME)) + self.assertEqual(tensor_names, list(_EXOECTED_OUTPUT_TENSOR_NAMES)) + + def test_get_input_tensor_types(self): + tensor_types = writer_utils.get_input_tensor_types( + model_buffer=test_utils.load_file(_MODEL_NAME)) + self.assertEqual(tensor_types, [_EXPECTED_INPUT_TYPES]) + + def test_get_output_tensor_types(self): + tensor_types = writer_utils.get_output_tensor_types( + model_buffer=test_utils.load_file(_MODEL_NAME)) + self.assertEqual(tensor_types, list(_EXPECTED_OUTPUT_TYPES)) + + def test_get_input_tensor_shape(self): + tensor_shape = writer_utils.get_input_tensor_shape( + test_utils.load_file(_MODEL_NAME), _IMAGE_TENSOR_INDEX) + self.assertEqual(list(tensor_shape), list(_EXPECTED_INPUT_IMAGE_SHAPE)) + + def test_save_and_load_file(self): + expected_file_bytes = b"This is a test file." + file_path = self.create_tempfile().full_path + + writer_utils.save_file(expected_file_bytes, file_path) + file_bytes = writer_utils.load_file(file_path) + self.assertEqual(file_bytes, expected_file_bytes) + + def test_get_tokenizer_associated_files_with_bert_tokenizer(self): + # Create Bert tokenizer + vocab_file = "vocab.txt" + tokenizer_md = metadata_info.BertTokenizerMd(vocab_file) + + associated_files = writer_utils.get_tokenizer_associated_files( + tokenizer_md.create_metadata().options) + self.assertEqual(associated_files, [vocab_file]) + + def test_get_tokenizer_associated_files_with_sentence_piece_tokenizer(self): + # Create Sentence Piece tokenizer + vocab_file = "vocab.txt" + sp_model = "sp.model" + tokenizer_md = metadata_info.SentencePieceTokenizerMd(sp_model, vocab_file) + + associated_files = writer_utils.get_tokenizer_associated_files( + tokenizer_md.create_metadata().options) + self.assertEqual(set(associated_files), set([vocab_file, sp_model])) + + def test_get_tokenizer_associated_files_with_regex_tokenizer(self): + # Create Regex tokenizer + delim_regex_pattern = r"[^\w\']+" + vocab_file = "vocab.txt" + tokenizer_md = metadata_info.RegexTokenizerMd(delim_regex_pattern, + vocab_file) + + associated_files = writer_utils.get_tokenizer_associated_files( + tokenizer_md.create_metadata().options) + self.assertEqual(associated_files, [vocab_file]) + + +if __name__ == "__main__": + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/BUILD new file mode 100644 index 0000000..0f8caca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/associated_file_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/associated_file_meta.json new file mode 100644 index 0000000..2a3d47b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/associated_file_meta.json
@@ -0,0 +1,10 @@ +{ + "associated_files": [ + { + "name": "label.txt", + "description": "The label file.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/BUILD new file mode 100644 index 0000000..0f8caca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/daredevil_sound_recognizer_320ms.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/daredevil_sound_recognizer_320ms.json new file mode 100644 index 0000000..8881154 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/daredevil_sound_recognizer_320ms.json
@@ -0,0 +1,63 @@ +{ + "name": "AudioClassifier", + "description": "Identify the most prominent type in the audio clip from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio_clip", + "description": "Input audio clip to be classified.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 2, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Scores of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/labelmap.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/labelmap.txt new file mode 100644 index 0000000..15f635e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/labelmap.txt
@@ -0,0 +1,2 @@ +speech +laughter
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads.json new file mode 100644 index 0000000..a572614 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads.json
@@ -0,0 +1,95 @@ +{ + "name": "AudioClassifier", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio_clip", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 2, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "head2", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + }, + "associated_files": [ + { + "name": "labels_en_2.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "labels_cn_2.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_cali_2.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + }, + { + "name": "head1", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + }, + "associated_files": [ + { + "name": "labels_en_1.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "labels_cn_1.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_cali_1.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads.tflite new file mode 100644 index 0000000..5ad4ffe --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads_default.json new file mode 100644 index 0000000..9f98f84 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/two_heads_default.json
@@ -0,0 +1,29 @@ +{ + "name": "AudioClassifier", + "description": "Identify the most prominent type in the audio clip from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio_clip", + "description": "Input audio clip to be classified.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "Identity" + }, + { + "name": "Identity_1" + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt new file mode 100644 index 0000000..c3e864e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt
@@ -0,0 +1,521 @@ +Speech +Child speech, kid speaking +Conversation +Narration, monologue +Babbling +Speech synthesizer +Shout +Bellow +Whoop +Yell +Children shouting +Screaming +Whispering +Laughter +Baby laughter +Giggle +Snicker +Belly laugh +Chuckle, chortle +Crying, sobbing +Baby cry, infant cry +Whimper +Wail, moan +Sigh +Singing +Choir +Yodeling +Chant +Mantra +Child singing +Synthetic singing +Rapping +Humming +Groan +Grunt +Whistling +Breathing +Wheeze +Snoring +Gasp +Pant +Snort +Cough +Throat clearing +Sneeze +Sniff +Run +Shuffle +Walk, footsteps +Chewing, mastication +Biting +Gargling +Stomach rumble +Burping, eructation +Hiccup +Fart +Hands +Finger snapping +Clapping +Heart sounds, heartbeat +Heart murmur +Cheering +Applause +Chatter +Crowd +Hubbub, speech noise, speech babble +Children playing +Animal +Domestic animals, pets +Dog +Bark +Yip +Howl +Bow-wow +Growling +Whimper (dog) +Cat +Purr +Meow +Hiss +Caterwaul +Livestock, farm animals, working animals +Horse +Clip-clop +Neigh, whinny +Cattle, bovinae +Moo +Cowbell +Pig +Oink +Goat +Bleat +Sheep +Fowl +Chicken, rooster +Cluck +Crowing, cock-a-doodle-doo +Turkey +Gobble +Duck +Quack +Goose +Honk +Wild animals +Roaring cats (lions, tigers) +Roar +Bird +Bird vocalization, bird call, bird song +Chirp, tweet +Squawk +Pigeon, dove +Coo +Crow +Caw +Owl +Hoot +Bird flight, flapping wings +Canidae, dogs, wolves +Rodents, rats, mice +Mouse +Patter +Insect +Cricket +Mosquito +Fly, housefly +Buzz +Bee, wasp, etc. +Frog +Croak +Snake +Rattle +Whale vocalization +Music +Musical instrument +Plucked string instrument +Guitar +Electric guitar +Bass guitar +Acoustic guitar +Steel guitar, slide guitar +Tapping (guitar technique) +Strum +Banjo +Sitar +Mandolin +Zither +Ukulele +Keyboard (musical) +Piano +Electric piano +Organ +Electronic organ +Hammond organ +Synthesizer +Sampler +Harpsichord +Percussion +Drum kit +Drum machine +Drum +Snare drum +Rimshot +Drum roll +Bass drum +Timpani +Tabla +Cymbal +Hi-hat +Wood block +Tambourine +Rattle (instrument) +Maraca +Gong +Tubular bells +Mallet percussion +Marimba, xylophone +Glockenspiel +Vibraphone +Steelpan +Orchestra +Brass instrument +French horn +Trumpet +Trombone +Bowed string instrument +String section +Violin, fiddle +Pizzicato +Cello +Double bass +Wind instrument, woodwind instrument +Flute +Saxophone +Clarinet +Harp +Bell +Church bell +Jingle bell +Bicycle bell +Tuning fork +Chime +Wind chime +Change ringing (campanology) +Harmonica +Accordion +Bagpipes +Didgeridoo +Shofar +Theremin +Singing bowl +Scratching (performance technique) +Pop music +Hip hop music +Beatboxing +Rock music +Heavy metal +Punk rock +Grunge +Progressive rock +Rock and roll +Psychedelic rock +Rhythm and blues +Soul music +Reggae +Country +Swing music +Bluegrass +Funk +Folk music +Middle Eastern music +Jazz +Disco +Classical music +Opera +Electronic music +House music +Techno +Dubstep +Drum and bass +Electronica +Electronic dance music +Ambient music +Trance music +Music of Latin America +Salsa music +Flamenco +Blues +Music for children +New-age music +Vocal music +A capella +Music of Africa +Afrobeat +Christian music +Gospel music +Music of Asia +Carnatic music +Music of Bollywood +Ska +Traditional music +Independent music +Song +Background music +Theme music +Jingle (music) +Soundtrack music +Lullaby +Video game music +Christmas music +Dance music +Wedding music +Happy music +Sad music +Tender music +Exciting music +Angry music +Scary music +Wind +Rustling leaves +Wind noise (microphone) +Thunderstorm +Thunder +Water +Rain +Raindrop +Rain on surface +Stream +Waterfall +Ocean +Waves, surf +Steam +Gurgling +Fire +Crackle +Vehicle +Boat, Water vehicle +Sailboat, sailing ship +Rowboat, canoe, kayak +Motorboat, speedboat +Ship +Motor vehicle (road) +Car +Vehicle horn, car horn, honking +Toot +Car alarm +Power windows, electric windows +Skidding +Tire squeal +Car passing by +Race car, auto racing +Truck +Air brake +Air horn, truck horn +Reversing beeps +Ice cream truck, ice cream van +Bus +Emergency vehicle +Police car (siren) +Ambulance (siren) +Fire engine, fire truck (siren) +Motorcycle +Traffic noise, roadway noise +Rail transport +Train +Train whistle +Train horn +Railroad car, train wagon +Train wheels squealing +Subway, metro, underground +Aircraft +Aircraft engine +Jet engine +Propeller, airscrew +Helicopter +Fixed-wing aircraft, airplane +Bicycle +Skateboard +Engine +Light engine (high frequency) +Dental drill, dentist's drill +Lawn mower +Chainsaw +Medium engine (mid frequency) +Heavy engine (low frequency) +Engine knocking +Engine starting +Idling +Accelerating, revving, vroom +Door +Doorbell +Ding-dong +Sliding door +Slam +Knock +Tap +Squeak +Cupboard open or close +Drawer open or close +Dishes, pots, and pans +Cutlery, silverware +Chopping (food) +Frying (food) +Microwave oven +Blender +Water tap, faucet +Sink (filling or washing) +Bathtub (filling or washing) +Hair dryer +Toilet flush +Toothbrush +Electric toothbrush +Vacuum cleaner +Zipper (clothing) +Keys jangling +Coin (dropping) +Scissors +Electric shaver, electric razor +Shuffling cards +Typing +Typewriter +Computer keyboard +Writing +Alarm +Telephone +Telephone bell ringing +Ringtone +Telephone dialing, DTMF +Dial tone +Busy signal +Alarm clock +Siren +Civil defense siren +Buzzer +Smoke detector, smoke alarm +Fire alarm +Foghorn +Whistle +Steam whistle +Mechanisms +Ratchet, pawl +Clock +Tick +Tick-tock +Gears +Pulleys +Sewing machine +Mechanical fan +Air conditioning +Cash register +Printer +Camera +Single-lens reflex camera +Tools +Hammer +Jackhammer +Sawing +Filing (rasp) +Sanding +Power tool +Drill +Explosion +Gunshot, gunfire +Machine gun +Fusillade +Artillery fire +Cap gun +Fireworks +Firecracker +Burst, pop +Eruption +Boom +Wood +Chop +Splinter +Crack +Glass +Chink, clink +Shatter +Liquid +Splash, splatter +Slosh +Squish +Drip +Pour +Trickle, dribble +Gush +Fill (with liquid) +Spray +Pump (liquid) +Stir +Boiling +Sonar +Arrow +Whoosh, swoosh, swish +Thump, thud +Thunk +Electronic tuner +Effects unit +Chorus effect +Basketball bounce +Bang +Slap, smack +Whack, thwack +Smash, crash +Breaking +Bouncing +Whip +Flap +Scratch +Scrape +Rub +Roll +Crushing +Crumpling, crinkling +Tearing +Beep, bleep +Ping +Ding +Clang +Squeal +Creak +Rustle +Whir +Clatter +Sizzle +Clicking +Clickety-clack +Rumble +Plop +Jingle, tinkle +Hum +Zing +Boing +Crunch +Silence +Sine wave +Harmonic +Chirp tone +Sound effect +Pulse +Inside, small room +Inside, large room or hall +Inside, public space +Outside, urban or manmade +Outside, rural or natural +Reverberation +Echo +Noise +Environmental noise +Static +Mains hum +Distortion +Sidetone +Cacophony +White noise +Pink noise +Throbbing +Vibration +Television +Radio +Field recording
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_tfhub.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_tfhub.json new file mode 100644 index 0000000..8881154 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_tfhub.json
@@ -0,0 +1,63 @@ +{ + "name": "AudioClassifier", + "description": "Identify the most prominent type in the audio clip from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio_clip", + "description": "Input audio clip to be classified.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 2, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Scores of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_tfhub.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_tfhub.tflite new file mode 100644 index 0000000..f19a3fbc --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_tfhub.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.json new file mode 100644 index 0000000..a6a3a87 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.json
@@ -0,0 +1,63 @@ +{ + "name": "AudioClassifier", + "description": "Identify the most prominent type in the audio clip from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio_clip", + "description": "Input audio clip to be classified.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 2, + "channels": 1 + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Scores of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "yamnet_521_labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite new file mode 100644 index 0000000..52b93a3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6_default.json new file mode 100644 index 0000000..af556e6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6_default.json
@@ -0,0 +1,34 @@ +{ + "name": "AudioClassifier", + "description": "Identify the most prominent type in the audio clip from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio_clip", + "description": "Input audio clip to be classified.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Scores of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder/BUILD new file mode 100644 index 0000000..0f8caca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder/yamnet_embedding.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder/yamnet_embedding.tflite new file mode 100644 index 0000000..dff6786 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/audio_embedder/yamnet_embedding.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/BUILD new file mode 100644 index 0000000..7b77249 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/BUILD
@@ -0,0 +1,32 @@ +load("//tensorflow_lite_support/tools/build_rules:http_files.bzl", "tflite_file", "tflite_model") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + "*.model", + ]) + [ + ":bert_nl_classifier_no_metadata", + "mobilebert_vocab", + ":30k-clean", + ], +) + +tflite_model(name = "bert_nl_classifier_no_metadata") + +tflite_file( + name = "mobilebert_vocab", + extension = "txt", +) + +tflite_file( + name = "30k-clean", + extension = "model", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_default.json new file mode 100644 index 0000000..19b1bc7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_default.json
@@ -0,0 +1,59 @@ +{ + "name": "BertNLClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "input_process_units": [ + + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_with_bert_tokenizer.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_with_bert_tokenizer.json new file mode 100644 index 0000000..24e4367 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_with_bert_tokenizer.json
@@ -0,0 +1,84 @@ +{ + "name": "BertNLClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "mobilebert_vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "min_parser_version": "1.1.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_with_sentence_piece.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_with_sentence_piece.json new file mode 100644 index 0000000..9472e50 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/bert_nl_classifier_with_sentence_piece.json
@@ -0,0 +1,83 @@ +{ + "name": "BertNLClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "SentencePieceTokenizerOptions", + "options": { + "sentencePiece_model": [ + { + "name": "30k-clean.model", + "description": "The sentence piece model file." + } + ] + } + } + ] + } + ], + "min_parser_version": "1.1.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/labels.txt new file mode 100644 index 0000000..0cbde6b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_nl_classifier/labels.txt
@@ -0,0 +1,2 @@ +negative +positive
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_tokenizer_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_tokenizer_meta.json new file mode 100644 index 0000000..c90e7df --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bert_tokenizer_meta.json
@@ -0,0 +1,20 @@ +{ + "subgraph_metadata": [ + { + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bounding_box_tensor_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bounding_box_tensor_meta.json new file mode 100644 index 0000000..55b0624c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/bounding_box_tensor_meta.json
@@ -0,0 +1,35 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input", + "description": "The input tensor.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "locale": "cn" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/category_tensor_float_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/category_tensor_float_meta.json new file mode 100644 index 0000000..9ca0588 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/category_tensor_float_meta.json
@@ -0,0 +1,33 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "category", + "description": "The category tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_VALUE_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_VALUE_LABELS", + "locale": "cn" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_float_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_float_meta.json new file mode 100644 index 0000000..1b146d5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_float_meta.json
@@ -0,0 +1,52 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "probability", + "description": "The classification result tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "cn" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_uint8_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_uint8_meta.json new file mode 100644 index 0000000..f544afd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_uint8_meta.json
@@ -0,0 +1,52 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "probability", + "description": "The classification result tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "cn" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_unsupported_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_unsupported_meta.json new file mode 100644 index 0000000..98cf178 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/classification_tensor_unsupported_meta.json
@@ -0,0 +1,46 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "probability", + "description": "The classification result tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "default_score": 0.2 + } + } + ], + "stats": { + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS", + "locale": "cn" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/feature_tensor_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/feature_tensor_meta.json new file mode 100644 index 0000000..4502d24c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/feature_tensor_meta.json
@@ -0,0 +1,35 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input", + "description": "The input tensor.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "locale": "cn" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/general_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/general_meta.json new file mode 100644 index 0000000..1145cf404 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/general_meta.json
@@ -0,0 +1,7 @@ +{ + "name": "model", + "description": "A ML model.", + "version": "v1", + "author": "TensorFlow", + "license": "Apache" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/golden_json.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/golden_json.json new file mode 100644 index 0000000..601a597 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/golden_json.json
@@ -0,0 +1,28 @@ +{ + "name": "Mobilenet_quantized", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + }, + { + } + ], + "output_tensor_metadata": [ + { + "associated_files": [ + { + "name": "file2" + } + ] + } + ] + } + ], + "associated_files": [ + { + "name": "file1" + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/BUILD new file mode 100644 index 0000000..0f8caca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt new file mode 100644 index 0000000..fe81123 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt
@@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.json new file mode 100644 index 0000000..3c032df --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.json
@@ -0,0 +1,81 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be classified.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite new file mode 100644 index 0000000..56268e1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_default.json new file mode 100644 index 0000000..5380fa5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_default.json
@@ -0,0 +1,35 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be classified.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_quant.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_quant.json new file mode 100644 index 0000000..a1b58ce --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_quant.json
@@ -0,0 +1,81 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be classified.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_quant.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_quant.tflite new file mode 100644 index 0000000..c26ff77 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224_quant.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/score_calibration.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/score_calibration.txt new file mode 100644 index 0000000..b64f0c1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/score_calibration.txt
@@ -0,0 +1,511 @@ + +0.9876328110694885,0.36622241139411926,0.5352765321731567,0.71484375 +0.9584911465644836,1.0602262020111084,0.2777034342288971,0.019999999552965164 +0.9698624014854431,0.8795201778411865,0.539591908454895,0.00390625 +0.7486230731010437,1.1876736879348755,2.552982807159424,0.019999999552965164 +0.9745277166366577,0.3739396333694458,0.4621727764606476,0.19921875 +0.9683839678764343,0.6996201276779175,0.7690851092338562,0.019999999552965164 +0.6875,0.31044548749923706,1.0056899785995483,0.019999999552965164 +0.9849396347999573,0.8532888889312744,-0.2361421436071396,0.03125 +0.9878578186035156,1.0118975639343262,0.13313621282577515,0.359375 +0.9915205836296082,0.4434199929237366,1.0268371105194092,0.05078125 +0.9370332360267639,0.4586562216281891,-0.08101099729537964,0.019999999552965164 +0.9905818104743958,0.8670706152915955,0.012704282067716122,0.019999999552965164 +0.9080020189285278,0.8507471680641174,0.5081117749214172,0.019999999552965164 +0.985953152179718,0.9933826923370361,-0.8114940524101257,0.109375 +0.9819648861885071,1.12098228931427,-0.6330763697624207,0.01171875 +0.9025918245315552,0.7803755402565002,0.03275677561759949,0.08984375 +0.9863958954811096,0.11243592947721481,0.935604453086853,0.61328125 +0.9905291795730591,0.3710605800151825,0.708966851234436,0.359375 +0.9917052984237671,0.9596433043479919,0.19800108671188354,0.09765625 +0.8762937188148499,0.3449830114841461,0.5352474451065063,0.0078125 +0.9902125000953674,0.8918796181678772,-0.1306992471218109,0.26171875 + +0.9902340173721313,0.9177873134613037,-0.4322589933872223,0.019999999552965164 +0.9707600474357605,0.7028177976608276,0.9813734889030457,0.019999999552965164 +0.9823090434074402,1.0499590635299683,0.12045472860336304,0.0078125 +0.990516185760498,0.9449402093887329,1.3773189783096313,0.019999999552965164 +0.9875434041023254,0.577914297580719,1.282518982887268,0.0390625 +0.9821421504020691,0.0967339277267456,0.8279788494110107,0.47265625 +0.9875047206878662,0.9038218259811401,2.1208062171936035,0.38671875 +0.9857864379882812,0.8627446889877319,0.18189261853694916,0.019999999552965164 +0.9647751450538635,1.0752476453781128,-0.018294010311365128,0.0234375 +0.9830358624458313,0.5638481378555298,0.8346489667892456,0.019999999552965164 +0.9904966354370117,1.0160938501358032,-0.0573287308216095,0.00390625 +0.8458405137062073,0.4868394434452057,0.6617084741592407,0.019999999552965164 +0.9847381711006165,0.5939620137214661,0.008616370148956776,0.00390625 +0.9375938773155212,0.723095178604126,0.6635608077049255,0.019999999552965164 +0.9334303140640259,0.5689108967781067,0.37019580602645874,0.019999999552965164 +0.9716793894767761,1.0037211179733276,0.5898993611335754,0.02734375 +0.9197732210159302,0.46794334053993225,0.7365336418151855,0.640625 +0.9857497811317444,0.7299028635025024,0.9195274114608765,0.0390625 +0.8758038282394409,1.200216293334961,0.02580185979604721,0.019999999552965164 +0.9841026067733765,0.8050475716590881,0.9698556661605835,0.0078125 +0.9908539652824402,0.7911490201950073,0.19351358711719513,0.12109375 +0.9179956316947937,0.023991893976926804,0.35193610191345215,0.04296875 +0.9903728365898132,0.7744967341423035,0.2686336636543274,0.359375 +0.906022846698761,0.5766159892082214,1.0600007772445679,0.04296875 +0.9885554909706116,0.99117511510849,0.5611960291862488,0.4140625 +0.9906331896781921,1.1376535892486572,1.45369291305542,0.019999999552965164 +0.9640991687774658,0.5387894511222839,1.1824018955230713,0.019999999552965164 +0.9932155609130859,0.4347895085811615,1.3938102722167969,0.0078125 +0.9884702563285828,0.885567843914032,0.1556047648191452,0.1484375 +0.9891508221626282,0.04143073782324791,0.6111864447593689,0.0078125 +0.8935436010360718,0.2937895655632019,0.3215920031070709,0.00390625 +0.8327123522758484,0.8381986021995544,-0.026293788105249405,0.019999999552965164 +0.9839455485343933,0.9581400156021118,1.495324969291687,0.640625 +0.9904995560646057,0.9168422818183899,0.33293962478637695,0.015625 +0.9856975674629211,1.0433714389801025,0.5954801440238953,0.019999999552965164 +0.9942344427108765,0.7206616997718811,1.666426181793213,0.9609375 +0.8182767033576965,0.9546273946762085,0.5500107407569885,0.019999999552965164 +0.9631295800209045,0.6277880668640137,0.05952891707420349,0.05859375 +0.9819005727767944,1.0826934576034546,0.7444049715995789,0.30859375 +0.9884315133094788,1.0500890016555786,1.1161768436431885,0.019999999552965164 +0.9175815582275391,0.09232989698648453,1.596696138381958,0.47265625 +0.9868760108947754,0.903079628944397,-0.15774966776371002,0.8515625 +0.9866015911102295,0.7533788084983826,0.7489103078842163,0.03125 +0.8074312806129456,0.8615151643753052,0.40621864795684814,0.00390625 +0.9829285144805908,0.8954831957817078,0.4462486207485199,0.02734375 +0.9681841135025024,0.6257772445678711,0.43809664249420166,0.38671875 +0.9872947931289673,0.9947993159294128,0.9271130561828613,0.26171875 +0.7997345328330994,0.3995186686515808,-0.3755347430706024,0.019999999552965164 +0.9922754168510437,1.1357101202011108,-0.10267537832260132,0.5 +0.9861471652984619,0.8725204467773438,1.1657888889312744,0.019999999552965164 +0.9888646006584167,1.2098380327224731,-0.27832522988319397,0.05078125 +0.5641342997550964,1.0501892566680908,1.9519661664962769,0.019999999552965164 +0.9548168778419495,0.8971696496009827,1.378737449645996,0.00390625 +0.9875019788742065,0.8718118071556091,0.5476236939430237,0.0078125 +0.9725168347358704,0.6989551782608032,-1.3157455921173096,0.61328125 +0.9864014983177185,0.7576251029968262,-0.41650667786598206,0.00390625 +0.960071861743927,0.13068856298923492,0.4819187819957733,0.019999999552965164 +0.9849705100059509,0.7724528908729553,0.3877875804901123,0.03125 +0.9703006744384766,0.8848260641098022,-1.1767181158065796,0.80078125 +0.9837008714675903,0.7015050053596497,0.18209102749824524,0.00390625 +0.9579976797103882,0.053806986659765244,2.7309608459472656,0.4000000059604645 +0.9896979928016663,0.41135814785957336,0.5738034844398499,0.019999999552965164 +0.9853873252868652,0.5438565611839294,0.20562179386615753,0.02734375 +0.9784129858016968,0.6330984830856323,-0.1789831817150116,0.015625 +0.9375,0.855596125125885,-0.1933964192867279,0.019999999552965164 +0.9524176716804504,0.08709807693958282,0.6299692988395691,0.33203125 +0.9808038473129272,1.2909820079803467,0.3397117257118225,0.00390625 +0.8008236885070801,0.7974631786346436,1.0567312240600586,0.019999999552965164 +0.9421642422676086,0.6754576563835144,0.32419073581695557,0.23828125 +0.9072281718254089,1.1716840267181396,-0.10382208228111267,0.00390625 +0.9497162103652954,1.1582106351852417,-0.11845408380031586,0.00390625 +0.9773319959640503,0.5042116641998291,1.2815768718719482,0.23828125 +0.9743752479553223,1.1731196641921997,0.48585158586502075,0.1640625 +0.9601503610610962,1.0114264488220215,-0.9113408327102661,0.38671875 +0.97279292345047,0.32572469115257263,0.548393964767456,0.01171875 +0.9845231175422668,0.9852075576782227,1.0973742008209229,0.69140625 +0.9764596223831177,0.2248251885175705,0.8963963985443115,0.33203125 +0.8746626377105713,0.016590777784585953,1.4492003917694092,0.359375 +0.9726155996322632,0.8712832927703857,-0.6451321840286255,0.52734375 +0.980800211429596,0.8469374775886536,0.0718703418970108,0.04296875 +0.7734344005584717,0.8508065342903137,0.4233662784099579,0.019999999552965164 +0.969182550907135,0.8082079887390137,-0.4314402937889099,0.0234375 +0.9037994742393494,0.1387290209531784,1.8660004138946533,0.5 +0.9869260191917419,0.6927974820137024,0.4927133619785309,0.019999999552965164 +0.8794143795967102,0.8060213327407837,-0.6247795820236206,0.09765625 +0.9895913600921631,0.8851431012153625,0.9641156196594238,0.28515625 +0.9833245873451233,0.9379183053970337,1.5143399238586426,0.0078125 +0.26580730080604553,1.488408088684082,2.5120370388031006,0.019999999552965164 +0.9859549403190613,1.5805137157440186,0.7283271551132202,0.01171875 +0.9376091361045837,0.6854841709136963,0.20175717771053314,0.00390625 +0.965065598487854,0.7363166213035583,-0.3636060357093811,0.1484375 +0.9904685020446777,0.9182849526405334,0.30159056186676025,0.05859375 +0.5014551877975464,0.7409977912902832,0.2045259326696396,0.019999999552965164 +0.9434370398521423,0.3679845631122589,0.6447131633758545,0.38671875 +0.9806621670722961,0.9568924307823181,1.2417932748794556,0.019999999552965164 +0.9825865626335144,1.2273900508880615,-0.0674915760755539,0.0390625 +0.9859767556190491,0.7635276317596436,-0.8502742648124695,0.109375 +0.9701240658760071,0.46266916394233704,0.38697123527526855,0.0703125 +0.9651575088500977,0.5057743191719055,0.6578569412231445,0.0078125 + +0.9685596227645874,0.6961715817451477,0.20829983055591583,0.015625 +0.9772806167602539,0.8312440514564514,-0.09966880083084106,0.019999999552965164 +0.9718109369277954,0.8248763680458069,1.2387524843215942,0.08984375 +0.9890084266662598,2.0058324337005615,1.7648913860321045,0.019999999552965164 +0.9813475608825684,1.02803373336792,1.4689184427261353,0.019999999552965164 +0.9925220608711243,0.8020634055137634,0.7509317994117737,0.015625 +0.9754987955093384,0.5145153999328613,0.4638928472995758,0.00390625 +0.9735408425331116,0.7434492111206055,0.06251777708530426,0.01171875 +0.8753963112831116,1.6830265522003174,4.509310722351074,0.019999999552965164 +0.9385876655578613,0.46194836497306824,0.13496099412441254,0.13671875 +0.9676342010498047,0.5462782979011536,0.9306238889694214,0.1796875 +0.9829097986221313,0.8054409623146057,0.11194216459989548,0.08984375 +0.9503080248832703,0.44028621912002563,0.4689175486564636,0.00390625 +0.9808863997459412,0.8023126721382141,-0.022534284740686417,0.015625 +0.9079821109771729,0.33415740728378296,0.544142484664917,0.019999999552965164 +0.9839802980422974,0.9184480905532837,0.2658761739730835,0.1484375 +0.75,0.8216301798820496,0.3300539255142212,0.019999999552965164 +0.9590148329734802,0.722118616104126,0.255025178194046,0.015625 +0.9616804122924805,0.8398274779319763,0.33006206154823303,0.019999999552965164 +0.7859238386154175,0.5596626400947571,0.5452361702919006,0.019999999552965164 +0.9842674732208252,0.07029404491186142,1.189304232597351,0.30859375 +0.7237641215324402,0.2756437361240387,-0.10612351447343826,0.019999999552965164 +0.9793540239334106,0.5117573738098145,0.8033715486526489,0.01953125 +0.9825188517570496,0.3965616822242737,0.17742407321929932,0.019999999552965164 +0.9859991073608398,1.32109534740448,0.5763598084449768,0.019999999552965164 +0.9551243782043457,0.3639756143093109,0.19449777901172638,0.00390625 +0.9606218338012695,0.8222983479499817,0.43461644649505615,0.00390625 +0.9785885810852051,0.9104304909706116,0.2279568761587143,0.01171875 +0.9705367684364319,0.0769517719745636,0.7330215573310852,0.04296875 +0.9736841320991516,0.9110560417175293,0.10864781588315964,0.05859375 +0.9880238771438599,1.1702078580856323,0.05487633869051933,0.00390625 +0.9913991093635559,0.7445327043533325,1.2198610305786133,0.01171875 +0.8302573561668396,0.33997753262519836,1.0731935501098633,0.019999999552965164 +0.9880614280700684,0.9227356910705566,2.1198885440826416,0.61328125 +0.9173498153686523,0.2221490740776062,0.11565151065587997,0.0078125 +0.962620735168457,1.011454701423645,-1.5519139766693115,0.8203125 +0.9828791618347168,0.7543124556541443,0.29118794202804565,0.00390625 +0.9908701181411743,0.8183356523513794,0.48734790086746216,0.019999999552965164 +0.5002585649490356,0.12179236859083176,0.20199841260910034,0.019999999552965164 +0.9631574153900146,0.41631683707237244,1.1000276803970337,0.44140625 +0.9875426888465881,0.8117235898971558,0.8689690232276917,0.08203125 +0.9410585761070251,0.3703889548778534,0.7951740026473999,0.0078125 +0.9877454042434692,0.2155231237411499,1.635109305381775,0.94921875 +0.9860436320304871,1.0054532289505005,-0.9608616232872009,0.03125 +0.9721421003341675,0.5174740552902222,0.43327680230140686,0.0078125 + +0.9908374547958374,0.8122930526733398,0.21533408761024475,0.0078125 +0.9896888136863708,0.7030488848686218,-0.062063876539468765,0.01953125 +0.9861313700675964,0.49431633949279785,0.981758177280426,0.12109375 +0.9792494177818298,1.0670701265335083,0.7028639316558838,0.019999999552965164 +0.9871346950531006,1.3606067895889282,-3.00394868850708,0.61328125 +0.9583333134651184,0.9180184602737427,-0.05760742351412773,0.019999999552965164 +0.9764145612716675,0.5258041024208069,1.1425464153289795,0.019999999552965164 +0.9076833128929138,1.081973910331726,0.6340405344963074,0.019999999552965164 +0.9895729422569275,0.27958083152770996,1.2441545724868774,0.08203125 +0.916824221611023,0.7878308892250061,-1.3060243129730225,0.359375 +0.9883677363395691,0.6098470687866211,0.7665972709655762,0.52734375 +0.949999988079071,0.818132758140564,1.5476282835006714,0.019999999552965164 + +0.9666821360588074,0.707548201084137,0.7326748967170715,0.00390625 +0.9861665368080139,0.7194502353668213,2.1585183143615723,0.38671875 +0.9811879992485046,0.32190269231796265,0.31508582830429077,0.05078125 +0.9625869989395142,0.11173010617494583,0.9030138850212097,0.019999999552965164 +0.9675677418708801,0.49738144874572754,0.5481624007225037,0.019999999552965164 +0.9764066934585571,1.0306450128555298,0.2257029116153717,0.00390625 +0.9857029318809509,0.8312124013900757,-0.12777498364448547,0.00390625 +0.9781621098518372,0.621485710144043,0.3126043975353241,0.21875 +0.9705549478530884,0.15182119607925415,1.7296228408813477,0.13671875 +0.9801698923110962,0.8953424692153931,0.6697174310684204,0.019999999552965164 +0.9842199087142944,0.7984838485717773,0.7436375617980957,0.0078125 +0.9159231185913086,0.05519663542509079,0.011483916081488132,0.47265625 +0.9742691516876221,0.9268448352813721,1.1530364751815796,0.019999999552965164 +0.9579406380653381,0.7879363894462585,1.1582229137420654,0.00390625 +0.8999202251434326,0.8120636343955994,0.37021151185035706,0.019999999552965164 +0.9870507121086121,1.1666820049285889,1.387096881866455,0.019999999552965164 +0.9769532680511475,0.6519474983215332,0.3170791268348694,0.109375 +0.9546447396278381,0.7559569478034973,0.9533731937408447,0.0078125 +0.9773718118667603,1.3183629512786865,1.0090563297271729,0.019999999552965164 +0.9049819707870483,1.0706751346588135,1.7704588174819946,0.019999999552965164 +0.9003662467002869,0.7251236438751221,-1.4905513525009155,0.4140625 +0.9834321141242981,0.5246152877807617,1.2191725969314575,0.47265625 +0.9748008847236633,0.8448761105537415,-0.01744924671947956,0.00390625 +0.9904628396034241,0.8762193918228149,0.22459718585014343,0.01171875 +0.6833457946777344,0.8996955752372742,1.2423095703125,0.019999999552965164 +0.9909645318984985,0.8978683948516846,0.7022045254707336,0.019999999552965164 +0.9843918681144714,0.12815311551094055,1.5720607042312622,0.78125 +0.9382115602493286,0.4989806115627289,1.1206520795822144,0.03515625 +0.9832627177238464,0.6727185845375061,0.2797912657260895,0.08984375 +0.8830162286758423,1.1294968128204346,1.1474463939666748,0.019999999552965164 +0.9554208517074585,0.9476046562194824,0.8490120768547058,0.019999999552965164 +0.98823082447052,0.7835749983787537,0.5608289837837219,0.03515625 +0.9790570139884949,0.9982950091362,0.3763321042060852,0.00390625 +0.5039305686950684,0.9079190492630005,1.265581488609314,0.019999999552965164 +0.9871423840522766,0.6633929014205933,0.09028752893209457,0.019999999552965164 +0.8614975214004517,0.9595098495483398,-0.5349600315093994,0.00390625 +0.9873358011245728,0.698331892490387,0.7571848630905151,0.1484375 +0.7227392196655273,1.1300171613693237,1.1754553318023682,0.019999999552965164 +0.9814568758010864,0.46864795684814453,0.6286783218383789,0.19921875 +0.9876973032951355,0.29863566160202026,0.7726709842681885,0.61328125 +0.9887779951095581,1.1818888187408447,-1.0321481227874756,0.38671875 +0.9684743285179138,0.7226923108100891,0.0908145159482956,0.0390625 +0.9854185581207275,1.0576037168502808,0.35190048813819885,0.0078125 +0.9463624954223633,0.781932532787323,0.7598024606704712,0.01171875 +0.9837555885314941,0.8735848665237427,0.5948384404182434,0.019999999552965164 +0.9700835347175598,0.45710718631744385,2.141801357269287,0.8359375 +0.9896127581596375,1.018708348274231,0.23626597225666046,0.01953125 +0.7728451490402222,8.084141001063472e-08,0.7415778636932373,0.4000000059604645 +0.9838477969169617,0.8994008302688599,0.15494465827941895,0.00390625 +0.9421281218528748,0.4648025333881378,0.12706322968006134,0.00390625 +0.9843724370002747,1.0055731534957886,-0.911835253238678,0.23828125 +0.958256185054779,1.1208757162094116,-0.31016042828559875,0.0078125 +0.9832971692085266,0.056124646216630936,1.7148709297180176,0.23828125 +0.9804430603981018,0.4016909897327423,0.6085042357444763,0.0703125 +0.9825966358184814,0.9228396415710449,0.912163257598877,0.019999999552965164 +0.9441317915916443,0.048142336308956146,0.6141980290412903,0.109375 +0.9856440424919128,0.8616625666618347,0.28943121433258057,0.015625 +0.9913654923439026,1.0482347011566162,0.6889304518699646,0.015625 +0.97914719581604,0.8870795369148254,-0.700239360332489,0.015625 +0.9836585521697998,0.5450212955474854,0.009687358513474464,0.01953125 +0.990472137928009,0.8221097588539124,2.5926225185394287,0.97265625 +0.6274135708808899,0.6787079572677612,0.12988793849945068,0.015625 +0.982601523399353,0.7495649456977844,1.2217103242874146,0.019999999552965164 +0.9841020703315735,0.9071263670921326,1.3682825565338135,0.09765625 +0.9872562885284424,0.818276584148407,-0.14663955569267273,0.05859375 +0.5041943192481995,0.35444244742393494,0.46112486720085144,0.00390625 +0.7517910599708557,0.91172856092453,1.3611085414886475,0.019999999552965164 +0.9861181378364563,1.0613479614257812,-0.46272075176239014,0.015625 +0.9914185404777527,0.9464229941368103,1.2103853225708008,0.0234375 +0.984909176826477,0.5985794067382812,0.7704220414161682,0.08203125 +0.9575125575065613,0.7695640325546265,0.6132461428642273,0.00390625 +0.9845197200775146,0.7421835064888,1.332088589668274,0.019999999552965164 +0.9470700621604919,0.357934832572937,1.0986406803131104,0.359375 +0.9287161231040955,0.6833012104034424,0.373298704624176,0.00390625 +0.9531774520874023,0.3247152864933014,0.6011538505554199,0.66796875 +0.9779354929924011,0.828241229057312,0.3349589705467224,0.03125 +0.9863978028297424,0.932086169719696,0.04865559563040733,0.02734375 +0.9826814532279968,0.06353739649057388,1.879408359527588,0.61328125 +0.974474310874939,0.8063777685165405,0.8257133364677429,0.019999999552965164 +0.9670184254646301,0.09195757657289505,1.7024414539337158,0.5 +0.9885809421539307,0.7981435656547546,-0.11792337149381638,0.0703125 +0.9829109907150269,0.9578585028648376,-1.9371291399002075,0.13671875 +0.9754639863967896,1.137816071510315,0.5887423157691956,0.00390625 +0.9755549430847168,0.677255392074585,0.20494212210178375,0.00390625 +0.9903355836868286,1.0475162267684937,2.1768462657928467,0.52734375 +0.9855127930641174,0.9580414891242981,0.35021960735321045,0.76171875 +0.9450457692146301,0.4737727642059326,-0.3041325807571411,0.01171875 +0.9360163807868958,0.9219141006469727,1.2481396198272705,0.019999999552965164 +0.9696909189224243,0.06589268147945404,1.456658124923706,0.30000001192092896 +0.6495901942253113,0.8538134098052979,0.3043774366378784,0.019999999552965164 +0.9901140928268433,0.8112474679946899,0.7102972269058228,0.019999999552965164 +0.9925929307937622,0.49307680130004883,0.6297348737716675,0.019999999552965164 +0.9840761423110962,0.5691578388214111,0.9437046647071838,0.00390625 +0.9625457525253296,0.9322702288627625,1.3358750343322754,0.0234375 +0.9820173978805542,0.6805416345596313,1.0065922737121582,0.05859375 +0.9883391261100769,0.742003321647644,0.6168643236160278,0.0078125 +0.9119130969047546,0.8404607176780701,0.8882355690002441,0.01171875 +0.9854885935783386,1.295777440071106,0.5272557735443115,0.00390625 +0.9911734461784363,1.152715802192688,-0.05230601131916046,0.019999999552965164 +0.8071879744529724,0.4576769471168518,1.391660451889038,0.00390625 +0.9919166564941406,1.1775370836257935,0.5039792060852051,0.019999999552965164 +0.9831258654594421,0.9164834022521973,0.3790256977081299,0.01171875 +0.990642249584198,0.9242916107177734,1.477474570274353,0.38671875 +0.7415178418159485,0.2909083068370819,0.19971248507499695,0.019999999552965164 +0.9146556854248047,0.06850286573171616,1.3211928606033325,0.61328125 +0.976986825466156,0.6469135284423828,-0.7279839515686035,0.02734375 +0.968462347984314,0.4640704393386841,1.4650955200195312,0.1484375 +0.937825083732605,0.9767780303955078,-0.7378027439117432,0.0390625 +0.9878604412078857,1.1423084735870361,1.7311146259307861,0.1484375 +0.9904257655143738,0.9551829099655151,1.564165472984314,0.00390625 +0.9830996990203857,0.92529296875,-0.1086890697479248,0.02734375 + +0.9820512533187866,0.7556048631668091,0.6512532830238342,0.109375 +0.9740781188011169,0.8380919098854065,0.19731587171554565,0.019999999552965164 +0.9830799698829651,1.183397650718689,-0.801214873790741,0.019999999552965164 +0.9898439049720764,1.168870210647583,1.2985308170318604,0.00390625 +0.97286057472229,0.8012385964393616,-1.657444953918457,0.09765625 +0.9182834625244141,0.5254654884338379,-0.027080848813056946,0.04296875 +0.9729798436164856,0.4111078381538391,1.077646255493164,0.019999999552965164 +0.6875,1.756393551826477,0.34522199630737305,0.019999999552965164 +0.9920725226402283,1.0676580667495728,1.1592471599578857,0.019999999552965164 +0.37564563751220703,0.07466565072536469,0.3562135696411133,0.019999999552965164 +0.9894161224365234,0.8109862804412842,1.3056280612945557,0.0390625 +0.9386259317398071,0.5322021842002869,-0.03461914509534836,0.08984375 +0.9866133332252502,0.8940346240997314,1.0361984968185425,0.00390625 +0.9822850823402405,0.6215930581092834,-0.6859042048454285,0.00390625 +0.9752063155174255,1.0129338502883911,0.3866007626056671,0.019999999552965164 +0.9825329184532166,0.567034125328064,0.5370683670043945,0.5 +0.9422088861465454,0.9411858320236206,0.5332568883895874,0.38671875 +0.9506444931030273,0.7494101524353027,0.9869776368141174,0.00390625 +0.9923189282417297,1.1255286931991577,0.8734608292579651,0.019999999552965164 +0.9807777404785156,0.9558923244476318,1.5415621995925903,0.09765625 +0.961335301399231,0.7840818762779236,0.06915930658578873,0.00390625 +0.9867202639579773,1.0596263408660889,0.21268242597579956,0.0078125 +0.9926426410675049,0.8886650204658508,0.6200761198997498,0.019999999552965164 +0.9791930913925171,0.4474319517612457,0.5827012062072754,0.019999999552965164 +0.986801028251648,1.1846712827682495,1.4253416061401367,0.00390625 +0.9549052119255066,0.6142332553863525,0.4867286682128906,0.00390625 +0.983259916305542,0.42561075091362,0.9666317105293274,0.08203125 +0.98175048828125,0.7744573354721069,0.4953071177005768,0.019999999552965164 +0.987273097038269,0.8209654092788696,0.5267868041992188,0.019999999552965164 +0.9916341304779053,0.6881924271583557,0.9522916078567505,0.019999999552965164 +0.9819192886352539,0.8128346800804138,0.6556753516197205,0.05859375 +0.9854727387428284,0.6597779393196106,0.9645410180091858,0.8359375 +0.9891805648803711,0.7752296924591064,1.34084153175354,0.52734375 +0.9489904046058655,0.6988677978515625,0.5052891969680786,0.019999999552965164 +0.9741962552070618,0.43797168135643005,0.7825477123260498,0.01171875 +0.9907783269882202,0.8732656240463257,1.1458243131637573,0.19921875 +0.9760454297065735,0.7810378670692444,-0.29553040862083435,0.015625 +0.9885720014572144,0.8427382707595825,0.2628841996192932,0.019999999552965164 +0.8171960115432739,0.3271152079105377,1.30915105342865,0.26171875 +0.9881270527839661,0.13021250069141388,1.6307408809661865,0.55859375 +0.9751906991004944,0.8255484104156494,0.21788427233695984,0.019999999552965164 +0.9630831480026245,2.1396600701476974e-15,2.883542776107788,0.5 +0.8849332332611084,0.888649582862854,1.0651483535766602,0.01171875 +0.9897550344467163,0.08640030771493912,2.661073923110962,0.69140625 +0.9030827879905701,0.7017505168914795,0.07822071760892868,0.00390625 +0.9650112986564636,0.36098214983940125,0.7112777829170227,0.0078125 +0.9872719049453735,0.7115703821182251,0.6924230456352234,0.019999999552965164 +0.5884749889373779,0.0942283645272255,0.24825790524482727,0.019999999552965164 +0.9642857313156128,0.5304845571517944,0.6281308531761169,0.019999999552965164 +0.9651434421539307,0.07168509066104889,1.4704163074493408,0.61328125 +0.9779187440872192,1.0171563625335693,-2.8089962005615234,0.1484375 +0.9375227689743042,0.9291267991065979,0.6853470802307129,0.019999999552965164 +0.9820515513420105,0.7226945757865906,-0.19336646795272827,0.61328125 +0.984882652759552,0.8176864385604858,1.161419153213501,0.0078125 +0.9573767185211182,0.9027169346809387,0.15423306822776794,0.26171875 +0.9059234261512756,0.872424840927124,0.7419941425323486,0.019999999552965164 +0.9914654493331909,1.0662620067596436,2.7141172885894775,0.55859375 +0.9839044809341431,0.9037585854530334,0.7042809724807739,0.01953125 +0.986689567565918,0.6848335266113281,0.9014078974723816,0.00390625 +0.9837497472763062,0.7507086396217346,0.7179840207099915,0.0078125 +0.9895229339599609,1.1564929485321045,0.5822750926017761,0.019999999552965164 +0.9845471978187561,0.8716567158699036,0.19987598061561584,0.01953125 +0.971385657787323,0.49073365330696106,1.2333439588546753,0.73828125 +0.9841684699058533,0.6468350887298584,1.0000839233398438,0.0703125 +0.9882851839065552,0.26080548763275146,0.8985073566436768,0.01171875 +0.9851044416427612,0.8687262535095215,0.07842865586280823,0.1796875 +0.9799972772598267,0.25032666325569153,1.2494641542434692,0.10000000149011612 +0.9896620512008667,0.7762697339057922,0.20227234065532684,0.019999999552965164 +0.990495502948761,0.15801414847373962,1.006077766418457,0.01171875 +0.9806667566299438,0.7082678079605103,0.35462483763694763,0.02734375 +0.9715457558631897,0.0615643672645092,0.9478678703308105,0.4000000059604645 +0.9168440103530884,0.5679594874382019,-0.6143214106559753,0.1484375 +0.9824567437171936,0.45072048902511597,1.0683321952819824,0.1484375 +0.9840478301048279,0.08733312040567398,1.3535010814666748,0.47265625 +0.9896746873855591,1.1761761903762817,0.7102295756340027,0.94140625 +0.9827673435211182,0.8215981125831604,0.6729252338409424,0.019999999552965164 +0.9906817674636841,0.16318124532699585,1.133107304573059,0.30000001192092896 +0.9701097011566162,1.0519390106201172,-0.16105352342128754,0.00390625 +0.9417809844017029,0.7868722081184387,1.1539735794067383,0.019999999552965164 +0.9615354537963867,0.8469739556312561,0.6801642179489136,0.0390625 +0.988472580909729,0.81600022315979,0.6296193599700928,0.019999999552965164 +0.9841001629829407,0.8400164246559143,-0.06806250661611557,0.00390625 +0.9276565313339233,0.32582467794418335,-0.14148345589637756,0.019999999552965164 +0.7008209228515625,0.545078694820404,1.1250351667404175,0.019999999552965164 +0.9907881021499634,0.9919379353523254,-0.12143492698669434,0.019999999552965164 +0.9702130556106567,0.7762024402618408,0.24524429440498352,0.0078125 +0.9876235723495483,0.7181832790374756,0.41931474208831787,0.019999999552965164 +0.9841905236244202,0.8836563229560852,0.28947240114212036,0.00390625 +0.990247905254364,0.9825950860977173,0.6003378033638,0.00390625 +0.9635987281799316,0.3707619905471802,-0.03457726538181305,0.0390625 +0.9924789071083069,1.485293984413147,0.5796234607696533,0.00390625 +0.9839015603065491,0.06343062222003937,1.9442640542984009,0.5 +0.9927193522453308,0.7006005048751831,0.3714500069618225,0.019999999552965164 +0.9870567321777344,0.869498610496521,1.5008329153060913,0.00390625 +0.9002388119697571,0.4945279657840729,-0.27996397018432617,0.0078125 +0.98891282081604,0.8541091680526733,0.5112633109092712,0.66796875 +0.9001862406730652,0.43330734968185425,0.3592444360256195,0.00390625 +0.958705723285675,0.7425220012664795,0.15833647549152374,0.00390625 +0.9910086989402771,0.9245886206626892,0.8454338908195496,0.01953125 +0.9912900328636169,1.3806378841400146,1.0953043699264526,0.99609375 +0.9887956976890564,1.0331758260726929,0.6490115523338318,0.640625 +0.8638584017753601,0.902369499206543,-0.2767508327960968,0.0078125 +0.7059138417243958,1.0,1.032223091723683e-11,0.019999999552965164 +0.9889519810676575,0.8361310362815857,0.811896800994873,0.03515625 +0.970467209815979,0.07315781712532043,0.20799599587917328,0.00390625 +0.9828550219535828,0.8393198251724243,0.6089786291122437,0.28515625 +0.9553551077842712,0.7775288820266724,-0.4464336037635803,0.046875 +0.9782186150550842,0.4313304126262665,0.4458310604095459,0.019999999552965164 +0.9371097087860107,0.9338632225990295,1.3358187675476074,0.019999999552965164 +0.9861361384391785,0.24091234803199768,1.4301774501800537,0.80078125 +0.9890525341033936,1.1365840435028076,0.3055979013442993,0.00390625 +0.957517683506012,0.058012738823890686,0.15909947454929352,0.046875 +0.9762251377105713,0.72292160987854,0.49151331186294556,0.019999999552965164 +0.9875496625900269,0.9114606976509094,-0.5052767992019653,0.05859375 +0.9715835452079773,0.8113637566566467,-2.0302956104278564,0.019999999552965164 +0.9846333265304565,0.49688151478767395,0.7285738587379456,0.019999999552965164 +0.98553466796875,0.1484774351119995,1.3616747856140137,0.5859375 +0.9866309762001038,1.0217945575714111,-0.8717418313026428,0.02734375 +0.9891880750656128,0.42588523030281067,0.7833192944526672,0.109375 +0.9870361685752869,0.8525673151016235,1.2773776054382324,0.019999999552965164 +0.9897037744522095,0.8012522459030151,0.3973642885684967,0.109375 +0.9828903079032898,1.1558295488357544,-0.6781614422798157,0.5859375 +0.9924454689025879,1.1040401458740234,1.3243318796157837,0.019999999552965164 +0.9826735258102417,1.0064337253570557,-0.5324167013168335,0.38671875 +0.949999988079071,0.8152432441711426,0.6293236613273621,0.00390625 +0.9905489087104797,0.9191447496414185,0.5621309876441956,0.019999999552965164 +0.9664857387542725,0.5995981693267822,-0.7409313321113586,0.01171875 +0.9847198724746704,0.8284208178520203,0.2851041555404663,0.9296875 +0.9342833757400513,0.5566492676734924,0.6875373721122742,0.019999999552965164 +0.8894915580749512,0.4102778434753418,0.37977635860443115,0.01953125 +0.9870865941047668,0.44245558977127075,0.16041725873947144,0.10000000149011612 +0.9890456795692444,1.1491310596466064,1.0844204425811768,0.01953125 +0.7304704785346985,0.12790271639823914,-0.1085965558886528,0.019999999552965164 +0.9830618500709534,0.8738722205162048,-0.11583804339170456,0.0234375 +0.9885876178741455,0.744857668876648,0.11028216779232025,0.01953125 +0.9575535655021667,0.3011772632598877,0.5136104226112366,0.00390625 +0.9298899173736572,1.1736249923706055,4.0247297286987305,0.09765625 +0.9907795190811157,1.0897759199142456,0.6261603236198425,0.019999999552965164 +0.9855174422264099,0.6543705463409424,0.08955699950456619,0.08984375 +0.976660430431366,0.5610390901565552,0.6389923095703125,0.0390625 +0.9870068430900574,0.80875563621521,-0.6651867032051086,0.08984375 +0.9652793407440186,0.5887689590454102,0.5353426933288574,0.0703125 +0.9875175952911377,0.7699108123779297,0.876632034778595,0.019999999552965164 +0.9016479849815369,0.9994669556617737,0.30356451869010925,0.015625 +0.989987850189209,0.7350922226905823,0.8748764991760254,0.0078125 +0.983323335647583,0.8931586146354675,1.0226351022720337,0.01171875 +0.9914804100990295,0.9369975328445435,0.8283791542053223,0.019999999552965164 +0.9704275727272034,1.124052882194519,0.9457330107688904,0.019999999552965164 +0.9867291450500488,0.9667392373085022,-0.6122757196426392,0.44140625 +0.9887421131134033,0.7823470234870911,0.343982458114624,0.00390625 +0.9861542582511902,0.9171664118766785,0.35665032267570496,0.019999999552965164 +0.9772396683692932,0.08705096691846848,1.7621256113052368,0.66796875 +0.9819098114967346,0.8605496883392334,0.5151250958442688,0.01171875 +0.982971727848053,0.5631197690963745,1.608361005783081,0.019999999552965164 + +0.9914254546165466,0.3850722908973694,1.4068152904510498,0.98828125 +0.9880355596542358,1.1387118101119995,1.4653834104537964,0.05859375 +0.9586950540542603,1.7633997201919556,1.0344760417938232,0.019999999552965164 +0.9828103184700012,0.8817474842071533,0.7680216431617737,0.890625 +0.9880233407020569,0.899823784828186,0.44692227244377136,0.19921875 +0.9862816333770752,0.8610615134239197,0.4195229709148407,0.03125 +0.9813369512557983,0.8014124631881714,1.1136316061019897,0.0078125 +0.9148907661437988,0.5909111499786377,1.2860896587371826,0.015625 +0.9865161776542664,0.8720636963844299,0.6233670115470886,0.015625 +0.9786784648895264,0.48225611448287964,-0.005022380966693163,0.12109375 +0.9843324422836304,1.0519789457321167,-2.2056643962860107,0.03125 +0.9688847064971924,0.8007095456123352,0.14495795965194702,0.1640625 +0.9724696278572083,0.9987169504165649,0.32869264483451843,0.019999999552965164 +0.9875112175941467,1.0948023796081543,2.15657114982605,0.03125 +0.9923174381256104,0.10759950429201126,0.6762840747833252,0.019999999552965164 +0.9666666388511658,0.6234443783760071,1.4971232414245605,0.0390625 +0.989655613899231,0.8248854279518127,0.4701078534126282,0.019999999552965164 +0.9753870368003845,0.6746605634689331,-0.23550045490264893,0.1640625 +0.9170913100242615,1.0504746437072754,2.7344093322753906,0.019999999552965164 +0.9821392297744751,1.4154850244522095,1.2012253999710083,0.019999999552965164 +0.9886221885681152,1.22860586643219,1.160277009010315,0.890625 +0.9877735376358032,0.6805673837661743,1.5975077152252197,0.359375 +0.9831939339637756,0.6648986339569092,1.1059051752090454,0.28515625 +0.950076162815094,0.724887490272522,0.316800057888031,0.019999999552965164 +0.9817547798156738,0.8619367480278015,-0.24251239001750946,0.109375 +0.9849069714546204,0.8399055004119873,1.7567216157913208,0.4000000059604645 +0.9821556806564331,0.8135135769844055,0.33616918325424194,0.0078125 +0.8329862356185913,0.7938078045845032,1.0597797632217407,0.019999999552965164 +0.9856904149055481,0.05120579153299332,0.8267747759819031,0.5 +0.9766159057617188,0.7623113989830017,0.7656452059745789,0.09765625 +0.9885436296463013,0.9814053177833557,0.05546858534216881,0.00390625 +0.9900276064872742,0.9320858716964722,-0.36458709836006165,0.03125 +0.9058290123939514,0.7260504364967346,1.1726433038711548,0.019999999552965164 +0.9503811597824097,0.6632846593856812,0.7332696914672852,0.019999999552965164 +0.9846004247665405,0.6996731758117676,-0.8613988757133484,0.019999999552965164 +0.9897956252098083,0.8407823443412781,1.2952353954315186,0.76171875 +0.9898385405540466,0.7309674024581909,0.7317643761634827,0.019999999552965164 +0.9850022196769714,0.7537633180618286,0.3925366699695587,0.03125 +0.9858620762825012,0.9250133633613586,2.0220303535461426,0.9296875 +0.8120821714401245,0.3994182348251343,-0.4576922655105591,0.019999999552965164 +0.9496838450431824,0.8251343965530396,0.15125347673892975,0.019999999552965164 +0.9420520067214966,0.6087028384208679,1.0767998695373535,0.019999999552965164 +0.9899152517318726,0.8887513279914856,0.9602599143981934,0.019999999552965164 +0.9461711049079895,1.1373282670974731,0.6371906995773315,0.00390625 +0.9834751486778259,0.7226889729499817,0.8995278477668762,0.109375 +0.9850850105285645,1.2857465744018555,-2.2220215797424316,0.38671875 +0.9789451956748962,0.9153420925140381,0.12551555037498474,0.01171875 +0.8774109482765198,0.9271970987319946,0.5529487729072571,0.019999999552965164 +0.9074040651321411,0.920030951499939,0.40618932247161865,0.00390625 +0.9878932237625122,0.5347745418548584,0.8865230679512024,0.046875 +0.937852144241333,1.1346293687820435,-0.3324768841266632,0.019999999552965164 +0.7542195916175842,0.44728168845176697,0.45312440395355225,0.019999999552965164 +0.9915731549263,1.3838905096054077,-0.043990228325128555,0.01171875 +0.9284758567810059,0.4973248541355133,0.9887621998786926,0.019999999552965164 +0.9700435400009155,0.8664135336875916,1.0059133768081665,0.046875 +0.9667003750801086,0.7796391844749451,-0.10554620623588562,0.00390625 +0.9698932766914368,0.7340040802955627,0.4837290942668915,0.00390625 +0.973517894744873,0.9678344130516052,0.36683231592178345,0.00390625 +0.9770389795303345,0.8958415389060974,1.2423408031463623,0.015625 +0.9902989864349365,0.7568255066871643,0.9843511581420898,0.019999999552965164 +0.9908176064491272,0.8731094002723694,0.6906698346138,0.00390625 +0.9901729226112366,0.8561913371086121,0.8783953189849854,0.5859375 \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/BUILD new file mode 100644 index 0000000..0f8caca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.json new file mode 100644 index 0000000..c7bf7c5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.json
@@ -0,0 +1,66 @@ +{ + "name": "ImageSegmenter", + "description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be segmented.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "segmentation_masks", + "description": "Masks over the target objects with high accuracy.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "GRAYSCALE" + }, + "range": { + "min": 1, + "max": 2 + } + }, + "stats": { + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite new file mode 100644 index 0000000..d2d9a9b0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3_default.json new file mode 100644 index 0000000..18cb12f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3_default.json
@@ -0,0 +1,40 @@ +{ + "name": "ImageSegmenter", + "description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be segmented.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "segmentation_masks", + "description": "Masks over the target objects with high accuracy.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "GRAYSCALE" + }, + "range": { + "min": 1, + "max": 2 + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt new file mode 100644 index 0000000..204608c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt
@@ -0,0 +1,21 @@ +background +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +dining table +dog +horse +motorbike +person +potted plant +sheep +sofa +train +tv
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_tensor_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_tensor_meta.json new file mode 100644 index 0000000..834848c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/image_tensor_meta.json
@@ -0,0 +1,35 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input", + "description": "The input tensor.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "locale": "en" + }, + { + "name": "labels_cn.txt", + "locale": "cn" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_audio_tesnor_default_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_audio_tesnor_default_meta.json new file mode 100644 index 0000000..6cfdc83 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_audio_tesnor_default_meta.json
@@ -0,0 +1,17 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_audio_tesnor_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_audio_tesnor_meta.json new file mode 100644 index 0000000..6988dac --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_audio_tesnor_meta.json
@@ -0,0 +1,21 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input text", + "description": "The input string.", + "content": { + "content_properties_type": "AudioProperties", + "content_properties": { + "sample_rate": 10, + "channels": 2 + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_float_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_float_meta.json new file mode 100644 index 0000000..2f9b288 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_float_meta.json
@@ -0,0 +1,47 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "The input image.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0, + 127.5, + 255.0 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 2.0, + 1.0, + 0.0 + ], + "min": [ + 0.0, + -1.0, + -2.0 + ] + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_uint8_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_uint8_meta.json new file mode 100644 index 0000000..fc1d840 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_uint8_meta.json
@@ -0,0 +1,43 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "The input image.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0, + 127.5, + 255.0 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_unsupported_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_unsupported_meta.json new file mode 100644 index 0000000..09a05aa --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_image_tensor_unsupported_meta.json
@@ -0,0 +1,37 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "The input image.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 0.0, + 127.5, + 255.0 + ], + "std": [ + 127.5, + 127.5, + 127.5 + ] + } + } + ], + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_text_tesnor_default_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_text_tesnor_default_meta.json new file mode 100644 index 0000000..07c23bb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_text_tesnor_default_meta.json
@@ -0,0 +1,17 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_text_tesnor_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_text_tesnor_meta.json new file mode 100644 index 0000000..04bab2f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/input_text_tesnor_meta.json
@@ -0,0 +1,34 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input text", + "description": "The input string.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "RegexTokenizerOptions", + "options": { + "delim_regex_pattern": "[^\\w\\']+", + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ], + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/labels.txt new file mode 100644 index 0000000..fe81123 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/labels.txt
@@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant.json new file mode 100644 index 0000000..ecb72f8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant.json
@@ -0,0 +1,67 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be classified.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant.tflite new file mode 100644 index 0000000..c26ff77 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_default.json new file mode 100644 index 0000000..b1f84e58 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_default.json
@@ -0,0 +1,17 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input" + } + ], + "output_tensor_metadata": [ + { + "name": "output" + } + ] + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_dummy.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_dummy.json new file mode 100644 index 0000000..d4f9169 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_dummy.json
@@ -0,0 +1,18 @@ +{ + "name": "mobilenet_v2_1.0_224_quant", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image" + } + ], + "output_tensor_metadata": [ + { + "name": "probability" + } + ] + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_dummy_no_version.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_dummy_no_version.json new file mode 100644 index 0000000..52c1541 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_dummy_no_version.json
@@ -0,0 +1,17 @@ +{ + "name": "mobilenet_v2_1.0_224_quant", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image" + } + ], + "output_tensor_metadata": [ + { + "name": "probability" + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_meta_info_.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_meta_info_.json new file mode 100644 index 0000000..1396d15 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/mobilenet_v2_1.0_224_quant_meta_info_.json
@@ -0,0 +1,32 @@ +{ + "name": "mobilenet_v2_1.0_224_quant", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/multi_inputs.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/multi_inputs.json new file mode 100644 index 0000000..64f3066 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/multi_inputs.json
@@ -0,0 +1,47 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "end_logits" + }, + { + "name": "start_logits" + } + ] + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/multi_outputs.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/multi_outputs.json new file mode 100644 index 0000000..bdf0870b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/multi_outputs.json
@@ -0,0 +1,34 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "audio" + } + ], + "output_tensor_metadata": [ + { + "name": "Identity", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "Identity 1", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ], + "min_parser_version": "1.0.0" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/BUILD new file mode 100644 index 0000000..0f8caca --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/BUILD
@@ -0,0 +1,13 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt new file mode 100644 index 0000000..84b84ff --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt
@@ -0,0 +1,2 @@ +Negative +Positive \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite new file mode 100644 index 0000000..0ec3f5a4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review_default.json new file mode 100644 index 0000000..2b06c29 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review_default.json
@@ -0,0 +1,34 @@ +{ + "name": "NLClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input_text", + "description": "Embedding vectors representing the input text to be classified.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review_regex.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review_regex.json new file mode 100644 index 0000000..de81e42 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review_regex.json
@@ -0,0 +1,63 @@ +{ + "name": "NLClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input_text", + "description": "Embedding vectors representing the input text to be classified.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "RegexTokenizerOptions", + "options": { + "delim_regex_pattern": "[^\\w\\']+", + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ], + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "probability", + "description": "Probabilities of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.2.1" +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/vocab.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/vocab.txt new file mode 100644 index 0000000..0a27d7c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/vocab.txt
@@ -0,0 +1,10000 @@ +<PAD> 0 +<START> 1 +<UNKNOWN> 2 +<UNUSED> 3 +the 4 +and 5 +a 6 +of 7 +to 8 +is 9 +br 10 +in 11 +it 12 +i 13 +this 14 +that 15 +was 16 +as 17 +for 18 +with 19 +movie 20 +but 21 +film 22 +on 23 +not 24 +you 25 +are 26 +his 27 +have 28 +he 29 +be 30 +one 31 +all 32 +at 33 +by 34 +an 35 +they 36 +who 37 +so 38 +from 39 +like 40 +her 41 +or 42 +just 43 +about 44 +it's 45 +out 46 +has 47 +if 48 +some 49 +there 50 +what 51 +good 52 +more 53 +when 54 +very 55 +up 56 +no 57 +time 58 +she 59 +even 60 +my 61 +would 62 +which 63 +only 64 +story 65 +really 66 +see 67 +their 68 +had 69 +can 70 +were 71 +me 72 +well 73 +than 74 +we 75 +much 76 +been 77 +bad 78 +get 79 +will 80 +do 81 +also 82 +into 83 +people 84 +other 85 +first 86 +great 87 +because 88 +how 89 +him 90 +most 91 +don't 92 +made 93 +its 94 +then 95 +way 96 +make 97 +them 98 +too 99 +could 100 +any 101 +movies 102 +after 103 +think 104 +characters 105 +watch 106 +two 107 +films 108 +character 109 +seen 110 +many 111 +being 112 +life 113 +plot 114 +never 115 +acting 116 +little 117 +best 118 +love 119 +over 120 +where 121 +did 122 +show 123 +know 124 +off 125 +ever 126 +does 127 +better 128 +your 129 +end 130 +still 131 +man 132 +here 133 +these 134 +say 135 +scene 136 +while 137 +why 138 +scenes 139 +go 140 +such 141 +something 142 +through 143 +should 144 +back 145 +i'm 146 +real 147 +those 148 +watching 149 +now 150 +though 151 +doesn't 152 +years 153 +old 154 +thing 155 +actors 156 +work 157 +10 158 +before 159 +another 160 +didn't 161 +new 162 +funny 163 +nothing 164 +actually 165 +makes 166 +director 167 +look 168 +find 169 +going 170 +few 171 +same 172 +part 173 +again 174 +every 175 +lot 176 +cast 177 +us 178 +quite 179 +down 180 +want 181 +world 182 +things 183 +pretty 184 +young 185 +seems 186 +around 187 +got 188 +horror 189 +however 190 +can't 191 +fact 192 +take 193 +big 194 +enough 195 +long 196 +thought 197 +that's 198 +both 199 +between 200 +series 201 +give 202 +may 203 +original 204 +own 205 +action 206 +i've 207 +right 208 +without 209 +always 210 +times 211 +comedy 212 +point 213 +gets 214 +must 215 +come 216 +role 217 +isn't 218 +saw 219 +almost 220 +interesting 221 +least 222 +family 223 +done 224 +there's 225 +whole 226 +bit 227 +music 228 +script 229 +far 230 +making 231 +guy 232 +anything 233 +minutes 234 +feel 235 +last 236 +since 237 +might 238 +performance 239 +he's 240 +2 241 +probably 242 +kind 243 +am 244 +away 245 +yet 246 +rather 247 +tv 248 +worst 249 +girl 250 +day 251 +sure 252 +fun 253 +hard 254 +woman 255 +played 256 +each 257 +found 258 +anyone 259 +having 260 +although 261 +especially 262 +our 263 +believe 264 +course 265 +comes 266 +looking 267 +screen 268 +trying 269 +set 270 +goes 271 +looks 272 +place 273 +book 274 +different 275 +put 276 +ending 277 +money 278 +maybe 279 +once 280 +sense 281 +reason 282 +true 283 +actor 284 +everything 285 +wasn't 286 +shows 287 +dvd 288 +three 289 +worth 290 +year 291 +job 292 +main 293 +someone 294 +together 295 +watched 296 +play 297 +american 298 +plays 299 +1 300 +said 301 +effects 302 +later 303 +takes 304 +instead 305 +seem 306 +beautiful 307 +john 308 +himself 309 +version 310 +audience 311 +high 312 +house 313 +night 314 +during 315 +everyone 316 +left 317 +special 318 +seeing 319 +half 320 +excellent 321 +wife 322 +star 323 +shot 324 +war 325 +idea 326 +nice 327 +black 328 +less 329 +mind 330 +simply 331 +read 332 +second 333 +else 334 +you're 335 +father 336 +fan 337 +poor 338 +help 339 +completely 340 +death 341 +3 342 +used 343 +home 344 +either 345 +short 346 +line 347 +given 348 +men 349 +top 350 +dead 351 +budget 352 +try 353 +performances 354 +wrong 355 +classic 356 +boring 357 +enjoy 358 +need 359 +rest 360 +use 361 +kids 362 +hollywood 363 +low 364 +production 365 +until 366 +along 367 +full 368 +friends 369 +camera 370 +truly 371 +women 372 +awful 373 +video 374 +next 375 +tell 376 +remember 377 +couple 378 +stupid 379 +start 380 +stars 381 +perhaps 382 +sex 383 +mean 384 +came 385 +recommend 386 +let 387 +moments 388 +wonderful 389 +episode 390 +understand 391 +small 392 +face 393 +terrible 394 +playing 395 +school 396 +getting 397 +written 398 +doing 399 +often 400 +keep 401 +early 402 +name 403 +perfect 404 +style 405 +human 406 +definitely 407 +gives 408 +others 409 +itself 410 +lines 411 +live 412 +become 413 +dialogue 414 +person 415 +lost 416 +finally 417 +piece 418 +head 419 +case 420 +felt 421 +yes 422 +liked 423 +supposed 424 +title 425 +couldn't 426 +absolutely 427 +white 428 +against 429 +boy 430 +picture 431 +sort 432 +worse 433 +certainly 434 +went 435 +entire 436 +waste 437 +cinema 438 +problem 439 +hope 440 +entertaining 441 +she's 442 +mr 443 +overall 444 +evil 445 +called 446 +loved 447 +based 448 +oh 449 +several 450 +fans 451 +mother 452 +drama 453 +beginning 454 +killer 455 +lives 456 +5 457 +direction 458 +care 459 +already 460 +becomes 461 +laugh 462 +example 463 +friend 464 +dark 465 +despite 466 +under 467 +seemed 468 +throughout 469 +4 470 +turn 471 +unfortunately 472 +wanted 473 +i'd 474 +– 475 +children 476 +final 477 +fine 478 +history 479 +amazing 480 +sound 481 +guess 482 +heart 483 +totally 484 +lead 485 +humor 486 +writing 487 +michael 488 +quality 489 +you'll 490 +close 491 +son 492 +guys 493 +wants 494 +works 495 +behind 496 +tries 497 +art 498 +side 499 +game 500 +past 501 +able 502 +b 503 +days 504 +turns 505 +child 506 +they're 507 +hand 508 +flick 509 +enjoyed 510 +act 511 +genre 512 +town 513 +favorite 514 +soon 515 +kill 516 +starts 517 +sometimes 518 +car 519 +gave 520 +run 521 +late 522 +eyes 523 +actress 524 +etc 525 +directed 526 +horrible 527 +won't 528 +viewer 529 +brilliant 530 +parts 531 +self 532 +themselves 533 +hour 534 +expect 535 +thinking 536 +stories 537 +stuff 538 +girls 539 +obviously 540 +blood 541 +decent 542 +city 543 +voice 544 +highly 545 +myself 546 +feeling 547 +fight 548 +except 549 +slow 550 +matter 551 +type 552 +anyway 553 +kid 554 +roles 555 +killed 556 +heard 557 +god 558 +age 559 +says 560 +moment 561 +took 562 +leave 563 +writer 564 +strong 565 +cannot 566 +violence 567 +police 568 +hit 569 +stop 570 +happens 571 +particularly 572 +known 573 +involved 574 +happened 575 +extremely 576 +daughter 577 +obvious 578 +told 579 +chance 580 +living 581 +coming 582 +lack 583 +alone 584 +experience 585 +wouldn't 586 +including 587 +murder 588 +attempt 589 +s 590 +please 591 +james 592 +happen 593 +wonder 594 +crap 595 +ago 596 +brother 597 +film's 598 +gore 599 +none 600 +complete 601 +interest 602 +score 603 +group 604 +cut 605 +simple 606 +save 607 +ok 608 +hell 609 +looked 610 +career 611 +number 612 +song 613 +possible 614 +seriously 615 +annoying 616 +shown 617 +exactly 618 +sad 619 +running 620 +musical 621 +serious 622 +taken 623 +yourself 624 +whose 625 +released 626 +cinematography 627 +david 628 +scary 629 +ends 630 +english 631 +hero 632 +usually 633 +hours 634 +reality 635 +opening 636 +i'll 637 +across 638 +today 639 +jokes 640 +light 641 +hilarious 642 +somewhat 643 +usual 644 +started 645 +cool 646 +ridiculous 647 +body 648 +relationship 649 +view 650 +level 651 +opinion 652 +change 653 +happy 654 +middle 655 +taking 656 +wish 657 +husband 658 +finds 659 +saying 660 +order 661 +talking 662 +ones 663 +documentary 664 +shots 665 +huge 666 +novel 667 +female 668 +mostly 669 +robert 670 +power 671 +episodes 672 +room 673 +important 674 +rating 675 +talent 676 +five 677 +major 678 +turned 679 +strange 680 +word 681 +modern 682 +call 683 +apparently 684 +disappointed 685 +single 686 +events 687 +due 688 +four 689 +songs 690 +basically 691 +attention 692 +7 693 +knows 694 +clearly 695 +supporting 696 +knew 697 +british 698 +television 699 +comic 700 +non 701 +fast 702 +earth 703 +country 704 +future 705 +cheap 706 +class 707 +thriller 708 +8 709 +silly 710 +king 711 +problems 712 +aren't 713 +easily 714 +words 715 +tells 716 +miss 717 +jack 718 +local 719 +sequence 720 +bring 721 +entertainment 722 +paul 723 +beyond 724 +upon 725 +whether 726 +predictable 727 +moving 728 +similar 729 +straight 730 +romantic 731 +sets 732 +review 733 +falls 734 +oscar 735 +mystery 736 +enjoyable 737 +needs 738 +appears 739 +talk 740 +rock 741 +george 742 +giving 743 +eye 744 +richard 745 +within 746 +ten 747 +animation 748 +message 749 +theater 750 +near 751 +above 752 +dull 753 +nearly 754 +sequel 755 +theme 756 +points 757 +' 758 +stand 759 +mention 760 +lady 761 +bunch 762 +add 763 +feels 764 +herself 765 +release 766 +red 767 +team 768 +storyline 769 +surprised 770 +ways 771 +using 772 +named 773 +haven't 774 +lots 775 +easy 776 +fantastic 777 +begins 778 +actual 779 +working 780 +effort 781 +york 782 +die 783 +hate 784 +french 785 +minute 786 +tale 787 +clear 788 +stay 789 +9 790 +elements 791 +feature 792 +among 793 +follow 794 +comments 795 +re 796 +viewers 797 +avoid 798 +sister 799 +showing 800 +typical 801 +editing 802 +what's 803 +famous 804 +tried 805 +sorry 806 +dialog 807 +check 808 +fall 809 +period 810 +season 811 +form 812 +certain 813 +filmed 814 +weak 815 +soundtrack 816 +means 817 +buy 818 +material 819 +somehow 820 +realistic 821 +figure 822 +crime 823 +doubt 824 +gone 825 +peter 826 +tom 827 +kept 828 +viewing 829 +t 830 +general 831 +leads 832 +greatest 833 +space 834 +lame 835 +suspense 836 +dance 837 +imagine 838 +brought 839 +third 840 +atmosphere 841 +hear 842 +particular 843 +sequences 844 +whatever 845 +parents 846 +move 847 +lee 848 +indeed 849 +learn 850 +rent 851 +de 852 +eventually 853 +note 854 +deal 855 +average 856 +reviews 857 +wait 858 +forget 859 +japanese 860 +sexual 861 +poorly 862 +premise 863 +okay 864 +zombie 865 +surprise 866 +believable 867 +stage 868 +possibly 869 +sit 870 +who's 871 +decided 872 +expected 873 +you've 874 +subject 875 +nature 876 +became 877 +difficult 878 +free 879 +killing 880 +screenplay 881 +truth 882 +romance 883 +dr 884 +nor 885 +reading 886 +needed 887 +question 888 +leaves 889 +street 890 +20 891 +meets 892 +hot 893 +unless 894 +begin 895 +baby 896 +superb 897 +credits 898 +imdb 899 +otherwise 900 +write 901 +shame 902 +let's 903 +situation 904 +dramatic 905 +memorable 906 +directors 907 +earlier 908 +meet 909 +disney 910 +open 911 +dog 912 +badly 913 +joe 914 +male 915 +weird 916 +acted 917 +forced 918 +laughs 919 +sci 920 +emotional 921 +older 922 +realize 923 +fi 924 +dream 925 +society 926 +writers 927 +interested 928 +footage 929 +forward 930 +comment 931 +crazy 932 +deep 933 +sounds 934 +plus 935 +beauty 936 +whom 937 +america 938 +fantasy 939 +directing 940 +keeps 941 +ask 942 +development 943 +features 944 +air 945 +quickly 946 +mess 947 +creepy 948 +towards 949 +perfectly 950 +mark 951 +worked 952 +box 953 +cheesy 954 +unique 955 +setting 956 +hands 957 +plenty 958 +result 959 +previous 960 +brings 961 +effect 962 +e 963 +total 964 +personal 965 +incredibly 966 +rate 967 +fire 968 +monster 969 +business 970 +leading 971 +apart 972 +casting 973 +admit 974 +joke 975 +powerful 976 +appear 977 +background 978 +telling 979 +girlfriend 980 +meant 981 +christmas 982 +hardly 983 +present 984 +battle 985 +potential 986 +create 987 +bill 988 +break 989 +pay 990 +masterpiece 991 +gay 992 +political 993 +return 994 +dumb 995 +fails 996 +fighting 997 +various 998 +era 999 +portrayed 1000 +co 1001 +cop 1002 +secret 1003 +inside 1004 +outside 1005 +nudity 1006 +reasons 1007 +ideas 1008 +twist 1009 +western 1010 +front 1011 +missing 1012 +boys 1013 +match 1014 +deserves 1015 +jane 1016 +expecting 1017 +fairly 1018 +villain 1019 +talented 1020 +married 1021 +ben 1022 +success 1023 +william 1024 +unlike 1025 +rich 1026 +attempts 1027 +spoilers 1028 +list 1029 +manages 1030 +social 1031 +odd 1032 +recently 1033 +remake 1034 +flat 1035 +cute 1036 +further 1037 +sadly 1038 +copy 1039 +wrote 1040 +agree 1041 +doctor 1042 +cold 1043 +plain 1044 +following 1045 +mentioned 1046 +sweet 1047 +incredible 1048 +missed 1049 +pure 1050 +crew 1051 +office 1052 +wasted 1053 +ended 1054 +produced 1055 +gun 1056 +filmmakers 1057 +large 1058 +caught 1059 +revenge 1060 +filled 1061 +pace 1062 +popular 1063 +waiting 1064 +'the 1065 +members 1066 +science 1067 +decides 1068 +considering 1069 +hold 1070 +public 1071 +cartoon 1072 +party 1073 +tension 1074 +created 1075 +slightly 1076 +uses 1077 +convincing 1078 +compared 1079 +la 1080 +familiar 1081 +neither 1082 +mary 1083 +spent 1084 +sees 1085 +6 1086 +suddenly 1087 +30 1088 +intelligent 1089 +escape 1090 +scott 1091 +fear 1092 +water 1093 +brothers 1094 +d 1095 +clever 1096 +entirely 1097 +kills 1098 +choice 1099 +bored 1100 +language 1101 +moves 1102 +spirit 1103 +laughing 1104 +dancing 1105 +we're 1106 +value 1107 +cover 1108 +credit 1109 +state 1110 +island 1111 +successful 1112 +trouble 1113 +visual 1114 +violent 1115 +ultimately 1116 +century 1117 +singing 1118 +15 1119 +concept 1120 +basic 1121 +italian 1122 +positive 1123 +german 1124 +animated 1125 +biggest 1126 +exciting 1127 +speak 1128 +runs 1129 +store 1130 +died 1131 +cat 1132 +consider 1133 +effective 1134 +walk 1135 +recent 1136 +depth 1137 +former 1138 +amusing 1139 +control 1140 +common 1141 +spend 1142 +band 1143 +appreciate 1144 +zombies 1145 +portrayal 1146 +force 1147 +c 1148 +pointless 1149 +rated 1150 +books 1151 +focus 1152 +hair 1153 +adventure 1154 +younger 1155 +solid 1156 +trash 1157 +adult 1158 +impressive 1159 +follows 1160 +respect 1161 +bizarre 1162 +tone 1163 +law 1164 +super 1165 +amount 1166 +impossible 1167 +mad 1168 +company 1169 +college 1170 +van 1171 +prison 1172 +weren't 1173 +conclusion 1174 +chemistry 1175 +win 1176 +showed 1177 +recommended 1178 +slasher 1179 +producers 1180 +culture 1181 +studio 1182 +fit 1183 +starring 1184 +heavy 1185 +situations 1186 +project 1187 +makers 1188 +trip 1189 +awesome 1190 +accent 1191 +considered 1192 +disturbing 1193 +changed 1194 +sick 1195 +failed 1196 +decide 1197 +somewhere 1198 +won 1199 +leaving 1200 +barely 1201 +honest 1202 +cause 1203 +questions 1204 +shooting 1205 +u 1206 +longer 1207 +post 1208 +f 1209 +anti 1210 +tough 1211 +aside 1212 +ghost 1213 +fake 1214 +cult 1215 +thanks 1216 +meaning 1217 +images 1218 +fiction 1219 +charming 1220 +audiences 1221 +computer 1222 +tony 1223 +brain 1224 +planet 1225 +south 1226 +literally 1227 +generally 1228 +touch 1229 +steve 1230 +stick 1231 +likes 1232 +ex 1233 +values 1234 +pathetic 1235 +magic 1236 +involving 1237 +surprisingly 1238 +alive 1239 +jim 1240 +immediately 1241 +grade 1242 +yeah 1243 +garbage 1244 +100 1245 +dad 1246 +bought 1247 +military 1248 +natural 1249 +camp 1250 +aspect 1251 +honestly 1252 +adaptation 1253 +utterly 1254 +detective 1255 +ability 1256 +fair 1257 +shoot 1258 +smith 1259 +explain 1260 +pick 1261 +genius 1262 +west 1263 +glad 1264 +frank 1265 +sitting 1266 +appearance 1267 +pictures 1268 +week 1269 +motion 1270 +appeal 1271 +army 1272 +standard 1273 +attack 1274 +knowing 1275 +personally 1276 +catch 1277 +drive 1278 +sexy 1279 +normal 1280 +rare 1281 +nowhere 1282 +added 1283 +sam 1284 +humour 1285 +walking 1286 +remains 1287 +purpose 1288 +edge 1289 +comedies 1290 +thinks 1291 +loud 1292 +beautifully 1293 +thank 1294 +silent 1295 +taste 1296 +unbelievable 1297 +naked 1298 +twists 1299 +master 1300 +touching 1301 +subtle 1302 +terms 1303 +date 1304 +equally 1305 +dreams 1306 +terrific 1307 +channel 1308 +drawn 1309 +mood 1310 +journey 1311 +door 1312 +chase 1313 +fully 1314 +complex 1315 +london 1316 +key 1317 +wow 1318 +managed 1319 +road 1320 +narrative 1321 +laughable 1322 +mistake 1323 +bottom 1324 +producer 1325 +themes 1326 +movie's 1327 +pieces 1328 +likely 1329 +climax 1330 +g 1331 +disappointing 1332 +club 1333 +lovely 1334 +harry 1335 +blue 1336 +nobody 1337 +excuse 1338 +outstanding 1339 +soldiers 1340 +issues 1341 +stewart 1342 +constantly 1343 +award 1344 +pass 1345 +thus 1346 +plan 1347 +surely 1348 +marriage 1349 +painful 1350 +justice 1351 +costumes 1352 +presented 1353 +batman 1354 +80's 1355 +innocent 1356 +soul 1357 +wild 1358 +noir 1359 +cinematic 1360 +spoiler 1361 +vampire 1362 +finish 1363 +slowly 1364 +ride 1365 +gang 1366 +contains 1367 +christopher 1368 +presence 1369 +places 1370 +besides 1371 +government 1372 +details 1373 +train 1374 +central 1375 +thrown 1376 +manner 1377 +chris 1378 +historical 1379 +stunning 1380 +photography 1381 +charm 1382 +hoping 1383 +impression 1384 +scenery 1385 +speaking 1386 +disappointment 1387 +loves 1388 +animals 1389 +you'd 1390 +developed 1391 +drug 1392 +smart 1393 +charles 1394 +indian 1395 +numbers 1396 +mysterious 1397 +expectations 1398 +color 1399 +hey 1400 +exception 1401 +throw 1402 +minor 1403 +ahead 1404 +double 1405 +track 1406 +stands 1407 +suppose 1408 +aspects 1409 +boss 1410 +woods 1411 +sent 1412 +festival 1413 +bother 1414 +cry 1415 +church 1416 +feelings 1417 +critics 1418 +green 1419 +brief 1420 +acts 1421 +opera 1422 +filming 1423 +mainly 1424 +support 1425 +emotion 1426 +element 1427 +held 1428 +fascinating 1429 +building 1430 +million 1431 +boyfriend 1432 +names 1433 +opportunity 1434 +serial 1435 +intended 1436 +forever 1437 +emotions 1438 +available 1439 +victim 1440 +charlie 1441 +dies 1442 +changes 1443 +compelling 1444 +bed 1445 +six 1446 +born 1447 +happening 1448 +bar 1449 +paris 1450 +likable 1451 +lived 1452 +twice 1453 +falling 1454 +hotel 1455 +zero 1456 +puts 1457 +tired 1458 +image 1459 +pain 1460 +lover 1461 +everybody 1462 +giant 1463 +offer 1464 +shock 1465 +spot 1466 +suggest 1467 +j 1468 +henry 1469 +include 1470 +confused 1471 +trailer 1472 +adults 1473 +difference 1474 +student 1475 +fresh 1476 +followed 1477 +bruce 1478 +r 1479 +kelly 1480 +hasn't 1481 +appeared 1482 +approach 1483 +victims 1484 +christian 1485 +fellow 1486 +hurt 1487 +impact 1488 +putting 1489 +gorgeous 1490 +step 1491 +sub 1492 +mix 1493 +event 1494 +notice 1495 +murders 1496 +share 1497 +laughed 1498 +confusing 1499 +content 1500 +mediocre 1501 +11 1502 +lacks 1503 +direct 1504 +supposedly 1505 +summer 1506 +actresses 1507 +flaws 1508 +porn 1509 +system 1510 +page 1511 +holes 1512 +wall 1513 +billy 1514 +moral 1515 +jerry 1516 +worthy 1517 +creative 1518 +relationships 1519 +rape 1520 +tragedy 1521 +race 1522 +thin 1523 +lighting 1524 +helps 1525 +random 1526 +answer 1527 +gem 1528 +funniest 1529 +ii 1530 +americans 1531 +jones 1532 +merely 1533 +proves 1534 +wondering 1535 +alien 1536 +students 1537 +ray 1538 +paid 1539 +al 1540 +land 1541 +seven 1542 +damn 1543 +agent 1544 +delivers 1545 +imagination 1546 +park 1547 +childhood 1548 +flying 1549 +hospital 1550 +forgotten 1551 +90 1552 +standards 1553 +flicks 1554 +impressed 1555 +finding 1556 +absolute 1557 +ugly 1558 +beat 1559 +jean 1560 +don 1561 +thoroughly 1562 +ms 1563 +attractive 1564 +ground 1565 +negative 1566 +wise 1567 +provides 1568 +latter 1569 +50 1570 +stuck 1571 +extreme 1572 +seemingly 1573 +seconds 1574 +becoming 1575 +winning 1576 +addition 1577 +reminded 1578 +tragic 1579 +offers 1580 +inspired 1581 +count 1582 +fell 1583 +thats 1584 +lose 1585 +affair 1586 +turning 1587 +folks 1588 +detail 1589 +faces 1590 +cliché 1591 +design 1592 +martin 1593 +collection 1594 +afraid 1595 +intense 1596 +fashion 1597 +pull 1598 +hidden 1599 +industry 1600 +man's 1601 +allen 1602 +apartment 1603 +o 1604 +quick 1605 +nasty 1606 +arthur 1607 +adds 1608 +area 1609 +rented 1610 +alan 1611 +angry 1612 +personality 1613 +artistic 1614 +length 1615 +shouldn't 1616 +therefore 1617 +information 1618 +chinese 1619 +brian 1620 +shocking 1621 +location 1622 +ready 1623 +professional 1624 +lets 1625 +animal 1626 +anymore 1627 +games 1628 +teen 1629 +states 1630 +soldier 1631 +listen 1632 +mom 1633 +describe 1634 +lord 1635 +news 1636 +picked 1637 +led 1638 +wooden 1639 +favourite 1640 +dirty 1641 +mouth 1642 +asks 1643 +food 1644 +deliver 1645 +onto 1646 +martial 1647 +bond 1648 +clothes 1649 +wars 1650 +struggle 1651 +queen 1652 +redeeming 1653 +stone 1654 +jason 1655 +scientist 1656 +p 1657 +wearing 1658 +ed 1659 +stephen 1660 +compare 1661 +castle 1662 +intelligence 1663 +creature 1664 +cross 1665 +sleep 1666 +teenage 1667 +allowed 1668 +wonderfully 1669 +necessary 1670 +carry 1671 +drugs 1672 +40 1673 +tears 1674 +fox 1675 +criminal 1676 +rip 1677 +helped 1678 +member 1679 +desperate 1680 +moved 1681 +sight 1682 +cgi 1683 +trust 1684 +deeply 1685 +roll 1686 +includes 1687 +willing 1688 +whatsoever 1689 +disaster 1690 +12 1691 +machine 1692 +ship 1693 +treat 1694 +began 1695 +mid 1696 +uncle 1697 +grace 1698 +phone 1699 +70's 1700 +williams 1701 +commentary 1702 +build 1703 +accident 1704 +captain 1705 +realized 1706 +plane 1707 +energy 1708 +station 1709 +warning 1710 +epic 1711 +davis 1712 +rarely 1713 +humans 1714 +loving 1715 +theatre 1716 +comedic 1717 +witch 1718 +pop 1719 +suicide 1720 +dying 1721 +powers 1722 +filmmaker 1723 +independent 1724 +introduced 1725 +nightmare 1726 +extra 1727 +engaging 1728 +actions 1729 +character's 1730 +superior 1731 +unusual 1732 +arts 1733 +apparent 1734 +suit 1735 +religious 1736 +heroes 1737 +danny 1738 +remarkable 1739 +artist 1740 +allow 1741 +pleasure 1742 +continue 1743 +unnecessary 1744 +x 1745 +ring 1746 +returns 1747 +physical 1748 +sky 1749 +teacher 1750 +pre 1751 +mental 1752 +watchable 1753 +provide 1754 +absurd 1755 +tim 1756 +memory 1757 +grand 1758 +technical 1759 +normally 1760 +wedding 1761 +desire 1762 +limited 1763 +anywhere 1764 +scared 1765 +russian 1766 +surprising 1767 +douglas 1768 +finished 1769 +brutal 1770 +skip 1771 +vision 1772 +process 1773 +intriguing 1774 +bloody 1775 +media 1776 +holds 1777 +exist 1778 +accept 1779 +nicely 1780 +suspect 1781 +000 1782 +jump 1783 +twenty 1784 +paced 1785 +wanting 1786 +search 1787 +cops 1788 +torture 1789 +growing 1790 +reminds 1791 +jr 1792 +according 1793 +pacing 1794 +legend 1795 +soft 1796 +passion 1797 +andy 1798 +player 1799 +hated 1800 +bits 1801 +fred 1802 +asked 1803 +faith 1804 +joy 1805 +johnny 1806 +clichés 1807 +jeff 1808 +academy 1809 +dressed 1810 +pilot 1811 +eddie 1812 +constant 1813 +anybody 1814 +ill 1815 +deserved 1816 +horse 1817 +gold 1818 +drunk 1819 +joan 1820 +blame 1821 +originally 1822 +explanation 1823 +dangerous 1824 +instance 1825 +smile 1826 +heaven 1827 +heads 1828 +sat 1829 +community 1830 +england 1831 +superman 1832 +deserve 1833 +issue 1834 +nonsense 1835 +met 1836 +dick 1837 +lies 1838 +capture 1839 +gotten 1840 +toward 1841 +kevin 1842 +somebody 1843 +soap 1844 +field 1845 +lovers 1846 +plots 1847 +taylor 1848 +mixed 1849 +players 1850 +nick 1851 +explained 1852 +record 1853 +fail 1854 +creating 1855 +vhs 1856 +knowledge 1857 +quiet 1858 +unknown 1859 +fights 1860 +starting 1861 +friendship 1862 +accurate 1863 +whilst 1864 +guns 1865 +price 1866 +adam 1867 +kate 1868 +hadn't 1869 +sucks 1870 +ball 1871 +river 1872 +floor 1873 +european 1874 +spanish 1875 +wide 1876 +cable 1877 +radio 1878 +fu 1879 +cars 1880 +jackson 1881 +realism 1882 +memories 1883 +moon 1884 +finest 1885 +heroine 1886 +aware 1887 +loose 1888 +eating 1889 +featuring 1890 +prince 1891 +lacking 1892 +responsible 1893 +saved 1894 +keeping 1895 +empty 1896 +understanding 1897 +japan 1898 +treated 1899 +eat 1900 +results 1901 +cuts 1902 +ice 1903 +bland 1904 +terribly 1905 +pulled 1906 +saving 1907 +below 1908 +officer 1909 +villains 1910 +candy 1911 +broken 1912 +sign 1913 +ladies 1914 +hopes 1915 +rubbish 1916 +delightful 1917 +vs 1918 +judge 1919 +witty 1920 +manage 1921 +fat 1922 +mine 1923 +gene 1924 +noticed 1925 +included 1926 +bright 1927 +months 1928 +forces 1929 +screaming 1930 +higher 1931 +kinda 1932 +wind 1933 +tarzan 1934 +cage 1935 +hits 1936 +loss 1937 +today's 1938 +monsters 1939 +youth 1940 +sing 1941 +numerous 1942 +partner 1943 +conflict 1944 +whenever 1945 +humanity 1946 +concerned 1947 +pretentious 1948 +fate 1949 +singer 1950 +dealing 1951 +mike 1952 +driving 1953 +jesus 1954 +private 1955 +talents 1956 +discovered 1957 +naturally 1958 +skills 1959 +unfunny 1960 +opposite 1961 +finale 1962 +bigger 1963 +v 1964 +ann 1965 +international 1966 +dated 1967 +kick 1968 +ups 1969 +prove 1970 +perspective 1971 +morning 1972 +mission 1973 +discover 1974 +portray 1975 +blonde 1976 +here's 1977 +loses 1978 +locations 1979 +visit 1980 +ordinary 1981 +bank 1982 +m 1983 +humorous 1984 +werewolf 1985 +streets 1986 +psychological 1987 +regular 1988 +reviewers 1989 +received 1990 +kong 1991 +w 1992 +edited 1993 +gags 1994 +ass 1995 +luck 1996 +curious 1997 +gary 1998 +continues 1999 +magnificent 2000 +13 2001 +we've 2002 +behavior 2003 +captured 2004 +jimmy 2005 +satire 2006 +survive 2007 +context 2008 +visually 2009 +breaks 2010 +existence 2011 +shallow 2012 +opens 2013 +l 2014 +mrs 2015 +debut 2016 +advice 2017 +calls 2018 +sea 2019 +foot 2020 +morgan 2021 +shop 2022 +h 2023 +murdered 2024 +connection 2025 +core 2026 +essentially 2027 +current 2028 +revealed 2029 +director's 2030 +corny 2031 +remembered 2032 +deals 2033 +blind 2034 +frankly 2035 +occasionally 2036 +lesson 2037 +genuine 2038 +scream 2039 +traditional 2040 +they've 2041 +lucky 2042 +identity 2043 +dimensional 2044 +african 2045 +bob 2046 +anthony 2047 +efforts 2048 +sean 2049 +golden 2050 +learned 2051 +segment 2052 +stock 2053 +window 2054 +cameo 2055 +owner 2056 +visuals 2057 +versions 2058 +village 2059 +albert 2060 +develop 2061 +santa 2062 +formula 2063 +miles 2064 +keaton 2065 +one's 2066 +sucked 2067 +decade 2068 +buddy 2069 +genuinely 2070 +grown 2071 +references 2072 +suffering 2073 +boat 2074 +lewis 2075 +unexpected 2076 +favor 2077 +study 2078 +washington 2079 +allows 2080 +program 2081 +national 2082 +grew 2083 +80s 2084 +proved 2085 +meanwhile 2086 +overly 2087 +ages 2088 +board 2089 +standing 2090 +logic 2091 +desert 2092 +spectacular 2093 +awkward 2094 +ultimate 2095 +comparison 2096 +reaction 2097 +rob 2098 +sheer 2099 +jennifer 2100 +reach 2101 +thomas 2102 +unable 2103 +failure 2104 +brilliantly 2105 +travel 2106 +grant 2107 +ford 2108 +vampires 2109 +types 2110 +parody 2111 +gangster 2112 +devil 2113 +steal 2114 +brown 2115 +passed 2116 +sudden 2117 +stereotypes 2118 +sake 2119 +flesh 2120 +leader 2121 +frame 2122 +bear 2123 +strength 2124 +speed 2125 +creates 2126 +eric 2127 +awards 2128 +laughter 2129 +dan 2130 +technology 2131 +delivered 2132 +author 2133 +bet 2134 +kung 2135 +crappy 2136 +wood 2137 +site 2138 +broadway 2139 +insane 2140 +trek 2141 +executed 2142 +relief 2143 +lake 2144 +hitler 2145 +gonna 2146 +discovers 2147 +emotionally 2148 +painfully 2149 +dreadful 2150 +marie 2151 +utter 2152 +commercial 2153 +decision 2154 +code 2155 +steven 2156 +fault 2157 +anime 2158 +majority 2159 +anne 2160 +round 2161 +pair 2162 +robin 2163 +caused 2164 +bomb 2165 +families 2166 +psycho 2167 +driven 2168 +attitude 2169 +clean 2170 +built 2171 +gratuitous 2172 +harris 2173 +native 2174 +luke 2175 +entertained 2176 +graphic 2177 +ran 2178 +killers 2179 +meeting 2180 +test 2181 +simon 2182 +flashbacks 2183 +underrated 2184 +nevertheless 2185 +model 2186 +seasons 2187 +asian 2188 +foreign 2189 +hill 2190 +levels 2191 +obsessed 2192 +evening 2193 +feet 2194 +halloween 2195 +vehicle 2196 +barbara 2197 +relate 2198 +treatment 2199 +rise 2200 +practically 2201 +range 2202 +endless 2203 +freedom 2204 +costs 2205 +religion 2206 +gory 2207 +cash 2208 +described 2209 +wit 2210 +pleasant 2211 +aged 2212 +ancient 2213 +tape 2214 +reviewer 2215 +center 2216 +president 2217 +chosen 2218 +lynch 2219 +product 2220 +combination 2221 +send 2222 +fly 2223 +seat 2224 +sell 2225 +70s 2226 +irritating 2227 +exploitation 2228 +excited 2229 +stopped 2230 +hearing 2231 +rescue 2232 +fill 2233 +howard 2234 +portrays 2235 +gordon 2236 +assume 2237 +parker 2238 +classics 2239 +pity 2240 +0 2241 +produce 2242 +hunter 2243 +breaking 2244 +dry 2245 +fame 2246 +anna 2247 +generation 2248 +sheriff 2249 +capable 2250 +believes 2251 +handsome 2252 +theatrical 2253 +asking 2254 +sports 2255 +largely 2256 +choose 2257 +theaters 2258 +sympathetic 2259 +extras 2260 +proper 2261 +ruined 2262 +cares 2263 +contrived 2264 +portraying 2265 +drew 2266 +individual 2267 +embarrassing 2268 +rules 2269 +unrealistic 2270 +learns 2271 +warm 2272 +victor 2273 +daniel 2274 +marry 2275 +appealing 2276 +safe 2277 +dubbed 2278 +depressing 2279 +canadian 2280 +freddy 2281 +shakespeare 2282 +recall 2283 +chick 2284 +uk 2285 +winner 2286 +hearted 2287 +contrast 2288 +sequels 2289 +involves 2290 +par 2291 +woody 2292 +crowd 2293 +matters 2294 +k 2295 +correct 2296 +chief 2297 +costume 2298 +haunting 2299 +paper 2300 +research 2301 +vote 2302 +strongly 2303 +heck 2304 +nominated 2305 +grow 2306 +clue 2307 +claim 2308 +facts 2309 +eight 2310 +protagonist 2311 +matt 2312 +rose 2313 +evidence 2314 +joseph 2315 +appropriate 2316 +disgusting 2317 +excitement 2318 +football 2319 +lousy 2320 +germany 2321 +cost 2322 +france 2323 +saturday 2324 +priest 2325 +talks 2326 +substance 2327 +losing 2328 +patrick 2329 +destroy 2330 +circumstances 2331 +tedious 2332 +training 2333 +thoughts 2334 +hunt 2335 +market 2336 +scare 2337 +voices 2338 +promise 2339 +naive 2340 +bringing 2341 +amateurish 2342 +teenager 2343 +angel 2344 +walter 2345 +captures 2346 +convinced 2347 +hanging 2348 +satisfying 2349 +bodies 2350 +united 2351 +fits 2352 +tend 2353 +jackie 2354 +trilogy 2355 +roy 2356 +horribly 2357 +lower 2358 +asleep 2359 +virtually 2360 +baseball 2361 +robot 2362 +hopefully 2363 +rental 2364 +alex 2365 +com 2366 +factor 2367 +haunted 2368 +teenagers 2369 +hall 2370 +walks 2371 +spoil 2372 +creatures 2373 +amateur 2374 +relatively 2375 +steals 2376 +mask 2377 +welcome 2378 +cinderella 2379 +covered 2380 +ryan 2381 +danger 2382 +europe 2383 +insult 2384 +category 2385 +continuity 2386 +mini 2387 +unlikely 2388 +drag 2389 +sinatra 2390 +skin 2391 +contemporary 2392 +louis 2393 +semi 2394 +viewed 2395 +fare 2396 +north 2397 +influence 2398 +depicted 2399 +handled 2400 +target 2401 +oliver 2402 +offensive 2403 +hat 2404 +initial 2405 +nancy 2406 +scale 2407 +lawyer 2408 +tiny 2409 +cutting 2410 +unfortunate 2411 +holding 2412 +witness 2413 +shocked 2414 +africa 2415 +remain 2416 +believed 2417 +fool 2418 +inner 2419 +politics 2420 +hide 2421 +reporter 2422 +presents 2423 +section 2424 +movement 2425 +provided 2426 +surreal 2427 +promising 2428 +designed 2429 +makeup 2430 +max 2431 +qualities 2432 +liners 2433 +refreshing 2434 +australian 2435 +source 2436 +14 2437 +structure 2438 +closer 2439 +drop 2440 +forgettable 2441 +touches 2442 +welles 2443 +display 2444 +angles 2445 +pile 2446 +fairy 2447 +repeated 2448 +till 2449 +texas 2450 +wayne 2451 +claims 2452 +previously 2453 +faced 2454 +sharp 2455 +deaths 2456 +ruin 2457 +accents 2458 +surprises 2459 +universal 2460 +degree 2461 +focused 2462 +propaganda 2463 +plans 2464 +serves 2465 +speaks 2466 +supernatural 2467 +highlight 2468 +service 2469 +peace 2470 +chose 2471 +related 2472 +cartoons 2473 +adventures 2474 +erotic 2475 +25 2476 +roger 2477 +suffers 2478 +blow 2479 +weekend 2480 +sisters 2481 +granted 2482 +mainstream 2483 +latest 2484 +weeks 2485 +prime 2486 +crash 2487 +cant 2488 +professor 2489 +experiences 2490 +speech 2491 +print 2492 +lesbian 2493 +harsh 2494 +deadly 2495 +veteran 2496 +mistakes 2497 +edward 2498 +routine 2499 +whoever 2500 +notch 2501 +uninteresting 2502 +realizes 2503 +invisible 2504 +combined 2505 +sympathy 2506 +accidentally 2507 +kim 2508 +twisted 2509 +brave 2510 +colors 2511 +dollars 2512 +security 2513 +draw 2514 +dogs 2515 +nude 2516 +rain 2517 +universe 2518 +struggling 2519 +dozen 2520 +teens 2521 +convince 2522 +guilty 2523 +path 2524 +appreciated 2525 +atrocious 2526 +mountain 2527 +treasure 2528 +walked 2529 +columbo 2530 +irish 2531 +frightening 2532 +would've 2533 +committed 2534 +aliens 2535 +technically 2536 +recognize 2537 +cowboy 2538 +blah 2539 +birth 2540 +enter 2541 +gritty 2542 +enemy 2543 +aka 2544 +spy 2545 +changing 2546 +magical 2547 +anderson 2548 +princess 2549 +department 2550 +gas 2551 +occasional 2552 +friday 2553 +sword 2554 +directly 2555 +false 2556 +massive 2557 +surface 2558 +narration 2559 +legendary 2560 +featured 2561 +victoria 2562 +anger 2563 +offered 2564 +paint 2565 +performed 2566 +moore 2567 +explains 2568 +abuse 2569 +suspenseful 2570 +vietnam 2571 +kinds 2572 +terror 2573 +experienced 2574 +friendly 2575 +subtitles 2576 +reputation 2577 +crying 2578 +hong 2579 +sorts 2580 +passing 2581 +junk 2582 +beach 2583 +multiple 2584 +forest 2585 +stolen 2586 +everywhere 2587 +figures 2588 +forth 2589 +statement 2590 +exact 2591 +powell 2592 +variety 2593 +required 2594 +clark 2595 +reveal 2596 +donald 2597 +regret 2598 +conversation 2599 +prior 2600 +darkness 2601 +remotely 2602 +execution 2603 +theory 2604 +trapped 2605 +proud 2606 +belief 2607 +urban 2608 +russell 2609 +lonely 2610 +placed 2611 +downright 2612 +wilson 2613 +san 2614 +fictional 2615 +melodrama 2616 +spends 2617 +insight 2618 +court 2619 +effectively 2620 +listening 2621 +grave 2622 +express 2623 +demons 2624 +crude 2625 +figured 2626 +bothered 2627 +abandoned 2628 +scares 2629 +network 2630 +unconvincing 2631 +jobs 2632 +hired 2633 +revolution 2634 +favorites 2635 +jon 2636 +wear 2637 +minds 2638 +metal 2639 +worthwhile 2640 +emma 2641 +california 2642 +dean 2643 +buying 2644 +blockbuster 2645 +lifetime 2646 +bus 2647 +paying 2648 +pulls 2649 +account 2650 +angle 2651 +happiness 2652 +von 2653 +blown 2654 +afternoon 2655 +imagery 2656 +rights 2657 +driver 2658 +alright 2659 +rolling 2660 +matrix 2661 +mexican 2662 +productions 2663 +amazed 2664 +idiot 2665 +rings 2666 +cultural 2667 +status 2668 +delivery 2669 +thankfully 2670 +grim 2671 +reveals 2672 +rule 2673 +stayed 2674 +handed 2675 +alice 2676 +stays 2677 +scenario 2678 +focuses 2679 +ha 2680 +significant 2681 +quest 2682 +rough 2683 +starred 2684 +examples 2685 +julia 2686 +jungle 2687 +sir 2688 +indie 2689 +lights 2690 +mere 2691 +views 2692 +murphy 2693 +shadow 2694 +sarah 2695 +bore 2696 +con 2697 +teeth 2698 +heavily 2699 +mature 2700 +device 2701 +table 2702 +skill 2703 +interview 2704 +caine 2705 +tight 2706 +necessarily 2707 +he'd 2708 +ron 2709 +sunday 2710 +clichéd 2711 +suffer 2712 +mexico 2713 +china 2714 +achieve 2715 +spite 2716 +understood 2717 +format 2718 +artists 2719 +position 2720 +initially 2721 +closing 2722 +campy 2723 +desperately 2724 +bound 2725 +fabulous 2726 +dress 2727 +sensitive 2728 +mgm 2729 +destroyed 2730 +hip 2731 +complicated 2732 +burns 2733 +demon 2734 +summary 2735 +seek 2736 +faithful 2737 +forgot 2738 +sun 2739 +decades 2740 +breath 2741 +gross 2742 +pitt 2743 +bourne 2744 +ghosts 2745 +titanic 2746 +cruel 2747 +murderer 2748 +stereotypical 2749 +deeper 2750 +lisa 2751 +facial 2752 +renting 2753 +ignore 2754 +pregnant 2755 +league 2756 +answers 2757 +racist 2758 +un 2759 +helping 2760 +ludicrous 2761 +beloved 2762 +flashback 2763 +slapstick 2764 +sleeping 2765 +17 2766 +dude 2767 +cell 2768 +musicals 2769 +fourth 2770 +wing 2771 +intellectual 2772 +beast 2773 +sounded 2774 +settings 2775 +environment 2776 +suck 2777 +critical 2778 +drinking 2779 +nazi 2780 +reminiscent 2781 +brad 2782 +calling 2783 +lugosi 2784 +dragon 2785 +description 2786 +susan 2787 +prefer 2788 +amazingly 2789 +task 2790 +mildly 2791 +pacino 2792 +disbelief 2793 +encounter 2794 +regarding 2795 +larry 2796 +inept 2797 +greater 2798 +learning 2799 +arms 2800 +dennis 2801 +extraordinary 2802 +turkey 2803 +storytelling 2804 +funnier 2805 +julie 2806 +halfway 2807 +ain't 2808 +expert 2809 +base 2810 +criticism 2811 +quirky 2812 +father's 2813 +leslie 2814 +warned 2815 +cabin 2816 +flight 2817 +titles 2818 +criminals 2819 +johnson 2820 +raw 2821 +praise 2822 +depiction 2823 +screening 2824 +throwing 2825 +extent 2826 +expression 2827 +kiss 2828 +jail 2829 +studios 2830 +freeman 2831 +truck 2832 +convey 2833 +originality 2834 +chan 2835 +entertain 2836 +choices 2837 +spoof 2838 +notorious 2839 +tree 2840 +raised 2841 +touched 2842 +children's 2843 +rachel 2844 +punch 2845 +experiment 2846 +daughters 2847 +prepared 2848 +comical 2849 +spoken 2850 +people's 2851 +timing 2852 +india 2853 +headed 2854 +purely 2855 +could've 2856 +basis 2857 +hoffman 2858 +bollywood 2859 +chilling 2860 +michelle 2861 +underground 2862 +dollar 2863 +via 2864 +picks 2865 +lie 2866 +inspiration 2867 +novels 2868 +wave 2869 +elizabeth 2870 +introduction 2871 +weapons 2872 +trick 2873 +lazy 2874 +jessica 2875 +graphics 2876 +breathtaking 2877 +notable 2878 +stomach 2879 +succeeds 2880 +term 2881 +crafted 2882 +join 2883 +throws 2884 +handle 2885 +strangely 2886 +properly 2887 +toy 2888 +nowadays 2889 +christ 2890 +sidney 2891 +reference 2892 +adding 2893 +claire 2894 +serve 2895 +ratings 2896 +locked 2897 +honor 2898 +wears 2899 +sitcom 2900 +ted 2901 +authentic 2902 +foster 2903 +regard 2904 +everyday 2905 +causes 2906 +maria 2907 +provoking 2908 +charge 2909 +protect 2910 +lesser 2911 +hitchcock 2912 +caring 2913 +mouse 2914 +mirror 2915 +bat 2916 +fallen 2917 +carrying 2918 +bitter 2919 +jewish 2920 +established 2921 +pet 2922 +amongst 2923 +east 2924 +shut 2925 +guard 2926 +midnight 2927 +sleazy 2928 +southern 2929 +determined 2930 +ned 2931 +challenge 2932 +daily 2933 +obnoxious 2934 +nonetheless 2935 +cases 2936 +carried 2937 +carries 2938 +wins 2939 +alas 2940 +remote 2941 +embarrassed 2942 +gruesome 2943 +hole 2944 +2006 2945 +lane 2946 +attempting 2947 +westerns 2948 +escapes 2949 +sinister 2950 +confusion 2951 +nation 2952 +tales 2953 +ironic 2954 +tradition 2955 +interpretation 2956 +arrives 2957 +busy 2958 +replaced 2959 +risk 2960 +enjoying 2961 +sold 2962 +essential 2963 +needless 2964 +aunt 2965 +hardy 2966 +burt 2967 +goofy 2968 +mass 2969 +obsession 2970 +minded 2971 +balance 2972 +flow 2973 +clips 2974 +existent 2975 +successfully 2976 +legs 2977 +presentation 2978 +screenwriter 2979 +jumps 2980 +exists 2981 +attacked 2982 +blair 2983 +laid 2984 +mentally 2985 +bbc 2986 +seeking 2987 +raise 2988 +topic 2989 +oddly 2990 +warner 2991 +inspector 2992 +horrific 2993 +fortunately 2994 +shape 2995 +marvelous 2996 +usa 2997 +intentions 2998 +buck 2999 +retarded 3000 +madness 3001 +stupidity 3002 +stops 3003 +text 3004 +stylish 3005 +stanley 3006 +che 3007 +rival 3008 +served 3009 +workers 3010 +maker 3011 +sides 3012 +ashamed 3013 +shower 3014 +packed 3015 +comedian 3016 +thrilling 3017 +wwii 3018 +interviews 3019 +nine 3020 +laura 3021 +frequently 3022 +upper 3023 +mob 3024 +mansion 3025 +bridge 3026 +remind 3027 +tongue 3028 +navy 3029 +wanna 3030 +contain 3031 +albeit 3032 +intensity 3033 +attacks 3034 +vacation 3035 +thief 3036 +delight 3037 +manager 3038 +chair 3039 +sum 3040 +hence 3041 +80 3042 +cheese 3043 +drives 3044 +2001 3045 +expressions 3046 +struggles 3047 +flawed 3048 +poignant 3049 +angels 3050 +personalities 3051 +rogers 3052 +riding 3053 +revolves 3054 +refuses 3055 +adapted 3056 +opened 3057 +greatly 3058 +credibility 3059 +philip 3060 +cooper 3061 +glass 3062 +pitch 3063 +tracy 3064 +1950s 3065 +jay 3066 +torn 3067 +dinner 3068 +bette 3069 +18 3070 +cynical 3071 +upset 3072 +pool 3073 +sin 3074 +tour 3075 +2000 3076 +internet 3077 +suspects 3078 +advantage 3079 +lessons 3080 +warn 3081 +lion 3082 +overcome 3083 +credible 3084 +wishes 3085 +thousands 3086 +spin 3087 +miller 3088 +racism 3089 +90's 3090 +mindless 3091 +wealthy 3092 +innocence 3093 +tense 3094 +broke 3095 +bugs 3096 +happily 3097 +catholic 3098 +guessing 3099 +trial 3100 +lucy 3101 +hood 3102 +hundreds 3103 +trite 3104 +physically 3105 +thrillers 3106 +cook 3107 +fish 3108 +alike 3109 +dubbing 3110 +fbi 3111 +crisis 3112 +per 3113 +pride 3114 +succeed 3115 +controversial 3116 +suffered 3117 +reed 3118 +bag 3119 +technique 3120 +wasting 3121 +dislike 3122 +medical 3123 +sexuality 3124 +countries 3125 +perform 3126 +patient 3127 +stranger 3128 +enjoyment 3129 +corner 3130 +arm 3131 +glimpse 3132 +gripping 3133 +reunion 3134 +franchise 3135 +holmes 3136 +ensemble 3137 +separate 3138 +hundred 3139 +lincoln 3140 +60's 3141 +sings 3142 +noble 3143 +shines 3144 +whereas 3145 +tied 3146 +ourselves 3147 +uncomfortable 3148 +infamous 3149 +neat 3150 +atmospheric 3151 +millions 3152 +shorts 3153 +contact 3154 +card 3155 +hint 3156 +pack 3157 +courage 3158 +irony 3159 +exceptional 3160 +plastic 3161 +storm 3162 +drink 3163 +ralph 3164 +searching 3165 +oscars 3166 +scripts 3167 +connected 3168 +italy 3169 +proof 3170 +sandler 3171 +snow 3172 +lying 3173 +flash 3174 +nose 3175 +curse 3176 +helen 3177 +sentimental 3178 +mst3k 3179 +grey 3180 +aired 3181 +holiday 3182 +steps 3183 +hills 3184 +performers 3185 +letting 3186 +chasing 3187 +suggests 3188 +dancer 3189 +tune 3190 +meaningful 3191 +idiotic 3192 +knife 3193 +quote 3194 +weapon 3195 +plague 3196 +sons 3197 +entry 3198 +kurt 3199 +fortune 3200 +cameos 3201 +consists 3202 +perfection 3203 +lovable 3204 +hoped 3205 +troubled 3206 +thousand 3207 +hiding 3208 +develops 3209 +unforgettable 3210 +accepted 3211 +noted 3212 +portrait 3213 +dear 3214 +equal 3215 +bettie 3216 +assistant 3217 +stretch 3218 +woman's 3219 +saves 3220 +colorful 3221 +annoyed 3222 +larger 3223 +attraction 3224 +condition 3225 +miscast 3226 +chases 3227 +brooks 3228 +virgin 3229 +spots 3230 +basement 3231 +host 3232 +dialogs 3233 +shoots 3234 +gain 3235 +horses 3236 +guilt 3237 +protagonists 3238 +oil 3239 +terrifying 3240 +month 3241 +cousin 3242 +neighborhood 3243 +vincent 3244 +pg 3245 +belongs 3246 +stealing 3247 +16 3248 +nelson 3249 +worry 3250 +burning 3251 +concert 3252 +ad 3253 +zone 3254 +strip 3255 +appearing 3256 +worlds 3257 +object 3258 +split 3259 +repeat 3260 +hang 3261 +boredom 3262 +destruction 3263 +thirty 3264 +redemption 3265 +hunting 3266 +encounters 3267 +imaginative 3268 +expensive 3269 +eerie 3270 +cube 3271 +seagal 3272 +jake 3273 +pie 3274 +competent 3275 +homeless 3276 +concerns 3277 +andrew 3278 +flaw 3279 +closely 3280 +bo 3281 +ultra 3282 +factory 3283 +1st 3284 +multi 3285 +civil 3286 +dramas 3287 +gag 3288 +stunts 3289 +wake 3290 +guts 3291 +sends 3292 +60 3293 +sutherland 3294 +glory 3295 +knock 3296 +matthau 3297 +massacre 3298 +letter 3299 +elsewhere 3300 +achieved 3301 +dig 3302 +checking 3303 +widmark 3304 +hooked 3305 +complaint 3306 +neck 3307 +endearing 3308 +segments 3309 +shark 3310 +sullivan 3311 +rushed 3312 +virus 3313 +ripped 3314 +charisma 3315 +incoherent 3316 +dragged 3317 +beating 3318 +dentist 3319 +essence 3320 +bears 3321 +profound 3322 +library 3323 +weight 3324 +tear 3325 +crimes 3326 +arnold 3327 +dare 3328 +appearances 3329 +solve 3330 +trade 3331 +pat 3332 +24 3333 +stanwyck 3334 +colour 3335 +teach 3336 +dorothy 3337 +roberts 3338 +rocks 3339 +fest 3340 +spell 3341 +catherine 3342 +dealt 3343 +stan 3344 +fitting 3345 +hitting 3346 +striking 3347 +pro 3348 +2005 3349 +tribute 3350 +tricks 3351 +60s 3352 +battles 3353 +believing 3354 +briefly 3355 +countless 3356 +fashioned 3357 +loser 3358 +goal 3359 +gothic 3360 +noise 3361 +techniques 3362 +n 3363 +videos 3364 +health 3365 +thumbs 3366 +attempted 3367 +scientists 3368 +st 3369 +painting 3370 +baker 3371 +strikes 3372 +inspiring 3373 +huh 3374 +sexually 3375 +birthday 3376 +secretary 3377 +curtis 3378 +jeremy 3379 +covers 3380 +pointed 3381 +slight 3382 +specific 3383 +tea 3384 +hearts 3385 +unintentionally 3386 +denzel 3387 +horrendous 3388 +charismatic 3389 +silver 3390 +surrounded 3391 +surrounding 3392 +reactions 3393 +branagh 3394 +importance 3395 +rochester 3396 +admittedly 3397 +carefully 3398 +jerk 3399 +tons 3400 +hype 3401 +relevant 3402 +they'd 3403 +walls 3404 +stood 3405 +eyed 3406 +bible 3407 +corrupt 3408 +rush 3409 +stunt 3410 +revelation 3411 +smoking 3412 +magazine 3413 +lloyd 3414 +kicks 3415 +karloff 3416 +stronger 3417 +grows 3418 +mild 3419 +hamlet 3420 +represents 3421 +dawn 3422 +andrews 3423 +intention 3424 +easier 3425 +enters 3426 +spending 3427 +scooby 3428 +fired 3429 +killings 3430 +stated 3431 +chances 3432 +shall 3433 +brand 3434 +exercise 3435 +university 3436 +increasingly 3437 +row 3438 +disagree 3439 +cardboard 3440 +winter 3441 +comics 3442 +requires 3443 +dropped 3444 +associated 3445 +world's 3446 +chuck 3447 +iii 3448 +medium 3449 +bush 3450 +projects 3451 +bride 3452 +occurs 3453 +korean 3454 +inevitable 3455 +messages 3456 +brando 3457 +le 3458 +strike 3459 +poverty 3460 +forgive 3461 +performing 3462 +stiff 3463 +attached 3464 +drags 3465 +luckily 3466 +ian 3467 +identify 3468 +1970s 3469 +gift 3470 +bobby 3471 +acceptable 3472 +resolution 3473 +eva 3474 +typically 3475 +canada 3476 +guest 3477 +nuclear 3478 +elvis 3479 +toilet 3480 +strictly 3481 +vague 3482 +spike 3483 +contract 3484 +hire 3485 +1980s 3486 +thrills 3487 +selling 3488 +hudson 3489 +homage 3490 +lab 3491 +boll 3492 +mafia 3493 +depression 3494 +sophisticated 3495 +fifteen 3496 +disease 3497 +allowing 3498 +brilliance 3499 +investigation 3500 +continued 3501 +struck 3502 +insulting 3503 +worker 3504 +instantly 3505 +useless 3506 +breasts 3507 +barry 3508 +jesse 3509 +sally 3510 +afterwards 3511 +chaplin 3512 +britain 3513 +carter 3514 +executive 3515 +handful 3516 +importantly 3517 +godfather 3518 +estate 3519 +hanks 3520 +pleased 3521 +overlooked 3522 +evident 3523 +burn 3524 +gotta 3525 +wreck 3526 +nights 3527 +2002 3528 +beings 3529 +ego 3530 +kidnapped 3531 +presumably 3532 +competition 3533 +press 3534 +partly 3535 +digital 3536 +shining 3537 +commit 3538 +tremendous 3539 +raped 3540 +menacing 3541 +silence 3542 +talked 3543 +derek 3544 +worthless 3545 +jamie 3546 +realise 3547 +ambitious 3548 +meat 3549 +wondered 3550 +photographed 3551 +sacrifice 3552 +arrested 3553 +buried 3554 +burton 3555 +threatening 3556 +smooth 3557 +aforementioned 3558 +superbly 3559 +boxing 3560 +kane 3561 +flawless 3562 +regardless 3563 +fears 3564 +creation 3565 +shy 3566 +heat 3567 +highlights 3568 +savage 3569 +persona 3570 +frustrated 3571 +drivel 3572 +conspiracy 3573 +individuals 3574 +wonders 3575 +listed 3576 +appalling 3577 +doc 3578 +'s 3579 +spiritual 3580 +pushed 3581 +returning 3582 +jumping 3583 +elvira 3584 +cox 3585 +corpse 3586 +size 3587 +characterization 3588 +bullets 3589 +walken 3590 +generous 3591 +string 3592 +rex 3593 +doors 3594 +pleasantly 3595 +bucks 3596 +relative 3597 +45 3598 +outrageous 3599 +kudos 3600 +planning 3601 +ticket 3602 +achievement 3603 +accomplished 3604 +miserably 3605 +monkey 3606 +beaten 3607 +neighbor 3608 +distant 3609 +fatal 3610 +repetitive 3611 +accused 3612 +picking 3613 +ironically 3614 +consequences 3615 +curiosity 3616 +union 3617 +admire 3618 +guide 3619 +splendid 3620 +prevent 3621 +reynolds 3622 +border 3623 +attracted 3624 +butt 3625 +clues 3626 +trap 3627 +notes 3628 +chain 3629 +opposed 3630 +watches 3631 +samurai 3632 +shortly 3633 +heston 3634 +twin 3635 +cole 3636 +glover 3637 +slightest 3638 +response 3639 +beer 3640 +territory 3641 +spooky 3642 +diamond 3643 +rap 3644 +horrors 3645 +20th 3646 +cup 3647 +dire 3648 +spirited 3649 +melodramatic 3650 +lucas 3651 +flynn 3652 +los 3653 +piano 3654 +push 3655 +revealing 3656 +spoiled 3657 +uninspired 3658 +ritter 3659 +convoluted 3660 +pulling 3661 +ken 3662 +root 3663 +they'll 3664 +streisand 3665 +motivation 3666 +directorial 3667 +installment 3668 +precious 3669 +titled 3670 +logical 3671 +documentaries 3672 +spring 3673 +lacked 3674 +suits 3675 +tall 3676 +subplot 3677 +mate 3678 +timeless 3679 +hatred 3680 +throat 3681 +blows 3682 +jealous 3683 +creators 3684 +blank 3685 +farce 3686 +spielberg 3687 +slap 3688 +ward 3689 +carol 3690 +subsequent 3691 +cared 3692 +mile 3693 +exaggerated 3694 +duke 3695 +morality 3696 +liberal 3697 +francisco 3698 +indians 3699 +psychotic 3700 +overdone 3701 +psychiatrist 3702 +astaire 3703 +intrigued 3704 +jet 3705 +blob 3706 +50's 3707 +conceived 3708 +fx 3709 +neil 3710 +aimed 3711 +remaining 3712 +doo 3713 +ignored 3714 +elderly 3715 +reasonably 3716 +mitchell 3717 +failing 3718 +sole 3719 +obscure 3720 +drunken 3721 +minimal 3722 +temple 3723 +progress 3724 +fancy 3725 +captivating 3726 +repeatedly 3727 +wes 3728 +tunes 3729 +shoes 3730 +grandmother 3731 +cia 3732 +nurse 3733 +marks 3734 +notably 3735 +emily 3736 +soviet 3737 +shirt 3738 +explore 3739 +smoke 3740 +souls 3741 +pushing 3742 +argument 3743 +distance 3744 +warrior 3745 +outcome 3746 +reduced 3747 +loosely 3748 +scientific 3749 +goldberg 3750 +gradually 3751 +bleak 3752 +timothy 3753 +manhattan 3754 +idiots 3755 +restaurant 3756 +scripted 3757 +misses 3758 +explicit 3759 +providing 3760 +elaborate 3761 +poster 3762 +lou 3763 +dignity 3764 +carpenter 3765 +norman 3766 +rid 3767 +turner 3768 +show's 3769 +davies 3770 +draws 3771 +discussion 3772 +exposed 3773 +mel 3774 +sticks 3775 +kenneth 3776 +definite 3777 +darker 3778 +laurel 3779 +intent 3780 +1950's 3781 +returned 3782 +superhero 3783 +sloppy 3784 +cried 3785 +worried 3786 +childish 3787 +shadows 3788 +craig 3789 +cruise 3790 +hysterical 3791 +imagined 3792 +reasonable 3793 +editor 3794 +ah 3795 +birds 3796 +horrid 3797 +areas 3798 +wicked 3799 +gentle 3800 +wannabe 3801 +alexander 3802 +thick 3803 +contrary 3804 +joey 3805 +empire 3806 +connect 3807 +discovery 3808 +unbearable 3809 +tortured 3810 +screams 3811 +fever 3812 +unbelievably 3813 +1930s 3814 +disc 3815 +99 3816 +load 3817 +heroic 3818 +absence 3819 +reached 3820 +ho 3821 +choreography 3822 +triumph 3823 +complain 3824 +annie 3825 +broad 3826 +improved 3827 +concerning 3828 +brazil 3829 +movements 3830 +2003 3831 +2004 3832 +dave 3833 +folk 3834 +eve 3835 +purple 3836 +commercials 3837 +futuristic 3838 +vicious 3839 +gray 3840 +freak 3841 +threat 3842 +cusack 3843 +extended 3844 +citizen 3845 +stole 3846 +anyways 3847 +glenn 3848 +existed 3849 +cheek 3850 +broadcast 3851 +photographer 3852 +translation 3853 +arrive 3854 +differences 3855 +displays 3856 +critic 3857 +slave 3858 +landscape 3859 +occurred 3860 +builds 3861 +drawing 3862 +incident 3863 +warren 3864 +burned 3865 +involvement 3866 +styles 3867 +bathroom 3868 +machines 3869 +narrator 3870 +antics 3871 +he'll 3872 +fisher 3873 +swear 3874 +australia 3875 +matthew 3876 +resembles 3877 +lily 3878 +overrated 3879 +currently 3880 +symbolism 3881 +ought 3882 +bare 3883 +audio 3884 +web 3885 +farm 3886 +contained 3887 +greek 3888 +affected 3889 +blend 3890 +q 3891 +recognized 3892 +duo 3893 +genres 3894 +population 3895 +carrie 3896 +ranks 3897 +demands 3898 +we'll 3899 +abc 3900 +prom 3901 +altogether 3902 +superficial 3903 +kitchen 3904 +pseudo 3905 +sunshine 3906 +sadness 3907 +secrets 3908 +bone 3909 +website 3910 +receive 3911 +popcorn 3912 +threw 3913 +craft 3914 +enjoys 3915 +occur 3916 +twelve 3917 +block 3918 +girl's 3919 +proceedings 3920 +dynamic 3921 +daring 3922 +swedish 3923 +argue 3924 +bite 3925 +wolf 3926 +adequate 3927 +investigate 3928 +harder 3929 +ruth 3930 +ridiculously 3931 +tap 3932 +dinosaurs 3933 +hugh 3934 +synopsis 3935 +beats 3936 +carrey 3937 +explosion 3938 +foul 3939 +merit 3940 +suited 3941 +holy 3942 +staged 3943 +journalist 3944 +pretend 3945 +composed 3946 +cagney 3947 +robots 3948 +giallo 3949 +aging 3950 +fay 3951 +sadistic 3952 +engaged 3953 +escaped 3954 +juvenile 3955 +rambo 3956 +ireland 3957 +conversations 3958 +thugs 3959 +modesty 3960 +selfish 3961 +margaret 3962 +dialogues 3963 +ease 3964 +cameras 3965 +tame 3966 +leg 3967 +rural 3968 +comfortable 3969 +nazis 3970 +clothing 3971 +innovative 3972 +terry 3973 +thrill 3974 +2nd 3975 +dancers 3976 +brosnan 3977 +explosions 3978 +bin 3979 +rage 3980 +overwhelming 3981 +jazz 3982 +vivid 3983 +coherent 3984 +bullet 3985 +odds 3986 +mountains 3987 +kidding 3988 +versus 3989 +lit 3990 +offering 3991 +mother's 3992 +trio 3993 +newspaper 3994 +pulp 3995 +ellen 3996 +dawson 3997 +bird 3998 +buddies 3999 +combat 4000 +dracula 4001 +lol 4002 +grab 4003 +orders 4004 +staff 4005 +nearby 4006 +cats 4007 +wealth 4008 +unpleasant 4009 +staying 4010 +devoted 4011 +centered 4012 +errors 4013 +disturbed 4014 +bell 4015 +atlantis 4016 +snake 4017 +felix 4018 +damage 4019 +clint 4020 +lust 4021 +groups 4022 +banned 4023 +blowing 4024 +fighter 4025 +removed 4026 +react 4027 +conventional 4028 +kapoor 4029 +intrigue 4030 +possessed 4031 +cringe 4032 +eyre 4033 +liking 4034 +implausible 4035 +philosophy 4036 +producing 4037 +abilities 4038 +seventies 4039 +bang 4040 +murderous 4041 +deliberately 4042 +gandhi 4043 +tommy 4044 +meaningless 4045 +subjects 4046 +lips 4047 +ingredients 4048 +mildred 4049 +perry 4050 +warming 4051 +causing 4052 +possibility 4053 +detailed 4054 +walker 4055 +garden 4056 +prostitute 4057 +nightmares 4058 +cameron 4059 +flop 4060 +influenced 4061 +spare 4062 +unwatchable 4063 +undoubtedly 4064 +celluloid 4065 +relies 4066 +resemblance 4067 +neo 4068 +parent 4069 +falk 4070 +uneven 4071 +unintentional 4072 +eccentric 4073 +mistaken 4074 +distracting 4075 +careers 4076 +yesterday 4077 +forbidden 4078 +panic 4079 +crack 4080 +brains 4081 +highest 4082 +occasion 4083 +signs 4084 +focusing 4085 +hollow 4086 +explored 4087 +aid 4088 +cary 4089 +scheme 4090 +shine 4091 +it'll 4092 +kirk 4093 +bedroom 4094 +satisfied 4095 +rat 4096 +passes 4097 +survival 4098 +coffee 4099 +furthermore 4100 +primary 4101 +succeeded 4102 +politically 4103 +pays 4104 +apes 4105 +stiller 4106 +dating 4107 +defeat 4108 +sport 4109 +catches 4110 +mickey 4111 +clown 4112 +roman 4113 +discuss 4114 +karen 4115 +clumsy 4116 +chaos 4117 +financial 4118 +official 4119 +trees 4120 +explaining 4121 +models 4122 +spirits 4123 +carl 4124 +jeffrey 4125 +duty 4126 +whale 4127 +funeral 4128 +secondly 4129 +sentence 4130 +2007 4131 +classes 4132 +sidekick 4133 +tracks 4134 +props 4135 +travels 4136 +flies 4137 +remarkably 4138 +smaller 4139 +wallace 4140 +awake 4141 +1996 4142 +brady 4143 +blatant 4144 +decisions 4145 +afford 4146 +notion 4147 +recorded 4148 +glorious 4149 +enterprise 4150 +maggie 4151 +consistently 4152 +toys 4153 +offended 4154 +officers 4155 +danes 4156 +backdrop 4157 +beneath 4158 +masters 4159 +measure 4160 +endings 4161 +doomed 4162 +mysteries 4163 +lifestyle 4164 +houses 4165 +portion 4166 +primarily 4167 +satan 4168 +hates 4169 +devoid 4170 +impress 4171 +outer 4172 +generic 4173 +dutch 4174 +punk 4175 +lyrics 4176 +yellow 4177 +eastwood 4178 +exotic 4179 +represent 4180 +instant 4181 +desperation 4182 +mixture 4183 +settle 4184 +frustration 4185 +unfolds 4186 +goodness 4187 +wives 4188 +directs 4189 +fetched 4190 +ape 4191 +cheating 4192 +dozens 4193 +rebel 4194 +cuba 4195 +paulie 4196 +enormous 4197 +revolutionary 4198 +hints 4199 +shelf 4200 +brooklyn 4201 +florida 4202 +dances 4203 +motives 4204 +destiny 4205 +1999 4206 +donna 4207 +hardcore 4208 +mill 4209 +wrestling 4210 +subtlety 4211 +forty 4212 +describes 4213 +drops 4214 +blake 4215 +stinker 4216 +doll 4217 +painted 4218 +fond 4219 +linda 4220 +principal 4221 +rank 4222 +ideal 4223 +kennedy 4224 +hammer 4225 +montage 4226 +hollywood's 4227 +tie 4228 +disjointed 4229 +3rd 4230 +reaches 4231 +amy 4232 +immensely 4233 +ginger 4234 +judging 4235 +companion 4236 +communist 4237 +urge 4238 +winds 4239 +developing 4240 +trailers 4241 +cliff 4242 +lawrence 4243 +stellar 4244 +topless 4245 +circle 4246 +surviving 4247 +avoided 4248 +relations 4249 +bold 4250 +hideous 4251 +voight 4252 +closet 4253 +et 4254 +surfing 4255 +melting 4256 +soccer 4257 +edie 4258 +matches 4259 +backgrounds 4260 +planned 4261 +enemies 4262 +advance 4263 +bull 4264 +authority 4265 +crush 4266 +outfit 4267 +emphasis 4268 +method 4269 +terrorist 4270 +senseless 4271 +pig 4272 +uwe 4273 +simplistic 4274 +benefit 4275 +adorable 4276 +eighties 4277 +ruthless 4278 +godzilla 4279 +blew 4280 +countryside 4281 +specifically 4282 +wont 4283 +performer 4284 +hbo 4285 +traveling 4286 +todd 4287 +practice 4288 +diane 4289 +fix 4290 +faster 4291 +1980 4292 +commented 4293 +sh 4294 +loyal 4295 +saga 4296 +ties 4297 +disappear 4298 +awe 4299 +earned 4300 +buff 4301 +rick 4302 +loads 4303 +link 4304 +angeles 4305 +corruption 4306 +forms 4307 +menace 4308 +miserable 4309 +claimed 4310 +vast 4311 +coach 4312 +divorce 4313 +hal 4314 +gadget 4315 +chorus 4316 +limits 4317 +cure 4318 +introduces 4319 +cards 4320 +solo 4321 +blues 4322 +splatter 4323 +april 4324 +endure 4325 +riveting 4326 +dedicated 4327 +tender 4328 +winters 4329 +illogical 4330 +choreographed 4331 +disappeared 4332 +unsettling 4333 +waters 4334 +guessed 4335 +lemmon 4336 +involve 4337 +transformation 4338 +depressed 4339 +rooms 4340 +lasted 4341 +displayed 4342 +weakest 4343 +leonard 4344 +philosophical 4345 +racial 4346 +interaction 4347 +arrogant 4348 +tag 4349 +rocket 4350 +similarities 4351 +hurts 4352 +thoughtful 4353 +realizing 4354 +harvey 4355 +justify 4356 +hook 4357 +survivors 4358 +represented 4359 +pot 4360 +possibilities 4361 +wore 4362 +disappoint 4363 +voiced 4364 +kicked 4365 +abysmal 4366 +hamilton 4367 +buffs 4368 +safety 4369 +widow 4370 +ears 4371 +nomination 4372 +trashy 4373 +honesty 4374 +stereotype 4375 +severe 4376 +formulaic 4377 +moody 4378 +similarly 4379 +stress 4380 +pan 4381 +chased 4382 +isolated 4383 +blond 4384 +stinks 4385 +mario 4386 +passionate 4387 +finger 4388 +shirley 4389 +march 4390 +hank 4391 +improve 4392 +mann 4393 +understandable 4394 +characters' 4395 +considerable 4396 +scope 4397 +holly 4398 +diana 4399 +grasp 4400 +command 4401 +solely 4402 +'em 4403 +concern 4404 +treats 4405 +akshay 4406 +promised 4407 +colonel 4408 +jonathan 4409 +faults 4410 +helicopter 4411 +inventive 4412 +sounding 4413 +quotes 4414 +trained 4415 +switch 4416 +celebrity 4417 +tad 4418 +swimming 4419 +orson 4420 +education 4421 +aids 4422 +nail 4423 +judy 4424 +cg 4425 +user 4426 +nervous 4427 +nostalgic 4428 +daddy 4429 +alert 4430 +amanda 4431 +facing 4432 +comparing 4433 +unhappy 4434 +preview 4435 +report 4436 +bonus 4437 +purchase 4438 +chess 4439 +wet 4440 +lately 4441 +horrifying 4442 +agrees 4443 +thru 4444 +dolls 4445 +cinematographer 4446 +ignorant 4447 +species 4448 +seed 4449 +consistent 4450 +downhill 4451 +corporate 4452 +photos 4453 +confidence 4454 +letters 4455 +berlin 4456 +dinosaur 4457 +rotten 4458 +taught 4459 +fooled 4460 +laws 4461 +nicholson 4462 +namely 4463 +shake 4464 +waited 4465 +wished 4466 +embarrassment 4467 +everyone's 4468 +boot 4469 +pretending 4470 +reaching 4471 +someone's 4472 +transfer 4473 +sits 4474 +armed 4475 +del 4476 +dub 4477 +defend 4478 +hart 4479 +35 4480 +constructed 4481 +mall 4482 +poetic 4483 +motivations 4484 +inane 4485 +behave 4486 +tonight 4487 +staring 4488 +humble 4489 +snl 4490 +elephant 4491 +agents 4492 +oz 4493 +grandfather 4494 +writes 4495 +relation 4496 +hop 4497 +delivering 4498 +fonda 4499 +edgar 4500 +cave 4501 +artificial 4502 +grinch 4503 +sappy 4504 +prize 4505 +1972 4506 +useful 4507 +buildings 4508 +li 4509 +cake 4510 +eager 4511 +closest 4512 +suitable 4513 +raising 4514 +destroying 4515 +combine 4516 +beatty 4517 +pants 4518 +cleverly 4519 +ballet 4520 +convincingly 4521 +porno 4522 +1990 4523 +miike 4524 +affect 4525 +engage 4526 +cd 4527 +conservative 4528 +wound 4529 +arrived 4530 +stevens 4531 +alcoholic 4532 +valuable 4533 +ya 4534 +reads 4535 +scottish 4536 +elegant 4537 +vegas 4538 +chest 4539 +charlotte 4540 +climactic 4541 +tiresome 4542 +z 4543 +conflicts 4544 +babe 4545 +vengeance 4546 +square 4547 +bath 4548 +secretly 4549 +airport 4550 +campbell 4551 +kingdom 4552 +september 4553 +inferior 4554 +1968 4555 +latin 4556 +plant 4557 +button 4558 +museum 4559 +maintain 4560 +wrapped 4561 +kicking 4562 +cheated 4563 +global 4564 +robbery 4565 +virginia 4566 +wells 4567 +waves 4568 +stilted 4569 +blunt 4570 +lena 4571 +boom 4572 +access 4573 +raymond 4574 +1960s 4575 +catching 4576 +nicholas 4577 +yelling 4578 +scarecrow 4579 +beliefs 4580 +paranoia 4581 +christians 4582 +vice 4583 +jumped 4584 +lay 4585 +iron 4586 +steel 4587 +lowest 4588 +reflect 4589 +closed 4590 +mummy 4591 +transition 4592 +advertising 4593 +vulnerable 4594 +abusive 4595 +1970's 4596 +spoke 4597 +plight 4598 +mars 4599 +spread 4600 +adams 4601 +wizard 4602 +poetry 4603 +im 4604 +sandra 4605 +germans 4606 +pokemon 4607 +progresses 4608 +70 4609 +00 4610 +hung 4611 +questionable 4612 +remarks 4613 +airplane 4614 +centers 4615 +potentially 4616 +bottle 4617 +chicago 4618 +guarantee 4619 +couples 4620 +messed 4621 +catchy 4622 +slick 4623 +gangsters 4624 +misery 4625 +blade 4626 +designs 4627 +construction 4628 +ethan 4629 +desired 4630 +miracle 4631 +carradine 4632 +firstly 4633 +scores 4634 +wandering 4635 +greedy 4636 +recognition 4637 +understated 4638 +restored 4639 +complexity 4640 +madonna 4641 +attitudes 4642 +rendition 4643 +hunters 4644 +intentionally 4645 +experiments 4646 +ruby 4647 +alongside 4648 +vaguely 4649 +inappropriate 4650 +copies 4651 +operation 4652 +brutally 4653 +taxi 4654 +amounts 4655 +stooges 4656 +joined 4657 +pearl 4658 +demand 4659 +crocodile 4660 +depicts 4661 +purchased 4662 +acid 4663 +myers 4664 +exploration 4665 +advise 4666 +illegal 4667 +balls 4668 +king's 4669 +gundam 4670 +disney's 4671 +gender 4672 +lengthy 4673 +survived 4674 +hopper 4675 +niro 4676 +advanced 4677 +simplicity 4678 +bela 4679 +parallel 4680 +ocean 4681 +slaughter 4682 +rising 4683 +witnesses 4684 +chicks 4685 +streep 4686 +visible 4687 +nostalgia 4688 +arguably 4689 +careful 4690 +intimate 4691 +online 4692 +floating 4693 +rubber 4694 +june 4695 +illness 4696 +resources 4697 +khan 4698 +jaw 4699 +newly 4700 +witches 4701 +showcase 4702 +signed 4703 +opinions 4704 +dust 4705 +eaten 4706 +civilization 4707 +shelley 4708 +incomprehensible 4709 +invasion 4710 +lee's 4711 +monkeys 4712 +resort 4713 +literature 4714 +junior 4715 +likewise 4716 +homosexual 4717 +family's 4718 +viewings 4719 +sue 4720 +wisdom 4721 +matched 4722 +amitabh 4723 +edition 4724 +witnessed 4725 +visits 4726 +mistress 4727 +1983 4728 +demented 4729 +basketball 4730 +neighbors 4731 +macy 4732 +fascinated 4733 +dreary 4734 +suspicious 4735 +accompanied 4736 +worn 4737 +mail 4738 +challenging 4739 +doom 4740 +ensues 4741 +manipulative 4742 +robinson 4743 +classical 4744 +olivier 4745 +agreed 4746 +appreciation 4747 +franco 4748 +montana 4749 +troops 4750 +capturing 4751 +alternate 4752 +bands 4753 +twilight 4754 +ridden 4755 +responsibility 4756 +proceeds 4757 +chapter 4758 +jenny 4759 +prisoners 4760 +pops 4761 +analysis 4762 +subplots 4763 +lively 4764 +nuts 4765 +prisoner 4766 +incompetent 4767 +damon 4768 +sellers 4769 +mayor 4770 +rats 4771 +simpson 4772 +90s 4773 +persons 4774 +feed 4775 +descent 4776 +reel 4777 +bay 4778 +assault 4779 +losers 4780 +widely 4781 +rabbit 4782 +smiling 4783 +relatives 4784 +excessive 4785 +defined 4786 +satisfy 4787 +solution 4788 +legal 4789 +molly 4790 +arrival 4791 +overacting 4792 +equivalent 4793 +iran 4794 +pit 4795 +masterful 4796 +capital 4797 +richardson 4798 +compelled 4799 +plausible 4800 +stale 4801 +scrooge 4802 +cities 4803 +francis 4804 +enthusiasm 4805 +lone 4806 +parties 4807 +tomatoes 4808 +channels 4809 +hilariously 4810 +rocky 4811 +crucial 4812 +dropping 4813 +unit 4814 +waitress 4815 +domestic 4816 +attorney 4817 +bakshi 4818 +serving 4819 +wrap 4820 +jaws 4821 +historically 4822 +3d 4823 +defense 4824 +hello 4825 +greed 4826 +1973 4827 +priceless 4828 +sincere 4829 +warmth 4830 +paltrow 4831 +gerard 4832 +tends 4833 +god's 4834 +patients 4835 +creep 4836 +counter 4837 +dalton 4838 +kay 4839 +whats 4840 +louise 4841 +peoples 4842 +exceptionally 4843 +nyc 4844 +pal 4845 +seeks 4846 +terrorists 4847 +lumet 4848 +morris 4849 +ninja 4850 +randomly 4851 +frequent 4852 +despair 4853 +irrelevant 4854 +dressing 4855 +pursuit 4856 +prequel 4857 +creativity 4858 +imitation 4859 +bumbling 4860 +hyde 4861 +property 4862 +muslim 4863 +wishing 4864 +richards 4865 +bargain 4866 +50s 4867 +creator 4868 +calm 4869 +bacall 4870 +gabriel 4871 +mentioning 4872 +rangers 4873 +methods 4874 +earl 4875 +royal 4876 +butler 4877 +justin 4878 +psychic 4879 +chooses 4880 +belong 4881 +der 4882 +photo 4883 +polanski 4884 +mundane 4885 +specially 4886 +mighty 4887 +homer 4888 +ear 4889 +masterpieces 4890 +generated 4891 +leo 4892 +improvement 4893 +poem 4894 +ham 4895 +cliche 4896 +marty 4897 +caliber 4898 +mentions 4899 +minimum 4900 +showdown 4901 +borrowed 4902 +elm 4903 +icon 4904 +brenda 4905 +polished 4906 +1984 4907 +mechanical 4908 +overlook 4909 +loaded 4910 +map 4911 +recording 4912 +craven 4913 +tiger 4914 +roth 4915 +awfully 4916 +suffice 4917 +troubles 4918 +introduce 4919 +equipment 4920 +ashley 4921 +wendy 4922 +pamela 4923 +empathy 4924 +phantom 4925 +betty 4926 +resident 4927 +unreal 4928 +ruins 4929 +performs 4930 +promises 4931 +monk 4932 +iraq 4933 +hippie 4934 +purposes 4935 +marketing 4936 +angela 4937 +keith 4938 +sink 4939 +gifted 4940 +opportunities 4941 +garbo 4942 +assigned 4943 +feminist 4944 +household 4945 +wacky 4946 +alfred 4947 +absent 4948 +sneak 4949 +popularity 4950 +trail 4951 +inducing 4952 +moronic 4953 +wounded 4954 +receives 4955 +willis 4956 +unseen 4957 +stretched 4958 +fulci 4959 +unaware 4960 +dimension 4961 +dolph 4962 +definition 4963 +testament 4964 +educational 4965 +survivor 4966 +attend 4967 +clip 4968 +contest 4969 +petty 4970 +13th 4971 +christy 4972 +respected 4973 +resist 4974 +year's 4975 +album 4976 +expressed 4977 +randy 4978 +quit 4979 +phony 4980 +unoriginal 4981 +punishment 4982 +activities 4983 +suspend 4984 +rolled 4985 +eastern 4986 +1933 4987 +instinct 4988 +distinct 4989 +championship 4990 +tech 4991 +doubts 4992 +interests 4993 +exposure 4994 +travesty 4995 +israel 4996 +sixties 4997 +pink 4998 +orange 4999 +resulting 5000 +spain 5001 +bergman 5002 +1987 5003 +verhoeven 5004 +distribution 5005 +laughably 5006 +depicting 5007 +kissing 5008 +tooth 5009 +shed 5010 +kubrick 5011 +pin 5012 +nonsensical 5013 +roots 5014 +assumed 5015 +swim 5016 +whoopi 5017 +domino 5018 +heights 5019 +spock 5020 +inevitably 5021 +abraham 5022 +stunned 5023 +businessman 5024 +correctly 5025 +deceased 5026 +buffalo 5027 +wholly 5028 +underlying 5029 +dud 5030 +othello 5031 +unpredictable 5032 +package 5033 +hopeless 5034 +teaching 5035 +valley 5036 +uplifting 5037 +peters 5038 +integrity 5039 +1993 5040 +biography 5041 +yard 5042 +brutality 5043 +america's 5044 +trademark 5045 +retired 5046 +shaw 5047 +reflection 5048 +maniac 5049 +– 5050 +meryl 5051 +accuracy 5052 +sid 5053 +compassion 5054 +dreck 5055 +2008 5056 +edgy 5057 +greatness 5058 +assassin 5059 +greg 5060 +palace 5061 +suggested 5062 +patience 5063 +landscapes 5064 +1971 5065 +mankind 5066 +supported 5067 +merits 5068 +directions 5069 +fed 5070 +romero 5071 +spider 5072 +mtv 5073 +metaphor 5074 +masses 5075 +puppet 5076 +seldom 5077 +wife's 5078 +loyalty 5079 +deaf 5080 +grayson 5081 +strangers 5082 +3000 5083 +passable 5084 +checked 5085 +connery 5086 +confess 5087 +shaky 5088 +drake 5089 +eugene 5090 +significance 5091 +pierce 5092 +unfair 5093 +maid 5094 +indulgent 5095 +comfort 5096 +orleans 5097 +willie 5098 +glasses 5099 +pressure 5100 +alec 5101 +composer 5102 +marion 5103 +nicole 5104 +tribe 5105 +fought 5106 +technicolor 5107 +watson 5108 +dee 5109 +emperor 5110 +adaptations 5111 +romp 5112 +peak 5113 +conditions 5114 +grabs 5115 +exchange 5116 +fury 5117 +immediate 5118 +women's 5119 +timon 5120 +omen 5121 +generations 5122 +barrymore 5123 +resemble 5124 +1995 5125 +1997 5126 +confrontation 5127 +landing 5128 +frustrating 5129 +demise 5130 +spacey 5131 +lackluster 5132 +disliked 5133 +kyle 5134 +y 5135 +victory 5136 +wretched 5137 +Â… 5138 +farrell 5139 +we'd 5140 +respectively 5141 +crazed 5142 +din 5143 +expedition 5144 +chicken 5145 +cannibal 5146 +conscious 5147 +experimental 5148 +astonishing 5149 +inability 5150 +examination 5151 +wilderness 5152 +tube 5153 +blast 5154 +nerd 5155 +legacy 5156 +companies 5157 +subjected 5158 +ships 5159 +rises 5160 +invented 5161 +stuart 5162 +ambiguous 5163 +grief 5164 +rave 5165 +cracking 5166 +unexpectedly 5167 +scotland 5168 +stargate 5169 +milk 5170 +singers 5171 +darren 5172 +billed 5173 +tripe 5174 +ordered 5175 +furious 5176 +flair 5177 +griffith 5178 +refused 5179 +fascination 5180 +tastes 5181 +owen 5182 +frightened 5183 +amused 5184 +masks 5185 +females 5186 +graham 5187 +rates 5188 +simultaneously 5189 +senses 5190 +walsh 5191 +marc 5192 +simmons 5193 +shanghai 5194 +premiere 5195 +remained 5196 +warriors 5197 +1936 5198 +josh 5199 +antwone 5200 +difficulties 5201 +shoulders 5202 +femme 5203 +alternative 5204 +sentiment 5205 +relax 5206 +ollie 5207 +leon 5208 +rooney 5209 +objective 5210 +deranged 5211 +alcohol 5212 +austin 5213 +sissy 5214 +tank 5215 +dysfunctional 5216 +vulgar 5217 +stumbled 5218 +desires 5219 +replace 5220 +dixon 5221 +claus 5222 +joel 5223 +hears 5224 +coast 5225 +poison 5226 +addicted 5227 +slice 5228 +lundgren 5229 +parade 5230 +gather 5231 +appropriately 5232 +abused 5233 +cream 5234 +challenged 5235 +awhile 5236 +tacky 5237 +interactions 5238 +function 5239 +pun 5240 +bud 5241 +filling 5242 +primitive 5243 +fishing 5244 +raises 5245 +infected 5246 +musicians 5247 +precisely 5248 +caricatures 5249 +karl 5250 +underneath 5251 +ross 5252 +alicia 5253 +prey 5254 +fingers 5255 +nephew 5256 +crystal 5257 +skull 5258 +remakes 5259 +favour 5260 +wildly 5261 +phil 5262 +phrase 5263 +julian 5264 +sopranos 5265 +complaints 5266 +presenting 5267 +noises 5268 +19th 5269 +twins 5270 +les 5271 +ramones 5272 +lands 5273 +joins 5274 +wakes 5275 +require 5276 +fifty 5277 +items 5278 +frankenstein 5279 +nathan 5280 +christianity 5281 +reid 5282 +accomplish 5283 +22 5284 +dana 5285 +wang 5286 +breed 5287 +millionaire 5288 +sums 5289 +knocked 5290 +teaches 5291 +literary 5292 +loneliness 5293 +fiancé 5294 +complaining 5295 +silliness 5296 +sharon 5297 +celebration 5298 +gentleman 5299 +ustinov 5300 +husband's 5301 +exposition 5302 +choppy 5303 +altman 5304 +minus 5305 +amusement 5306 +sugar 5307 +husbands 5308 +framed 5309 +other's 5310 +andre 5311 +unlikable 5312 +sunny 5313 +roommate 5314 +stark 5315 +absurdity 5316 +rifle 5317 +electric 5318 +posters 5319 +aspiring 5320 +conscience 5321 +fields 5322 +hackneyed 5323 +downey 5324 +buster 5325 +edit 5326 +straightforward 5327 +misleading 5328 +carell 5329 +murdering 5330 +credited 5331 +sung 5332 +releases 5333 +muddled 5334 +raines 5335 +coincidence 5336 +unfold 5337 +rude 5338 +charged 5339 +weakness 5340 +quietly 5341 +pitiful 5342 +marshall 5343 +objects 5344 +shared 5345 +inexplicably 5346 +automatically 5347 +heartfelt 5348 +agenda 5349 +dresses 5350 +trend 5351 +acclaimed 5352 +blacks 5353 +murray 5354 +beverly 5355 +asylum 5356 +belushi 5357 +en 5358 +moreover 5359 +shoddy 5360 +bernard 5361 +teachers 5362 +devices 5363 +cattle 5364 +preston 5365 +dont 5366 +grotesque 5367 +visited 5368 +discovering 5369 +roof 5370 +spark 5371 +realised 5372 +handling 5373 +adopted 5374 +bread 5375 +haired 5376 +ethnic 5377 +encourage 5378 +lock 5379 +conviction 5380 +imaginable 5381 +fog 5382 +crawford 5383 +firm 5384 +servant 5385 +invites 5386 +dirt 5387 +cancer 5388 +fantasies 5389 +rely 5390 +biased 5391 +occasions 5392 +dose 5393 +industrial 5394 +harm 5395 +hungry 5396 +vance 5397 +kansas 5398 +active 5399 +preposterous 5400 +profanity 5401 +positively 5402 +prepare 5403 +ladder 5404 +sketch 5405 +alison 5406 +controlled 5407 +squad 5408 +outfits 5409 +deniro 5410 +canyon 5411 +babies 5412 +frankie 5413 +referred 5414 +kumar 5415 +regarded 5416 +designer 5417 +1988 5418 +paradise 5419 +comedians 5420 +russia 5421 +fido 5422 +provocative 5423 +behaviour 5424 +region 5425 +1930's 5426 +baldwin 5427 +laurence 5428 +translated 5429 +tracking 5430 +clock 5431 +1939 5432 +chills 5433 +hawke 5434 +cue 5435 +heist 5436 +citizens 5437 +da 5438 +1978 5439 +mode 5440 +hk 5441 +counts 5442 +riot 5443 +uncut 5444 +musician 5445 +accepts 5446 +shoulder 5447 +heartbreaking 5448 +secondary 5449 +option 5450 +75 5451 +roller 5452 +1980's 5453 +fathers 5454 +mclaglen 5455 +hopelessly 5456 +tasteless 5457 +bye 5458 +challenges 5459 +bitch 5460 +additional 5461 +backs 5462 +should've 5463 +swing 5464 +betrayal 5465 +labor 5466 +lush 5467 +morbid 5468 +abrupt 5469 +gambling 5470 +historic 5471 +iv 5472 +insurance 5473 +1986 5474 +fade 5475 +screens 5476 +bike 5477 +damme 5478 +pages 5479 +nut 5480 +admirable 5481 +rejected 5482 +skits 5483 +lip 5484 +ignorance 5485 +chainsaw 5486 +cassidy 5487 +suspension 5488 +respective 5489 +nod 5490 +chuckle 5491 +recommendation 5492 +guitar 5493 +youngest 5494 +reign 5495 +1970 5496 +biko 5497 +severely 5498 +affection 5499 +coaster 5500 +visiting 5501 +kid's 5502 +darn 5503 +refer 5504 +boxer 5505 +naughty 5506 +macarthur 5507 +deserted 5508 +amazon 5509 +paramount 5510 +files 5511 +corpses 5512 +realm 5513 +nemesis 5514 +1979 5515 +sabrina 5516 +address 5517 +beware 5518 +shares 5519 +tomorrow 5520 +prejudice 5521 +el 5522 +guaranteed 5523 +wwe 5524 +sooner 5525 +reluctant 5526 +1989 5527 +invited 5528 +aim 5529 +dickens 5530 +evidently 5531 +lindsay 5532 +hyped 5533 +penny 5534 +praised 5535 +jews 5536 +sympathize 5537 +barrel 5538 +disappears 5539 +guests 5540 +anticipation 5541 +conventions 5542 +outs 5543 +tail 5544 +deleted 5545 +freaks 5546 +rome 5547 +indication 5548 +bunny 5549 +actor's 5550 +19 5551 +fist 5552 +mayhem 5553 +1969 5554 +policeman 5555 +cannon 5556 +thread 5557 +basinger 5558 +bridget 5559 +selection 5560 +palma 5561 +inconsistent 5562 +saint 5563 +stopping 5564 +gut 5565 +burst 5566 +visions 5567 +angst 5568 +daughter's 5569 +beside 5570 +reader 5571 +sentinel 5572 +nails 5573 +promote 5574 +weaknesses 5575 +heading 5576 +www 5577 +venture 5578 +malone 5579 +misguided 5580 +1960's 5581 +muppet 5582 +uh 5583 +drove 5584 +overlong 5585 +gal 5586 +cope 5587 +mccoy 5588 +threatens 5589 +iconic 5590 +rita 5591 +stages 5592 +underworld 5593 +adolescent 5594 +tip 5595 +previews 5596 +depending 5597 +hammy 5598 +behold 5599 +steady 5600 +circus 5601 +filler 5602 +conveys 5603 +glowing 5604 +vader 5605 +shades 5606 +acceptance 5607 +psychology 5608 +bent 5609 +banal 5610 +receiving 5611 +palance 5612 +reflects 5613 +cruelty 5614 +guy's 5615 +tyler 5616 +insipid 5617 +posted 5618 +hack 5619 +curly 5620 +sassy 5621 +nicolas 5622 +harmless 5623 +morally 5624 +affairs 5625 +macho 5626 +understands 5627 +fluff 5628 +demonstrates 5629 +exceptions 5630 +bow 5631 +investigating 5632 +widescreen 5633 +30's 5634 +remade 5635 +studies 5636 +records 5637 +bros 5638 +unexplained 5639 +sirk 5640 +oldest 5641 +firing 5642 +vein 5643 +explores 5644 +completed 5645 +eternal 5646 +marvel 5647 +preachy 5648 +triple 5649 +schlock 5650 +min 5651 +employed 5652 +campaign 5653 +difficulty 5654 +strongest 5655 +gregory 5656 +grainy 5657 +popping 5658 +disguise 5659 +filth 5660 +dates 5661 +obligatory 5662 +robbins 5663 +terrified 5664 +portrayals 5665 +commander 5666 +hokey 5667 +emerges 5668 +confident 5669 +connections 5670 +lifted 5671 +artsy 5672 +height 5673 +entitled 5674 +outing 5675 +rukh 5676 +hopkins 5677 +pounds 5678 +sending 5679 +hapless 5680 +physics 5681 +phenomenon 5682 +assuming 5683 +unrelated 5684 +kitty 5685 +repeating 5686 +stores 5687 +attract 5688 +fifties 5689 +assured 5690 +clan 5691 +insists 5692 +interestingly 5693 +patricia 5694 +mentality 5695 +knight 5696 +1981 5697 +bug 5698 +paxton 5699 +pole 5700 +hughes 5701 +communicate 5702 +sox 5703 +rhythm 5704 +nolan 5705 +bitten 5706 +despicable 5707 +slimy 5708 +predict 5709 +recognizable 5710 +rounded 5711 +shakespeare's 5712 +gate 5713 +1945 5714 +recycled 5715 +conclude 5716 +casual 5717 +disgusted 5718 +comparisons 5719 +zombi 5720 +couch 5721 +offs 5722 +vital 5723 +representation 5724 +rod 5725 +duck 5726 +martha 5727 +danish 5728 +yawn 5729 +studying 5730 +1976 5731 +clarke 5732 +woo 5733 +route 5734 +prominent 5735 +tarantino 5736 +legends 5737 +paintings 5738 +suitably 5739 +someday 5740 +snakes 5741 +absorbed 5742 +stairs 5743 +redeem 5744 +gear 5745 +shortcomings 5746 +agency 5747 +tempted 5748 +rapist 5749 +inexplicable 5750 +locals 5751 +http 5752 +clueless 5753 +pleasing 5754 +vibrant 5755 +independence 5756 +marries 5757 +clad 5758 +charms 5759 +rendered 5760 +heartwarming 5761 +melody 5762 +shouting 5763 +wig 5764 +defeated 5765 +friend's 5766 +stack 5767 +lois 5768 +novak 5769 +coup 5770 +globe 5771 +soup 5772 +claustrophobic 5773 +eats 5774 +flashy 5775 +trivia 5776 +spinal 5777 +thompson 5778 +considerably 5779 +forcing 5780 +befriends 5781 +grudge 5782 +chavez 5783 +net 5784 +shopping 5785 +gems 5786 +claiming 5787 +foxx 5788 +muppets 5789 +discussing 5790 +boston 5791 +ingenious 5792 +flowers 5793 +harold 5794 +feeding 5795 +eternity 5796 +norm 5797 +sharing 5798 +meg 5799 +quinn 5800 +election 5801 +camcorder 5802 +limit 5803 +genie 5804 +daniels 5805 +quaid 5806 +bacon 5807 +runner 5808 +tierney 5809 +champion 5810 +stallone 5811 +minister 5812 +publicity 5813 +static 5814 +springer 5815 +info 5816 +screw 5817 +inhabitants 5818 +'70s 5819 +renaissance 5820 +carla 5821 +screwed 5822 +delicate 5823 +marlon 5824 +weather 5825 +deserving 5826 +incidentally 5827 +depends 5828 +winchester 5829 +boyle 5830 +gina 5831 +immature 5832 +lift 5833 +wings 5834 +partners 5835 +rope 5836 +ace 5837 +phillips 5838 +kathryn 5839 +elite 5840 +pete 5841 +brother's 5842 +glamorous 5843 +transformed 5844 +blatantly 5845 +symbolic 5846 +traffic 5847 +belt 5848 +strings 5849 +excess 5850 +stalker 5851 +smiles 5852 +ton 5853 +politician 5854 +keen 5855 +esther 5856 +ambition 5857 +surgery 5858 +ants 5859 +audrey 5860 +housewife 5861 +ish 5862 +lasting 5863 +allen's 5864 +dvds 5865 +schools 5866 +concepts 5867 +hilarity 5868 +newman 5869 +shaking 5870 +28 5871 +programs 5872 +frames 5873 +coupled 5874 +cheer 5875 +disorder 5876 +salt 5877 +beatles 5878 +fuller 5879 +shorter 5880 +voted 5881 +toronto 5882 +raj 5883 +1940 5884 +exploring 5885 +debate 5886 +yeti 5887 +layers 5888 +fontaine 5889 +backwards 5890 +continually 5891 +feat 5892 +georges 5893 +organized 5894 +destined 5895 +bombs 5896 +differently 5897 +nope 5898 +bend 5899 +towers 5900 +mothers 5901 +partially 5902 +outdated 5903 +punches 5904 +stumbles 5905 +bully 5906 +threatened 5907 +thrilled 5908 +leigh 5909 +charlton 5910 +wax 5911 +bondage 5912 +kolchak 5913 +spree 5914 +assassination 5915 +doctors 5916 +remove 5917 +claude 5918 +europa 5919 +wire 5920 +leather 5921 +messy 5922 +item 5923 +institution 5924 +departure 5925 +centre 5926 +else's 5927 +detectives 5928 +triangle 5929 +lifeless 5930 +handles 5931 +hides 5932 +wanders 5933 +dudley 5934 +accurately 5935 +duration 5936 +hum 5937 +harrison 5938 +damaged 5939 +satirical 5940 +1950 5941 +minority 5942 +suggestion 5943 +insightful 5944 +hangs 5945 +btw 5946 +preferred 5947 +sorely 5948 +windows 5949 +formed 5950 +profession 5951 +boy's 5952 +commenting 5953 +newer 5954 +landed 5955 +colin 5956 +tenant 5957 +goers 5958 +gunga 5959 +uniformly 5960 +neurotic 5961 +trials 5962 +authorities 5963 +oriented 5964 +swept 5965 +northern 5966 +computers 5967 +dylan 5968 +racing 5969 +kline 5970 +95 5971 +vocal 5972 +steele 5973 +1990s 5974 +viewer's 5975 +bridges 5976 +proving 5977 +entered 5978 +demonic 5979 +natives 5980 +seeming 5981 +brendan 5982 +reeves 5983 +obtain 5984 +rear 5985 +evolution 5986 +ie 5987 +christine 5988 +token 5989 +elevator 5990 +braveheart 5991 +garner 5992 +ripping 5993 +refuse 5994 +firmly 5995 +outright 5996 +mermaid 5997 +exquisite 5998 +mutual 5999 +posey 6000 +biblical 6001 +disastrous 6002 +sleaze 6003 +bars 6004 +helpful 6005 +wendigo 6006 +eleven 6007 +choosing 6008 +neatly 6009 +engrossing 6010 +kidman 6011 +freddy's 6012 +earn 6013 +tops 6014 +uma 6015 +anton 6016 +justified 6017 +wtf 6018 +demanding 6019 +mannerisms 6020 +inspire 6021 +speeches 6022 +containing 6023 +pacific 6024 +myth 6025 +sleeps 6026 +reliable 6027 +fifth 6028 +gillian 6029 +setup 6030 +vile 6031 +cookie 6032 +4th 6033 +hitler's 6034 +bowl 6035 +she'll 6036 +sincerely 6037 +tapes 6038 +vanessa 6039 +insanity 6040 +casts 6041 +ratso 6042 +brooding 6043 +disgrace 6044 +luis 6045 +helpless 6046 +1991 6047 +mirrors 6048 +label 6049 +emerge 6050 +kent 6051 +altered 6052 +forgiven 6053 +predecessor 6054 +heels 6055 +skit 6056 +contempt 6057 +activity 6058 +crossing 6059 +describing 6060 +1985 6061 +duvall 6062 +rampage 6063 +healthy 6064 +knightley 6065 +mercy 6066 +undead 6067 +cemetery 6068 +spies 6069 +mesmerizing 6070 +homicide 6071 +cons 6072 +frontal 6073 +ariel 6074 +restrained 6075 +valentine 6076 +approaches 6077 +startling 6078 +cerebral 6079 +vain 6080 +rooting 6081 +destroys 6082 +preparing 6083 +subtly 6084 +1977 6085 +1974 6086 +jordan 6087 +hats 6088 +grateful 6089 +pc 6090 +boasts 6091 +gere 6092 +regards 6093 +creek 6094 +survives 6095 +mixing 6096 +realities 6097 +conan 6098 +topics 6099 +educated 6100 +shaped 6101 +insights 6102 +melissa 6103 +carey 6104 +tunnel 6105 +artwork 6106 +hulk 6107 +hartley 6108 +radical 6109 +deny 6110 +modest 6111 +unlikeable 6112 +compete 6113 +1994 6114 +sometime 6115 +statue 6116 +grounds 6117 +weaker 6118 +seedy 6119 +mitch 6120 +breakfast 6121 +inspirational 6122 +jess 6123 +hugely 6124 +leaders 6125 +coat 6126 +miami 6127 +scariest 6128 +owners 6129 +casino 6130 +miniseries 6131 +freeze 6132 +akin 6133 +timberlake 6134 +deer 6135 +jared 6136 +bulk 6137 +conrad 6138 +wardrobe 6139 +poker 6140 +crashes 6141 +hers 6142 +rapidly 6143 +applaud 6144 +tara 6145 +nominations 6146 +wrenching 6147 +votes 6148 +contribution 6149 +candidate 6150 +loretta 6151 +affects 6152 +homes 6153 +cinemas 6154 +dubious 6155 +child's 6156 +stare 6157 +banter 6158 +exploits 6159 +advertised 6160 +21st 6161 +guards 6162 +vastly 6163 +relentless 6164 +disguised 6165 +masterfully 6166 +critique 6167 +dim 6168 +located 6169 +refers 6170 +narrow 6171 +des 6172 +washed 6173 +origin 6174 +puppets 6175 +addict 6176 +internal 6177 +error 6178 +disgust 6179 +injured 6180 +cartoonish 6181 +bronson 6182 +gods 6183 +alvin 6184 +30s 6185 +shell 6186 +owes 6187 +repulsive 6188 +gimmick 6189 +boris 6190 +linear 6191 +randolph 6192 +photographs 6193 +rides 6194 +ingrid 6195 +scifi 6196 +abruptly 6197 +limitations 6198 +joker 6199 +youthful 6200 +dandy 6201 +unsure 6202 +dazzling 6203 +gained 6204 +arab 6205 +detract 6206 +underwear 6207 +christina 6208 +caricature 6209 +bloom 6210 +continuing 6211 +lasts 6212 +inaccurate 6213 +where's 6214 +swallow 6215 +standout 6216 +motive 6217 +nations 6218 +convicted 6219 +bravo 6220 +youtube 6221 +nolte 6222 +lauren 6223 +holocaust 6224 +vehicles 6225 +bones 6226 +thirties 6227 +audition 6228 +factors 6229 +headache 6230 +growth 6231 +natured 6232 +mason 6233 +expertly 6234 +spine 6235 +hires 6236 +zizek 6237 +undeniably 6238 +bates 6239 +excellently 6240 +highway 6241 +nina 6242 +screenwriters 6243 +buzz 6244 +chronicles 6245 +insults 6246 +corn 6247 +stunningly 6248 +dread 6249 +homosexuality 6250 +perception 6251 +antonio 6252 +lukas 6253 +reward 6254 +decline 6255 +son's 6256 +las 6257 +mol 6258 +unsuspecting 6259 +strengths 6260 +convinces 6261 +spit 6262 +entering 6263 +natalie 6264 +tossed 6265 +toni 6266 +colours 6267 +ronald 6268 +mathieu 6269 +implied 6270 +teams 6271 +resolved 6272 +tower 6273 +entirety 6274 +confront 6275 +wander 6276 +derivative 6277 +missile 6278 +definitive 6279 +gates 6280 +supply 6281 +bachelor 6282 +anyone's 6283 +divorced 6284 +attenborough 6285 +males 6286 +promptly 6287 +painter 6288 +sinking 6289 +polly 6290 +origins 6291 +endlessly 6292 +nerves 6293 +1959 6294 +wagner 6295 +carmen 6296 +judd 6297 +poe 6298 +walt 6299 +unimaginative 6300 +anil 6301 +mice 6302 +1940s 6303 +confronted 6304 +200 6305 +lend 6306 +authenticity 6307 +siblings 6308 +longest 6309 +repressed 6310 +alexandre 6311 +span 6312 +sergeant 6313 +stardom 6314 +cassavetes 6315 +vividly 6316 +salvation 6317 +yep 6318 +jacket 6319 +users 6320 +jarring 6321 +enhanced 6322 +puerto 6323 +colleagues 6324 +referring 6325 +jedi 6326 +tokyo 6327 +niece 6328 +published 6329 +jackson's 6330 +mates 6331 +cbs 6332 +damned 6333 +sgt 6334 +delicious 6335 +uniform 6336 +dominated 6337 +judgment 6338 +juliet 6339 +accessible 6340 +bsg 6341 +exterior 6342 +misfortune 6343 +zane 6344 +phillip 6345 +ally 6346 +giants 6347 +netflix 6348 +energetic 6349 +austen 6350 +unattractive 6351 +devil's 6352 +mobile 6353 +underwater 6354 +stalking 6355 +disabled 6356 +depict 6357 +offbeat 6358 +earnest 6359 +servants 6360 +jill 6361 +bruno 6362 +cliches 6363 +crisp 6364 +nerve 6365 +peck 6366 +wounds 6367 +hepburn 6368 +terminator 6369 +sized 6370 +suburban 6371 +depths 6372 +buys 6373 +hindi 6374 +sticking 6375 +literal 6376 +playboy 6377 +gable 6378 +meandering 6379 +belly 6380 +sensible 6381 +lighter 6382 +21 6383 +stranded 6384 +yokai 6385 +pray 6386 +mutant 6387 +sale 6388 +exit 6389 +estranged 6390 +anyhow 6391 +identical 6392 +foolish 6393 +eventual 6394 +errol 6395 +separated 6396 +bashing 6397 +cushing 6398 +soylent 6399 +antonioni 6400 +galaxy 6401 +glued 6402 +imo 6403 +tormented 6404 +syndrome 6405 +biting 6406 +dragons 6407 +macabre 6408 +dealer 6409 +filthy 6410 +residents 6411 +victorian 6412 +witchcraft 6413 +cents 6414 +improbable 6415 +inherent 6416 +alley 6417 +lester 6418 +readers 6419 +scratch 6420 +pirate 6421 +cher 6422 +pickford 6423 +astounding 6424 +devastating 6425 +breathing 6426 +clash 6427 +approaching 6428 +severed 6429 +owned 6430 +interact 6431 +cleaning 6432 +characteristics 6433 +expects 6434 +guinness 6435 +dismal 6436 +sniper 6437 +lance 6438 +sand 6439 +respectable 6440 +budgets 6441 +sought 6442 +scoop 6443 +slide 6444 +butch 6445 +nightclub 6446 +yours 6447 +blooded 6448 +she'd 6449 +appeals 6450 +ebert 6451 +harriet 6452 +farmer 6453 +stylized 6454 +owns 6455 +noticeable 6456 +kurosawa 6457 +dustin 6458 +id 6459 +balanced 6460 +fragile 6461 +sublime 6462 +salman 6463 +answered 6464 +penn 6465 +amrita 6466 +adore 6467 +logan 6468 +demonstrate 6469 +concentrate 6470 +exploit 6471 +races 6472 +laden 6473 +psychopath 6474 +affleck 6475 +1982 6476 +garland 6477 +worms 6478 +23 6479 +filmmaking 6480 +pattern 6481 +habit 6482 +incapable 6483 +isolation 6484 +fatale 6485 +decidedly 6486 +steam 6487 +jules 6488 +ford's 6489 +asia 6490 +possess 6491 +senior 6492 +reminder 6493 +cheaply 6494 +principals 6495 +immortal 6496 +christie 6497 +monty 6498 +sf 6499 +evelyn 6500 +denis 6501 +corporation 6502 +turd 6503 +soderbergh 6504 +deliverance 6505 +subway 6506 +potter 6507 +breakdown 6508 +flimsy 6509 +packs 6510 +judged 6511 +wisely 6512 +moe 6513 +bogus 6514 +enthusiastic 6515 +cries 6516 +conveyed 6517 +escaping 6518 +plotting 6519 +wilder 6520 +pale 6521 +deliberate 6522 +dvd's 6523 +informed 6524 +promoted 6525 +axe 6526 +flashes 6527 +cypher 6528 +tremendously 6529 +esquire 6530 +1944 6531 +feast 6532 +glaring 6533 +irene 6534 +spectacle 6535 +chopped 6536 +cyborg 6537 +assembled 6538 +drinks 6539 +dump 6540 +celebrated 6541 +quarter 6542 +boyer 6543 +clara 6544 +arguing 6545 +selected 6546 +numbing 6547 +romeo 6548 +volume 6549 +truman 6550 +combines 6551 +embrace 6552 +troma 6553 +expose 6554 +laurie 6555 +kidnapping 6556 +debt 6557 +contribute 6558 +ominous 6559 +jodie 6560 +magician 6561 +o'hara 6562 +conveniently 6563 +outline 6564 +excruciatingly 6565 +accounts 6566 +pound 6567 +pixar 6568 +pierre 6569 +hackman 6570 +lightning 6571 +absorbing 6572 +copied 6573 +clone 6574 +lola 6575 +ugh 6576 +burke 6577 +cecil 6578 +jan 6579 +mitchum 6580 +jealousy 6581 +advised 6582 +40s 6583 +ensure 6584 +collect 6585 +rewarding 6586 +updated 6587 +freaky 6588 +attacking 6589 +rescued 6590 +lex 6591 +1975 6592 +dilemma 6593 +colored 6594 +beowulf 6595 +hi 6596 +melvyn 6597 +ps 6598 +pocket 6599 +passengers 6600 +accepting 6601 +sydney 6602 +classy 6603 +whiny 6604 +loy 6605 +experiencing 6606 +exorcist 6607 +destructive 6608 +300 6609 +goods 6610 +spencer 6611 +corbett 6612 +shepherd 6613 +reports 6614 +expectation 6615 +sophie 6616 +sentimentality 6617 +pause 6618 +sidewalk 6619 +karate 6620 +quantum 6621 +intricate 6622 +tax 6623 +scarface 6624 +crippled 6625 +longing 6626 +nbc 6627 +reeve 6628 +vintage 6629 +crown 6630 +1998 6631 +quentin 6632 +obsessive 6633 +immense 6634 +knocks 6635 +bounty 6636 +indiana 6637 +adaption 6638 +delighted 6639 +er 6640 +naschy 6641 +liam 6642 +establish 6643 +addiction 6644 +europeans 6645 +tool 6646 +stroke 6647 +overblown 6648 +goldblum 6649 +jaded 6650 +pursue 6651 +sucker 6652 +slip 6653 +theories 6654 +rookie 6655 +havoc 6656 +1953 6657 +anticipated 6658 +dukes 6659 +principle 6660 +voyage 6661 +gamera 6662 +swearing 6663 +unsatisfying 6664 +wonderland 6665 +frontier 6666 +parallels 6667 +crashing 6668 +downs 6669 +incorrect 6670 +erika 6671 +aggressive 6672 +divine 6673 +paula 6674 +dashing 6675 +turmoil 6676 +suspected 6677 +aided 6678 +grass 6679 +story's 6680 +distract 6681 +cape 6682 +snuff 6683 +bach 6684 +comprehend 6685 +werewolves 6686 +masterson 6687 +resulted 6688 +miranda 6689 +tendency 6690 +fright 6691 +spaghetti 6692 +goals 6693 +rainy 6694 +reviewing 6695 +juliette 6696 +establishment 6697 +redundant 6698 +switched 6699 +taped 6700 +sarcastic 6701 +arguments 6702 +rider 6703 +peaceful 6704 +barbra 6705 +butcher 6706 +shootout 6707 +bubble 6708 +routines 6709 +demonstrated 6710 +spice 6711 +backed 6712 +polish 6713 +cultures 6714 +parsons 6715 +distress 6716 +hero's 6717 +chill 6718 +morons 6719 +slugs 6720 +subtext 6721 +ultimatum 6722 +intentional 6723 +virtual 6724 +morals 6725 +cutter 6726 +hayworth 6727 +mouthed 6728 +fleshed 6729 +fascist 6730 +dramatically 6731 +passage 6732 +realization 6733 +slaves 6734 +gentlemen 6735 +liu 6736 +hyper 6737 +peculiar 6738 +avoiding 6739 +lavish 6740 +adrian 6741 +vanilla 6742 +boiled 6743 +admired 6744 +thieves 6745 +moron 6746 +sixth 6747 +'cause 6748 +arranged 6749 +climb 6750 +horny 6751 +approached 6752 +alleged 6753 +pumbaa 6754 +predictably 6755 +wielding 6756 +armstrong 6757 +commitment 6758 +seymour 6759 +serum 6760 +odyssey 6761 +hybrid 6762 +messing 6763 +begging 6764 +alter 6765 +establishing 6766 +toby 6767 +whining 6768 +canceled 6769 +collective 6770 +define 6771 +dame 6772 +bikini 6773 +afterward 6774 +mystical 6775 +tourist 6776 +furniture 6777 +fairbanks 6778 +casper 6779 +revolt 6780 +remembering 6781 +exploding 6782 +consideration 6783 +arrest 6784 +inmates 6785 +1934 6786 +shift 6787 +aiming 6788 +samantha 6789 +puzzle 6790 +ghetto 6791 +arc 6792 +traits 6793 +apply 6794 +olds 6795 +sang 6796 +distraction 6797 +hateful 6798 +fools 6799 +anytime 6800 +reviewed 6801 +enhance 6802 +lunch 6803 +coke 6804 +upside 6805 +papers 6806 +insist 6807 +medieval 6808 +wine 6809 +vega 6810 +insomnia 6811 +arriving 6812 +keaton's 6813 +phenomenal 6814 +fills 6815 +graveyard 6816 +stella 6817 +exploited 6818 +writer's 6819 +acquired 6820 +strict 6821 +slapped 6822 +jewel 6823 +thelma 6824 +mcqueen 6825 +pedestrian 6826 +cal 6827 +anthology 6828 +vince 6829 +mythology 6830 +consciousness 6831 +kinnear 6832 +life's 6833 +carnage 6834 +courtroom 6835 +tolerable 6836 +populated 6837 +huston 6838 +contributed 6839 +poses 6840 +actors' 6841 +optimistic 6842 +verdict 6843 +rebellious 6844 +trace 6845 +whites 6846 +commits 6847 +kelly's 6848 +mouths 6849 +stream 6850 +respects 6851 +leap 6852 +sickening 6853 +puppy 6854 +overboard 6855 +diverse 6856 +monologue 6857 +tuned 6858 +corman 6859 +gypo 6860 +skilled 6861 +seasoned 6862 +settled 6863 +horrified 6864 +remembers 6865 +relentlessly 6866 +dj 6867 +— 6868 +jersey 6869 +psychologist 6870 +borders 6871 +lethal 6872 +tony's 6873 +shoe 6874 +smash 6875 +taboo 6876 +wiped 6877 +excuses 6878 +crosses 6879 +salesman 6880 +ritual 6881 +mormon 6882 +achieves 6883 +thunderbirds 6884 +scored 6885 +vanity 6886 +pad 6887 +aussie 6888 +explodes 6889 +ira 6890 +dynamics 6891 +preminger 6892 +franklin 6893 +verbal 6894 +feminine 6895 +policy 6896 +flavor 6897 +expense 6898 +suggesting 6899 +trains 6900 +instincts 6901 +nuances 6902 +dumber 6903 +flock 6904 +feeble 6905 +deanna 6906 +hoot 6907 +cuban 6908 +kathy 6909 +possession 6910 +document 6911 +cohen 6912 +foundation 6913 +diary 6914 +guinea 6915 +covering 6916 +vomit 6917 +readily 6918 +fluid 6919 +cigarette 6920 +tactics 6921 +deliciously 6922 +seductive 6923 +circles 6924 +phase 6925 +themed 6926 +busey 6927 +marilyn 6928 +amidst 6929 +posing 6930 +lean 6931 +cooking 6932 +deputy 6933 +duel 6934 +brainless 6935 +mute 6936 +meantime 6937 +unsympathetic 6938 +wheel 6939 +update 6940 +immigrant 6941 +weary 6942 +basket 6943 +attending 6944 +mortal 6945 +clive 6946 +regularly 6947 +delightfully 6948 +possesses 6949 +newcomer 6950 +porter 6951 +invention 6952 +sources 6953 +wash 6954 +contestants 6955 +shockingly 6956 +wheelchair 6957 +stephanie 6958 +ritchie 6959 +wong 6960 +pushes 6961 +ricky 6962 +audience's 6963 +einstein 6964 +controlling 6965 +mama 6966 +encountered 6967 +pathos 6968 +zorro 6969 +mysteriously 6970 +korea 6971 +bachchan 6972 +jury 6973 +keys 6974 +skinny 6975 +sells 6976 +satisfaction 6977 +romances 6978 +meal 6979 +explosive 6980 +defies 6981 +drab 6982 +clerk 6983 +pfeiffer 6984 +sunrise 6985 +symbol 6986 +pirates 6987 +otto 6988 +novelty 6989 +jacques 6990 +void 6991 +herbert 6992 +narrated 6993 +lionel 6994 +targets 6995 +august 6996 +razor 6997 +rivers 6998 +admitted 6999 +mum 7000 +sundance 7001 +lends 7002 +cliched 7003 +screwball 7004 +serials 7005 +neglected 7006 +olivia 7007 +truths 7008 +sided 7009 +steer 7010 +flower 7011 +indifferent 7012 +dumped 7013 +lucille 7014 +mole 7015 +products 7016 +beg 7017 +releasing 7018 +niven 7019 +stewart's 7020 +ordeal 7021 +darth 7022 +um 7023 +crosby 7024 +statements 7025 +followers 7026 +psyche 7027 +excruciating 7028 +noteworthy 7029 +swinging 7030 +deed 7031 +aftermath 7032 +ranch 7033 +consist 7034 +embarrassingly 7035 +unusually 7036 +convention 7037 +shifts 7038 +produces 7039 +motorcycle 7040 +tickets 7041 +wider 7042 +longoria 7043 +gwyneth 7044 +employee 7045 +instances 7046 +parking 7047 +intact 7048 +starters 7049 +rapid 7050 +arrow 7051 +thurman 7052 +debbie 7053 +dumbest 7054 +wastes 7055 +sarandon 7056 +economic 7057 +israeli 7058 +additionally 7059 +fanatic 7060 +planes 7061 +pursued 7062 +legitimate 7063 +discussed 7064 +forties 7065 +introducing 7066 +anxious 7067 +cannes 7068 +biker 7069 +deciding 7070 +sanders 7071 +fuzzy 7072 +agony 7073 +alot 7074 +assignment 7075 +stones 7076 +scorsese 7077 +caron 7078 +degrees 7079 +medicine 7080 +hannah 7081 +reverse 7082 +inaccuracies 7083 +july 7084 +attended 7085 +gilbert 7086 +forgetting 7087 +jane's 7088 +gielgud 7089 +angie 7090 +milo 7091 +laputa 7092 +branagh's 7093 +motions 7094 +auto 7095 +controversy 7096 +grandma 7097 +cunningham 7098 +professionals 7099 +criticize 7100 +kidnap 7101 +artistry 7102 +sarcasm 7103 +fishburne 7104 +brow 7105 +bogart 7106 +columbia 7107 +incidents 7108 +vera 7109 +meteor 7110 +georgia 7111 +arty 7112 +freaking 7113 +hadley 7114 +suspicion 7115 +scott's 7116 +coffin 7117 +juan 7118 +crossed 7119 +idol 7120 +grip 7121 +obstacles 7122 +mentor 7123 +consequently 7124 +begs 7125 +stating 7126 +ambitions 7127 +muslims 7128 +executives 7129 +daisy 7130 +manners 7131 +warns 7132 +1948 7133 +jolie 7134 +arquette 7135 +distracted 7136 +centuries 7137 +abound 7138 +jose 7139 +factual 7140 +goodbye 7141 +trigger 7142 +breast 7143 +invite 7144 +tcm 7145 +unanswered 7146 +indicate 7147 +shepard 7148 +session 7149 +daylight 7150 +minnelli 7151 +cindy 7152 +funding 7153 +pains 7154 +predator 7155 +flames 7156 +fried 7157 +scripting 7158 +rational 7159 +stabbed 7160 +collette 7161 +'i 7162 +compliment 7163 +hooker 7164 +cliffhanger 7165 +inclusion 7166 +debra 7167 +roughly 7168 +moss 7169 +1967 7170 +awakening 7171 +viewpoint 7172 +kazan 7173 +rejects 7174 +toned 7175 +sentences 7176 +denise 7177 +originals 7178 +cycle 7179 +informative 7180 +pros 7181 +harlow 7182 +stern 7183 +corey 7184 +stalked 7185 +foil 7186 +plodding 7187 +varied 7188 +sweden 7189 +detroit 7190 +misunderstood 7191 +clay 7192 +relevance 7193 +depictions 7194 +blamed 7195 +paints 7196 +pointing 7197 +click 7198 +stance 7199 +protest 7200 +chamber 7201 +robbers 7202 +gooding 7203 +soprano 7204 +likeable 7205 +exclusively 7206 +slim 7207 +campus 7208 +haines 7209 +cheadle 7210 +cap 7211 +cab 7212 +rambling 7213 +paranoid 7214 +seats 7215 +frances 7216 +rowlands 7217 +101 7218 +consequence 7219 +murky 7220 +abandon 7221 +gap 7222 +berkeley 7223 +ruining 7224 +stink 7225 +denouement 7226 +penelope 7227 +intro 7228 +abortion 7229 +tomei 7230 +replies 7231 +antagonist 7232 +gloria 7233 +stardust 7234 +tomb 7235 +gallery 7236 +bug's 7237 +determination 7238 +40's 7239 +c'mon 7240 +translate 7241 +bait 7242 +killer's 7243 +eagerly 7244 +relating 7245 +iranian 7246 +rips 7247 +momentum 7248 +uncanny 7249 +frozen 7250 +begun 7251 +generate 7252 +uniforms 7253 +intensely 7254 +dreamy 7255 +martian 7256 +festivals 7257 +grabbed 7258 +mock 7259 +jenna 7260 +che's 7261 +schedule 7262 +surroundings 7263 +coma 7264 +imaginary 7265 +schneider 7266 +gus 7267 +foremost 7268 +composition 7269 +robertson 7270 +politicians 7271 +services 7272 +hysterically 7273 +snowman 7274 +maureen 7275 +omar 7276 +republic 7277 +lurking 7278 +pans 7279 +alliance 7280 +hostel 7281 +diner 7282 +sheen 7283 +injury 7284 +rupert 7285 +hippies 7286 +rosario 7287 +chamberlain 7288 +ww2 7289 +scenarios 7290 +participants 7291 +realistically 7292 +communication 7293 +kris 7294 +sg 7295 +kathleen 7296 +brat 7297 +redneck 7298 +launch 7299 +therapy 7300 +quasi 7301 +miyazaki 7302 +hmmm 7303 +85 7304 +faux 7305 +geisha 7306 +bauer 7307 +mick 7308 +enigmatic 7309 +1951 7310 +phones 7311 +shaggy 7312 +hostage 7313 +destination 7314 +lens 7315 +glimpses 7316 +1943 7317 +lastly 7318 +rehash 7319 +gestures 7320 +shotgun 7321 +casablanca 7322 +dismiss 7323 +sights 7324 +periods 7325 +burnt 7326 +bats 7327 +resembling 7328 +charlie's 7329 +apt 7330 +linked 7331 +widowed 7332 +dominic 7333 +glance 7334 +cow 7335 +tho 7336 +traps 7337 +curiously 7338 +heath 7339 +envy 7340 +playwright 7341 +gigantic 7342 +paths 7343 +bleed 7344 +ambiguity 7345 +gaps 7346 +bosses 7347 +hayes 7348 +sterling 7349 +necessity 7350 +comeback 7351 +sketches 7352 +sondra 7353 +ignoring 7354 +revolving 7355 +apocalyptic 7356 +reiser 7357 +sailor 7358 +saloon 7359 +frantic 7360 +resistance 7361 +pegg 7362 +overs 7363 +precise 7364 +herman 7365 +rounds 7366 +arkin 7367 +gloomy 7368 +pressed 7369 +haunt 7370 +1992 7371 +enchanted 7372 +iturbi 7373 +fuel 7374 +blaise 7375 +mabel 7376 +laboratory 7377 +county 7378 +veterans 7379 +studied 7380 +cheers 7381 +bearing 7382 +eh 7383 +sunset 7384 +reflected 7385 +rolls 7386 +investigator 7387 +adele 7388 +pen 7389 +maintains 7390 +capacity 7391 +kubrick's 7392 +unstable 7393 +avid 7394 +midst 7395 +man' 7396 +qualify 7397 +bonnie 7398 +person's 7399 +mins 7400 +geek 7401 +nun 7402 +jude 7403 +angelina 7404 +galactica 7405 +sufficient 7406 +substantial 7407 +incest 7408 +handicapped 7409 +trier 7410 +ample 7411 +doctor's 7412 +warden 7413 +supreme 7414 +hinted 7415 +slashers 7416 +rewarded 7417 +rice 7418 +complications 7419 +trauma 7420 +biopic 7421 +sebastian 7422 +'80s 7423 +characterizations 7424 +awareness 7425 +popped 7426 +sparks 7427 +vignettes 7428 +psychedelic 7429 +unclear 7430 +kells 7431 +tightly 7432 +existing 7433 +du 7434 +entrance 7435 +offend 7436 +goldie 7437 +guardian 7438 +collins 7439 +targeted 7440 +talky 7441 +extensive 7442 +ny 7443 +benefits 7444 +epics 7445 +pilots 7446 +payoff 7447 +stadium 7448 +october 7449 +stake 7450 +characterisation 7451 +applied 7452 +applies 7453 +pivotal 7454 +lowe 7455 +gathering 7456 +marisa 7457 +brent 7458 +upcoming 7459 +1963 7460 +overbearing 7461 +eli 7462 +occult 7463 +joking 7464 +ol' 7465 +graduate 7466 +beckinsale 7467 +nuanced 7468 +homicidal 7469 +addressed 7470 +evans 7471 +lunatic 7472 +parrot 7473 +edith 7474 +revival 7475 +convict 7476 +ignores 7477 +safely 7478 +plate 7479 +sour 7480 +turkish 7481 +favourites 7482 +ajay 7483 +boundaries 7484 +northam 7485 +profile 7486 +russ 7487 +skeptical 7488 +frog 7489 +invested 7490 +repeats 7491 +bias 7492 +'60s 7493 +drowned 7494 +iq 7495 +diversity 7496 +outlandish 7497 +nightmarish 7498 +dynamite 7499 +unfolding 7500 +convent 7501 +clooney 7502 +observations 7503 +johansson 7504 +1955 7505 +enchanting 7506 +tire 7507 +stabbing 7508 +disco 7509 +excellence 7510 +27 7511 +clunky 7512 +valid 7513 +array 7514 +engine 7515 +sammo 7516 +doug 7517 +sly 7518 +interior 7519 +resolve 7520 +hating 7521 +olsen 7522 +interviewed 7523 +chong 7524 +protection 7525 +maximum 7526 +nauseating 7527 +versa 7528 +apocalypse 7529 +exploitative 7530 +observation 7531 +murderers 7532 +questioning 7533 +gosh 7534 +stereotyped 7535 +flag 7536 +shore 7537 +pose 7538 +acknowledge 7539 +fruit 7540 +caretaker 7541 +rosemary's 7542 +interpretations 7543 +shin 7544 +stations 7545 +flavia 7546 +nutshell 7547 +announced 7548 +assure 7549 +silverman 7550 +duh 7551 +sonny 7552 +1958 7553 +blockbusters 7554 +pornography 7555 +vivian 7556 +sensibility 7557 +courtesy 7558 +battlestar 7559 +macdonald 7560 +boots 7561 +brides 7562 +reunite 7563 +brooke 7564 +controls 7565 +masked 7566 +phantasm 7567 +prophecy 7568 +slower 7569 +relying 7570 +sweat 7571 +divided 7572 +mannered 7573 +marked 7574 +witnessing 7575 +girlfriends 7576 +snipes 7577 +fortunate 7578 +watcher 7579 +brett 7580 +ernie 7581 +villainous 7582 +strung 7583 +rebels 7584 +candle 7585 +counting 7586 +mccarthy 7587 +rodriguez 7588 +bonham 7589 +portuguese 7590 +daytime 7591 +rea 7592 +insert 7593 +misty 7594 +displaying 7595 +substitute 7596 +satanic 7597 +wayans 7598 +magically 7599 +sincerity 7600 +owl 7601 +cocaine 7602 +spotlight 7603 +inter 7604 +chewing 7605 +lopez 7606 +chiba 7607 +progressed 7608 +entries 7609 +demille 7610 +chuckles 7611 +climbing 7612 +26 7613 +chaotic 7614 +criticized 7615 +confined 7616 +sanity 7617 +goat 7618 +unhinged 7619 +bittersweet 7620 +collar 7621 +realises 7622 +peril 7623 +bust 7624 +smell 7625 +turtle 7626 +wartime 7627 +admits 7628 +commanding 7629 +evokes 7630 +beard 7631 +seduce 7632 +harrowing 7633 +janet 7634 +phoenix 7635 +stiles 7636 +interrupted 7637 +whore 7638 +shocks 7639 +inadvertently 7640 +jar 7641 +wright 7642 +fart 7643 +resume 7644 +lynch's 7645 +needing 7646 +delirious 7647 +upstairs 7648 +obscurity 7649 +famed 7650 +palm 7651 +weekly 7652 +replacement 7653 +monotonous 7654 +smug 7655 +preaching 7656 +projected 7657 +randall 7658 +enduring 7659 +hmm 7660 +organization 7661 +landmark 7662 +thereby 7663 +fundamental 7664 +ripoff 7665 +rightly 7666 +ins 7667 +chew 7668 +slavery 7669 +unnatural 7670 +arrogance 7671 +waking 7672 +manipulation 7673 +jagger 7674 +reserved 7675 +blazing 7676 +finishes 7677 +somethings 7678 +observe 7679 +raging 7680 +thrust 7681 +trivial 7682 +madsen 7683 +carlos 7684 +samuel 7685 +tones 7686 +commendable 7687 +crushed 7688 +similarity 7689 +deemed 7690 +choir 7691 +imagining 7692 +unappealing 7693 +understatement 7694 +apple 7695 +discipline 7696 +thailand 7697 +colleague 7698 +convenient 7699 +rendering 7700 +hines 7701 +cena 7702 +mandy 7703 +testing 7704 +motel 7705 +subsequently 7706 +fassbinder 7707 +reluctantly 7708 +platform 7709 +men's 7710 +egyptian 7711 +aesthetic 7712 +hooper 7713 +accompanying 7714 +protective 7715 +penned 7716 +fetish 7717 +kirsten 7718 +herd 7719 +layered 7720 +scarecrows 7721 +incestuous 7722 +thunder 7723 +boogie 7724 +participate 7725 +forgiveness 7726 +baddies 7727 +hardened 7728 +forgets 7729 +comparable 7730 +combs 7731 +understandably 7732 +shahid 7733 +laying 7734 +marine 7735 +recover 7736 +scheming 7737 +cancelled 7738 +vargas 7739 +stumble 7740 +celebrities 7741 +merry 7742 +russo 7743 +frost 7744 +unfamiliar 7745 +madeleine 7746 +isabelle 7747 +crooks 7748 +python 7749 +filmography 7750 +explode 7751 +sylvia 7752 +article 7753 +climatic 7754 +achievements 7755 +conductor 7756 +pizza 7757 +reminding 7758 +remark 7759 +lo 7760 +gackt 7761 +traumatic 7762 +benjamin 7763 +stuffed 7764 +accidental 7765 +travis 7766 +govinda 7767 +must've 7768 +quintessential 7769 +deathtrap 7770 +cheerful 7771 +hostile 7772 +orchestra 7773 +ninety 7774 +gorilla 7775 +marcel 7776 +cameraman 7777 +shred 7778 +sholay 7779 +wrestler 7780 +customers 7781 +hallmark 7782 +beers 7783 +glossy 7784 +despise 7785 +anita 7786 +goings 7787 +spontaneous 7788 +1932 7789 +fleet 7790 +shameless 7791 +charges 7792 +camping 7793 +finishing 7794 +district 7795 +sins 7796 +dallas 7797 +file 7798 +yell 7799 +serbian 7800 +myrna 7801 +wholesome 7802 +titular 7803 +boo 7804 +o'brien 7805 +implies 7806 +sack 7807 +flip 7808 +salvage 7809 +annoy 7810 +restraint 7811 +imho 7812 +creations 7813 +affecting 7814 +pornographic 7815 +spoiling 7816 +bonanza 7817 +ala 7818 +raid 7819 +raunchy 7820 +sales 7821 +cheering 7822 +captivated 7823 +je 7824 +espionage 7825 +license 7826 +defining 7827 +beforehand 7828 +se 7829 +conclusions 7830 +bakshi's 7831 +hawn 7832 +sherlock 7833 +caprica 7834 +ruled 7835 +unconventional 7836 +diego 7837 +awry 7838 +verge 7839 +krueger 7840 +grin 7841 +whimsical 7842 +ideals 7843 +meyer 7844 +surround 7845 +characteristic 7846 +digging 7847 +shameful 7848 +coolest 7849 +philo 7850 +cells 7851 +reagan 7852 +seattle 7853 +infinitely 7854 +sickness 7855 +excels 7856 +2009 7857 +novelist 7858 +1946 7859 +burial 7860 +fades 7861 +faded 7862 +shannon 7863 +traditions 7864 +fraud 7865 +perverted 7866 +sheets 7867 +voodoo 7868 +desk 7869 +abundance 7870 +flashing 7871 +hunted 7872 +betrayed 7873 +admission 7874 +gershwin 7875 +rampant 7876 +relaxed 7877 +fires 7878 +polar 7879 +kindly 7880 +tits 7881 +melancholy 7882 +drowning 7883 +semblance 7884 +temper 7885 +cracks 7886 +tide 7887 +oblivious 7888 +miraculously 7889 +clarity 7890 +elliott 7891 +inserted 7892 +considers 7893 +constraints 7894 +drift 7895 +sunk 7896 +distributed 7897 +unnecessarily 7898 +welles' 7899 +flows 7900 +sexist 7901 +beckham 7902 +summed 7903 +henchmen 7904 +tools 7905 +transparent 7906 +devotion 7907 +hitchcock's 7908 +earliest 7909 +scarlett 7910 +dangerously 7911 +taut 7912 +dafoe 7913 +dreaming 7914 +seth 7915 +prop 7916 +cain 7917 +wesley 7918 +adapt 7919 +openly 7920 +sane 7921 +hugo 7922 +creasy 7923 +chops 7924 +pitched 7925 +juice 7926 +riff 7927 +blandings 7928 +shah 7929 +screened 7930 +tashan 7931 +meredith 7932 +doyle 7933 +mud 7934 +zodiac 7935 +regime 7936 +irritated 7937 +eagle 7938 +paycheck 7939 +egypt 7940 +spiral 7941 +letdown 7942 +wherever 7943 +madison 7944 +deeds 7945 +robotic 7946 +faint 7947 +outrageously 7948 +sheep 7949 +elsa 7950 +baron 7951 +overtones 7952 +searched 7953 +unleashed 7954 +sporting 7955 +lennon 7956 +gangs 7957 +dahmer 7958 +peggy 7959 +vapid 7960 +heap 7961 +circa 7962 +simpsons 7963 +slater 7964 +permanent 7965 +voyager 7966 +presidential 7967 +compensate 7968 +deepest 7969 +reject 7970 +uneasy 7971 +ghastly 7972 +gretchen 7973 +sophia 7974 +warehouse 7975 +switching 7976 +cedric 7977 +lara 7978 +evoke 7979 +flame 7980 +automatic 7981 +submarine 7982 +plug 7983 +programme 7984 +sucking 7985 +pursuing 7986 +avoids 7987 +assistance 7988 +assumes 7989 +orphan 7990 +mart 7991 +practical 7992 +joining 7993 +failures 7994 +liner 7995 +garfield 7996 +dwight 7997 +slut 7998 +oprah 7999 +committing 8000 +intend 8001 +ealing 8002 +shirts 8003 +locke 8004 +admirer 8005 +awaiting 8006 +ram 8007 +fritz 8008 +melbourne 8009 +contestant 8010 +timmy 8011 +rivals 8012 +buffy 8013 +clouds 8014 +ambiance 8015 +babes 8016 +ensue 8017 +coburn 8018 +occupied 8019 +sergio 8020 +sitcoms 8021 +variation 8022 +censorship 8023 +ferrell 8024 +radiation 8025 +snap 8026 +underdeveloped 8027 +takashi 8028 +hobgoblins 8029 +finney 8030 +listened 8031 +fiancée 8032 +complained 8033 +pauline 8034 +kinski 8035 +alarm 8036 +engineer 8037 +chloe 8038 +proceed 8039 +demeanor 8040 +suzanne 8041 +battlefield 8042 +rebellion 8043 +criticisms 8044 +remainder 8045 +ghostly 8046 +spaceship 8047 +howling 8048 +motivated 8049 +joint 8050 +carpenter's 8051 +fodder 8052 +bert 8053 +dominate 8054 +monks 8055 +dragging 8056 +inclined 8057 +upbeat 8058 +encouraged 8059 +networks 8060 +han 8061 +loren 8062 +brazilian 8063 +atlantic 8064 +flowing 8065 +progression 8066 +tess 8067 +meek 8068 +darkly 8069 +disappearance 8070 +colman 8071 +crashed 8072 +caper 8073 +solved 8074 +fairness 8075 +distinction 8076 +sensual 8077 +feinstone 8078 +sho 8079 +warrant 8080 +grease 8081 +visitor 8082 +marijuana 8083 +sections 8084 +avenge 8085 +tv's 8086 +croc 8087 +sober 8088 +badness 8089 +who've 8090 +ninjas 8091 +myrtle 8092 +runaway 8093 +helmet 8094 +scratching 8095 +quaint 8096 +busby 8097 +defending 8098 +buttons 8099 +artemisia 8100 +cloak 8101 +noting 8102 +confuse 8103 +experts 8104 +whip 8105 +borrow 8106 +barney 8107 +garage 8108 +happenings 8109 +mega 8110 +1990's 8111 +disregard 8112 +bean 8113 +aaron 8114 +edges 8115 +diving 8116 +investment 8117 +wee 8118 +electronic 8119 +gena 8120 +gypsy 8121 +suave 8122 +mustache 8123 +toxic 8124 +mira 8125 +bartender 8126 +prologue 8127 +transport 8128 +atrocity 8129 +everett 8130 +bernsen 8131 +notices 8132 +jo 8133 +boogeyman 8134 +knees 8135 +1966 8136 +1000 8137 +robbed 8138 +epitome 8139 +bennett 8140 +vcr 8141 +who'd 8142 +'a 8143 +detached 8144 +brit 8145 +hometown 8146 +jack's 8147 +prone 8148 +enormously 8149 +gilliam 8150 +jackman 8151 +dom 8152 +impending 8153 +bloodbath 8154 +mister 8155 +macmurray 8156 +vigilante 8157 +offense 8158 +prostitutes 8159 +fashions 8160 +idealistic 8161 +pigs 8162 +abomination 8163 +carpet 8164 +battling 8165 +principles 8166 +paz 8167 +pretends 8168 +awarded 8169 +admiration 8170 +incidental 8171 +tin 8172 +pairing 8173 +woefully 8174 +chip 8175 +classmates 8176 +timed 8177 +budding 8178 +gandolfini 8179 +revolver 8180 +liberty 8181 +associate 8182 +padding 8183 +colony 8184 +zelah 8185 +drum 8186 +vincenzo 8187 +secure 8188 +palestinian 8189 +girls' 8190 +blames 8191 +torment 8192 +kids' 8193 +framing 8194 +tackle 8195 +tended 8196 +peers 8197 +policemen 8198 +facility 8199 +ostensibly 8200 +harron 8201 +prank 8202 +lindy 8203 +bimbo 8204 +1957 8205 +saints 8206 +capote 8207 +shrek 8208 +breathe 8209 +nineties 8210 +worrying 8211 +believability 8212 +paragraph 8213 +mediocrity 8214 +influences 8215 +reported 8216 +conveying 8217 +programming 8218 +stoned 8219 +val 8220 +barnes 8221 +sharks 8222 +unravel 8223 +courageous 8224 +deck 8225 +giovanna 8226 +grating 8227 +britney 8228 +distinctive 8229 +blondell 8230 +spoofs 8231 +brush 8232 +effortlessly 8233 +riders 8234 +midget 8235 +annoyance 8236 +counterparts 8237 +economy 8238 +rivalry 8239 +stab 8240 +knights 8241 +socially 8242 +symbols 8243 +bodyguard 8244 +qualifies 8245 +connie 8246 +acclaim 8247 +managing 8248 +vibe 8249 +monroe 8250 +frat 8251 +baked 8252 +combining 8253 +martians 8254 +boobs 8255 +prostitution 8256 +closure 8257 +senator 8258 +outset 8259 +magazines 8260 +respond 8261 +interiors 8262 +division 8263 +slam 8264 +celebrate 8265 +elected 8266 +zu 8267 +monica 8268 +dillinger 8269 +brashear 8270 +cohesive 8271 +clinic 8272 +gig 8273 +tacked 8274 +coward 8275 +parodies 8276 +greene 8277 +billing 8278 +weirdness 8279 +dunst 8280 +rourke 8281 +manipulated 8282 +concentration 8283 +sinks 8284 +dreyfuss 8285 +asset 8286 +duchovny 8287 +superstar 8288 +clyde 8289 +december 8290 +pompous 8291 +fabric 8292 +placement 8293 +gibson 8294 +bless 8295 +boards 8296 +troopers 8297 +reese 8298 +goodman 8299 +transplant 8300 +shocker 8301 +examine 8302 +chock 8303 +scarlet 8304 +informs 8305 +responds 8306 +collapse 8307 +data 8308 +swiss 8309 +reasoning 8310 +confines 8311 +categories 8312 +injustice 8313 +laser 8314 +dish 8315 +employees 8316 +smith's 8317 +em 8318 +gasp 8319 +sacrifices 8320 +maurice 8321 +worship 8322 +screenplays 8323 +tolerate 8324 +pee 8325 +overshadowed 8326 +dern 8327 +reunited 8328 +brick 8329 +loner 8330 +holt 8331 +sites 8332 +uncertain 8333 +theatres 8334 +morse 8335 +yells 8336 +sibling 8337 +cheech 8338 +butchered 8339 +mae 8340 +ernest 8341 +sensibilities 8342 +500 8343 +ali 8344 +irving 8345 +castro 8346 +influential 8347 +terrorism 8348 +strained 8349 +derived 8350 +chandler 8351 +slept 8352 +perspectives 8353 +bleeding 8354 +madman 8355 +1942 8356 +inconsistencies 8357 +sensitivity 8358 +jam 8359 +hans 8360 +sustain 8361 +systems 8362 +armor 8363 +burgess 8364 +fiery 8365 +queens 8366 +katie 8367 +gruff 8368 +ewoks 8369 +faye 8370 +tramp 8371 +brandon 8372 +lighthearted 8373 +inform 8374 +cursed 8375 +retro 8376 +250 8377 +malden 8378 +cody 8379 +spelled 8380 +manic 8381 +labeled 8382 +perverse 8383 +collector 8384 +drain 8385 +shelter 8386 +spade 8387 +fallon 8388 +ang 8389 +gino 8390 +kareena 8391 +depardieu 8392 +apollo 8393 +officially 8394 +playful 8395 +informer 8396 +banks 8397 +retirement 8398 +booth 8399 +replacing 8400 +transforms 8401 +surrender 8402 +shield 8403 +jigsaw 8404 +fiend 8405 +predecessors 8406 +judgement 8407 +bing 8408 +englund 8409 +ads 8410 +damsel 8411 +stirring 8412 +structured 8413 +patty 8414 +poet 8415 +signature 8416 +tolerance 8417 +bites 8418 +dash 8419 +seriousness 8420 +casted 8421 +mercifully 8422 +edison 8423 +advances 8424 +padded 8425 +czech 8426 +lingering 8427 +sensational 8428 +crowded 8429 +bigfoot 8430 +captive 8431 +plotted 8432 +premiered 8433 +dictator 8434 +locale 8435 +bastard 8436 +manga 8437 +fighters 8438 +sophistication 8439 +lifts 8440 +yarn 8441 +spelling 8442 +uptight 8443 +farrah 8444 +drummer 8445 +amid 8446 +kidnaps 8447 +peaks 8448 +drastically 8449 +cringing 8450 +coop 8451 +dealers 8452 +geoffrey 8453 +rousing 8454 +supermarket 8455 +standpoint 8456 +thereafter 8457 +portions 8458 +latino 8459 +henchman 8460 +berenger 8461 +slash 8462 +sandy 8463 +lurid 8464 +coal 8465 +interplay 8466 +stares 8467 +willingly 8468 +mines 8469 +ss 8470 +ceremony 8471 +inexperienced 8472 +awfulness 8473 +condemned 8474 +benny 8475 +alba 8476 +mythical 8477 +spotted 8478 +sara 8479 +fierce 8480 +thereof 8481 +bloodshed 8482 +enthralling 8483 +geniuses 8484 +lars 8485 +rant 8486 +theodore 8487 +heather 8488 +echoes 8489 +maintaining 8490 +bombed 8491 +bitchy 8492 +fiasco 8493 +powered 8494 +tina 8495 +ossessione 8496 +worm 8497 +godard 8498 +observed 8499 +staging 8500 +attendant 8501 +anxiety 8502 +villa 8503 +varying 8504 +stepmother 8505 +aircraft 8506 +david's 8507 +justification 8508 +identified 8509 +downfall 8510 +anguish 8511 +shoved 8512 +allan 8513 +bliss 8514 +caution 8515 +transported 8516 +impressions 8517 +miike's 8518 +alexandra 8519 +shout 8520 +functions 8521 +imitate 8522 +norris 8523 +dwarf 8524 +nearest 8525 +funky 8526 +drugged 8527 +stabs 8528 +marrying 8529 +hallucinations 8530 +allies 8531 +communism 8532 +fixed 8533 +sorrow 8534 +orlando 8535 +register 8536 +surf 8537 +scarier 8538 +freed 8539 +tasty 8540 +baddie 8541 +vet 8542 +attic 8543 +representing 8544 +widower 8545 +cunning 8546 +plagued 8547 +hunky 8548 +apartheid 8549 +cockney 8550 +luc 8551 +islands 8552 +fur 8553 +emphasize 8554 +confession 8555 +ceiling 8556 +hairy 8557 +warhols 8558 +stricken 8559 +presume 8560 +rosenstrasse 8561 +meadows 8562 +distorted 8563 +virtue 8564 +natali 8565 +forrest 8566 +starship 8567 +lampoon 8568 +depend 8569 +marvin 8570 +mixes 8571 +jewelry 8572 +correctness 8573 +nest 8574 +myra 8575 +rockets 8576 +russians 8577 +glenda 8578 +byron 8579 +sammy 8580 +grandpa 8581 +monday 8582 +entertains 8583 +adultery 8584 +egg 8585 +massey 8586 +drawings 8587 +travolta 8588 +tricked 8589 +abu 8590 +bio 8591 +lin 8592 +fagin 8593 +cowardly 8594 +overwrought 8595 +determine 8596 +throne 8597 +ratio 8598 +tsui 8599 +paired 8600 +cannibals 8601 +fuss 8602 +client 8603 +animator 8604 +hurry 8605 +romania 8606 +foreboding 8607 +pub 8608 +earns 8609 +bon 8610 +gen 8611 +della 8612 +photograph 8613 +pecker 8614 +censors 8615 +groundbreaking 8616 +predicted 8617 +crooked 8618 +engagement 8619 +arnie 8620 +torturing 8621 +towns 8622 +intellectually 8623 +bald 8624 +finely 8625 +confirmed 8626 +natasha 8627 +hale 8628 +chemical 8629 +spells 8630 +loony 8631 +richly 8632 +edmund 8633 +groove 8634 +vaudeville 8635 +bills 8636 +ma 8637 +millennium 8638 +gladiator 8639 +icy 8640 +irrational 8641 +ballroom 8642 +daria 8643 +conflicted 8644 +clarence 8645 +subdued 8646 +sigh 8647 +artistically 8648 +keanu 8649 +laced 8650 +potent 8651 +representative 8652 +gently 8653 +reckless 8654 +dopey 8655 +jerky 8656 +deborah 8657 +decency 8658 +grossly 8659 +predictability 8660 +consumed 8661 +belle 8662 +blessed 8663 +parks 8664 +curtain 8665 +dukakis 8666 +federal 8667 +analyze 8668 +echo 8669 +contributes 8670 +accomplishment 8671 +cheesiness 8672 +romanian 8673 +almighty 8674 +continuously 8675 +gathered 8676 +dive 8677 +undercover 8678 +diaz 8679 +profoundly 8680 +identities 8681 +crypt 8682 +downbeat 8683 +1949 8684 +gusto 8685 +missions 8686 +sasquatch 8687 +locate 8688 +borrows 8689 +maturity 8690 +harbor 8691 +denial 8692 +emmy 8693 +arch 8694 +animations 8695 +airing 8696 +superfluous 8697 +lists 8698 +officials 8699 +steaming 8700 +operate 8701 +threads 8702 +significantly 8703 +aniston 8704 +goldsworthy 8705 +anchors 8706 +disappoints 8707 +collaboration 8708 +trusted 8709 +lays 8710 +sync 8711 +1920s 8712 +wrongly 8713 +lindsey 8714 +optimism 8715 +vertigo 8716 +abroad 8717 +judges 8718 +continent 8719 +lizard 8720 +muni 8721 +helena 8722 +hartley's 8723 +zeta 8724 +denying 8725 +proportions 8726 +winners 8727 +ll 8728 +monologues 8729 +gravity 8730 +forbes 8731 +launched 8732 +robbing 8733 +mash 8734 +mocking 8735 +confronts 8736 +mutants 8737 +beetle 8738 +nifty 8739 +fence 8740 +horn 8741 +luxury 8742 +athletic 8743 +imprisoned 8744 +scriptwriter 8745 +mack 8746 +handy 8747 +pia 8748 +uninspiring 8749 +rhyme 8750 +1964 8751 +promoting 8752 +73 8753 +flew 8754 +98 8755 +corbin 8756 +chevy 8757 +mobster 8758 +altman's 8759 +extraordinarily 8760 +applause 8761 +abstract 8762 +switches 8763 +garde 8764 +icons 8765 +showcases 8766 +intelligently 8767 +capitalism 8768 +developments 8769 +lions 8770 +hanzo 8771 +hypnotic 8772 +temptation 8773 +dedication 8774 +opposition 8775 +sensation 8776 +kristofferson 8777 +barton 8778 +lds 8779 +bothers 8780 +satisfactory 8781 +nora 8782 +genetic 8783 +moonstruck 8784 +illustrate 8785 +notwithstanding 8786 +elephants 8787 +stripper 8788 +grendel 8789 +fulfilling 8790 +languages 8791 +hilton 8792 +autobiography 8793 +pleasures 8794 +lightweight 8795 +increasing 8796 +preferably 8797 +shifting 8798 +bearable 8799 +prefers 8800 +idiocy 8801 +heroin 8802 +manipulate 8803 +uncredited 8804 +sheridan 8805 +conniving 8806 +surgeon 8807 +nonexistent 8808 +deservedly 8809 +clutter 8810 +bullies 8811 +penalty 8812 +scattered 8813 +owe 8814 +lawn 8815 +upbringing 8816 +increase 8817 +oblivion 8818 +fanning 8819 +shiny 8820 +cynicism 8821 +kings 8822 +hazzard 8823 +preacher 8824 +ongoing 8825 +luthor 8826 +sister's 8827 +quirks 8828 +michaels 8829 +transitions 8830 +ravishing 8831 +reno 8832 +corridors 8833 +shady 8834 +cloth 8835 +liotta 8836 +spinning 8837 +sleeper 8838 +auteur 8839 +plummer 8840 +appalled 8841 +reportedly 8842 +dodgy 8843 +todays 8844 +harilal 8845 +kilmer 8846 +blackmail 8847 +toss 8848 +distinctly 8849 +violently 8850 +ebay 8851 +limp 8852 +marines 8853 +lesbians 8854 +vaughn 8855 +bart 8856 +knocking 8857 +palma's 8858 +boost 8859 +aboard 8860 +defy 8861 +civilians 8862 +brunette 8863 +fewer 8864 +cinematographic 8865 +liberties 8866 +shrill 8867 +youngsters 8868 +strain 8869 +hammerhead 8870 +inhabit 8871 +thug 8872 +dyke 8873 +euro 8874 +cassie 8875 +fellini 8876 +puzzled 8877 +chop 8878 +sweeping 8879 +throats 8880 +thirds 8881 +billion 8882 +witted 8883 +operating 8884 +atomic 8885 +lt 8886 +supportive 8887 +henderson 8888 +profit 8889 +prolific 8890 +sore 8891 +virginity 8892 +sleepy 8893 +golf 8894 +outlaw 8895 +unnerving 8896 +expresses 8897 +mills 8898 +forsythe 8899 +authors 8900 +behaving 8901 +visconti 8902 +efficient 8903 +visceral 8904 +glow 8905 +jones' 8906 +melinda 8907 +muscle 8908 +pepper 8909 +heavenly 8910 +unwilling 8911 +1965 8912 +roach 8913 +marcus 8914 +tables 8915 +shelves 8916 +dunne 8917 +tedium 8918 +illustrated 8919 +explanations 8920 +snowy 8921 +patriotic 8922 +alcoholism 8923 +whipped 8924 +ledger 8925 +slaughtered 8926 +redford 8927 +percent 8928 +rapes 8929 +disasters 8930 +dickinson 8931 +examined 8932 +cradle 8933 +fleeing 8934 +healing 8935 +lightly 8936 +nerdy 8937 +torch 8938 +rodney 8939 +believer 8940 +teddy 8941 +meyers 8942 +lorre 8943 +denver 8944 +dangers 8945 +architect 8946 +vulnerability 8947 +knives 8948 +dillon 8949 +goo 8950 +numbingly 8951 +inch 8952 +compositions 8953 +flipping 8954 +amoral 8955 +wrath 8956 +rack 8957 +imply 8958 +bonds 8959 +pistol 8960 +perceived 8961 +aura 8962 +tobe 8963 +seventh 8964 +verhoeven's 8965 +insignificant 8966 +simpler 8967 +shatner 8968 +mac 8969 +kornbluth 8970 +barbarian 8971 +zoom 8972 +proudly 8973 +hawaii 8974 +hustler 8975 +penguin 8976 +supports 8977 +thumb 8978 +segal 8979 +fulfill 8980 +bothering 8981 +jurassic 8982 +compromise 8983 +annoyingly 8984 +kenny 8985 +scandal 8986 +overtly 8987 +fleeting 8988 +metropolis 8989 +guru 8990 +rotting 8991 +sixteen 8992 +deadpan 8993 +retrieve 8994 +moderately 8995 +chat 8996 +lang 8997 +simon's 8998 +illusion 8999 +heartless 9000 +backwoods 9001 +climate 9002 +righteous 9003 +beth 9004 +grisly 9005 +prejudices 9006 +immigrants 9007 +alienation 9008 +muscular 9009 +astonishingly 9010 +doses 9011 +traveled 9012 +happier 9013 +electricity 9014 +succession 9015 +cousins 9016 +mandatory 9017 +dental 9018 +breakthrough 9019 +freaked 9020 +clockwork 9021 +ursula 9022 +recurring 9023 +notions 9024 +mechanic 9025 +recovering 9026 +zhang 9027 +comprised 9028 +coverage 9029 +elder 9030 +afghanistan 9031 +trendy 9032 +keeper 9033 +hungarian 9034 +attributes 9035 +brennan 9036 +protecting 9037 +priests 9038 +aztec 9039 +ranger 9040 +recipe 9041 +vienna 9042 +ogre 9043 +farnsworth 9044 +tasks 9045 +romero's 9046 +purse 9047 +subtitled 9048 +lansbury 9049 +pickup 9050 +pals 9051 +unconscious 9052 +animators 9053 +legion 9054 +meanings 9055 +needlessly 9056 +sleuth 9057 +association 9058 +slips 9059 +doris 9060 +pond 9061 +improvised 9062 +relates 9063 +mcdowell 9064 +volumes 9065 +ranging 9066 +zany 9067 +irresistible 9068 +elisha 9069 +herrings 9070 +coppola 9071 +prolonged 9072 +relaxing 9073 +1931 9074 +1938 9075 +rudd 9076 +heir 9077 +innuendo 9078 +urgency 9079 +bloke 9080 +flamboyant 9081 +muriel 9082 +prophet 9083 +reruns 9084 +christensen 9085 +lure 9086 +cracker 9087 +levy 9088 +shakespearean 9089 +encourages 9090 +mockery 9091 +swords 9092 +penis 9093 +pam 9094 +welcomed 9095 +rugged 9096 +academic 9097 +honeymoon 9098 +climbs 9099 +snatch 9100 +overwhelmed 9101 +gays 9102 +roommates 9103 +jolly 9104 +heavens 9105 +placing 9106 +watered 9107 +fable 9108 +zealand 9109 +carnival 9110 +gee 9111 +archer 9112 +locales 9113 +thorn 9114 +smarmy 9115 +kiddie 9116 +farewell 9117 +cheat 9118 +hopeful 9119 +backdrops 9120 +treating 9121 +kamal 9122 +irresponsible 9123 +behalf 9124 +benoit 9125 +unemployed 9126 +backyard 9127 +norton 9128 +stumbling 9129 +theirs 9130 +anonymous 9131 +temporary 9132 +distinguished 9133 +moore's 9134 +inhabited 9135 +wwi 9136 +eastwood's 9137 +pranks 9138 +custody 9139 +yearning 9140 +interspersed 9141 +agatha 9142 +chocolate 9143 +hug 9144 +guided 9145 +martino 9146 +steamy 9147 +feared 9148 +opponents 9149 +crawl 9150 +mans 9151 +jew 9152 +bombing 9153 +assortment 9154 +poke 9155 +imitating 9156 +management 9157 +keitel 9158 +frenzy 9159 +mcadams 9160 +architecture 9161 +spitting 9162 +48 9163 +hector 9164 +fitzgerald 9165 +rko 9166 +redgrave 9167 +induced 9168 +plants 9169 +rusty 9170 +janitor 9171 +weaver 9172 +recreate 9173 +islam 9174 +rogue 9175 +roads 9176 +rewrite 9177 +dodge 9178 +balloon 9179 +honey 9180 +neeson 9181 +conquest 9182 +slug 9183 +wolves 9184 +neglect 9185 +shawn 9186 +concentrated 9187 +tested 9188 +existential 9189 +expanded 9190 +worldwide 9191 +truthful 9192 +unlucky 9193 +liz 9194 +compassionate 9195 +limbs 9196 +impeccable 9197 +dogma 9198 +shattering 9199 +sailors 9200 +peterson 9201 +jock 9202 +rizzo 9203 +kalifornia 9204 +mcdermott 9205 +versatile 9206 +400 9207 +michael's 9208 +naval 9209 +burden 9210 +cheung 9211 +largest 9212 +culkin 9213 +retelling 9214 +muted 9215 +leaps 9216 +theo 9217 +passive 9218 +bucket 9219 +pertwee 9220 +eddy 9221 +rapture 9222 +continuous 9223 +gage 9224 +stretches 9225 +giggle 9226 +marx 9227 +concludes 9228 +stalks 9229 +amok 9230 +adequately 9231 +melt 9232 +stature 9233 +counted 9234 +borderline 9235 +mastermind 9236 +boxes 9237 +posh 9238 +taker 9239 +counterpart 9240 +izzard 9241 +straw 9242 +toe 9243 +shamelessly 9244 +crenna 9245 +tango 9246 +pour 9247 +behaves 9248 +sematary 9249 +expand 9250 +azumi 9251 +country's 9252 +stimulating 9253 +grady 9254 +expressing 9255 +payne 9256 +crass 9257 +intellect 9258 +booker 9259 +dani 9260 +parents' 9261 +lotr 9262 +miyazaki's 9263 +wits 9264 +waving 9265 +traumatized 9266 +illiterate 9267 +chan's 9268 +puzzling 9269 +splitting 9270 +subtleties 9271 +seduction 9272 +condescending 9273 +rebecca 9274 +inherited 9275 +seal 9276 +consisted 9277 +stubborn 9278 +didnt 9279 +lieutenant 9280 +slows 9281 +john's 9282 +glee 9283 +honorable 9284 +'73 9285 +valerie 9286 +smoothly 9287 +poo 9288 +evolved 9289 +darling 9290 +planted 9291 +mold 9292 +supremacy 9293 +opener 9294 +seuss 9295 +craven's 9296 +celine 9297 +hesitate 9298 +conception 9299 +supporters 9300 +revolting 9301 +practices 9302 +orgy 9303 +cheaper 9304 +town's 9305 +forgivable 9306 +nutty 9307 +speechless 9308 +nailed 9309 +associates 9310 +platoon 9311 +disdain 9312 +waits 9313 +knox 9314 +it´s 9315 +collecting 9316 +alligator 9317 +hispanic 9318 +mutated 9319 +woven 9320 +hardest 9321 +lubitsch 9322 +january 9323 +apprentice 9324 +uber 9325 +sarne 9326 +pets 9327 +fawcett 9328 +marred 9329 +elevate 9330 +drivers 9331 +creepiness 9332 +revive 9333 +harlem 9334 +vivah 9335 +kindness 9336 +marathon 9337 +bishop 9338 +gannon 9339 +carole 9340 +brits 9341 +submit 9342 +embarrass 9343 +boyfriends 9344 +dreadfully 9345 +oppressive 9346 +discernible 9347 +intruder 9348 +tourists 9349 +conduct 9350 +rehearsal 9351 +bolivia 9352 +astronaut 9353 +joanna 9354 +grounded 9355 +sessions 9356 +cocktail 9357 +stir 9358 +gimmicks 9359 +archive 9360 +stereotyping 9361 +aweigh 9362 +18th 9363 +undeveloped 9364 +rico 9365 +concentrates 9366 +bruckheimer 9367 +psychiatric 9368 +incompetence 9369 +villagers 9370 +customs 9371 +alienate 9372 +slew 9373 +footsteps 9374 +approximately 9375 +discussions 9376 +blink 9377 +vault 9378 +transformers 9379 +sloane 9380 +choke 9381 +infidelity 9382 +relied 9383 +undertaker 9384 +lovingly 9385 +casually 9386 +luzhin 9387 +disappearing 9388 +historians 9389 +shaolin 9390 +mastroianni 9391 +midler 9392 +atrocities 9393 +bash 9394 +inc 9395 +hedy 9396 +drums 9397 +bonding 9398 +entertainer 9399 +revelations 9400 +holland 9401 +floriane 9402 +downtown 9403 +denied 9404 +connor 9405 +stupidest 9406 +tel 9407 +sinatra's 9408 +lyrical 9409 +woke 9410 +knack 9411 +dripping 9412 +saddest 9413 +loathing 9414 +insects 9415 +hoover 9416 +apologize 9417 +premises 9418 +elmer 9419 +screamed 9420 +lecture 9421 +skipping 9422 +bursts 9423 +noam 9424 +passions 9425 +cocky 9426 +prevalent 9427 +regrets 9428 +suspended 9429 +shack 9430 +democracy 9431 +overacts 9432 +enhances 9433 +deathstalker 9434 +1960 9435 +choreographer 9436 +keeler 9437 +cillian 9438 +contemplate 9439 +smarter 9440 +marlene 9441 +philadelphia 9442 +sammi 9443 +kingsley 9444 +micheal 9445 +mpaa 9446 +duryea 9447 +creeps 9448 +capsule 9449 +converted 9450 +zabriskie 9451 +perceive 9452 +confronting 9453 +administration 9454 +arizona 9455 +viggo 9456 +ecstasy 9457 +candidates 9458 +branch 9459 +passenger 9460 +benson 9461 +sans 9462 +victoria's 9463 +callahan 9464 +intestines 9465 +swamp 9466 +sparse 9467 +request 9468 +overseas 9469 +bass 9470 +surpasses 9471 +organs 9472 +rohmer 9473 +montages 9474 +joshua 9475 +ella 9476 +maguire 9477 +rhys 9478 +cloud 9479 +stripped 9480 +rushes 9481 +kentucky 9482 +tensions 9483 +mom's 9484 +operas 9485 +chapters 9486 +monstrous 9487 +usage 9488 +fugitive 9489 +shaun 9490 +slipped 9491 +documents 9492 +email 9493 +classified 9494 +norwegian 9495 +reception 9496 +ash 9497 +sacrificed 9498 +switzerland 9499 +rightfully 9500 +cruella 9501 +psychologically 9502 +bury 9503 +liar 9504 +clumsily 9505 +crow 9506 +mindset 9507 +untrue 9508 +barker 9509 +lange 9510 +toro 9511 +ahmad 9512 +wipe 9513 +sixty 9514 +brink 9515 +insanely 9516 +mourning 9517 +vets 9518 +wu 9519 +1956 9520 +restless 9521 +loop 9522 +fanatics 9523 +rests 9524 +guevara 9525 +connecting 9526 +city's 9527 +friendships 9528 +satellite 9529 +empathize 9530 +surfers 9531 +immersed 9532 +mostel 9533 +squeeze 9534 +backing 9535 +admirably 9536 +confirm 9537 +equals 9538 +vengeful 9539 +pauses 9540 +snippets 9541 +mamet 9542 +that'll 9543 +anchorman 9544 +dense 9545 +strikingly 9546 +daphne 9547 +misplaced 9548 +1941 9549 +streak 9550 +shrink 9551 +garnered 9552 +breathless 9553 +hiv 9554 +delve 9555 +grain 9556 +spectrum 9557 +dusty 9558 +durbin 9559 +locks 9560 +november 9561 +o'neill 9562 +crook 9563 +render 9564 +participation 9565 +deception 9566 +replay 9567 +apartments 9568 +sr 9569 +lawyers 9570 +requisite 9571 +telly 9572 +basil 9573 +kinky 9574 +assist 9575 +spectacularly 9576 +scantily 9577 +prevented 9578 +obscene 9579 +reincarnation 9580 +morgana 9581 +bout 9582 +looney 9583 +adventurous 9584 +sykes 9585 +maverick 9586 +lucio 9587 +travelling 9588 +diabolical 9589 +capt 9590 +promotion 9591 +partial 9592 +eater 9593 +dime 9594 +bathing 9595 +criminally 9596 +underdog 9597 +interpret 9598 +suggestive 9599 +springs 9600 +graves 9601 +spielberg's 9602 +technological 9603 +wan 9604 +cortez 9605 +proverbial 9606 +granger 9607 +phrases 9608 +societies 9609 +thankful 9610 +palette 9611 +outrage 9612 +betrays 9613 +lung 9614 +marquis 9615 +ing 9616 +regal 9617 +oriental 9618 +duties 9619 +whacked 9620 +kerr 9621 +documented 9622 +700 9623 +stoic 9624 +fairytale 9625 +listing 9626 +acknowledged 9627 +allison 9628 +matching 9629 +longtime 9630 +garcia 9631 +elliot 9632 +33 9633 +adopt 9634 +flea 9635 +carlito's 9636 +1940's 9637 +coleman 9638 +draft 9639 +witless 9640 +kramer 9641 +haha 9642 +lap 9643 +alternately 9644 +1930 9645 +sentenced 9646 +harry's 9647 +daisies 9648 +overt 9649 +mining 9650 +stepped 9651 +eliminate 9652 +chains 9653 +regain 9654 +nuance 9655 +italians 9656 +hurting 9657 +honour 9658 +sealed 9659 +societal 9660 +indifference 9661 +lombard 9662 +teamed 9663 +cathy 9664 +its' 9665 +unfinished 9666 +floors 9667 +downside 9668 +tucker 9669 +paperhouse 9670 +compound 9671 +eggs 9672 +underused 9673 +incarnation 9674 +hunk 9675 +goer 9676 +presumed 9677 +caruso 9678 +interpreted 9679 +colourful 9680 +stills 9681 +caroline 9682 +keyboard 9683 +claw 9684 +snappy 9685 +camps 9686 +crop 9687 +sheet 9688 +overnight 9689 +dung 9690 +booze 9691 +risks 9692 +rub 9693 +oddball 9694 +exhibit 9695 +anchor 9696 +fireworks 9697 +batwoman 9698 +gesture 9699 +skinned 9700 +undertones 9701 +achieving 9702 +lanza 9703 +goofs 9704 +flee 9705 +recalls 9706 +stable 9707 +fantastically 9708 +exposing 9709 +shakes 9710 +addressing 9711 +prototype 9712 +carface 9713 +hes 9714 +competently 9715 +retain 9716 +schemes 9717 +hogan 9718 +voting 9719 +episodic 9720 +occurring 9721 +topped 9722 +1954 9723 +norma 9724 +chore 9725 +chang 9726 +shouts 9727 +rainer 9728 +colonial 9729 +recreation 9730 +forum 9731 +companions 9732 +apologies 9733 +insulted 9734 +holidays 9735 +throwaway 9736 +tepid 9737 +darkest 9738 +pulse 9739 +pita 9740 +superiors 9741 +grumpy 9742 +illustrates 9743 +sweetheart 9744 +showtime 9745 +aiello 9746 +btk 9747 +cbc 9748 +baseketball 9749 +horizon 9750 +eliminated 9751 +weirdo 9752 +welch 9753 +stepping 9754 +leno 9755 +beau 9756 +affections 9757 +leopold 9758 +inheritance 9759 +masturbation 9760 +itchy 9761 +locker 9762 +universally 9763 +shadowy 9764 +employ 9765 +skywalker 9766 +grips 9767 +gardens 9768 +sorvino 9769 +expertise 9770 +irwin 9771 +t'aime 9772 +babysitter 9773 +bryan 9774 +positions 9775 +coarse 9776 +tremors 9777 +iceberg 9778 +monumental 9779 +thinner 9780 +allegedly 9781 +dominick 9782 +allied 9783 +bogdanovich 9784 +raving 9785 +supplies 9786 +kaufman 9787 +sacred 9788 +shootings 9789 +primal 9790 +hiring 9791 +hockey 9792 +flamenco 9793 +thirteen 9794 +carlito 9795 +polite 9796 +exudes 9797 +gaining 9798 +darius 9799 +quarters 9800 +willem 9801 +crummy 9802 +duff 9803 +sorta 9804 +rigid 9805 +eponymous 9806 +smitten 9807 +attributed 9808 +variations 9809 +mischievous 9810 +unborn 9811 +wayne's 9812 +circuit 9813 +integrated 9814 +unimpressive 9815 +carson 9816 +150 9817 +siege 9818 +endured 9819 +surrogate 9820 +gifts 9821 +practicing 9822 +disgruntled 9823 +drifter 9824 +renowned 9825 +chef 9826 +operatic 9827 +maiden 9828 +frenetic 9829 +wal 9830 +roaring 9831 +author's 9832 +wondrous 9833 +greta 9834 +gamut 9835 +marital 9836 +gym 9837 +offerings 9838 +zatoichi 9839 +emerged 9840 +exaggeration 9841 +planets 9842 +raft 9843 +connolly 9844 +mcintire 9845 +strangest 9846 +marvellous 9847 +runtime 9848 +misfire 9849 +extremes 9850 +swift 9851 +seinfeld 9852 +jackass 9853 +harmony 9854 +plantation 9855 +bravery 9856 +pavarotti 9857 +catastrophe 9858 +malcolm 9859 +portman 9860 +solving 9861 +albums 9862 +winston 9863 +corky 9864 +allegory 9865 +spears 9866 +saif 9867 +goof 9868 +outta 9869 +virtues 9870 +monstrosity 9871 +ideology 9872 +edits 9873 +celebrating 9874 +adapting 9875 +ferry 9876 +desolate 9877 +jessie 9878 +inflicted 9879 +rocker 9880 +projection 9881 +irs 9882 +cambodia 9883 +enthralled 9884 +ensuing 9885 +leia 9886 +o'toole 9887 +transferred 9888 +exposes 9889 +competing 9890 +yourselves 9891 +sentiments 9892 +kisses 9893 +stray 9894 +turgid 9895 +declares 9896 +nuns 9897 +mercilessly 9898 +it'd 9899 +exceedingly 9900 +ted's 9901 +insecure 9902 +ben's 9903 +tanks 9904 +kusturica 9905 +spaces 9906 +spliced 9907 +sheila 9908 +crowds 9909 +balcony 9910 +menu 9911 +lamas 9912 +diver 9913 +secluded 9914 +integral 9915 +redeemed 9916 +halt 9917 +decapitated 9918 +stealth 9919 +budgeted 9920 +voters 9921 +overweight 9922 +praying 9923 +stevenson 9924 +cleveland 9925 +stakes 9926 +mattei 9927 +charity 9928 +stalk 9929 +olympia 9930 +olympic 9931 +aspirations 9932 +decoration 9933 +slack 9934 +bullying 9935 +bum 9936 +mo 9937 +capitalize 9938 +jameson 9939 +skimpy 9940 +wicker 9941 +starving 9942 +frenchman 9943 +frye 9944 +ate 9945 +monastery 9946 +wb 9947 +hayden 9948 +banana 9949 +grandparents 9950 +vacuous 9951 +willy 9952 +darkman 9953 +neutral 9954 +rumors 9955 +somber 9956 +aunts 9957 +amateurs 9958 +radar 9959 +ounce 9960 +bagdad 9961 +stud 9962 +closeups 9963 +insisted 9964 +jed 9965 +geeky 9966 +64 9967 +aims 9968 +complains 9969 +ewan 9970 +exhausted 9971 +day's 9972 +weaves 9973 +gladly 9974 +misogynistic 9975 +soles 9976 +michel 9977 +uniquely 9978 +interminable 9979 +aristocrat 9980 +paul's 9981 +everybody's 9982 +avant 9983 +answering 9984 +smallest 9985 +contacts 9986 +enlightenment 9987 +murphy's 9988 +employs 9989 +unforgivable 9990 +punchline 9991 +culminating 9992 +talentless 9993 +grabbing 9994 +soulless 9995 +unfairly 9996 +grail 9997 +retrospect 9998 +edged 9999
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/BUILD new file mode 100644 index 0000000..9cf6039e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/BUILD
@@ -0,0 +1,14 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + "*.csv", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite new file mode 100644 index 0000000..8015ee5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite new file mode 100644 index 0000000..6416c4c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration.json new file mode 100644 index 0000000..9459b77 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration.json
@@ -0,0 +1,139 @@ +{ + "name": "ObjectDetector", + "description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be detected.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "location", + "description": "The locations of the detected boxes.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + "index": [ + 1, + 0, + 3, + 2 + ], + "type": "BOUNDARIES" + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "category", + "description": "The categories of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_VALUE_LABELS" + } + ] + }, + { + "name": "score", + "description": "The scores of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "INVERSE_LOGISTIC", + "default_score": 0.2 + } + } + ], + "stats": { + }, + "associated_files": [ + { + "name": "score_calibration.csv", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + }, + { + "name": "number of detections", + "description": "The number of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_groups": [ + { + "name": "detection_result", + "tensor_names": [ + "location", + "category", + "score" + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration_dummy.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration_dummy.json new file mode 100644 index 0000000..888c832a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/coco_ssd_mobilenet_v1_score_calibration_dummy.json
@@ -0,0 +1,138 @@ +{ + "name": "ObjectDetector", + "description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be detected.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "location", + "description": "The locations of the detected boxes.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + "index": [ + 1, + 0, + 3, + 2 + ], + "type": "BOUNDARIES" + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "category", + "description": "The categories of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_VALUE_LABELS" + } + ] + }, + { + "name": "score", + "description": "The scores of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "INVERSE_LOGISTIC" + } + } + ], + "stats": { + }, + "associated_files": [ + { + "name": "score_calibration_dummy.csv", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + }, + { + "name": "number of detections", + "description": "The number of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_groups": [ + { + "name": "detection_result", + "tensor_names": [ + "location", + "category", + "score" + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1.json new file mode 100644 index 0000000..3c29c17 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1.json
@@ -0,0 +1,123 @@ +{ + "name": "ObjectDetector", + "description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be detected.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "The scores of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "location", + "description": "The locations of the detected boxes.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + "index": [ + 1, + 0, + 3, + 2 + ], + "type": "BOUNDARIES" + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "number of detections", + "description": "The number of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "category", + "description": "The categories of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_VALUE_LABELS" + } + ] + } + ], + "output_tensor_groups": [ + { + "name": "detection_result", + "tensor_names": [ + "location", + "category", + "score" + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1.tflite new file mode 100644 index 0000000..c2d49da --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1_default.json new file mode 100644 index 0000000..77e4955 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/efficientdet_lite0_v1_default.json
@@ -0,0 +1,97 @@ +{ + "name": "ObjectDetector", + "description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be detected.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "The scores of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "location", + "description": "The locations of the detected boxes.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + "index": [ + 1, + 0, + 3, + 2 + ], + "type": "BOUNDARIES" + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "number of detections", + "description": "The number of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "category", + "description": "The categories of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + } + ], + "output_tensor_groups": [ + { + "name": "detection_result", + "tensor_names": [ + "location", + "category", + "score" + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt new file mode 100644 index 0000000..695772dc --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt
@@ -0,0 +1,90 @@ +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +??? +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +??? +backpack +umbrella +??? +??? +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +??? +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +??? +dining table +??? +??? +toilet +??? +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +??? +book +clock +vase +scissors +teddy bear +hair drier +toothbrush
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/score_calibration.csv b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/score_calibration.csv new file mode 100644 index 0000000..1127d3d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/score_calibration.csv
@@ -0,0 +1,89 @@ + +0.9876328110694885,0.36622241139411926,0.5352765321731567,0.71484375 +0.9584911465644836,1.0602262020111084,0.2777034342288971,0.019999999552965164 +0.9698624014854431,0.8795201778411865,0.539591908454895,0.00390625 +0.7486230731010437,1.1876736879348755,2.552982807159424,0.019999999552965164 +0.9745277166366577,0.3739396333694458,0.4621727764606476,0.19921875 +0.9683839678764343,0.6996201276779175,0.7690851092338562,0.019999999552965164 +0.6875,0.31044548749923706,1.0056899785995483,0.019999999552965164 +0.9849396347999573,0.8532888889312744,-0.2361421436071396,0.03125 +0.9878578186035156,1.0118975639343262,0.13313621282577515,0.359375 +0.9915205836296082,0.4434199929237366,1.0268371105194092,0.05078125 +0.9370332360267639,0.4586562216281891,-0.08101099729537964,0.019999999552965164 +0.9905818104743958,0.8670706152915955,0.012704282067716122,0.019999999552965164 +0.9080020189285278,0.8507471680641174,0.5081117749214172,0.019999999552965164 +0.985953152179718,0.9933826923370361,-0.8114940524101257,0.109375 +0.9819648861885071,1.12098228931427,-0.6330763697624207,0.01171875 +0.9025918245315552,0.7803755402565002,0.03275677561759949,0.08984375 +0.9863958954811096,0.11243592947721481,0.935604453086853,0.61328125 +0.9905291795730591,0.3710605800151825,0.708966851234436,0.359375 +0.9917052984237671,0.9596433043479919,0.19800108671188354,0.09765625 +0.8762937188148499,0.3449830114841461,0.5352474451065063,0.0078125 +0.9902125000953674,0.8918796181678772,-0.1306992471218109,0.26171875 + +0.9902340173721313,0.9177873134613037,-0.4322589933872223,0.019999999552965164 +0.9707600474357605,0.7028177976608276,0.9813734889030457,0.019999999552965164 +0.9823090434074402,1.0499590635299683,0.12045472860336304,0.0078125 +0.990516185760498,0.9449402093887329,1.3773189783096313,0.019999999552965164 +0.9875434041023254,0.577914297580719,1.282518982887268,0.0390625 +0.9821421504020691,0.0967339277267456,0.8279788494110107,0.47265625 +0.9875047206878662,0.9038218259811401,2.1208062171936035,0.38671875 +0.9857864379882812,0.8627446889877319,0.18189261853694916,0.019999999552965164 +0.9647751450538635,1.0752476453781128,-0.018294010311365128,0.0234375 +0.9830358624458313,0.5638481378555298,0.8346489667892456,0.019999999552965164 +0.9904966354370117,1.0160938501358032,-0.0573287308216095,0.00390625 +0.8458405137062073,0.4868394434452057,0.6617084741592407,0.019999999552965164 +0.9847381711006165,0.5939620137214661,0.008616370148956776,0.00390625 +0.9375938773155212,0.723095178604126,0.6635608077049255,0.019999999552965164 +0.9334303140640259,0.5689108967781067,0.37019580602645874,0.019999999552965164 +0.9716793894767761,1.0037211179733276,0.5898993611335754,0.02734375 +0.9197732210159302,0.46794334053993225,0.7365336418151855,0.640625 +0.9857497811317444,0.7299028635025024,0.9195274114608765,0.0390625 +0.8758038282394409,1.200216293334961,0.02580185979604721,0.019999999552965164 +0.9841026067733765,0.8050475716590881,0.9698556661605835,0.0078125 +0.9908539652824402,0.7911490201950073,0.19351358711719513,0.12109375 +0.9179956316947937,0.023991893976926804,0.35193610191345215,0.04296875 +0.9903728365898132,0.7744967341423035,0.2686336636543274,0.359375 +0.906022846698761,0.5766159892082214,1.0600007772445679,0.04296875 +0.9885554909706116,0.99117511510849,0.5611960291862488,0.4140625 +0.9906331896781921,1.1376535892486572,1.45369291305542,0.019999999552965164 +0.9640991687774658,0.5387894511222839,1.1824018955230713,0.019999999552965164 +0.9932155609130859,0.4347895085811615,1.3938102722167969,0.0078125 +0.9884702563285828,0.885567843914032,0.1556047648191452,0.1484375 +0.9891508221626282,0.04143073782324791,0.6111864447593689,0.0078125 +0.8935436010360718,0.2937895655632019,0.3215920031070709,0.00390625 +0.8327123522758484,0.8381986021995544,-0.026293788105249405,0.019999999552965164 +0.9839455485343933,0.9581400156021118,1.495324969291687,0.640625 +0.9904995560646057,0.9168422818183899,0.33293962478637695,0.015625 +0.9856975674629211,1.0433714389801025,0.5954801440238953,0.019999999552965164 +0.9942344427108765,0.7206616997718811,1.666426181793213,0.9609375 +0.8182767033576965,0.9546273946762085,0.5500107407569885,0.019999999552965164 +0.9631295800209045,0.6277880668640137,0.05952891707420349,0.05859375 +0.9819005727767944,1.0826934576034546,0.7444049715995789,0.30859375 +0.9884315133094788,1.0500890016555786,1.1161768436431885,0.019999999552965164 +0.9175815582275391,0.09232989698648453,1.596696138381958,0.47265625 +0.9868760108947754,0.903079628944397,-0.15774966776371002,0.8515625 +0.9866015911102295,0.7533788084983826,0.7489103078842163,0.03125 +0.8074312806129456,0.8615151643753052,0.40621864795684814,0.00390625 +0.9829285144805908,0.8954831957817078,0.4462486207485199,0.02734375 +0.9681841135025024,0.6257772445678711,0.43809664249420166,0.38671875 +0.9872947931289673,0.9947993159294128,0.9271130561828613,0.26171875 +0.7997345328330994,0.3995186686515808,-0.3755347430706024,0.019999999552965164 +0.9922754168510437,1.1357101202011108,-0.10267537832260132,0.5 +0.9861471652984619,0.8725204467773438,1.1657888889312744,0.019999999552965164 +0.9888646006584167,1.2098380327224731,-0.27832522988319397,0.05078125 +0.5641342997550964,1.0501892566680908,1.9519661664962769,0.019999999552965164 +0.9548168778419495,0.8971696496009827,1.378737449645996,0.00390625 +0.9875019788742065,0.8718118071556091,0.5476236939430237,0.0078125 +0.9725168347358704,0.6989551782608032,-1.3157455921173096,0.61328125 +0.9864014983177185,0.7576251029968262,-0.41650667786598206,0.00390625 +0.960071861743927,0.13068856298923492,0.4819187819957733,0.019999999552965164 +0.9849705100059509,0.7724528908729553,0.3877875804901123,0.03125 +0.9703006744384766,0.8848260641098022,-1.1767181158065796,0.80078125 +0.9837008714675903,0.7015050053596497,0.18209102749824524,0.00390625 +0.9579976797103882,0.053806986659765244,2.7309608459472656,0.4000000059604645 +0.9896979928016663,0.41135814785957336,0.5738034844398499,0.019999999552965164 +0.9853873252868652,0.5438565611839294,0.20562179386615753,0.02734375 +0.9784129858016968,0.6330984830856323,-0.1789831817150116,0.015625 +0.9375,0.855596125125885,-0.1933964192867279,0.019999999552965164 +0.9524176716804504,0.08709807693958282,0.6299692988395691,0.33203125
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/score_calibration_dummy.csv b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/score_calibration_dummy.csv new file mode 100644 index 0000000..98b1b14 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/score_calibration_dummy.csv
@@ -0,0 +1,89 @@ +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 + +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 + +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 + +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 + +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99 +1.0,1.0,0.0,-99
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.json new file mode 100644 index 0000000..36eb0c4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.json
@@ -0,0 +1,123 @@ +{ + "name": "ObjectDetector", + "description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be detected.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "location", + "description": "The locations of the detected boxes.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + "index": [ + 1, + 0, + 3, + 2 + ], + "type": "BOUNDARIES" + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "category", + "description": "The categories of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + }, + "associated_files": [ + { + "name": "labelmap.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_VALUE_LABELS" + } + ] + }, + { + "name": "score", + "description": "The scores of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "number of detections", + "description": "The number of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_groups": [ + { + "name": "detection_result", + "tensor_names": [ + "location", + "category", + "score" + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite new file mode 100644 index 0000000..8015ee5 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1_default.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1_default.json new file mode 100644 index 0000000..365892a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1_default.json
@@ -0,0 +1,97 @@ +{ + "name": "ObjectDetector", + "description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be detected.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "location", + "description": "The locations of the detected boxes.", + "content": { + "content_properties_type": "BoundingBoxProperties", + "content_properties": { + "index": [ + 1, + 0, + 3, + 2 + ], + "type": "BOUNDARIES" + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "category", + "description": "The categories of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "score", + "description": "The scores of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + }, + "range": { + "min": 2, + "max": 2 + } + }, + "stats": { + } + }, + { + "name": "number of detections", + "description": "The number of the detected boxes.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_groups": [ + { + "name": "detection_result", + "tensor_names": [ + "location", + "category", + "score" + ] + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/question_answerer/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/question_answerer/BUILD new file mode 100644 index 0000000..a136206 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/question_answerer/BUILD
@@ -0,0 +1,17 @@ +load("//tensorflow_lite_support/tools/build_rules:http_files.bzl", "tflite_model") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "test_files", + srcs = glob([ + "*.json", + "*.tflite", + "*.txt", + ]) + [":mobilebert_float"], +) + +tflite_model(name = "mobilebert_float")
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/regex_tokenizer_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/regex_tokenizer_meta.json new file mode 100644 index 0000000..dc1f2f6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/regex_tokenizer_meta.json
@@ -0,0 +1,21 @@ +{ + "subgraph_metadata": [ + { + "input_process_units": [ + { + "options_type": "RegexTokenizerOptions", + "options": { + "delim_regex_pattern": "[^\\w\\']+", + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/score_calibration_file_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/score_calibration_file_meta.json new file mode 100644 index 0000000..c47d846 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/score_calibration_file_meta.json
@@ -0,0 +1,9 @@ +{ + "associated_files": [ + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/score_calibration_tensor_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/score_calibration_tensor_meta.json new file mode 100644 index 0000000..fd6c0e4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/score_calibration_tensor_meta.json
@@ -0,0 +1,15 @@ +{ + "subgraph_metadata": [ + { + "input_process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/sentence_piece_tokenizer_meta.json b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/sentence_piece_tokenizer_meta.json new file mode 100644 index 0000000..b728908 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/sentence_piece_tokenizer_meta.json
@@ -0,0 +1,26 @@ +{ + "subgraph_metadata": [ + { + "input_process_units": [ + { + "options_type": "SentencePieceTokenizerOptions", + "options": { + "sentencePiece_model": [ + { + "name": "sp.model", + "description": "The sentence piece model file." + } + ], + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/README.md b/third_party/tflite_support/src/tensorflow_lite_support/odml/README.md new file mode 100644 index 0000000..f9f568d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/README.md
@@ -0,0 +1,7 @@ +# Standardized ODML Data Containers + +Provides a set of standardized on-device ML (ODML) data containers, which could +be shared cross-platforms, among different ODML frameworks (TFLite, MediaPipe, +ML Kit), to increase the interoperability, share data conversions with hardwares +(CPU, OpenGL, OpenCL, NNAPI, AHardwareBuffer, Vulkan, etc.) and reduce data +copies.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/BUILD new file mode 100644 index 0000000..0da2aa4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/BUILD
@@ -0,0 +1,126 @@ +# ODML image library for iOS + +load( + "@org_tensorflow//tensorflow/lite/ios:ios.bzl", + "TFL_MINIMUM_OS_VERSION", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner", +) +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_static_framework", + "ios_unit_test", +) + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 +) + +SOURCES = glob([ + "sources/*.h", + "sources/*.m", + "sources/*.mm", +]) + +API_HEADERS = glob([ + "apis/*.h", +]) + +# Compiler flags for building regular non-test libraries. +RELEASE_COPTS = [ + # Enables language-specific warnings for Objective-C, Objective-C++, C, and C++. + "-Wall", + # Warns if functions, variables, and types marked with the deprecated attribute are being used. + "-Wdeprecated-declarations", + # Warns for errors in documentation. + "-Wdocumentation", + # Turns all warnings into errors. + "-Werror", + # Enables extra warning flags that are not enabled by -Wall. + "-Wextra", + # Warns if a global function is defined without a previous prototype declaration. + "-Wmissing-prototypes", + # From -Wextra. Disables warning when signed value is converted to unsigned value during comparison. + "-Wno-sign-compare", + # From -Wextra. Disables warning for unused parameters, which are common in delegate methods and block callbacks. + "-Wno-unused-parameter", + # Warns if a global or local variable or type declaration shadows another variable, parameter, type, class member, or instance variable. + "-Wshadow", + # Warns if a function is declared or defined without specifying the argument types. For a block with no args, use (void) instead of (). + "-Wstrict-prototypes", + # Warns if an @selector() expression is encountered with a method name that hasn't been defined yet. + "-Wundeclared-selector", + # Turn off warnings for headers not part of TensorFlow Lite Objective-C API. + "--system-header-prefix=third_party/tensorflow/lite/c/", +] + +# Compiler flags for building test libraries. +TEST_COPTS = RELEASE_COPTS + [ + # From -Wall. Disables warning when passing nil to a callee that requires a non-null argument. + "-Wno-nonnull", + # Disables warning when a global or local variable or type declaration shadows another. + "-Wno-shadow", +] + +objc_library( + name = "MLImage", + srcs = SOURCES, + hdrs = API_HEADERS, + copts = RELEASE_COPTS, + module_name = "MLImage", + sdk_frameworks = [ + "CoreGraphics", + "CoreMedia", + "CoreVideo", + ], + visibility = ["//visibility:public"], + alwayslink = 1, +) + +ios_static_framework( + name = "MLImage_framework", + hdrs = API_HEADERS, + bundle_name = "MLImage", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + visibility = ["//visibility:public"], + deps = [":MLImage"], +) + +ios_unit_test( + name = "tests", + size = "small", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + deps = [ + ":TestsLibrary", + ], +) + +objc_library( + name = "TestsLibrary", + testonly = 1, + srcs = glob([ + "tests/*.m", + ]), + hdrs = glob([ + "apis/*.h", + "sources/*.h", + "tests/*.h", + ]), + copts = TEST_COPTS, + data = glob([ + "resources/*.jpg", + ]), + sdk_frameworks = [ + "Accelerate", + "CoreGraphics", + "CoreMedia", + "CoreVideo", + ], + deps = [ + ":MLImage", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj/Configs/MLImage.tulsigen b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj/Configs/MLImage.tulsigen new file mode 100644 index 0000000..2a9620d28 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj/Configs/MLImage.tulsigen
@@ -0,0 +1,54 @@ +{ + "sourceFilters" : [ + "third_party/tensorflow_lite_support/odml/ios/image", + "third_party/tensorflow_lite_support/odml/ios/image/apis", + "third_party/tensorflow_lite_support/odml/ios/image/sources", + "third_party/tensorflow_lite_support/odml/ios/image/tests" + ], + "buildTargets" : [ + "//third_party/tensorflow_lite_support/odml/ios/image:MLImage", + "//third_party/tensorflow_lite_support/odml/ios/image:tests" + ], + "projectName" : "MLImage", + "optionSet" : { + "LaunchActionPreActionScript" : { + "p" : "$(inherited)" + }, + "BazelBuildStartupOptionsRelease" : { + "p" : "$(inherited)" + }, + "BazelBuildOptionsRelease" : { + "p" : "$(inherited)" + }, + "BazelBuildOptionsDebug" : { + "p" : "$(inherited)" + }, + "EnvironmentVariables" : { + "p" : "$(inherited)" + }, + "BuildActionPreActionScript" : { + "p" : "$(inherited)" + }, + "CommandlineArguments" : { + "p" : "$(inherited)" + }, + "TestActionPreActionScript" : { + "p" : "$(inherited)" + }, + "BazelBuildStartupOptionsDebug" : { + "p" : "$(inherited)" + }, + "BuildActionPostActionScript" : { + "p" : "$(inherited)" + }, + "TestActionPostActionScript" : { + "p" : "$(inherited)" + }, + "LaunchActionPostActionScript" : { + "p" : "$(inherited)" + } + }, + "additionalFilePaths" : [ + "third_party/tensorflow_lite_support/odml/ios/image/BUILD" + ] +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj/project.tulsiconf b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj/project.tulsiconf new file mode 100644 index 0000000..15279da --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj/project.tulsiconf
@@ -0,0 +1,17 @@ +{ + "configDefaults" : { + "optionSet" : { + "BazelBuildOptionsDebug" : { + + }, + "BazelBuildOptionsRelease" : { + + }, + } + }, + "projectName" : "MLImage", + "packages" : [ + "third_party/tensorflow_lite_support/odml/ios/image" + ], + "workspaceRoot" : "../../../../../.." +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/README.md b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/README.md new file mode 100644 index 0000000..2f9a0658 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/README.md
@@ -0,0 +1,27 @@ +# MLImage for iOS + +MLImage is a lightweight image library for iOS developers. It provides a +common image format to increase interoperability among various on-device +machine learning frameworks. + +Build the `MLImage` Objective-C library target: + +```shell +bazel build -c opt --apple_platform_type=ios tensorflow_lite_support/odml/ios/image:MLImage +``` + +Build the `tests` target: + +```shell +bazel test -c opt --apple_platform_type=ios tensorflow_lite_support/odml/ios/image:tests +``` + +#### Generate the Xcode project using Tulsi + +Open the `//tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj` using +the [TulsiApp](https://github.com/bazelbuild/tulsi) or by running the +[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh): + +```shell +generate_xcodeproj.sh --genconfig tensorflow_lite_support/odml/ios/image/MLImage.tulsiproj:MLImage --outputfolder ~/path/to/generated/MLImage.xcodeproj +```
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h new file mode 100644 index 0000000..0c49491 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
@@ -0,0 +1,103 @@ +// Copyright 2021 Google LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <CoreMedia/CoreMedia.h> +#import <CoreVideo/CoreVideo.h> +#import <UIKit/UIKit.h> + +NS_ASSUME_NONNULL_BEGIN + +/** Types of image sources. */ +typedef NSInteger GMLImageSourceType NS_TYPED_ENUM + NS_SWIFT_NAME(MLImageSourceType); +/** Image source is a `UIImage`. */ +static const GMLImageSourceType GMLImageSourceTypeImage = 0; +/** Image source is a `CVPixelBuffer`. */ +static const GMLImageSourceType GMLImageSourceTypePixelBuffer = 1; +/** Image source is a `CMSampleBuffer`. */ +static const GMLImageSourceType GMLImageSourceTypeSampleBuffer = 2; + +/** An image used in on-device machine learning. */ +NS_SWIFT_NAME(MLImage) +@interface GMLImage : NSObject + +/** Width of the image in pixels. */ +@property(nonatomic, readonly) CGFloat width; + +/** Height of the image in pixels. */ +@property(nonatomic, readonly) CGFloat height; + +/** + * The display orientation of the image. If `imageSourceType` is `.image`, the + * default value is `image.imageOrientation`; otherwise the default value is + * `.up`. + */ +@property(nonatomic) UIImageOrientation orientation; + +/** The type of the image source. */ +@property(nonatomic, readonly) GMLImageSourceType imageSourceType; + +/** The source image. `nil` if `imageSourceType` is not `.image`. */ +@property(nonatomic, readonly, nullable) UIImage* image; + +/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`. + */ +@property(nonatomic, readonly, nullable) CVPixelBufferRef pixelBuffer; + +/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`. + */ +@property(nonatomic, readonly, nullable) CMSampleBufferRef sampleBuffer; + +/** + * Initializes an `MLImage` object with the given image. + * + * @param image The image to use as the source. Its `CGImage` property must not + * be `NULL`. + * @return A new `MLImage` instance with the given image as the source. `nil` if + * the given `image` is `nil` or invalid. + */ +- (nullable instancetype)initWithImage:(UIImage*)image + NS_DESIGNATED_INITIALIZER; + +/** + * Initializes an `MLImage` object with the given pixel buffer. + * + * @param pixelBuffer The pixel buffer to use as the source. It will be retained + * by the new `MLImage` instance for the duration of its lifecycle. + * @return A new `MLImage` instance with the given pixel buffer as the source. + * `nil` if the given pixel buffer is `nil` or invalid. + */ +- (nullable instancetype)initWithPixelBuffer:(CVPixelBufferRef)pixelBuffer + NS_DESIGNATED_INITIALIZER; + +/** + * Initializes an `MLImage` object with the given sample buffer. + * + * @param sampleBuffer The sample buffer to use as the source. It will be + * retained by the new `MLImage` instance for the duration of its lifecycle. The + * sample buffer must be based on a pixel buffer (not compressed data). In + * practice, it should be the video output of the camera on an iOS device, not + * other arbitrary types of `CMSampleBuffer`s. + * @return A new `MLImage` instance with the given sample buffer as the source. + * `nil` if the given sample buffer is `nil` or invalid. + */ +- (nullable instancetype)initWithSampleBuffer:(CMSampleBufferRef)sampleBuffer + NS_DESIGNATED_INITIALIZER; + +/** Unavailable. */ +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg new file mode 100644 index 0000000..d2a4278 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m new file mode 100644 index 0000000..38ca7426 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m
@@ -0,0 +1,91 @@ +// Copyright 2021 Google LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow_lite_support/odml/ios/image/apis/GMLImage.h" + +NS_ASSUME_NONNULL_BEGIN + +@implementation GMLImage + +#pragma mark - Public + +- (nullable instancetype)initWithImage:(UIImage*)image { + if (image.CGImage == NULL) { + return nil; + } + + self = [super init]; + if (self != nil) { + _imageSourceType = GMLImageSourceTypeImage; + _orientation = image.imageOrientation; + _image = image; + _width = image.size.width * image.scale; + _height = image.size.height * image.scale; + } + return self; +} + +- (nullable instancetype)initWithPixelBuffer:(CVPixelBufferRef)pixelBuffer { + if (pixelBuffer == NULL) { + return nil; + } + + self = [super init]; + if (self != nil) { + _imageSourceType = GMLImageSourceTypePixelBuffer; + _orientation = UIImageOrientationUp; + CVPixelBufferRetain(pixelBuffer); + _pixelBuffer = pixelBuffer; + _width = CVPixelBufferGetWidth(pixelBuffer); + _height = CVPixelBufferGetHeight(pixelBuffer); + } + return self; +} + +- (nullable instancetype)initWithSampleBuffer:(CMSampleBufferRef)sampleBuffer { + if (!CMSampleBufferIsValid(sampleBuffer)) { + return nil; + } + + CVImageBufferRef imageBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + if (imageBuffer == NULL) { + return nil; + } + + self = [super init]; + if (self != nil) { + _imageSourceType = GMLImageSourceTypeSampleBuffer; + _orientation = UIImageOrientationUp; + CFRetain(sampleBuffer); + _sampleBuffer = sampleBuffer; + _width = CVPixelBufferGetWidth(imageBuffer); + _height = CVPixelBufferGetHeight(imageBuffer); + } + return self; +} + +#pragma mark - NSObject + +- (void)dealloc { + if (_sampleBuffer != NULL) { + CFRelease(_sampleBuffer); + } + if (_pixelBuffer != NULL) { + CVPixelBufferRelease(_pixelBuffer); + } +} + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m new file mode 100644 index 0000000..8abee1a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m
@@ -0,0 +1,196 @@ +// Copyright 2021 Google LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow_lite_support/odml/ios/image/apis/GMLImage.h" + +#import <Accelerate/Accelerate.h> +#import <CoreGraphics/CoreGraphics.h> +#import <CoreMedia/CoreMedia.h> +#import <CoreVideo/CoreVideo.h> +#import <XCTest/XCTest.h> + +NS_ASSUME_NONNULL_BEGIN + +static NSString* const kTestImageName = @"grace_hopper"; +static NSString* const kTestImageType = @"jpg"; +static CGFloat kTestImageWidthInPixels = 517.0f; +static CGFloat kTestImageHeightInPixels = 606.0f; + +/** Unit tests for `GMLImage`. */ +@interface GMLImageTests : XCTestCase + +/** Test image. */ +@property(nonatomic, nullable) UIImage* image; + +@end + +@implementation GMLImageTests + +#pragma mark - Tests + +- (void)setUp { + [super setUp]; + NSString* imageName = + [[NSBundle bundleForClass:[self class]] pathForResource:kTestImageName + ofType:kTestImageType]; + self.image = [[UIImage alloc] initWithContentsOfFile:imageName]; +} + +- (void)tearDown { + self.image = nil; + [super tearDown]; +} + +- (void)testInitWithImage { + GMLImage* mlImage = [[GMLImage alloc] initWithImage:self.image]; + XCTAssertNotNil(mlImage); + XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypeImage); + XCTAssertEqual(mlImage.orientation, self.image.imageOrientation); + mlImage.orientation = UIImageOrientationDown; + XCTAssertEqual(mlImage.orientation, UIImageOrientationDown); + XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, + FLT_EPSILON); + XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, + FLT_EPSILON); +} + +- (void)testInitWithImage_nilImage { + GMLImage* mlImage = [[GMLImage alloc] initWithImage:nil]; + XCTAssertNil(mlImage); +} + +- (void)testInitWithSampleBuffer { + CMSampleBufferRef sampleBuffer = [self sampleBuffer]; + GMLImage* mlImage = [[GMLImage alloc] initWithSampleBuffer:sampleBuffer]; + XCTAssertNotNil(mlImage); + XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypeSampleBuffer); + XCTAssertEqual(mlImage.orientation, UIImageOrientationUp); + mlImage.orientation = UIImageOrientationDown; + XCTAssertEqual(mlImage.orientation, UIImageOrientationDown); + XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, + FLT_EPSILON); + XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, + FLT_EPSILON); +} + +- (void)testInitWithSampleBuffer_nilImage { + GMLImage* mlImage = [[GMLImage alloc] initWithSampleBuffer:nil]; + XCTAssertNil(mlImage); +} + +- (void)testInitWithPixelBuffer { + CMSampleBufferRef sampleBuffer = [self sampleBuffer]; + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + GMLImage* mlImage = [[GMLImage alloc] initWithPixelBuffer:pixelBuffer]; + XCTAssertNotNil(mlImage); + XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypePixelBuffer); + XCTAssertEqual(mlImage.orientation, UIImageOrientationUp); + mlImage.orientation = UIImageOrientationDown; + XCTAssertEqual(mlImage.orientation, UIImageOrientationDown); + XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, + FLT_EPSILON); + XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, + FLT_EPSILON); +} + +- (void)testInitWithPixelBuffer_nilImage { + GMLImage* mlImage = [[GMLImage alloc] initWithPixelBuffer:nil]; + XCTAssertNil(mlImage); +} + +#pragma mark - Private + +/** + * Converts the input image in RGBA space into a `CMSampleBuffer`. + * + * @return `CMSampleBuffer` converted from the given `UIImage`. + */ +- (CMSampleBufferRef)sampleBuffer { + // Rotate the image and convert from RGBA to BGRA. + CGImageRef CGImage = self.image.CGImage; + size_t width = CGImageGetWidth(CGImage); + size_t height = CGImageGetHeight(CGImage); + size_t bpr = CGImageGetBytesPerRow(CGImage); + + CGDataProviderRef provider = CGImageGetDataProvider(CGImage); + NSData* imageRGBAData = + (id)CFBridgingRelease(CGDataProviderCopyData(provider)); + const uint8_t order[4] = {2, 1, 0, 3}; + + NSData* imageBGRAData = nil; + unsigned char* bgraPixel = (unsigned char*)malloc([imageRGBAData length]); + if (bgraPixel) { + vImage_Buffer src; + src.height = height; + src.width = width; + src.rowBytes = bpr; + src.data = (void*)[imageRGBAData bytes]; + + vImage_Buffer dest; + dest.height = height; + dest.width = width; + dest.rowBytes = bpr; + dest.data = bgraPixel; + + // Specify ordering changes in map. + vImage_Error error = + vImagePermuteChannels_ARGB8888(&src, &dest, order, kvImageNoFlags); + + // Package the result. + if (error == kvImageNoError) { + imageBGRAData = [NSData dataWithBytes:bgraPixel + length:[imageRGBAData length]]; + } + + // Memory cleanup. + free(bgraPixel); + } + + if (imageBGRAData == nil) { + XCTFail(@"Failed to convert input image."); + } + + // Write data to `CMSampleBuffer`. + NSDictionary* options = @{ + (__bridge NSString*)kCVPixelBufferCGImageCompatibilityKey : @(YES), + (__bridge NSString*)kCVPixelBufferCGBitmapContextCompatibilityKey : @(YES) + }; + CVPixelBufferRef pixelBuffer; + CVReturn status = CVPixelBufferCreateWithBytes( + kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, + (void*)[imageBGRAData bytes], bpr, NULL, nil, + (__bridge CFDictionaryRef)options, &pixelBuffer); + + if (status != kCVReturnSuccess) { + XCTFail(@"Failed to create pixel buffer."); + } + + CVPixelBufferLockBaseAddress(pixelBuffer, 0); + CMVideoFormatDescriptionRef videoInfo = NULL; + CMVideoFormatDescriptionCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, + &videoInfo); + + CMSampleBufferRef buffer; + CMSampleBufferCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, true, + NULL, NULL, videoInfo, + &kCMTimingInfoInvalid, &buffer); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); + + return buffer; +} + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/AndroidManifest.xml new file mode 100644 index 0000000..aa95182 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/AndroidManifest.xml
@@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.google.android.odml.image"> + <uses-sdk android:minSdkVersion="16" /> + <application /> +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/BUILD new file mode 100644 index 0000000..882fcf2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/BUILD
@@ -0,0 +1,17 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +licenses(["notice"]) + +android_library( + name = "image", + srcs = glob(["src/**/*.java"]), + custom_package = "com.google.android.odml.image", + manifest = "AndroidManifest.xml", + visibility = [ + "//visibility:public", + ], + deps = [ + "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/image.pgcfg b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/image.pgcfg new file mode 100644 index 0000000..95a2a2d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/image.pgcfg
@@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Keep objects marked as needing to be "kept" when making the SDK, but don't +# export these rules to end developers, so they will be allowed to strip these. +-keep @interface com.google.android.odml.image.annotation.KeepForSdk +-keep @com.google.android.odml.image.annotation.KeepForSdk class * +-keepclasseswithmembers,includedescriptorclasses class * { + @com.google.android.odml.image.annotation.KeepForSdk <fields>; +} +-keepclasseswithmembers,includedescriptorclasses class * { + # We use "!bridge" below to avoid applying the keep on "bridge" methods. + # Bridge methods are generated by the Java compiler during type erasure: + # https://docs.oracle.com/javase/tutorial/java/generics/bridgeMethods.html + # This may cause a super class that does not need the @KeepForSdk on a method + # to get a generated method that does include the @KeepForSdk (because the + # subclass has @KeepForSdk on the method being bridged). Since the generated + # method in the super class gets @KeepForSdk, this causes the super class + # itself to be unnecessarily kept by name. Since our @KeepForSdk are meant to + # be explicit directives on a particular method (and not necessarily + # inheritted), we do not need to apply the @KeepForSdk rule on such bridge + # methods. + @com.google.android.odml.image.annotation.KeepForSdk !bridge <methods>; +} +-keepclasseswithmembers,includedescriptorclasses class * { + @com.google.android.odml.image.annotation.KeepForSdk <init>(...); +} \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java new file mode 100644 index 0000000..59116a7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
@@ -0,0 +1,50 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Bitmap; + +/** + * Utility for extracting {@link android.graphics.Bitmap} from {@link MlImage}. + * + * <p>Currently it only supports {@link MlImage} with {@link MlImage#STORAGE_TYPE_BITMAP}, otherwise + * {@link IllegalArgumentException} will be thrown. + */ +public final class BitmapExtractor { + /** + * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}. + * + * <p>Notice: Properties of the {@code image} like rotation will not take effects. + * + * @param image the image to extract {@link android.graphics.Bitmap} from. + * @return the {@link android.graphics.Bitmap} stored in {@link MlImage} + * @throws IllegalArgumentException when the extraction requires unsupported format or data type + * conversions. + */ + public static Bitmap extract(MlImage image) { + ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP); + if (imageContainer != null) { + return ((BitmapImageContainer) imageContainer).getBitmap(); + } else { + // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion. + throw new IllegalArgumentException( + "Extracting Bitmap from an MlImage created by objects other than Bitmap is not" + + " supported"); + } + } + + private BitmapExtractor() {} +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java new file mode 100644 index 0000000..b1b02f8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
@@ -0,0 +1,59 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Bitmap; + +import com.google.android.odml.image.MlImage.ImageFormat; + +class BitmapImageContainer implements ImageContainer { + private final Bitmap bitmap; + private final ImageProperties properties; + + public BitmapImageContainer(Bitmap bitmap) { + this.bitmap = bitmap; + this.properties = ImageProperties.builder() + .setImageFormat(convertFormatCode(bitmap.getConfig())) + .setStorageType(MlImage.STORAGE_TYPE_BITMAP) + .build(); + } + + public Bitmap getBitmap() { + return bitmap; + } + + @Override + public ImageProperties getImageProperties() { + return properties; + } + + @Override + public void close() { + bitmap.recycle(); + } + + @ImageFormat + static int convertFormatCode(Bitmap.Config config) { + switch (config) { + case ALPHA_8: + return MlImage.IMAGE_FORMAT_ALPHA; + case ARGB_8888: + return MlImage.IMAGE_FORMAT_RGBA; + default: + return MlImage.IMAGE_FORMAT_UNKNOWN; + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java new file mode 100644 index 0000000..6c4552b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
@@ -0,0 +1,108 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Rect; +import android.net.Uri; +import android.provider.MediaStore; + +import java.io.IOException; + +/** + * Builds {@link MlImage} from {@link android.graphics.Bitmap}. + * + * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once + * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content + * in it. + * + * <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in. + */ +public class BitmapMlImageBuilder { + // Mandatory fields. + private final Bitmap bitmap; + + // Optional fields. + private int rotation; + private Rect roi; + private long timestamp; + + /** + * Creates the builder with a mandatory {@link android.graphics.Bitmap}. + * + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the + * values will be set with default: + * + * <ul> + * <li>rotation: 0 + * </ul> + * + * @param bitmap image data object. + */ + public BitmapMlImageBuilder(Bitmap bitmap) { + this.bitmap = bitmap; + rotation = 0; + roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight()); + timestamp = 0; + } + + /** + * Creates the builder to build {@link MlImage} from a file. + * + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the + * values will be set with default: + * + * <ul> + * <li>rotation: 0 + * </ul> + * + * @param context the application context. + * @param uri the path to the resource file. + */ + public BitmapMlImageBuilder(Context context, Uri uri) throws IOException { + this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); + } + + /** + * Sets value for {@link MlImage#getRotation()}. + * + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. + */ + public BitmapMlImageBuilder setRotation(int rotation) { + MlImage.validateRotation(rotation); + this.rotation = rotation; + return this; + } + + /** Sets value for {@link MlImage#getRoi()}. */ + BitmapMlImageBuilder setRoi(Rect roi) { + this.roi = roi; + return this; + } + + /** Sets value for {@link MlImage#getTimestamp()}. */ + BitmapMlImageBuilder setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** Builds an {@link MlImage} instance. */ + public MlImage build() { + return new MlImage(new BitmapImageContainer(bitmap), rotation, roi, timestamp, + bitmap.getWidth(), bitmap.getHeight()); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java new file mode 100644 index 0000000..d5861c8c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
@@ -0,0 +1,267 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; + +import com.google.android.odml.image.MlImage.ImageFormat; +import com.google.auto.value.AutoValue; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Locale; + +/** + * Utility for extracting {@link ByteBuffer} from {@link MlImage}. + * + * <p>Currently it only supports {@link MlImage} with {@link MlImage#STORAGE_TYPE_BYTEBUFFER}, + * otherwise {@link IllegalArgumentException} will be thrown. + */ +public class ByteBufferExtractor { + /** + * Extracts a {@link ByteBuffer} from an {@link MlImage}. + * + * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link + * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}. + * + * @see MlImage#getContainedImageProperties() + * @return A read-only {@link ByteBuffer}. + * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. + */ + public static ByteBuffer extract(MlImage image) { + ImageContainer container = image.getContainer(); + switch (container.getImageProperties().getStorageType()) { + case MlImage.STORAGE_TYPE_BYTEBUFFER: + ByteBufferImageContainer byteBufferImageContainer = + (ByteBufferImageContainer) container; + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + default: + throw new IllegalArgumentException( + "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not" + + " supported"); + } + } + + /** + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}. + * + * <p>Notice: Properties of the {@code image} like rotation will not take effects. + * + * <p>Format conversion spec: + * + * <ul> + * <li>When extracting RGB images to RGBA format, A channel will always set to 255. + * <li>When extracting RGBA images to RGB format, A channel will be dropped. + * </ul> + * + * @param image the image to extract buffer from. + * @param targetFormat the image format of the result bytebuffer. + * @return the readonly {@link ByteBuffer} stored in {@link MlImage} + * @throws IllegalArgumentException when the extraction requires unsupported format or data type + * conversions. + */ + static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) { + ImageContainer container; + ImageProperties byteBufferProperties = + ImageProperties.builder() + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) + .setImageFormat(targetFormat) + .build(); + if ((container = image.getContainer(byteBufferProperties)) != null) { + ByteBufferImageContainer byteBufferImageContainer = + (ByteBufferImageContainer) container; + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) { + ByteBufferImageContainer byteBufferImageContainer = + (ByteBufferImageContainer) container; + @ImageFormat + int sourceFormat = byteBufferImageContainer.getImageFormat(); + return convertByteBuffer( + byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) + .asReadOnlyBuffer(); + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) { + BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; + ByteBuffer byteBuffer = + extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) + .asReadOnlyBuffer(); + image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat)); + return byteBuffer; + } else { + throw new IllegalArgumentException( + "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or" + + " Bytebuffer is not supported"); + } + } + + /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ + @AutoValue + abstract static class Result { + /** + * Gets the {@link ByteBuffer} in the result of {@link + * ByteBufferExtractor#extract(MlImage)}. + */ + public abstract ByteBuffer buffer(); + + /** + * Gets the {@link ImageFormat} in the result of {@link + * ByteBufferExtractor#extract(MlImage)}. + */ + @ImageFormat + public abstract int format(); + + static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { + return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); + } + } + + /** + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}. + * + * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid + * copy. + * + * <p>Notice: Properties of the {@code image} like rotation will not take effects. + * + * @return the readonly {@link ByteBuffer} stored in {@link MlImage} + * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with + * given {@code imageFormat} + */ + static Result extractInRecommendedFormat(MlImage image) { + ImageContainer container; + if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) { + Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); + @ImageFormat + int format = adviseImageFormat(bitmap); + Result result = Result.create( + extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); + + image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); + return result; + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) { + ByteBufferImageContainer byteBufferImageContainer = + (ByteBufferImageContainer) container; + return Result.create(byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), + byteBufferImageContainer.getImageFormat()); + } else { + throw new IllegalArgumentException( + "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer" + + " is not supported"); + } + } + + @ImageFormat + private static int adviseImageFormat(Bitmap bitmap) { + if (bitmap.getConfig() == Config.ARGB_8888) { + return MlImage.IMAGE_FORMAT_RGBA; + } else { + throw new IllegalArgumentException(String.format( + "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not" + + " supported", + bitmap.getConfig())); + } + } + + private static ByteBuffer extractByteBufferFromBitmap( + Bitmap bitmap, @ImageFormat int imageFormat) { + if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { + throw new IllegalArgumentException( + "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not" + + " supported"); + } + if (bitmap.getConfig() == Config.ARGB_8888) { + if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) { + ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); + bitmap.copyPixelsToBuffer(buffer); + buffer.rewind(); + return buffer; + } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) { + // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be + // faster. + int w = bitmap.getWidth(); + int h = bitmap.getHeight(); + int[] pixels = new int[w * h]; + bitmap.getPixels(pixels, 0, w, 0, 0, w, h); + ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3); + buffer.order(ByteOrder.nativeOrder()); + for (int pixel : pixels) { + // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns + // RGBA + buffer.put((byte) ((pixel >> 16) & 0xff)); + buffer.put((byte) ((pixel >> 8) & 0xff)); + buffer.put((byte) (pixel & 0xff)); + } + buffer.rewind(); + return buffer; + } + } + throw new IllegalArgumentException(String.format( + "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format" + + " %d is not supported", + bitmap.getConfig(), imageFormat)); + } + + private static ByteBuffer convertByteBuffer( + ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { + if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) { + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); + // Extend the buffer when the target is longer than the source. Use two cursors and + // sweep the array reversely to convert in-place. + byte[] array = new byte[target.capacity()]; + source.get(array, 0, source.capacity()); + source.rewind(); + int rgbCursor = source.capacity(); + int rgbaCursor = target.capacity(); + while (rgbCursor != rgbaCursor) { + array[--rgbaCursor] = (byte) 0xff; // A + array[--rgbaCursor] = array[--rgbCursor]; // B + array[--rgbaCursor] = array[--rgbCursor]; // G + array[--rgbaCursor] = array[--rgbCursor]; // R + } + target.put(array, 0, target.capacity()); + target.rewind(); + return target; + } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA + && targetFormat == MlImage.IMAGE_FORMAT_RGB) { + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); + // Shrink the buffer when the target is shorter than the source. Use two cursors and + // sweep the array to convert in-place. + byte[] array = new byte[source.capacity()]; + source.get(array, 0, source.capacity()); + source.rewind(); + int rgbaCursor = 0; + int rgbCursor = 0; + while (rgbaCursor < array.length) { + array[rgbCursor++] = array[rgbaCursor++]; // R + array[rgbCursor++] = array[rgbaCursor++]; // G + array[rgbCursor++] = array[rgbaCursor++]; // B + rgbaCursor++; + } + target.put(array, 0, target.capacity()); + target.rewind(); + return target; + } else { + throw new IllegalArgumentException(String.format(Locale.ENGLISH, + "Convert bytebuffer image format from %d to %d is not supported", sourceFormat, + targetFormat)); + } + } + + // ByteBuffer is not able to be instantiated. + private ByteBufferExtractor() {} +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java new file mode 100644 index 0000000..f872db4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
@@ -0,0 +1,55 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import com.google.android.odml.image.MlImage.ImageFormat; + +import java.nio.ByteBuffer; + +class ByteBufferImageContainer implements ImageContainer { + private final ByteBuffer buffer; + private final ImageProperties properties; + + public ByteBufferImageContainer(ByteBuffer buffer, @ImageFormat int imageFormat) { + this.buffer = buffer; + this.properties = ImageProperties.builder() + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) + .setImageFormat(imageFormat) + .build(); + } + + public ByteBuffer getByteBuffer() { + return buffer; + } + + @Override + public ImageProperties getImageProperties() { + return properties; + } + + /** + * Returns the image format. + */ + @ImageFormat + public int getImageFormat() { + return properties.getImageFormat(); + } + + @Override + public void close() { + // No op for ByteBuffer. + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java new file mode 100644 index 0000000..f4b0b31 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
@@ -0,0 +1,103 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Rect; + +import com.google.android.odml.image.MlImage.ImageFormat; + +import java.nio.ByteBuffer; + +/** + * Builds a {@link MlImage} from a {@link ByteBuffer}. + * + * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link + * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. + * + * <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in. + */ +public class ByteBufferMlImageBuilder { + // Mandatory fields. + private final ByteBuffer buffer; + private final int width; + private final int height; + @ImageFormat + private final int imageFormat; + + // Optional fields. + private int rotation; + private Rect roi; + private long timestamp; + + /** + * Creates the builder with mandatory {@link ByteBuffer} and the represented image. + * + * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code + * height} and {@code imageFormat}. + * + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the + * values will be set with default: + * + * <ul> + * <li>rotation: 0 + * </ul> + * + * @param byteBuffer image data object. + * @param width the width of the represented image. + * @param height the height of the represented image. + * @param imageFormat how the data encode the image. + */ + public ByteBufferMlImageBuilder( + ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { + this.buffer = byteBuffer; + this.width = width; + this.height = height; + this.imageFormat = imageFormat; + // TODO(b/180504869): Validate bytebuffer size with width, height and image format + this.rotation = 0; + this.roi = new Rect(0, 0, width, height); + this.timestamp = 0; + } + + /** + * Sets value for {@link MlImage#getRotation()}. + * + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. + */ + public ByteBufferMlImageBuilder setRotation(int rotation) { + MlImage.validateRotation(rotation); + this.rotation = rotation; + return this; + } + + /** Sets value for {@link MlImage#getRoi()}. */ + ByteBufferMlImageBuilder setRoi(Rect roi) { + this.roi = roi; + return this; + } + + /** Sets value for {@link MlImage#getTimestamp()}. */ + ByteBufferMlImageBuilder setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** Builds an {@link MlImage} instance. */ + public MlImage build() { + return new MlImage(new ByteBufferImageContainer(buffer, imageFormat), rotation, roi, + timestamp, width, height); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java new file mode 100644 index 0000000..bfa7c0a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
@@ -0,0 +1,30 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import com.google.android.odml.image.annotation.KeepForSdk; + +/** Manages internal image data storage. The interface is package-private. */ +@KeepForSdk +interface ImageContainer { + /** Returns the properties of the contained image. */ + @KeepForSdk + ImageProperties getImageProperties(); + + /** Close the image container and releases the image resource inside. */ + @KeepForSdk + void close(); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java new file mode 100644 index 0000000..a61e97b8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
@@ -0,0 +1,84 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import com.google.android.odml.image.MlImage.ImageFormat; +import com.google.android.odml.image.MlImage.StorageType; +import com.google.android.odml.image.annotation.KeepForSdk; +import com.google.auto.value.AutoValue; +import com.google.auto.value.extension.memoized.Memoized; + +/** Groups a set of properties to describe how an image is stored. */ +@AutoValue +public abstract class ImageProperties { + /** + * Gets the pixel format of the image. + * + * @see MlImage.ImageFormat + */ + @ImageFormat + public abstract int getImageFormat(); + + /** + * Gets the storage type of the image. + * + * @see MlImage.StorageType + */ + @StorageType + public abstract int getStorageType(); + + @Memoized + @Override + public abstract int hashCode(); + + /** + * Creates a builder of {@link ImageProperties}. + * + * @see ImageProperties.Builder + */ + @KeepForSdk + static Builder builder() { + return new AutoValue_ImageProperties.Builder(); + } + + /** Builds a {@link ImageProperties}. */ + @AutoValue.Builder + @KeepForSdk + abstract static class Builder { + /** + * Sets the {@link MlImage.ImageFormat}. + * + * @see ImageProperties#getImageFormat + */ + @KeepForSdk + abstract Builder setImageFormat(@ImageFormat int value); + + /** + * Sets the {@link MlImage.StorageType}. + * + * @see ImageProperties#getStorageType + */ + @KeepForSdk + abstract Builder setStorageType(@StorageType int value); + + /** Builds the {@link ImageProperties}. */ + @KeepForSdk + abstract ImageProperties build(); + } + + // Hide the constructor. + ImageProperties() {} +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java new file mode 100644 index 0000000..9ed88ee3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
@@ -0,0 +1,74 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.media.Image; +import android.os.Build; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; + +import androidx.annotation.RequiresApi; + +import com.google.android.odml.image.MlImage.ImageFormat; + +@RequiresApi(VERSION_CODES.KITKAT) +class MediaImageContainer implements ImageContainer { + private final Image mediaImage; + private final ImageProperties properties; + + public MediaImageContainer(Image mediaImage) { + this.mediaImage = mediaImage; + this.properties = ImageProperties.builder() + .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE) + .setImageFormat(convertFormatCode(mediaImage.getFormat())) + .build(); + } + + public Image getImage() { + return mediaImage; + } + + @Override + public ImageProperties getImageProperties() { + return properties; + } + + @Override + public void close() { + mediaImage.close(); + } + + @ImageFormat + static int convertFormatCode(int graphicsFormat) { + // We only cover the format mentioned in + // https://developer.android.com/reference/android/media/Image#getFormat() + if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { + if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { + return MlImage.IMAGE_FORMAT_RGBA; + } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { + return MlImage.IMAGE_FORMAT_RGB; + } + } + switch (graphicsFormat) { + case android.graphics.ImageFormat.JPEG: + return MlImage.IMAGE_FORMAT_JPEG; + case android.graphics.ImageFormat.YUV_420_888: + return MlImage.IMAGE_FORMAT_YUV_420_888; + default: + return MlImage.IMAGE_FORMAT_UNKNOWN; + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java new file mode 100644 index 0000000..59ed98b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
@@ -0,0 +1,52 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.media.Image; +import android.os.Build.VERSION_CODES; + +import androidx.annotation.RequiresApi; + +/** + * Utility for extracting {@link android.media.Image} from {@link MlImage}. + * + * <p>Currently it only supports {@link MlImage} with {@link MlImage#STORAGE_TYPE_MEDIA_IMAGE}, + * otherwise {@link IllegalArgumentException} will be thrown. + */ +@RequiresApi(VERSION_CODES.KITKAT) +public class MediaImageExtractor { + private MediaImageExtractor() {} + + /** + * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for + * {@link MlImage} that built from {@link MediaMlImageBuilder}. + * + * <p>Notice: Properties of the {@code image} like rotation will not take effects. + * + * @param image the image to extract {@link android.media.Image} from. + * @return {@link android.media.Image} that stored in {@link MlImage}. + * @throws IllegalArgumentException if the extraction failed. + */ + public static Image extract(MlImage image) { + ImageContainer container; + if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) { + return ((MediaImageContainer) container).getImage(); + } + throw new IllegalArgumentException( + "Extract Media Image from an MlImage created by objects other than Media Image" + + " is not supported"); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java new file mode 100644 index 0000000..80771bd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
@@ -0,0 +1,89 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Rect; +import android.media.Image; +import android.os.Build.VERSION_CODES; + +import androidx.annotation.RequiresApi; + +/** + * Builds {@link MlImage} from {@link android.media.Image}. + * + * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify + * content in it. + * + * <p>Use {@link MediaImageExtractor} to get {@link android.media.Image} you passed in. + */ +@RequiresApi(VERSION_CODES.KITKAT) +public class MediaMlImageBuilder { + // Mandatory fields. + private final Image mediaImage; + + // Optional fields. + private int rotation; + private Rect roi; + private long timestamp; + + /** + * Creates the builder with a mandatory {@link android.media.Image}. + * + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the + * values will be set with default: + * + * <ul> + * <li>rotation: 0 + * </ul> + * + * @param mediaImage image data object. + */ + public MediaMlImageBuilder(Image mediaImage) { + this.mediaImage = mediaImage; + this.rotation = 0; + this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight()); + this.timestamp = 0; + } + + /** + * Sets value for {@link MlImage#getRotation()}. + * + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270. + */ + public MediaMlImageBuilder setRotation(int rotation) { + MlImage.validateRotation(rotation); + this.rotation = rotation; + return this; + } + + /** Sets value for {@link MlImage#getRoi()}. */ + MediaMlImageBuilder setRoi(Rect roi) { + this.roi = roi; + return this; + } + + /** Sets value for {@link MlImage#getTimestamp()}. */ + MediaMlImageBuilder setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** Builds an {@link MlImage} instance. */ + public MlImage build() { + return new MlImage(new MediaImageContainer(mediaImage), rotation, roi, timestamp, + mediaImage.getWidth(), mediaImage.getHeight()); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java new file mode 100644 index 0000000..7e21e6ad --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
@@ -0,0 +1,296 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Rect; + +import androidx.annotation.IntDef; +import androidx.annotation.Nullable; + +import com.google.android.odml.image.annotation.KeepForSdk; + +import java.io.Closeable; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Wraps image data for on-device machine learning (ODML) usages. + * + * <p>{@link MlImage} is designed to be an immutable image container, which could be shared + * cross-platforms, among different Google ODML frameworks(TFLite Support, MLKit). + * + * <p>It is a common abstraction image that could help to chain different frameworks that adapts + * {@link MlImage} together. + * + * <p>To construct an {@link MlImage}, use the provided builders: + * + * <ul> + * <li>{@link ByteBufferMlImageBuilder} + * <li>{@link BitmapMlImageBuilder} + * <li>{@link MediaMlImageBuilder} + * </ul> + * + * <p>{@link MlImage} uses reference counting to maintain internal storage. When it is created the + * reference count is 1. Developer can call {@link #close()} to reduce reference count to release + * internal storage earlier, otherwise Java garbage collection will release the storage eventually. + * + * <p>To extract concrete image, first check {@link StorageType} and then use the provided + * extractors: + * + * <ul> + * <li>{@link ByteBufferExtractor} + * <li>{@link BitmapExtractor} + * <li>{@link MediaImageExtractor} + * </ul> + * + * In future release, {@link MlImage} will support internal conversion(e.g. Bitmap -> ByteBuffer) + * and multiple storages. + */ +public class MlImage implements Closeable { + /** Specifies the image format of an image. */ + @IntDef({ + IMAGE_FORMAT_UNKNOWN, + IMAGE_FORMAT_RGBA, + IMAGE_FORMAT_RGB, + IMAGE_FORMAT_NV12, + IMAGE_FORMAT_NV21, + IMAGE_FORMAT_YV12, + IMAGE_FORMAT_YV21, + IMAGE_FORMAT_YUV_420_888, + IMAGE_FORMAT_ALPHA, + IMAGE_FORMAT_JPEG, + }) + @Retention(RetentionPolicy.SOURCE) + public @interface ImageFormat {} + + public static final int IMAGE_FORMAT_UNKNOWN = 0; + public static final int IMAGE_FORMAT_RGBA = 1; + public static final int IMAGE_FORMAT_RGB = 2; + public static final int IMAGE_FORMAT_NV12 = 3; + public static final int IMAGE_FORMAT_NV21 = 4; + public static final int IMAGE_FORMAT_YV12 = 5; + public static final int IMAGE_FORMAT_YV21 = 6; + public static final int IMAGE_FORMAT_YUV_420_888 = 7; + public static final int IMAGE_FORMAT_ALPHA = 8; + public static final int IMAGE_FORMAT_JPEG = 9; + + /** Specifies the image container type. Would be useful for choosing extractors. */ + @IntDef({ + STORAGE_TYPE_BITMAP, + STORAGE_TYPE_BYTEBUFFER, + STORAGE_TYPE_MEDIA_IMAGE, + STORAGE_TYPE_IMAGE_PROXY, + }) + @Retention(RetentionPolicy.SOURCE) + public @interface StorageType {} + + public static final int STORAGE_TYPE_BITMAP = 1; + public static final int STORAGE_TYPE_BYTEBUFFER = 2; + public static final int STORAGE_TYPE_MEDIA_IMAGE = 3; + public static final int STORAGE_TYPE_IMAGE_PROXY = 4; + + /** + * Returns a list of supported image properties for this {@link MlImage}. + * + * <p>Currently {@link MlImage} only support single storage type so the size of return list will + * always be 1. + * + * @see ImageProperties + */ + public List<ImageProperties> getContainedImageProperties() { + return Collections.singletonList(getContainer().getImageProperties()); + } + + /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */ + public int getRotation() { + return rotation; + } + + /** Returns the timestamp attached to the image. */ + long getTimestamp() { + return timestamp; + } + + /** Returns the width of the image. */ + public int getWidth() { + return width; + } + + /** Returns the height of the image. */ + public int getHeight() { + return height; + } + + /** Returns the region-of-interest rectangle attached to the image. */ + Rect getRoi() { + Rect result = new Rect(); + result.set(roi); + return result; + } + + /** + * Acquires a reference on this {@link MlImage}. This will increase the reference count by 1. + */ + private synchronized void acquire() { + referenceCount += 1; + } + + /** + * Removes a reference that was previously acquired or init. + * + * <p>When {@link MlImage} is created, it has 1 reference count. + * + * <p>When the reference count becomes 0, it will release the resource under the hood. + */ + @Override + // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount + public synchronized void close() { + referenceCount -= 1; + if (referenceCount == 0) { + for (ImageContainer imageContainer : containerMap.values()) { + imageContainer.close(); + } + } + } + + /** + * Advanced API access for {@link MlImage}. + * + * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference + * count for {@link MlImage}. However, an App developer should avoid using the following APIs. + * + * <p>APIs inside are treated as internal APIs which are subject to change. + */ + public static final class Internal { + /** + * Acquires a reference on this {@link MlImage}. This will increase the reference count + * by 1. + * + * <p>This method is more useful for image consumer to acquire a reference so image resource + * will not be closed accidentally. As image creator, normal developer doesn't need to call + * this method. + * + * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link + * #close()} to indicate it doesn't need this {@link MlImage} anymore. + * + * @see #close() + */ + public void acquire() { + image.acquire(); + } + + private final MlImage image; + + // Only MlImage creates the internal helper. + private Internal(MlImage image) { + this.image = image; + } + } + + /** Gets {@link Internal} object which contains internal APIs. */ + public Internal getInternal() { + return new Internal(this); + } + + private final Map<ImageProperties, ImageContainer> containerMap; + private final int rotation; + private final Rect roi; + private final long timestamp; + private final int width; + private final int height; + + private int referenceCount; + + /** Constructs an {@link MlImage} with a built container. */ + @KeepForSdk + MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width, + int height) { + this.containerMap = new HashMap<>(); + containerMap.put(container.getImageProperties(), container); + this.rotation = rotation; + this.roi = new Rect(); + this.roi.set(roi); + this.timestamp = timestamp; + this.width = width; + this.height = height; + this.referenceCount = 1; + } + + /** + * Gets one available container. + * + * @return the current container. + */ + @KeepForSdk + ImageContainer getContainer() { + // According to the design, in the future we will support multiple containers in one image. + // Currently just return the original container. + // TODO(b/182443927): Cache multiple containers in MlImage. + return containerMap.values().iterator().next(); + } + + /** + * Gets container from required {@code storageType}. Returns {@code null} if not existed. + * + * <p>If there are multiple containers with required {@code storageType}, returns the first one. + */ + @Nullable + @KeepForSdk + ImageContainer getContainer(@StorageType int storageType) { + for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { + if (entry.getKey().getStorageType() == storageType) { + return entry.getValue(); + } + } + return null; + } + + /** + * Gets container from required {@code imageProperties}. Returns {@code null} if non existed. + */ + @Nullable + @KeepForSdk + ImageContainer getContainer(ImageProperties imageProperties) { + return containerMap.get(imageProperties); + } + + /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ + boolean addContainer(ImageContainer container) { + ImageProperties imageProperties = container.getImageProperties(); + if (containerMap.containsKey(imageProperties)) { + return false; + } + containerMap.put(imageProperties, container); + return true; + } + + /** + * Validates rotation values for builders. Only supports 0, 90, 180, 270. + * + * @throws IllegalArgumentException if the rotation value is invalid. + */ + static void validateRotation(int rotation) { + if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) { + throw new IllegalArgumentException( + "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270."); + } + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/annotation/KeepForSdk.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/annotation/KeepForSdk.java new file mode 100644 index 0000000..8442e1d6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/annotation/KeepForSdk.java
@@ -0,0 +1,27 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package com.google.android.odml.image.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Target; + +/** + * Indicates that this object (class, method, etc) should be retained and not renamed when + * generating the SDK, but should be allowed to be stripped or renamed in end developer apps. + */ +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.CONSTRUCTOR}) +@Documented // Technically not needed as this doesn't impact end apps, but kept for consistency. +public @interface KeepForSdk {}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/AndroidManifest.xml new file mode 100644 index 0000000..d04979a7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/AndroidManifest.xml
@@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.google.android.odml.image"> + + <uses-sdk + android:minSdkVersion="16" + android:targetSdkVersion="29" /> +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BUILD new file mode 100644 index 0000000..87642d10 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BUILD
@@ -0,0 +1,111 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library", "android_local_test") + +licenses(["notice"]) + +INSTRUMENTED_TESTS = glob(["*InstrumentedTest.java"]) + +android_library( + name = "odml-image-test", + srcs = [ + "TestImageCreator.java", + ], + custom_package = "com.google.android.odml.image", + # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" wierdly break the build. + javacopts = ["-source 7 -target 7"], + manifest = "AndroidManifest.xml", + deps = [ + "@maven//:androidx_annotation_annotation", + ], +) + +android_local_test( + name = "BitmapMlImageBuilderTest", + srcs = ["BitmapMlImageBuilderTest.java"], + manifest = "AndroidManifest.xml", + test_class = "com.google.android.odml.image.BitmapMlImageBuilderTest", + deps = [ + "//tensorflow_lite_support/odml/java/image", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "BitmapExtractorTest", + srcs = ["BitmapExtractorTest.java"], + manifest = "AndroidManifest.xml", + test_class = "com.google.android.odml.image.BitmapExtractorTest", + deps = [ + ":odml-image-test", + "//tensorflow_lite_support/odml/java/image", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "ByteBufferMlImageBuilderTest", + srcs = ["ByteBufferMlImageBuilderTest.java"], + manifest = "AndroidManifest.xml", + test_class = "com.google.android.odml.image.ByteBufferMlImageBuilderTest", + deps = [ + "//tensorflow_lite_support/odml/java/image", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "ByteBufferExtractorTest", + srcs = ["ByteBufferExtractorTest.java"], + manifest = "AndroidManifest.xml", + test_class = "com.google.android.odml.image.ByteBufferExtractorTest", + deps = [ + ":odml-image-test", + "//tensorflow_lite_support/odml/java/image", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "MediaMlImageTest", + srcs = [ + "MediaImageExtractorTest.java", + "MediaMlImageBuilderTest.java", + ], + manifest = "AndroidManifest.xml", + test_class = "com.google.android.odml.image.MediaMlImageBuilderTest", + deps = [ + "//tensorflow_lite_support/odml/java/image", + "//third_party/java/mockito:mockito-android", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +) + +android_local_test( + name = "TextureImageExtractorTest", + srcs = ["TextureImageExtractorTest.java"], + manifest = "AndroidManifest.xml", + test_class = "com.google.android.odml.image.TextureImageExtractorTest", + deps = [ + ":odml-image-test", + "//tensorflow_lite_support/odml/java/image", + "//third_party/java/mockito:mockito-android", + "@maven//:androidx_test_core", + "@maven//:com_google_truth_truth", # android + "@maven//:org_robolectric_robolectric", + "@robolectric//bazel:android-all", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java new file mode 100644 index 0000000..8408a0e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
@@ -0,0 +1,52 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import android.graphics.Bitmap; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +import java.nio.ByteBuffer; + +/** Unit test for {@link BitmapExtractor}. */ +@RunWith(RobolectricTestRunner.class) +public class BitmapExtractorTest { + @Test + public void extract_fromBitmap_succeeds() { + Bitmap bitmap = TestImageCreator.createRgbaBitmap(); + MlImage image = new BitmapMlImageBuilder(bitmap).build(); + + Bitmap result = BitmapExtractor.extract(image); + + assertThat(result).isSameInstanceAs(bitmap); + } + + @Test + public void extract_fromByteBuffer_throwsException() { + ByteBuffer buffer = TestImageCreator.createRgbBuffer(); + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) + .build(); + + assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java new file mode 100644 index 0000000..9a4051c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
@@ -0,0 +1,88 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Rect; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +/** Tests for {@link BitmapMlImageBuilder} */ +@RunWith(RobolectricTestRunner.class) +public final class BitmapMlImageBuilderTest { + @Test + public void build_fromBitmap_succeeds() { + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); + + MlImage image = new BitmapMlImageBuilder(bitmap).build(); + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP); + + assertThat(image.getWidth()).isEqualTo(20); + assertThat(image.getHeight()).isEqualTo(25); + assertThat(image.getContainedImageProperties()) + .containsExactly(ImageProperties.builder() + .setImageFormat(MlImage.IMAGE_FORMAT_RGBA) + .setStorageType(MlImage.STORAGE_TYPE_BITMAP) + .build()); + assertThat(((BitmapImageContainer) container).getBitmap().getConfig()) + .isEqualTo(Config.ARGB_8888); + } + + @Test + public void build_withOptionalProperties_succeeds() { + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); + + MlImage image = new BitmapMlImageBuilder(bitmap) + .setRoi(new Rect(0, 5, 10, 15)) + .setRotation(90) + .setTimestamp(12345) + .build(); + + assertThat(image.getTimestamp()).isEqualTo(12345); + assertThat(image.getRotation()).isEqualTo(90); + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); + } + + @Test + public void build_withInvalidRotation_throwsException() { + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); + BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap); + + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); + } + + @Test + public void release_recyclesBitmap() { + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888); + + MlImage image = new BitmapMlImageBuilder(bitmap) + .setRoi(new Rect(0, 5, 10, 15)) + .setRotation(90) + .setTimestamp(12345) + .build(); + assertThat(bitmap.isRecycled()).isFalse(); + image.close(); + + assertThat(bitmap.isRecycled()).isTrue(); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java new file mode 100644 index 0000000..e675ba9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
@@ -0,0 +1,157 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import android.graphics.Bitmap; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +import java.nio.Buffer; +import java.nio.ByteBuffer; + +/** + * Tests for {@link ByteBufferExtractor}. + * + * <p>{@code RGBA}{@link Bitmap} to {@code RGBA}{@link ByteBuffer} convesion is not tested here, + * like {@link ByteBufferExtractorInstrumentedTest#extract_rgbaFromRgbaBitmap_succeeds()}, because + * Robolectric seems not handling {@link Bitmap#copyPixelsToBuffer(Buffer)} correctly. So that we + * only test that path in the instrumented unit test. + */ +@RunWith(RobolectricTestRunner.class) +public final class ByteBufferExtractorTest { + @Test + public void extract_fromByteBuffer_succeeds() { + ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer(); + MlImage image = new ByteBufferMlImageBuilder(byteBuffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) + .build(); + + ByteBuffer result = ByteBufferExtractor.extract(image); + + assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer); + assertThat(result.isReadOnly()).isTrue(); + } + + @Test + public void extract_fromBitmap_throws() { + Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap(); + MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build(); + + assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image)); + } + + @Test + public void extract_rgbFromRgbByteBuffer_succeeds() { + ByteBuffer buffer = TestImageCreator.createRgbBuffer(); + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) + .build(); + + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); + + assertThat(result.isReadOnly()).isTrue(); + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); + } + + @Test + public void extract_rgbFromRgbaByteBuffer_succeeds() { + ByteBuffer buffer = TestImageCreator.createRgbaBuffer(); + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA) + .build(); + + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); + + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); + assertThat(buffer.position()).isEqualTo(0); + } + + @Test + public void extract_rgbaFromRgbByteBuffer_succeeds() { + ByteBuffer buffer = TestImageCreator.createRgbBuffer(); + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) + .build(); + + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA); + + assertThat(result).isEquivalentAccordingToCompareTo( + TestImageCreator.createOpaqueRgbaBuffer()); + assertThat(buffer.position()).isEqualTo(0); + } + + @Test + public void extract_rgbFromRgbaBitmap_succeeds() { + Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap(); + MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build(); + + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); + + assertThat(result.isReadOnly()).isTrue(); + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer()); + + // Verifies ByteBuffer is cached inside MlImage. + ByteBufferImageContainer byteBufferImageContainer = + (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); + assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result); + assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); + + // Verifies that extracted ByteBuffer is the cached one. + ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB); + assertThat(result2).isEqualTo(result); + } + + @Test + public void extract_unsupportedFormatFromByteBuffer_throws() { + ByteBuffer buffer = TestImageCreator.createRgbaBuffer(); + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA) + .build(); + + assertThrows(IllegalArgumentException.class, + () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888)); + } + + @Test + public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() { + ByteBuffer buffer = TestImageCreator.createRgbBuffer(); + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(), + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB) + .build(); + + ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image); + + assertThat(result.buffer().isReadOnly()).isTrue(); + assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); + + // Verifies ByteBuffer is cached inside MlImage. + ByteBufferImageContainer byteBufferImageContainer = + (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); + assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer()); + assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB); + + // Verifies that extracted ByteBuffer is the cached one. + ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image); + assertThat(result2.buffer()).isEqualTo(result.buffer()); + assertThat(result2.format()).isEqualTo(result.format()); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java new file mode 100644 index 0000000..374c82b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
@@ -0,0 +1,77 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; + +import android.graphics.Rect; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +import java.nio.ByteBuffer; + +/** Tests for {@link ByteBufferMlImageBuilder} */ +@RunWith(RobolectricTestRunner.class) +public final class ByteBufferMlImageBuilderTest { + @Test + public void build_fromByteBuffer_succeeds() { + ByteBuffer buffer = ByteBuffer.allocate(500); + + MlImage image = + new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build(); + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER); + + assertThat(image.getWidth()).isEqualTo(20); + assertThat(image.getHeight()).isEqualTo(25); + assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25)); + assertThat(image.getRotation()).isEqualTo(0); + assertThat(image.getContainedImageProperties()) + .containsExactly(ImageProperties.builder() + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER) + .setImageFormat(MlImage.IMAGE_FORMAT_RGB) + .build()); + assertThat(((ByteBufferImageContainer) container).getImageFormat()) + .isEqualTo(MlImage.IMAGE_FORMAT_RGB); + } + + @Test + public void build_withOptionalProperties_succeeds() { + ByteBuffer buffer = ByteBuffer.allocate(500); + + MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB) + .setRoi(new Rect(0, 5, 10, 15)) + .setRotation(90) + .setTimestamp(12345) + .build(); + + assertThat(image.getTimestamp()).isEqualTo(12345); + assertThat(image.getRotation()).isEqualTo(90); + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); + } + + @Test + public void build_withInvalidRotation_throwsException() { + ByteBuffer buffer = ByteBuffer.allocate(500); + ByteBufferMlImageBuilder builder = + new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB); + + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java new file mode 100644 index 0000000..fa832671 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
@@ -0,0 +1,68 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.ImageFormat; +import android.media.Image; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.robolectric.RobolectricTestRunner; + +/** Tests for {@link MediaImageExtractor} */ +@RunWith(RobolectricTestRunner.class) +public final class MediaImageExtractorTest { + private static final int HEIGHT = 100; + private static final int WIDTH = 50; + + @Mock + private Image mediaImage; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + when(mediaImage.getHeight()).thenReturn(HEIGHT); + when(mediaImage.getWidth()).thenReturn(WIDTH); + when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888); + } + + @Test + public void extract_fromMediaMlImage_succeeds() { + MlImage image = new MediaMlImageBuilder(mediaImage).build(); + Image extractedMediaImage = MediaImageExtractor.extract(image); + + assertThat(extractedMediaImage).isSameInstanceAs(image); + } + + @Test + public void extract_fromBitmapMlImage_throwsException() { + MlImage image = new BitmapMlImageBuilder( + Bitmap.createBitmap(/* width= */ 20, /* height= */ 25, Config.ARGB_8888)) + .build(); + assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java new file mode 100644 index 0000000..60397fec --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
@@ -0,0 +1,90 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import static com.google.common.truth.Truth.assertThat; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; + +import android.graphics.ImageFormat; +import android.graphics.Rect; +import android.media.Image; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.robolectric.RobolectricTestRunner; + +/** Tests for {@link MediaMlImageBuilder} */ +@RunWith(RobolectricTestRunner.class) +public final class MediaMlImageBuilderTest { + private static final int HEIGHT = 100; + private static final int WIDTH = 50; + + @Mock + private Image mediaImage; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + when(mediaImage.getHeight()).thenReturn(HEIGHT); + when(mediaImage.getWidth()).thenReturn(WIDTH); + when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888); + } + + @Test + public void build_fromMediaImage_succeeds() { + MlImage image = new MediaMlImageBuilder(mediaImage).build(); + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE); + + assertThat(image.getWidth()).isEqualTo(WIDTH); + assertThat(image.getHeight()).isEqualTo(HEIGHT); + assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT)); + assertThat(image.getRotation()).isEqualTo(0); + assertThat(image.getTimestamp()).isAtLeast(0); + assertThat(image.getContainedImageProperties()) + .containsExactly(ImageProperties.builder() + .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE) + .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888) + .build()); + assertThat(((MediaImageContainer) container).getImage().getFormat()) + .isEqualTo(ImageFormat.YUV_420_888); + } + + @Test + public void build_withOptionalProperties_succeeds() { + MlImage image = new MediaMlImageBuilder(mediaImage) + .setTimestamp(12345) + .setRoi(new Rect(0, 5, 10, 15)) + .setRotation(90) + .build(); + + assertThat(image.getTimestamp()).isEqualTo(12345); + assertThat(image.getRotation()).isEqualTo(90); + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15)); + } + + @Test + public void build_withInvalidRotation_throwsException() { + MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage); + + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360)); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java new file mode 100644 index 0000000..28f54be --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
@@ -0,0 +1,148 @@ +/* Copyright 2021 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.android.odml.image; + +import android.graphics.Bitmap; +import android.graphics.Color; + +import java.nio.ByteBuffer; + +/** + * Creates test images. + * + * <p>Typically, {@link TestImageCreator} creates a 10x2 image, which looks like: + * + * <p>{@code BBBBBWWWWW} + * + * <p>{@code GGGGGRRRRR} + * + * <p>where B=0x000093, W=0xffffff, G=0x009300, R=0x930000. + * + * <p>The value of ALPHA channel is 0x70 or 0xff, depending on the settings. + * + * <p>The created {@link Bitmap} is not pre-multiplied. + */ +final class TestImageCreator { + private static final int RED = 0x73; + private static final int GREEN = 0x85; + private static final int BLUE = 0x96; + private static final int ALPHA = 0x70; + + static int getWidth() { + return 10; + } + + static int getHeight() { + return 2; + } + + /** + * Creates an example non-pre-multiplied bitmap which is 100% opaque. + * + * @see TestImageCreator for details. + */ + static Bitmap createOpaqueRgbaBitmap() { + return createRgbaBitmap(0xff); + } + + /** + * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel. + * + * @see TestImageCreator for details. + */ + static Bitmap createRgbaBitmap() { + return createRgbaBitmap(ALPHA); + } + + /** + * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code + * alpha}. + */ + static Bitmap createRgbaBitmap(int alpha) { + int[] colors = new int[20]; + for (int i = 0; i < 5; i++) { + colors[i] = Color.argb(alpha, 0, 0, BLUE); + colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff); + colors[i + 10] = Color.argb(alpha, 0, GREEN, 0); + colors[i + 15] = Color.argb(alpha, RED, 0, 0); + } + // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates + // pre-multiplied bitmaps. + Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888); + bitmap.setPremultiplied(false); + bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2); + return bitmap; + } + + /** + * Creates an example 10*10*3 bytebuffer in R-G-B format. + * + * @see TestImageCreator for details. + */ + static ByteBuffer createRgbBuffer() { + return createRgbOrRgbaBuffer(false, 0xff); + } + + /** + * Creates an example 10*10*4 bytebuffer in R-G-B-A format. + * + * @see TestImageCreator for details. + */ + static ByteBuffer createRgbaBuffer() { + return createRgbOrRgbaBuffer(true, ALPHA); + } + + /** + * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF. + * + * @see TestImageCreator for details. + */ + static ByteBuffer createOpaqueRgbaBuffer() { + return createRgbOrRgbaBuffer(true, 0xff); + } + + /** + * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc. + * + * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored. + * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}. + */ + static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) { + int capacity = withAlpha ? 80 : 60; + ByteBuffer buffer = ByteBuffer.allocateDirect(capacity); + putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5); + putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5); + putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5); + putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5); + buffer.rewind(); + return buffer; + } + + private static void putColorInByteBuffer( + ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) { + for (int i = 0; i < num; i++) { + buffer.put((byte) r); + buffer.put((byte) g); + buffer.put((byte) b); + if (withAlpha) { + buffer.put((byte) alpha); + } + } + } + + // Should not be instantiated. + private TestImageCreator() {} +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD new file mode 100644 index 0000000..b3c4abb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD
@@ -0,0 +1,68 @@ +# This file is auto generated by addlicensedeps.py. Please do not edit any part +# of this file except to add missing licenses to the extra_licenses parameter. +# +# This file should not need to be updated unless the license_presubmit_check is +# failing. To update it, run the following command from a CitC client on Linux: +# +# /google/bin/releases/opensource/tools/addlicensedeps_py/addlicensedeps --target=//third_party/tensorflow_lite_support/odml/java/image:image --create_licenses_and_metadata_filegroups + +load("//devtools/compliance/addlicensedeps/py:third_party_licenses.bzl", "third_party_licenses") + +licenses(["notice"]) + +third_party_licenses( + create_licenses_and_metadata_filegroups = True, + custom_output_directory = None, + custom_package = None, + extra_licenses = {}, + generated_licenses = { + "AndroidX activity library": "third_party/java/androidx/activity/hack_for_licenses/LICENSE", + "AndroidX annotation experimental library": "third_party/java/androidx/annotation/experimental/LICENSE", + "AndroidX annotation library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/annotations/LICENSE", + "AndroidX architecture core library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/core/common/LICENSE", + "AndroidX architecture library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/core/runtime/LICENSE", + "AndroidX asynclayoutinflater library": "third_party/java/androidx/asynclayoutinflater/LICENSE", + "AndroidX collection library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/collections/LICENSE", + "AndroidX concurrent futures library": "third_party/java/androidx/concurrent/futures/LICENSE", + "AndroidX coordinatorlayout library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/coordinatorlayout/LICENSE", + "AndroidX core library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/compat/LICENSE", + "AndroidX cursoradapter library": "third_party/java/androidx/cursoradapter/LICENSE", + "AndroidX customview library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/customview/LICENSE", + "AndroidX documentfile library": "third_party/java/androidx/documentfile/LICENSE", + "AndroidX drawerlayout library": "third_party/java/androidx/drawerlayout/LICENSE", + "AndroidX fragment library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/fragment/LICENSE", + "AndroidX interpolator library": "third_party/java/androidx/interpolator/LICENSE", + "AndroidX legacy coreui library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/core_ui/LICENSE", + "AndroidX legacy coreutils library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/core_utils/LICENSE", + "AndroidX legacy v4 library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/v4/LICENSE", + "AndroidX lifecycle common library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/lifecycle/common/LICENSE", + "AndroidX lifecycle livedatacore library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/lifecycle/livedata_core/LICENSE", + "AndroidX lifecycle runtime library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/lifecycle/runtime/LICENSE", + "AndroidX lifecycle viewmodel library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/lifecycle/viewmodel/LICENSE", + "AndroidX lifecycle viewmodel savedstate library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/arch_components/lifecycle/lifecycle_viewmodel_savedstate/LICENSE", + "AndroidX loader library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/loader/LICENSE", + "AndroidX localbroadcastmanager library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/localbroadcastmanager/LICENSE", + "AndroidX media base library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/media_compat/LICENSE", + "AndroidX print library": "third_party/java/androidx/print/LICENSE", + "AndroidX savedstate library": "third_party/java/androidx/savedstate/hack_for_licenses/LICENSE", + "AndroidX swiperefreshlayout library": "third_party/java/androidx/swiperefreshlayout/LICENSE", + "AndroidX tracing library": "third_party/java/androidx/tracing/hack_for_licenses/LICENSE", + "AndroidX versionedparcelable library": "third_party/java/android/android_sdk_linux/extras/android/compatibility/versionedparcelable/LICENSE", + "AndroidX viewpager library": "third_party/java/androidx/viewpager/LICENSE", + "Animal Sniffer": "third_party/java/animal_sniffer/LICENSE", + "Checker Framework Annotations": "third_party/java/checker_framework_annotations/LICENSE", + "Error Prone": "third_party/java/error_prone/LICENSE", + "Google Auto": "third_party/java/auto/LICENSE", + "Guava JDK5": "third_party/java/android_libs/guava_jdk5/LICENSE", + "Guava JDK7": "third_party/java_src/google_common/LICENSE", + "J2ObjC": "third_party/java/j2objc/LICENSE", + "JSR 250": "third_party/java/jsr250_annotations/LICENSE", + "JSR 305": "third_party/java/jsr305_annotations/LICENSE", + "JSpecify": "third_party/java/jspecify_annotations/LICENSE", + "JsInterop Annotations": "third_party/java_src/jsinterop_annotations/java/jsinterop/annotations/LICENSE", + "TensorFlow Lite Support": "third_party/tensorflow_lite_support/LICENSE", + "Kotlin": "third_party/kotlin/kotlin/LICENSE", + }, + platform = "android", + target = "//tensorflow_lite_support/odml/java/image", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files b/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files index be426420..71f2c89 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files +++ b/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files
@@ -1,7 +1,8 @@ +tensorflow_lite_support/cc/port/build_defs.bzl tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl +tensorflow_lite_support/opensource/.bazelversion tensorflow_lite_support/opensource/BUILD tensorflow_lite_support/opensource/WORKSPACE -tensorflow_lite_support/opensource/cc_build_defs.bzl tensorflow_lite_support/third_party/android/BUILD tensorflow_lite_support/third_party/android/android.bzl.tpl tensorflow_lite_support/third_party/android/android_configure.BUILD.tpl @@ -24,6 +25,13 @@ tensorflow_lite_support/third_party/toolchains/java/BUILD tensorflow_lite_support/third_party/utf.BUILD tensorflow_lite_support/third_party/zlib.BUILD +tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_instrumentation_test_template.xml +tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_target_stub.xml +tensorflow_lite_support/tools/build_rules/android_test/android_library_instrumentation_tests.bzl +tensorflow_lite_support/tools/build_rules/android_test/android_multidevice_instrumentation_test.bzl +tensorflow_lite_support/tools/build_rules/android_test/generate_instrumentation_tests.bzl +tensorflow_lite_support/tools/build_rules/android_test/infer_java_package_name.bzl +tensorflow_lite_support/tools/build_rules/http_files.bzl tensorflow_lite_support/tools/ci_build/build_all.sh tensorflow_lite_support/tools/ci_build/common.sh tensorflow_lite_support/tools/ci_build/common_win.bat @@ -31,6 +39,7 @@ tensorflow_lite_support/tools/pip_package/MANIFEST.in tensorflow_lite_support/tools/pip_package/README tensorflow_lite_support/tools/pip_package/build_pip_package.sh +tensorflow_lite_support/tools/pip_package/metadata_writers.__init__.py tensorflow_lite_support/tools/pip_package/setup.py tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/tools/BUILD index c3525ca..402b274 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/BUILD
@@ -13,8 +13,3 @@ "@absl_py//absl/flags", ], ) - -py_library( - name = "expect_flatbuffers_installed", - srcs = [], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb b/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb new file mode 100644 index 0000000..3e7e64e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb
@@ -0,0 +1,278 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Build_TensorFlow_Lite_Support_Libraries_with_Bazel_and_Colab.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "bFF0wwGKvyxw" + }, + "source": [ + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_u2bOkHzv4F4" + }, + "source": [ + "# Build TensorFlow Lite Support libraries with Bazel " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sh4xCk3Ygq6p" + }, + "source": [ + "## Set up Android environment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "S-NOtaam7Ebn" + }, + "source": [ + "# Create folders\n", + "!mkdir -p '/android/sdk'\n", + "\n", + "# Download and move android SDK tools to specific folders\n", + "!wget -q 'https://dl.google.com/android/repository/tools_r25.2.5-linux.zip'\n", + "\n", + "!unzip 'tools_r25.2.5-linux.zip'\n", + "!mv '/content/tools' '/android/sdk'\n", + "# Copy paste the folder\n", + "!cp -r /android/sdk/tools /android/android-sdk-linux\n", + "\n", + "# Download NDK, unzip and move contents\n", + "!wget 'https://dl.google.com/android/repository/android-ndk-r19c-linux-x86_64.zip'\n", + "\n", + "!unzip 'android-ndk-r19c-linux-x86_64.zip'\n", + "!mv /content/android-ndk-r19c /content/ndk\n", + "!mv '/content/ndk' '/android'\n", + "# Copy paste the folder\n", + "!cp -r /android/ndk /android/android-ndk-r19c\n", + "\n", + "# Remove .zip files\n", + "!rm 'tools_r25.2.5-linux.zip'\n", + "!rm 'android-ndk-r19c-linux-x86_64.zip'\n", + "\n", + "# Make android ndk executable to all users\n", + "!chmod -R go=u '/android'" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dTZrWYNwhd7G" + }, + "source": [ + "# Set and view environment variables\n", + "%env PATH = /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/tools/node/bin:/tools/google-cloud-sdk/bin:/opt/bin:/android/sdk/tools:/android/sdk/platform-tools:/android/ndk\n", + "%env ANDROID_SDK_API_LEVEL=29\n", + "%env ANDROID_API_LEVEL=29\n", + "%env ANDROID_BUILD_TOOLS_VERSION=29.0.2\n", + "%env ANDROID_DEV_HOME=/android\n", + "%env ANDROID_NDK_API_LEVEL=21\n", + "%env ANDROID_NDK_FILENAME=android-ndk-r19c-linux-x86_64.zip\n", + "%env ANDROID_NDK_HOME=/android/ndk\n", + "%env ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r19c-linux-x86_64.zip\n", + "%env ANDROID_SDK_FILENAME=tools_r25.2.5-linux.zip\n", + "%env ANDROID_SDK_HOME=/android/sdk\n", + "#%env ANDROID_HOME=/android/sdk\n", + "%env ANDROID_SDK_URL=https://dl.google.com/android/repository/tools_r25.2.5-linux.zip\n", + "\n", + "#!echo $PATH\n", + "!export -p" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "e26VrdXdL2U6" + }, + "source": [ + "# Install specific versions of sdk, tools etc.\n", + "!android update sdk --no-ui -a \\\n", + " --filter tools,platform-tools,android-29,build-tools-29.0.2" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3eAfl3MmySzO" + }, + "source": [ + "## Install BAZEL with Baselisk" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Dj_P2Fb-yTqq" + }, + "source": [ + "# Download Latest version of Bazelisk\n", + "!wget https://github.com/bazelbuild/bazelisk/releases/latest/download/bazelisk-linux-amd64\n", + "\n", + "# Make script executable\n", + "!chmod +x bazelisk-linux-amd64\n", + "\n", + "# Adding to the path\n", + "!sudo mv bazelisk-linux-amd64 /usr/local/bin/bazel\n", + "\n", + "# Extract bazel info\n", + "!bazel" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "AMPL9u6tzOMJ" + }, + "source": [ + "# Clone TensorFlow Lite Support repository OR upload your custom folder to build\n", + "!git clone https://github.com/tensorflow/tflite-support.git" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "9KXH-1So0VOd" + }, + "source": [ + "# Move into tflite-support folder\n", + "%cd /content/tflite-support/\n", + "\n", + "!ls" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2D2tsSFs7Hrq" + }, + "source": [ + "## Build .aar files" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8XM37u827jD_" + }, + "source": [ + "#@title Select library. { display-mode: \"form\" }\n", + "\n", + "library = 'Support library' #@param [\"Support library\", \"Task Vision library\", \"Task Text library\", \"Task Audio library\",\"Metadata library\",\"C++ image_classifier\",\"C++ image_objector\",\"C++ image_segmenter\",\"C++ image_embedder\",\"C++ nl_classifier\",\"C++ bert_nl_classifier\", \"C++ bert_question_answerer\", \"C++ metadata_extractor\"]\n", + "\n", + "print('You selected:', library)\n", + "\n", + "if library == 'Support library':\n", + " library = '//tensorflow_lite_support/java:tensorflowlite_support.aar'\n", + "elif library == 'Task Vision library':\n", + " library = '//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision'\n", + "elif library == 'Task Text library':\n", + " library = '//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text:task-library-text'\n", + "elif library == 'Task Audio library':\n", + " library = '//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio:task-library-audio'\n", + "elif library == 'Metadata library':\n", + " library = '//tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata-lib'\n", + "elif library == 'C++ image_classifier':\n", + " library = '//tensorflow_lite_support/cc/task/vision:image_classifier'\n", + "elif library == 'C++ image_objector':\n", + " library = '//tensorflow_lite_support/cc/task/vision:image_objector'\n", + "elif library == 'C++ image_segmenter':\n", + " library = '//tensorflow_lite_support/cc/task/vision:image_segmenter'\n", + "elif library == 'C++ image_embedder':\n", + " library = '//tensorflow_lite_support/cc/task/vision:image_embedder'\n", + "elif library == 'C++ nl_classifier':\n", + " library = '//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier'\n", + "elif library == 'C++ bert_nl_classifier':\n", + " library = '//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier'\n", + "elif library == 'C++ bert_question_answerer':\n", + " library = '//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer'\n", + "elif library == 'C++ metadata_extractor':\n", + " library = '//tensorflow_lite_support/metadata/cc:metadata_extractor'\n", + "\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "WwanGSg-FE-0" + }, + "source": [ + "#@title Select platform(s). { display-mode: \"form\" }\n", + "\n", + "platforms = 'arm64-v8a,armeabi-v7a' #@param [\"arm64-v8a,armeabi-v7a\",\"x86\", \"x86_64\", \"arm64-v8a\", \"armeabi-v7a\",\"x86,x86_64,arm64-v8a,armeabi-v7a\"]\n", + "print('You selected:', platforms)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "FBw-vFrw0hdc" + }, + "source": [ + "# Build library\n", + "!bazel build \\\n", + " --fat_apk_cpu='{platforms}' \\\n", + " '{library}'" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "vRCanN-W5Av7" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/BUILD new file mode 100644 index 0000000..34d04920 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/BUILD
@@ -0,0 +1,4 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_instrumentation_test_template.xml b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_instrumentation_test_template.xml new file mode 100644 index 0000000..7c43c16 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_instrumentation_test_template.xml
@@ -0,0 +1,31 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2018 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="${applicationId}" > + + <uses-sdk + android:minSdkVersion="19" + android:targetSdkVersion="28" /> + + <application/> + + <instrumentation + android:name="androidx.test.runner.AndroidJUnitRunner" + android:targetPackage="${instrumentationTargetPackage}"> + </instrumentation> + +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_target_stub.xml b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_target_stub.xml new file mode 100644 index 0000000..e1f4f30 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/AndroidManifest_target_stub.xml
@@ -0,0 +1,24 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2018 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="${applicationId}" > + + <uses-sdk + android:minSdkVersion="19" + android:targetSdkVersion="28" /> + +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/BUILD new file mode 100644 index 0000000..a515bdda --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/BUILD
@@ -0,0 +1,11 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files( + [ + "AndroidManifest_instrumentation_test_template.xml", + "AndroidManifest_target_stub.xml", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/android_library_instrumentation_tests.bzl b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/android_library_instrumentation_tests.bzl new file mode 100644 index 0000000..b0028e8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/android_library_instrumentation_tests.bzl
@@ -0,0 +1,78 @@ +"""A rule wrapper for an instrumentation test for an android library.""" + +load( + "//tensorflow_lite_support/tools/build_rules/android_test:generate_instrumentation_tests.bzl", + "generate_instrumentation_tests", +) +load( + "//tensorflow_lite_support/tools/build_rules/android_test:infer_java_package_name.bzl", + "infer_java_package_name", +) + +def android_library_instrumentation_tests( + name, + srcs, + deps, + target_devices, + test_java_package = None, + library_args = {}, + binary_args = {}, + **kwargs): + """A macro for an instrumentation test whose target under test is an android_library. + + Will generate a 'self-instrumentating' test binary and other associated rules + + The intent of this wrapper is to simplify the build API for creating instrumentation test rules + for simple cases, while still supporting build_cleaner for automatic dependency management. + + This will generate: + - an unused stub android_binary under test, to placate bazel + - a test_lib android_library, containing all sources and dependencies + - a test_binary android_binary (soon to be android_application) + - the manifest to use for the test library. + - for each device combination: + - an android_instrumentation_test rule) + + Args: + name: the name to use for the generated android_library rule. This is needed for build_cleaner to + manage dependencies + srcs: the test sources to generate rules for + deps: the build dependencies to use for the generated test binary + target_devices: array of device targets to execute on + test_java_package: Optional. A custom root package name to use for the tests. If unset + will be derived based on current path to a java source root + library_args: additional arguments to pass to generated android_library + binary_args: additional arguments to pass to generated android_binary + **kwargs: arguments to pass to generated android_instrumentation_test rules + """ + library_name = "%s_library" % name + test_java_package_name = test_java_package if test_java_package else infer_java_package_name() + + native.android_binary( + name = "target_stub_binary", + manifest = "//tensorflow_lite_support/tools/build_rules/android_test:AndroidManifest_target_stub.xml", + # use the same package name as the test package, so it gets overridden + manifest_values = {"applicationId": test_java_package_name}, + testonly = 1, + ) + + native.android_library( + name = library_name, + srcs = srcs, + testonly = 1, + deps = deps, + **library_args + ) + + generate_instrumentation_tests( + name = name, + srcs = srcs, + deps = [library_name], + target_devices = target_devices, + test_java_package_name = test_java_package_name, + test_android_package_name = test_java_package_name, + instrumentation_target_package = test_java_package_name, + instruments = ":target_stub_binary", + binary_args = binary_args, + **kwargs + )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/android_multidevice_instrumentation_test.bzl b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/android_multidevice_instrumentation_test.bzl new file mode 100644 index 0000000..c252b03b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/android_multidevice_instrumentation_test.bzl
@@ -0,0 +1,17 @@ +"""Utility for running single test on multiple emulator targets.""" + +def android_multidevice_instrumentation_test(name, target_devices, **kwargs): + """Generates a android_instrumentation_test rule for each given device. + + Args: + name: Name prefix to use for the rules. The name of the generated rules will follow: + name + target_device[-6:] eg name-15_x86 + target_devices: array of device targets + **kwargs: arguments to pass to generated android_test rules + """ + for device in target_devices: + native.android_instrumentation_test( + name = name + "-" + device[-6:], + target_device = device, + **kwargs + )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/generate_instrumentation_tests.bzl b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/generate_instrumentation_tests.bzl new file mode 100644 index 0000000..fb586d1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/generate_instrumentation_tests.bzl
@@ -0,0 +1,62 @@ +"""Internal helper function for generating instrumentation tests .""" + +load( + "//tensorflow_lite_support/tools/build_rules/android_test:android_multidevice_instrumentation_test.bzl", + "android_multidevice_instrumentation_test", +) + +def generate_instrumentation_tests( + name, + srcs, + deps, + target_devices, + test_java_package_name, + test_android_package_name, + instrumentation_target_package, + instruments, + binary_args = {}, + **kwargs): + """A helper rule to generate instrumentation tests. + + + This will generate: + - a test_binary android_binary (soon to be android_application) + - the manifest to use for the test library. + - for each device combination: + - an android_instrumentation_test rule) + + Args: + name: unique prefix to use for generated rules + srcs: the test sources to generate rules for + deps: the build dependencies to use for the generated test binary + target_devices: array of device targets to execute on + test_java_package_name: the root java package name for the tests. + test_android_package_name: the android package name to use for the android_binary test app. Typically this is the same as test_java_package_name + instrumentation_target_package: the android package name to specify as instrumentationTargetPackage in the test_app manifest + instruments: The android binary the tests instrument. + binary_args: Optional additional arguments to pass to generated android_binary + **kwargs: arguments to pass to generated android_instrumentation_test rules + """ + + _manifest_values = { + "applicationId": test_android_package_name, + "instrumentationTargetPackage": instrumentation_target_package, + } + _manifest_values.update(binary_args.pop("manifest_values", {})) + native.android_binary( + name = "%s_binary" % name, + instruments = instruments, + manifest = "//tensorflow_lite_support/tools/build_rules/android_test:AndroidManifest_instrumentation_test_template.xml", + manifest_values = _manifest_values, + testonly = 1, + deps = deps + [ + "@maven//:androidx_test_runner", + ], + **binary_args + ) + android_multidevice_instrumentation_test( + name = "%s_tests" % name, + target_devices = target_devices, + test_app = "%s_binary" % name, + **kwargs + )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/infer_java_package_name.bzl b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/infer_java_package_name.bzl new file mode 100644 index 0000000..929bb7e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/android_test/infer_java_package_name.bzl
@@ -0,0 +1,28 @@ +"""A rule for inferring a java package name.""" + +_JAVA_ROOTS = [ + "javatests/", + "javatest/", + "java/", +] + +def infer_java_package_name(): + """Infer a java package name based on current path below 'javatests' or 'java'""" + return _infer_java_package_name_from_path(native.package_name()) + +def infer_java_package_name_from_label(label): + package_path = _get_path_from_label(label) + return _infer_java_package_name_from_path(package_path) + +def _infer_java_package_name_from_path(package_path): + for root in _JAVA_ROOTS: + if root in package_path: + root_index = package_path.rindex(root) + len(root) + return package_path[root_index:].replace("/", ".") + fail("Could not find one of java roots %s in %s" % (_JAVA_ROOTS, package_path)) + +def _get_path_from_label(label_string): + label_string = label_string.split(":")[0] + if not label_string.startswith("//"): + label_string = "//%s%s" % (native.package_name(), label_string) + return label_string
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/http_files.bzl b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/http_files.bzl new file mode 100644 index 0000000..58be497 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/http_files.bzl
@@ -0,0 +1,28 @@ +"""Build rule to depend on files downloaded from http_file.""" + +def tflite_file(name, extension): + """Links the tflite file from http_file with the current directory. + + Args: + name: the name of the tflite_file target, which is also the name of the + tflite file specified through http_file in WORKSPACE. For example, if + `name` is Foo, `tflite_file` will create a link to the downloaded file + file "@Foo//file" to the current directory as "Foo.tflite". + extension: the extension of the file. + """ + native.genrule( + name = "%s_ln" % (name), + srcs = ["@%s//file" % (name)], + outs = ["%s.%s" % (name, extension)], + output_to_bindir = 1, + cmd = "ln $< $@", + ) + + native.filegroup( + name = name, + srcs = ["%s.%s" % (name, extension)], + ) + +def tflite_model(name): + """Links the tflite model from http_file with the current directory.""" + tflite_file(name, "tflite")
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh index 5bbfb60..4219a8e 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh
@@ -17,17 +17,56 @@ set -ex -bazel build -c opt --config=monolithic \ - //tensorflow_lite_support/java:tensorflowlite_support \ - //tensorflow_lite_support/codegen/python:codegen \ - //tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib \ - //tensorflow_lite_support/metadata/cc:metadata_extractor \ - //tensorflow_lite_support/custom_ops/kernel:all \ - //tensorflow_lite_support/custom_ops/python:tflite_text_api +bash tensorflow_lite_support/custom_ops/tf_configure.sh -# Build Task libraries. +# TODO(b/200756963): Make it possible to build flatbuffer schema libraries with +# more jobs. bazel build -c opt --config=monolithic \ + //tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib \ + //tensorflow_lite_support/metadata/cc:metadata_extractor + +export BAZEL_PARALLEL="-j 32" + +# General targets. +bazel build -c opt ${BAZEL_PARALLEL} --config=monolithic \ + //tensorflow_lite_support/codegen/python:codegen \ + //tensorflow_lite_support/custom_ops/kernel:all \ + //tensorflow_lite_support/custom_ops/python:tflite_text_api \ + //tensorflow_lite_support/examples/task/audio/desktop:audio_classifier_demo + +# Android targets. +bazel build -c opt ${BAZEL_PARALLEL} --config=monolithic \ --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ + //tensorflow_lite_support/java:tensorflowlite_support \ + //tensorflow_lite_support/cc/task/vision:image_embedder \ + //tensorflow_lite_support/cc/task/audio:audio_embedder \ + //tensorflow_lite_support/cc/task/processor:all \ + //tensorflow_lite_support/odml/java/image \ //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base-task-api.aar \ //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text:task-library-text \ - //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision + //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision \ + //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio:task-library-audio \ + //tensorflow_lite_support/acceleration/configuration:gpu-delegate-plugin + +bazel clean +# Coral plugin. +bazel build -c opt ${BAZEL_PARALLEL} --define=darwinn_portable=1 \ + //tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin + +# Tests. + +bazel clean + +bazel test -c opt $BAZEL_PARALLEL --test_output=all \ + //tensorflow_lite_support/c/test/... \ + //tensorflow_lite_support/cc/test/task/vision:all \ + //tensorflow_lite_support/cc/test/task/text/... \ + //tensorflow_lite_support/custom_ops/kernel/sentencepiece:all \ + //tensorflow_lite_support/metadata/python/tests:metadata_test \ + //tensorflow_lite_support/metadata/python/tests/metadata_writers:all \ + +bazel test -c opt $BAZEL_PARALLEL --test_output=all --build_tests_only \ + --build_tag_filters=-tflite_emulator_test_android \ + --test_tag_filters=-tflite_emulator_test_android \ + //tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/... +
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh new file mode 100755 index 0000000..8f8ae294 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh
@@ -0,0 +1,137 @@ +#!/usr/bin/env bash +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Set the following variables as appropriate. +# * BAZEL: path to bazel. defaults to the first one available in PATH +# * TFLS_BUILD_VERSION: to specify the release version. defaults to 0.0.1-dev +# * IS_RELEASE_BUILD: set as true if this build should be a release build +# * ARCHIVE_FRAMEWORK: set as true if the framework should be archived +# * DEST_DIR: destination directory to which the framework will be copied + +set -ex + +if [[ "$(uname)" != "Darwin" ]]; then + echo "This build script only works on macOS." + exit 1 +fi + +BAZEL="${BAZEL:-$(which bazel)}" +TFLS_BUILD_VERSION=${TFLS_BUILD_VERSION:-0.0.1-dev} +TFLS_ROOT_DIR=$(git rev-parse --show-toplevel) + +if [[ ! -x "${BAZEL}" ]]; then + echo "bazel executable is not found." + exit 1 +fi + +if [[ -z "${DEST_DIR}" || "${DEST_DIR}" == ${TFLS_ROOT_DIR}* ]]; then + echo "DEST_DIR variable must be set and not be under the repository root." + exit 1 +fi + +# Builds the C API framework for the specified framework for iOS. +function build_c_api_framework { + "${BAZEL}" build -c opt --config=ios_fat \ + //tensorflow_lite_support/ios:$1C_framework +} + +function create_framework_archive { + TARGET_FRAMEWORK_NAME="$1" + C_API_FRAMEWORK_NAME="$1C" + + TFLS_IOS_DIR=tensorflow_lite_support/ios + BAZEL_IOS_OUTDIR="bazel-bin/${TFLS_IOS_DIR}" + + # Change to the Bazel iOS output directory. + pushd "${BAZEL_IOS_OUTDIR}" + + # Create the temporary directory for the given framework. + ARCHIVE_NAME="${TARGET_FRAMEWORK_NAME}-${TFLS_BUILD_VERSION}" + TFLS_TMPDIR="$(mktemp -d)" + + # Unzip the framework into the appropriate directory structure for CocoaPods. + # The final archive should contain the C API framework as a vendored framework + # as well as all the Obj-C source code and the header files. + # + # The directory structure will be like: + # + # ${TFLS_TMPDIR}/ + # |-- tensorflow_lite_support/ + # | |-- cc/ + # | | +-- <C-API header files...> + # | +-- ios/ + # | +-- <Obj-C header/source files...> + # +-- Frameworks/ + # +-- TensorFlowLiteTaskTextC.framework + # + + # ----- (1) Copy source files ----- + pushd "${TFLS_ROOT_DIR}" + + # Source files to be archived along with the static framework. + SRC_PATTERNS=" + */cc/*c_api*.h + */ios/task/text/*/Sources/* + */ios/task/text/apis/* + " + + # List of individual files obtained from the patterns above. + SRC_FILES=$(xargs -n1 find * -wholename <<< "${SRC_PATTERNS}" | sort | uniq) + + # Copy source files with the intermediate directories preserved. + xargs -n1 -I{} rsync -R {} "${TFLS_TMPDIR}" <<< "${SRC_FILES}" + popd + + # ----- (2) Unzip the prebuilt C API framework ----- + unzip "${C_API_FRAMEWORK_NAME}_framework.zip" -d "${TFLS_TMPDIR}"/Frameworks + + # ----- (3) Move the framework to the destination ----- + if [[ "${ARCHIVE_FRAMEWORK}" == true ]]; then + TARGET_DIR="$(realpath "${TARGET_FRAMEWORK_NAME}")" + + # Create the framework archive directory. + if [[ "${IS_RELEASE_BUILD}" == true ]]; then + # Get the first 16 bytes of the sha256 checksum of the root directory. + SHA256_CHECKSUM=$(find "${TFLS_TMPDIR}" -type f -print0 | xargs -0 shasum -a 256 | sort | shasum -a 256 | cut -c1-16) + FRAMEWORK_ARCHIVE_DIR="${TARGET_DIR}/${TFLS_BUILD_VERSION}/${SHA256_CHECKSUM}" + else + FRAMEWORK_ARCHIVE_DIR="${TARGET_DIR}/${TFLS_BUILD_VERSION}" + fi + mkdir -p "${FRAMEWORK_ARCHIVE_DIR}" + + # Zip up the framework and move to the archive directory. + pushd "${TFLS_TMPDIR}" + TFLS_ARCHIVE_FILE="${ARCHIVE_NAME}.tar.gz" + tar -cvzf "${TFLS_ARCHIVE_FILE}" . + mv "${TFLS_ARCHIVE_FILE}" "${FRAMEWORK_ARCHIVE_DIR}" + popd + + # Move the target directory to the Kokoro artifacts directory. + mv "${TARGET_DIR}" "$(realpath "${DEST_DIR}")"/ + else + rsync -r "${TFLS_TMPDIR}/" "$(realpath "${DEST_DIR}")/" + fi + + # Clean up the temporary directory for the framework. + rm -rf "${TFLS_TMPDIR}" + + # Pop back to the TFLS root directory. + popd +} + +cd "${TFLS_ROOT_DIR}" +build_c_api_framework TensorFlowLiteTaskText +create_framework_archive TensorFlowLiteTaskText
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh new file mode 100755 index 0000000..49f8c3d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh
@@ -0,0 +1,117 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Pip install TensorFlow Lite Support and run basic test on the pip package. + +# Important: Use msys shell to run this script on Windows. + +set -e +set -x + +function run_smoke_test() { + VENV_TMP_DIR="$(mktemp -d)" + + if [[ "$OSTYPE" == "msys" ]]; then + VENV_TMP_DIR="$(cygpath -m $VENV_TMP_DIR)" + fi + + python -m virtualenv "${VENV_TMP_DIR}" || \ + die "FAILED: Unable to create virtualenv" + + if [[ "$OSTYPE" == "msys" ]]; then + source "${VENV_TMP_DIR}/Scripts/activate" || \ + die "FAILED: Unable to activate virtualenv " + else + source "${VENV_TMP_DIR}/bin/activate" || \ + die "FAILED: Unable to activate virtualenv " + fi + + # install tflite-support + python -m pip install ${WHL_NAME} || \ + die "pip install (forcing to reinstall tflite-support) FAILED" + echo "Successfully installed pip package ${WHL_NAME}" + + # Download a test model + export TEST_MODEL="$(pwd)/test.tflite" + wget https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_0.75_192_quantized/1/metadata/1\?lite-format\=tflite -O "$TEST_MODEL" + if [[ "$OSTYPE" == "msys" ]]; then + TEST_MODEL=$(cygpath -m $TEST_MODEL) + fi + + test_tfls_imports + + test_codegen + + # Deactivate from virtualenv. + deactivate || source deactivate || \ + die "FAILED: Unable to deactivate from existing virtualenv." + + echo "All smoke test passes!" +} + +function test_tfls_imports() { + TMP_DIR=$(mktemp -d) + pushd "${TMP_DIR}" + + # test for basic import and metadata display. + RET_VAL=$(python -c "from tflite_support import metadata; \ +md = metadata.MetadataDisplayer.with_model_file(\"$TEST_MODEL\"); \ +print(md.get_metadata_json())") + + # just check if the model name is there. + if ! [[ ${RET_VAL} == *"MobileNetV1 image classifier (quantized)"* ]]; then + echo "Unexpected return value: ${RET_VAL}" + echo "PIP smoke test on virtualenv FAILED, do not upload ${WHL_NAME}." + return 1 + fi + + RESULT=$? + + popd + return $RESULT +} + +function test_codegen() { + TMP_DIR=$(mktemp -d) + pushd "${TMP_DIR}" + + # test for basic import and metadata display. + tflite_codegen --model ${TEST_MODEL} --destination tmp + RESULT=$? + + # just check if the model name is there. + if [[ ${RESULT} -ne 0 ]]; then + echo "Unexpected return value: ${RESULT}" + echo "PIP smoke test on virtualenv FAILED, do not upload ${WHL_NAME}." + return 1 + fi + + popd + return $RESULT +} + +########################################################################### +# Main +########################################################################### +if [[ -z "${1}" ]]; then + echo "TFLite Support WHL path not given, unable to install and test." + return 1 +fi + +which python +python --version + +WHL_NAME=${1} +run_smoke_test
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat index 35f39a7..702fb89f 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat
@@ -22,7 +22,7 @@ @REM Setup Bazel @REM :: Download Bazel from github and make sure its found in PATH. -SET BAZEL_VERSION=3.1.0 +SET BAZEL_VERSION=3.7.2 md C:\tools\bazel\ wget -q https://github.com/bazelbuild/bazel/releases/download/%BAZEL_VERSION%/bazel-%BAZEL_VERSION%-windows-x86_64.exe -O C:/tools/bazel/bazel.exe SET PATH=C:\tools\bazel;%PATH%
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/update_version.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/update_version.py new file mode 100644 index 0000000..86fa588 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/update_version.py
@@ -0,0 +1,120 @@ +# lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.. +# ============================================================================== +"""Update version code in the repo. + +We use a python script rather than GNU tools to avoid cross-platform +difficulties. + +The script takes 3 argument: + --src <path> a path pointing to the code repo. + --version <version> the new version code. + --nightly [default: false] when true, the version code will append a build + suffix (e.g. dev20201103) + +It should not run by bazel. Use it as a simple python script. +""" + +import argparse +import datetime +import os +import re + +SETUP_PY_PATH = "tensorflow_lite_support/tools/pip_package/setup.py" + + +def replace_string_in_line(search, replace, filename): + """Replace the string in every line of the file in-place.""" + with open(filename, "r") as f: + content = f.read() + with open(filename, "w") as f: + f.write(re.sub(search, replace, content)) + + +def get_current_version(path): + """Get the current version code from setup.py.""" + for line in open(os.path.join(path, SETUP_PY_PATH)): + match = re.search("^_VERSION = '([a-z0-9\\.\\-]+)'", line) + if match: + return match.group(1) + print("Cannot find current version!") + return None + + +def update_version(path, current_version, new_version): + """Update the version code in the codebase.""" + # Update setup.py + replace_string_in_line( + "_VERSION = '%s'" % current_version, + # pep440 requires such a replacement + "_VERSION = '%s'" % new_version.replace("-", "."), + os.path.join(path, SETUP_PY_PATH)) + + +class CustomTimeZone(datetime.tzinfo): + + def utcoffset(self, dt): + return -datetime.timedelta(hours=8) + + def tzname(self, dt): + return "UTC-8" + + def dst(self, dt): + return datetime.timedelta(0) + + +def remove_build_suffix(version): + """Remove build suffix (if exists) from a version.""" + if version.find("-dev") >= 0: + return version[:version.find("-dev")] + if version.find(".dev") >= 0: + return version[:version.find(".dev")] + if version.find("dev") >= 0: + return version[:version.find("dev")] + return version + + +def main(): + parser = argparse.ArgumentParser(description="Update TFLS version in repo") + parser.add_argument( + "--src", + help="a path pointing to the code repo", + required=True, + default="") + parser.add_argument("--version", help="the new SemVer code", default="") + parser.add_argument( + "--nightly", + help="if true, a build suffix will append to the version code. If " + "current version code or the <version> argument provided contains a " + "build suffix, the suffix will be replaced with the timestamp", + action="store_true") + args = parser.parse_args() + + path = args.src + current_version = get_current_version(path) + if not current_version: + return + new_version = args.version if args.version else current_version + if args.nightly: + new_version = remove_build_suffix(new_version) + # Use UTC-8 rather than uncertain local time. + d = datetime.datetime.now(tz=CustomTimeZone()) + new_version += "-dev" + d.strftime("%Y%m%d") + print("Updating version from %s to %s" % (current_version, new_version)) + update_version(path, current_version, new_version) + + +if __name__ == "__main__": + main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD index 61df24a..4341e34 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD
@@ -12,6 +12,13 @@ "setup.py", "//tensorflow_lite_support/codegen/python:codegen", "//tensorflow_lite_support/metadata/python:metadata", + "//tensorflow_lite_support/metadata/python/metadata_writers:writer_utils", + "//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info", + "//tensorflow_lite_support/metadata/python/metadata_writers:image_classifier", + "//tensorflow_lite_support/metadata/python/metadata_writers:object_detector", + "//tensorflow_lite_support/metadata/python/metadata_writers:image_segmenter", + "//tensorflow_lite_support/metadata/python/metadata_writers:audio_classifier", + "//tensorflow_lite_support/metadata/python/metadata_writers:nl_classifier", ] filegroup(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh index 66c05b4..d27c392 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh
@@ -92,6 +92,8 @@ # A helper entry. mkdir ${TMPDIR}/tflite_support cp tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py ${TMPDIR}/tflite_support/__init__.py + mkdir ${TMPDIR}/tflite_support/metadata_writers + cp tensorflow_lite_support/tools/pip_package/metadata_writers.__init__.py ${TMPDIR}/tflite_support/metadata_writers/__init__.py } function build_wheel() { @@ -133,8 +135,9 @@ echo " if dstdir is not set do not build, only prepare sources" echo "" echo " Options:" - echo " --project_name <name> set project name to name" - echo " --nightly_flag build TFLite Support nightly" + echo " --project_name <name> set project name to <name>" + echo " --version <version> reset the pip package version to <version>" + echo " --nightly_flag build TFLite Support nightly" echo "" echo "When using bazel, add the following flag: --run_under=\"cd \$PWD && \"" echo "" @@ -148,6 +151,7 @@ SRCDIR="" DSTDIR="" CLEANSRC=1 + VERSION="" while true; do if [[ "$1" == "--help" ]]; then usage @@ -160,6 +164,12 @@ break fi PROJECT_NAME="$1" + elif [[ "$1" == "--version" ]]; then + shift + if [[ -z "$1" ]]; then + break + fi + VERSION="$1" elif [[ "$1" == "--src" ]]; then shift SRCDIR="$(real_path $1)" @@ -168,7 +178,9 @@ shift DSTDIR="$(real_path $1)" else - DSTDIR="$(real_path $1)" + echo "Unrecognized flag: $1" + usage + exit 1 fi shift @@ -188,8 +200,6 @@ SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)" fi - prepare_src "$SRCDIR" - if [[ -z "$DSTDIR" ]]; then # only want to prepare sources exit @@ -201,6 +211,24 @@ PKG_NAME_FLAG="--project_name tflite_support_nightly" fi + # Set additional package name flags (for ARM builds). + if [[ -n ${EXTRA_PKG_NAME_FLAG} ]]; then + PKG_NAME_FLAG="${PKG_NAME_FLAG} ${EXTRA_PKG_NAME_FLAG}" + fi + + if [[ ${NIGHTLY_BUILD} == "1" ]]; then + # we use a script to update versions to avoid any tool differences on different platforms. + if [[ ! -z ${VERSION} ]]; then + python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --version ${VERSION} --nightly + else + python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --nightly + fi + elif [[ ! -z ${VERSION} ]]; then + python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --version ${VERSION} + fi + + prepare_src "$SRCDIR" + build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG" if [[ $CLEANSRC -ne 0 ]]; then
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/metadata_writers.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/metadata_writers.__init__.py new file mode 100644 index 0000000..83336a7 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/metadata_writers.__init__.py
@@ -0,0 +1,23 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An import entry for the metadata writers library.""" + +from tensorflow_lite_support.metadata.python.metadata_writers import audio_classifier +from tensorflow_lite_support.metadata.python.metadata_writers import image_classifier +from tensorflow_lite_support.metadata.python.metadata_writers import image_segmenter +from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info +from tensorflow_lite_support.metadata.python.metadata_writers import nl_classifier +from tensorflow_lite_support.metadata.python.metadata_writers import object_detector +from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3 b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3 new file mode 100644 index 0000000..74ed5d9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3
@@ -0,0 +1,60 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG IMAGE +FROM ${IMAGE} +ARG PYTHON_VERSION + +COPY update_sources.sh / +RUN /update_sources.sh + +RUN apt-get update && \ + apt-get install -y \ + build-essential \ + software-properties-common \ + zlib1g-dev \ + curl \ + unzip \ + git && \ + apt-get clean + +# Install Python packages. +RUN dpkg --add-architecture armhf +RUN dpkg --add-architecture arm64 +RUN yes | add-apt-repository ppa:deadsnakes/ppa +RUN apt-get update && \ + apt-get install -y \ + python$PYTHON_VERSION \ + python$PYTHON_VERSION-dev \ + python$PYTHON_VERSION-distutils \ + libpython$PYTHON_VERSION-dev \ + libpython$PYTHON_VERSION-dev:armhf \ + libpython$PYTHON_VERSION-dev:arm64 +RUN ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python3 +RUN curl -OL https://bootstrap.pypa.io/get-pip.py +RUN python3 get-pip.py +RUN rm get-pip.py +RUN pip3 install --upgrade pip +RUN pip3 install numpy~=1.19.2 setuptools pybind11 +RUN ln -sf /usr/include/python$PYTHON_VERSION /usr/include/python3 +RUN ln -sf /usr/local/lib/python$PYTHON_VERSION/dist-packages/numpy/core/include/numpy /usr/include/python3/numpy +RUN ln -sf /usr/bin/python3 /usr/bin/python + +ENV CI_BUILD_PYTHON=python$PYTHON_VERSION +ENV CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python$PYTHON_VERSION + +COPY install_bazel.sh / +RUN /install_bazel.sh + +COPY with_the_same_user /
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Makefile b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Makefile new file mode 100644 index 0000000..31d600b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Makefile
@@ -0,0 +1,63 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Values: debian:<version>, ubuntu:<version> +BASE_IMAGE ?= ubuntu:16.04 +PYTHON_VERSION ?= 3.9 + +MAKEFILE_DIR := $(realpath $(dir $(lastword $(MAKEFILE_LIST)))) +WORKSPACE_DIR := $(MAKEFILE_DIR)/../../../.. +TAG_IMAGE := "tflite-runtime-builder-$(subst :,-,$(BASE_IMAGE))" + +DOCKER_PARAMS := --pid=host \ + --env "CI_BUILD_USER=$(shell id -u -n)" \ + --env "CI_BUILD_UID=$(shell id -u)" \ + --env "CI_BUILD_GROUP=$(shell id -g -n)" \ + --env "CI_BUILD_GID=$(shell id -g)" \ + --env "CI_BUILD_HOME=$(WORKSPACE_DIR)/bazel-ci_build-cache" \ + --volume $(WORKSPACE_DIR)/bazel-ci_build-cache:$(WORKSPACE_DIR)/bazel-ci-build-cache \ + --volume $(WORKSPACE_DIR):/workspace \ + --workdir /workspace + +.PHONY: help \ + docker-image \ + docker-shell \ + docker-build \ + clean + +help: + @echo "make docker-image -- build docker image" + @echo "make docker-shell -- run shell inside the docker image" + @echo "make docker-build -- build wheel and deb inside the docker image" + @echo "make clean -- remove wheel and deb files" + +docker-image: + docker build -t $(TAG_IMAGE) --build-arg IMAGE=$(BASE_IMAGE) --build-arg PYTHON_VERSION=$(PYTHON_VERSION) -f $(MAKEFILE_DIR)/Dockerfile.py3 $(MAKEFILE_DIR)/. + +docker-shell: docker-image + mkdir -p $(WORKSPACE_DIR)/bazel-ci_build-cache + docker run --rm --interactive --tty \ + $(DOCKER_PARAMS) \ + $(TAG_IMAGE) /with_the_same_user /bin/bash + +docker-build: docker-image + mkdir -p $(WORKSPACE_DIR)/bazel-ci_build-cache + docker run \ + --rm --interactive $(shell tty -s && echo --tty) \ + $(DOCKER_PARAMS) \ + $(TAG_IMAGE) \ + /with_the_same_user /bin/bash -C tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh + +clean: + rm -rf $(CURDIR)/wheels
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh new file mode 100755 index 0000000..9676de4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh
@@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +bazel build -c opt --config=elinux_armhf tensorflow_lite_support/tools/pip_package:build_pip_package +EXTRA_PKG_NAME_FLAG="--plat-name=manylinux2014-armv7l" ./bazel-bin/tensorflow_lite_support/tools/pip_package/build_pip_package --dst wheels --nightly_flag +bazel build -c opt --config=elinux_aarch64 tensorflow_lite_support/tools/pip_package:build_pip_package +EXTRA_PKG_NAME_FLAG="--plat-name=manylinux2014-aarch64" ./bazel-bin/tensorflow_lite_support/tools/pip_package/build_pip_package --dst wheels --nightly_flag
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh new file mode 100755 index 0000000..063e1f9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh
@@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Select bazel version. +BAZEL_VERSION="3.7.2" + +set +e +local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') + +if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then + exit 0 +fi + +set -e + +# Install bazel. +mkdir -p /bazel +cd /bazel +if [[ ! -f "bazel-$BAZEL_VERSION-installer-linux-x86_64.sh" ]]; then + curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh +fi +chmod +x /bazel/bazel-*.sh +/bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh +rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Enable bazel auto completion. +echo "source /usr/local/lib/bazel/bin/bazel-complete.bash" >> ~/.bashrc
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/update_sources.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/update_sources.sh new file mode 100755 index 0000000..40e3213 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/update_sources.sh
@@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +#!/bin/bash +. /etc/os-release + +[[ "${NAME}" == "Ubuntu" ]] || exit 0 + +sed -i "s/deb\ /deb \[arch=amd64\]\ /g" /etc/apt/sources.list + +cat <<EOT >> /etc/apt/sources.list +deb [arch=arm64,armhf] http://ports.ubuntu.com/ubuntu-ports ${UBUNTU_CODENAME} main universe +deb [arch=arm64,armhf] http://ports.ubuntu.com/ubuntu-ports ${UBUNTU_CODENAME}-updates main universe +deb [arch=arm64,armhf] http://ports.ubuntu.com/ubuntu-ports ${UBUNTU_CODENAME}-security main universe +EOT
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/with_the_same_user b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/with_the_same_user new file mode 100755 index 0000000..0c8c5069 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/with_the_same_user
@@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script is a wrapper creating the same user inside container as the one +# running the ci_build.sh outside the container. It also set the home directory +# for the user inside container to match the same absolute path as the workspace +# outside of container. +# We do this so that the bazel running inside container generate symbolic links +# and user permissions which makes sense outside of container. +# Do not run this manually. It does not make sense. It is intended to be called +# by ci_build.sh only. + +set -e + +COMMAND=("$@") + +if ! touch /this_is_writable_file_system; then + echo "You can't write to your filesystem!" + echo "If you are in Docker you should check you do not have too many images" \ + "with too many files in them. Docker has some issue with it." + exit 1 +else + rm /this_is_writable_file_system +fi + +if [ -n "${CI_BUILD_USER_FORCE_BADNAME}" ]; then + ADDUSER_OPTS="--force-badname" +fi + +apt-get install sudo + +getent group "${CI_BUILD_GID}" || addgroup ${ADDUSER_OPTS} --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" +getent passwd "${CI_BUILD_UID}" || adduser ${ADDUSER_OPTS} \ + --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" \ + --gecos "${CI_BUILD_USER} (generated by with_the_same_user script)" \ + --disabled-password --home "${CI_BUILD_HOME}" --quiet "${CI_BUILD_USER}" +usermod -a -G sudo "${CI_BUILD_USER}" +echo "${CI_BUILD_USER} ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-nopasswd-sudo + +if [[ "${TF_NEED_ROCM}" -eq 1 ]]; then + # ROCm requires the video group in order to use the GPU for compute. If it + # exists on the host, add it to the container. + getent group video || addgroup video && adduser "${CI_BUILD_USER}" video +fi + +if [ -e /root/.bazelrc ]; then + cp /root/.bazelrc "${CI_BUILD_HOME}/.bazelrc" + chown "${CI_BUILD_UID}:${CI_BUILD_GID}" "${CI_BUILD_HOME}/.bazelrc" +fi + +sudo -u "#${CI_BUILD_UID}" --preserve-env "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" \ +"HOME=${CI_BUILD_HOME}" ${COMMAND[@]}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py index 9b8d0ad..df97ebe7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py
@@ -43,16 +43,18 @@ # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '0.1.0' +_VERSION = '0.3.0' SETUP_PACKAGES = [ - 'pybind11 >= 2.4', + 'pybind11 >= 2.6.0', ] REQUIRED_PACKAGES = [ 'absl-py >= 0.7.0', - 'numpy >= 1.16.0', - 'flatbuffers >= 1.12', + 'numpy >= 1.19.2', + # TODO(b/187981032): remove the constraint for 2.0 once the incompatibile + # issue is resolved. + 'flatbuffers >= 1.12, <2', ] + SETUP_PACKAGES project_name = 'tflite-support'
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py index c5bf3de..da5c2d8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py
@@ -26,3 +26,4 @@ from tensorflow_lite_support.metadata import metadata_schema_py_generated from tensorflow_lite_support.metadata import schema_py_generated from tensorflow_lite_support.metadata.python import metadata +from tflite_support import metadata_writers
diff --git a/third_party/tflite_support/src/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/tflite_support/src/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff index 0cd2dff..5725c34 100644 --- a/third_party/tflite_support/src/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff +++ b/third_party/tflite_support/src/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff
@@ -1,8 +1,8 @@ diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel -index 9fceffe..e7f9d01 100644 +index 45a9529..461164d 100644 --- a/absl/time/internal/cctz/BUILD.bazel +++ b/absl/time/internal/cctz/BUILD.bazel -@@ -69,8 +69,5 @@ cc_library( +@@ -75,9 +75,6 @@ cc_library( "include/cctz/zone_info_source.h", ], linkopts = select({ @@ -11,4 +11,4 @@ - ], ":ios": [ "-framework Foundation", - ], \ No newline at end of file + ],
diff --git a/third_party/tflite_support/src/third_party/com_google_protobuf_fixes.diff b/third_party/tflite_support/src/third_party/com_google_protobuf_fixes.diff deleted file mode 100644 index b9bc17e..0000000 --- a/third_party/tflite_support/src/third_party/com_google_protobuf_fixes.diff +++ /dev/null
@@ -1,140 +0,0 @@ -diff --git a/BUILD b/BUILD -index 79871d621..51b3a063f 100644 ---- a/BUILD -+++ b/BUILD -@@ -26,7 +26,7 @@ config_setting( - # ZLIB configuration - ################################################################################ - --ZLIB_DEPS = ["@zlib//:zlib"] -+ZLIB_DEPS = ["@zlib"] - - ################################################################################ - # Protobuf Runtime Library -@@ -157,6 +157,7 @@ cc_library( - includes = ["src/"], - linkopts = LINK_OPTS, - visibility = ["//visibility:public"], -+ alwayslink = 1, - ) - - PROTOBUF_DEPS = select({ -@@ -230,6 +231,7 @@ cc_library( - linkopts = LINK_OPTS, - visibility = ["//visibility:public"], - deps = [":protobuf_lite"] + PROTOBUF_DEPS, -+ alwayslink = 1, - ) - - # This provides just the header files for use in projects that need to build -@@ -318,13 +320,13 @@ cc_proto_library( - - [native_cc_proto_library( - name = proto + "_cc_proto", -- deps = [proto + "_proto"], - visibility = ["//visibility:private"], -+ deps = [proto + "_proto"], - ) for proto in WELL_KNOWN_PROTO_MAP.keys()] - - cc_proto_blacklist_test( - name = "cc_proto_blacklist_test", -- deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()] -+ deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()], - ) - - ################################################################################ -@@ -900,7 +902,6 @@ py_proto_library( - py_extra_srcs = glob(["python/**/__init__.py"]), - py_libs = [ - ":python_srcs", -- "@six//:six", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -@@ -1002,7 +1003,9 @@ cc_library( - # Note: We use `native_proto_common` here because we depend on an implementation-detail of - # `proto_lang_toolchain`, which may not be available on `proto_common`. - reject_blacklisted_files = hasattr(native_proto_common, "proto_lang_toolchain_rejects_files_do_not_use_or_we_will_break_you_without_mercy") -+ - cc_toolchain_blacklisted_protos = [proto + "_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()] if reject_blacklisted_files else [":well_known_protos"] -+ - proto_lang_toolchain( - name = "cc_toolchain", - blacklisted_protos = cc_toolchain_blacklisted_protos, -diff --git a/protobuf.bzl b/protobuf.bzl -index 829464d44..4ac23594b 100644 ---- a/protobuf.bzl -+++ b/protobuf.bzl -@@ -87,6 +87,8 @@ def _proto_gen_impl(ctx): - for dep in ctx.attr.deps: - import_flags += dep.proto.import_flags - deps += dep.proto.deps -+ import_flags = depset(import_flags).to_list() -+ deps = depset(deps).to_list() - - if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin: - return struct( -diff --git a/src/google/protobuf/io/gzip_stream.h b/src/google/protobuf/io/gzip_stream.h -index b1ce1d36c..d5d560ea7 100644 ---- a/src/google/protobuf/io/gzip_stream.h -+++ b/src/google/protobuf/io/gzip_stream.h -@@ -47,10 +47,12 @@ - #include <google/protobuf/stubs/common.h> - #include <google/protobuf/io/zero_copy_stream.h> - #include <google/protobuf/port.h> --#include <zlib.h> -- - #include <google/protobuf/port_def.inc> - -+#if HAVE_ZLIB -+#include <zlib.h> -+#endif // HAVE_ZLIB -+ - namespace google { - namespace protobuf { - namespace io { -@@ -76,8 +78,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { - virtual ~GzipInputStream(); - - // Return last error message or NULL if no error. -+#if HAVE_ZLIB - inline const char* ZlibErrorMessage() const { return zcontext_.msg; } - inline int ZlibErrorCode() const { return zerror_; } -+#endif // HAVE_ZLIB - - // implements ZeroCopyInputStream ---------------------------------- - bool Next(const void** data, int* size); -@@ -90,8 +94,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { - - ZeroCopyInputStream* sub_stream_; - -+ #if HAVE_ZLIB - z_stream zcontext_; - int zerror_; -+ #endif // HAVE_ZLIB - - void* output_buffer_; - void* output_position_; -@@ -142,9 +148,11 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { - - virtual ~GzipOutputStream(); - -+#if HAVE_ZLIB - // Return last error message or NULL if no error. - inline const char* ZlibErrorMessage() const { return zcontext_.msg; } - inline int ZlibErrorCode() const { return zerror_; } -+#endif // HAVE_ZLIB - - // Flushes data written so far to zipped data in the underlying stream. - // It is the caller's responsibility to flush the underlying stream if -@@ -177,8 +185,10 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { - void* sub_data_; - int sub_data_size_; - -+#if HAVE_ZLIB - z_stream zcontext_; - int zerror_; -+#endif //HAVE_ZLIB - void* input_buffer_; - size_t input_buffer_length_; -
diff --git a/third_party/tflite_support/src/third_party/flatbuffers/BUILD.bazel b/third_party/tflite_support/src/third_party/flatbuffers/BUILD.bazel index 1ee46f05..9bd6965 100644 --- a/third_party/tflite_support/src/third_party/flatbuffers/BUILD.bazel +++ b/third_party/tflite_support/src/third_party/flatbuffers/BUILD.bazel
@@ -1,4 +1,5 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load(":build_defs.bzl", "flatbuffer_py_strip_prefix_srcs") package(default_visibility = ["//visibility:public"]) @@ -102,8 +103,10 @@ visibility = ["//visibility:public"], ) -filegroup( - name = "runtime_py_srcs", +# Note: Don't include "flexbuffers.py" as it's not available in the latest +# pip package (version 1.12.0) and causes failures in internal pip tests. +flatbuffer_py_strip_prefix_srcs( + name = "flatbuffer_py_strip_prefix", srcs = [ "python/flatbuffers/__init__.py", "python/flatbuffers/builder.py", @@ -114,6 +117,21 @@ "python/flatbuffers/table.py", "python/flatbuffers/util.py", ], + strip_prefix = "python/flatbuffers/", +) + +filegroup( + name = "runtime_py_srcs", + srcs = [ + "__init__.py", + "builder.py", + "compat.py", + "encode.py", + "number_types.py", + "packer.py", + "table.py", + "util.py", + ], ) py_library(
diff --git a/third_party/tflite_support/src/third_party/flatbuffers/build_defs.bzl b/third_party/tflite_support/src/third_party/flatbuffers/build_defs.bzl index 45ff816..2f1e8ce4 100644 --- a/third_party/tflite_support/src/third_party/flatbuffers/build_defs.bzl +++ b/third_party/tflite_support/src/third_party/flatbuffers/build_defs.bzl
@@ -24,6 +24,7 @@ out_prefix = "", includes = [], include_paths = [], + compatible_with = [], flatc_args = DEFAULT_FLATC_ARGS, reflection_name = "", reflection_visibility = None, @@ -43,6 +44,8 @@ single source targets. Usually is a directory name. includes: Optional, list of filegroups of schemas that the srcs depend on. include_paths: Optional, list of paths the includes files can be found in. + compatible_with: Optional, passed to genrule for environments this rule + can be built for. flatc_args: Optional, list of additional arguments to pass to flatc. reflection_name: Optional, if set this will generate the flatbuffer reflection binaries for the schemas. @@ -72,6 +75,7 @@ srcs = srcs, outs = outs, output_to_bindir = output_to_bindir, + compatible_with = compatible_with, tools = includes + [flatc_path], cmd = genrule_cmd, message = "Generating flatbuffer files for %s:" % (name), @@ -97,6 +101,7 @@ srcs = srcs, outs = reflection_outs, output_to_bindir = output_to_bindir, + compatible_with = compatible_with, tools = includes + [flatc_path], cmd = reflection_genrule_cmd, message = "Generating flatbuffer reflection binary for %s:" % (name), @@ -111,6 +116,7 @@ # native.FilesetEntry(files = reflection_outs), # ], # visibility = reflection_visibility, + # compatible_with = compatible_with, # ) def flatbuffer_cc_library( @@ -120,6 +126,7 @@ out_prefix = "", includes = [], include_paths = [], + compatible_with = [], flatc_args = DEFAULT_FLATC_ARGS, visibility = None, srcs_filegroup_visibility = None, @@ -175,6 +182,8 @@ includes: Optional, list of filegroups of schemas that the srcs depend on. ** SEE REMARKS BELOW ** include_paths: Optional, list of paths the includes files can be found in. + compatible_with: Optional, passed to genrule for environments this rule + can be built for flatc_args: Optional list of additional arguments to pass to flatc (e.g. --gen-mutable). visibility: The visibility of the generated cc_library. By default, use the @@ -198,6 +207,7 @@ out_prefix = out_prefix, includes = includes, include_paths = include_paths, + compatible_with = compatible_with, flatc_args = flatc_args, reflection_name = reflection_name, reflection_visibility = visibility, @@ -215,6 +225,7 @@ includes = ["."], linkstatic = 1, visibility = visibility, + compatible_with = compatible_with, ) # A filegroup for the `srcs`. That is, all the schema files for this @@ -223,6 +234,7 @@ name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name), srcs = srcs, visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility, + compatible_with = compatible_with, ) # Custom provider to track dependencies transitively. @@ -349,18 +361,37 @@ output_to_genfiles = True, ) +def flatbuffer_py_strip_prefix_srcs(name, srcs = [], strip_prefix = ""): + """Strips path prefix. + + Args: + name: Rule name. (required) + srcs: Source .py files. (required) + strip_prefix: Path that needs to be stripped from the srcs filepaths. (required) + """ + for src in srcs: + native.genrule( + name = name + "_" + src.replace(".", "_").replace("/", "_"), + srcs = [src], + outs = [src.replace(strip_prefix, "")], + cmd = "cp $< $@", + ) + def _concat_flatbuffer_py_srcs_impl(ctx): # Merge all generated python files. The files are concatenated and the # import statements are removed. Finally we import the flatbuffer runtime # library. - command = "for f in $(find %s -name '*.py'); do cat $f | sed '/import flatbuffers/d' >> %s; done; " - command += "sed -i '1s/^/import flatbuffers\\'$'\\n/' %s" + # IMPORTANT: Our Windows shell does not support "find ... -exec" properly. + # If changing the commandline here, please build wheels and run smoke tests + # on all the three operation systems. + command = "echo 'import flatbuffers\n' > %s; " + command += "for f in $(find %s -name '*.py'); do cat $f | sed '/import flatbuffers/d' >> %s; done " ctx.actions.run_shell( inputs = ctx.attr.deps[0].files, outputs = [ctx.outputs.out], command = command % ( - ctx.attr.deps[0].files.to_list()[0].path, ctx.outputs.out.path, + ctx.attr.deps[0].files.to_list()[0].path, ctx.outputs.out.path, ), )
diff --git a/third_party/tflite_support/src/third_party/libzip.BUILD b/third_party/tflite_support/src/third_party/libzip.BUILD index b69ccf4..2f75f40 100644 --- a/third_party/tflite_support/src/third_party/libzip.BUILD +++ b/third_party/tflite_support/src/third_party/libzip.BUILD
@@ -2,7 +2,7 @@ default_visibility = ["//visibility:public"], ) -load("@org_tensorflow_lite_support//tensorflow_lite_support/tools:build_rules/expand_template.bzl", "cmake_substitutions", "expand_template") +load("@org_tensorflow_lite_support//tensorflow_lite_support/tools/build_rules:expand_template.bzl", "cmake_substitutions", "expand_template") _CMAKE_VARIABLES = { "INT16_T_LIBZIP": 2, @@ -34,12 +34,12 @@ _CMAKE_VARIABLES.update(dict([ ( "ZIP_{sign}INT{size}_T".format( - sign = sign.upper(), size = size, + sign = sign.upper(), ), "{sign}int{size}_t".format( - sign = sign.lower(), size = size, + sign = sign.lower(), ), ) for sign in ("U", "") @@ -130,12 +130,12 @@ _VARS.update(dict([ ( "ZIP_{sign}INT{size}_T".format( - sign = sign.upper(), size = size, + sign = sign.upper(), ), "{sign}int{size}_t".format( - sign = sign.lower(), size = size, + sign = sign.lower(), ), ) for sign in ("U", "")
diff --git a/third_party/tflite_support/src/third_party/tensorflow_lite_ios_build.patch b/third_party/tflite_support/src/third_party/tensorflow_lite_ios_build.patch index e5d3128fc..2edf08b 100644 --- a/third_party/tflite_support/src/third_party/tensorflow_lite_ios_build.patch +++ b/third_party/tflite_support/src/third_party/tensorflow_lite_ios_build.patch
@@ -1,37 +1,4 @@ -diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD -similarity index 99% -rename from tensorflow/lite/experimental/ios/BUILD.apple -rename to tensorflow/lite/experimental/ios/BUILD -index e1e3be2bcd..912d27702b 100644 ---- a/tensorflow/lite/experimental/ios/BUILD.apple -+++ b/tensorflow/lite/experimental/ios/BUILD -@@ -21,6 +21,7 @@ sh_binary( - srcs = [ - "hide_symbols_with_allowlist.sh", - ], -+ visibility = ["//visibility:public"], - ) - - strip_common_include_path_prefix( -diff --git a/tensorflow/lite/experimental/ios/ios.bzl b/tensorflow/lite/experimental/ios/ios.bzl -index 63747eb8d1..07bcb49de0 100644 ---- a/tensorflow/lite/experimental/ios/ios.bzl -+++ b/tensorflow/lite/experimental/ios/ios.bzl -@@ -60,7 +60,7 @@ def tflite_ios_static_framework( - "BUNDLE_NAME=\"" + bundle_name + "\" " + - "ALLOWLIST_FILE_PATH=\"$(location " + allowlist_symbols_file + ")\" " + - "OUTPUT=\"$(OUTS)\" " + -- "\"$(location //tensorflow/lite/experimental/ios:hide_symbols_with_allowlist)\"") -+ "\"$(location @org_tensorflow//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist)\"") - - native.genrule( - name = name, -@@ -68,7 +68,7 @@ def tflite_ios_static_framework( - outs = [name + ".zip"], - cmd = cmd, - tools = [ -- "//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist", -+ "@org_tensorflow//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist", - ], - ) - +diff --git a/tensorflow/lite/ios/BUILD.apple b/tensorflow/lite/ios/BUILD +similarity index 100% +rename from tensorflow/lite/ios/BUILD.apple +rename to tensorflow/lite/ios/BUILD
diff --git a/third_party/tflite_support/src/third_party/zlib.BUILD b/third_party/tflite_support/src/third_party/zlib.BUILD index 275782e..3a93488 100644 --- a/third_party/tflite_support/src/third_party/zlib.BUILD +++ b/third_party/tflite_support/src/third_party/zlib.BUILD
@@ -37,3 +37,30 @@ copts = ["-Wno-implicit-function-declaration"], includes = ["."], ) + +cc_library( + name = "zlib_minizip", + srcs = [ + "contrib/minizip/ioapi.c", + "contrib/minizip/miniunz.c", + "contrib/minizip/minizip.c", + "contrib/minizip/unzip.c", + "contrib/minizip/zip.c", + ], + hdrs = [ + "contrib/minizip/crypt.h", + "contrib/minizip/ioapi.h", + "contrib/minizip/mztools.h", + "contrib/minizip/unzip.h", + "contrib/minizip/zip.h", + ], + copts = [ + "-Wno-dangling-else", + "-Wno-format", + "-Wno-incompatible-pointer-types", + "-Wno-incompatible-pointer-types-discards-qualifiers", + "-Wno-parentheses", + "-DIOAPI_NO_64", + ], + deps = [":zlib"], +)
diff --git a/tools/android/checkxmlstyle/checkxmlstyle.py b/tools/android/checkxmlstyle/checkxmlstyle.py index 44c46e5..3d3fb35 100644 --- a/tools/android/checkxmlstyle/checkxmlstyle.py +++ b/tools/android/checkxmlstyle/checkxmlstyle.py
@@ -310,7 +310,13 @@ namespace = {'android': 'http://schemas.android.com/apk/res/android'} errors = [] for f in IncludedFiles(input_api): - root = ET.fromstring(input_api.ReadFile(f)) + try: + root = ET.fromstring(input_api.ReadFile(f)) + except ET.ParseError: + print('*' * 80) + print('Parse error processing file:', f) + print('*' * 80) + raise # Check if there are text attributes defined outside text appearances. for attribute in text_attributes: # Get style name that contains text attributes but is not text appearance.
diff --git a/tools/binary_size/libsupersize/archive.py b/tools/binary_size/libsupersize/archive.py index 417bb39e..1ed215c 100644 --- a/tools/binary_size/libsupersize/archive.py +++ b/tools/binary_size/libsupersize/archive.py
@@ -47,14 +47,6 @@ import zip_util -# Holds computation state that is live only when an output directory exists. -_OutputDirectoryContext = collections.namedtuple('_OutputDirectoryContext', [ - 'elf_object_paths', # Only when elf_path is also provided. - 'known_inputs', # Only when elf_path is also provided. - 'output_directory', - 'thin_archives', -]) - # When ensuring matching section sizes between .elf and .map files, these # sections should be ignored. When lld creates a combined library with # partitions, some sections (like .text) exist in each partition, but the ones @@ -79,16 +71,34 @@ _MAX_SAME_NAME_ALIAS_COUNT = 40 # 50kb is basically negligible. +# Holds computation state that is live only when an output directory exists. +@dataclasses.dataclass +class _OutputDirectoryContext: + elf_object_paths: list # Non-None only when elf_path is. + known_inputs: list # Non-None only when elf_path is. + output_directory: str + thin_archives: list + + @dataclasses.dataclass class NativeSpec: # One (or more) of apk_so_path, map_path, elf_path must be non-None. tool_prefix: str # Never None. + # Path within the .apk of the .so file. Non-None only when apk_spec is. apk_so_path: str = None map_path: str = None elf_path: str = None # Unstripped .so path. - linker_name: str = None + linker_name: str = None # Requires map_path. Either 'gold' or 'lld'. track_string_literals: bool = True + @property + def algorithm(self): + if self.map_path: + return 'linker_map' + if self.elf_path: + return 'dwarf' + return 'sections' + @dataclasses.dataclass class PakSpec: @@ -700,13 +710,7 @@ native_spec.tool_prefix) update_build_config(models.BUILD_CONFIG_TOOL_PREFIX, relative_tool_prefix) - if native_spec.map_path: - metadata[models.METADATA_ELF_ALGORITHM] = 'linker_map' - elif native_spec.elf_path: - metadata[models.METADATA_ELF_ALGORITHM] = 'dwarf' - else: - metadata[models.METADATA_ELF_ALGORITHM] = 'sections' - + metadata[models.METADATA_ELF_ALGORITHM] = native_spec.algorithm if native_spec.linker_name: update_build_config(models.BUILD_CONFIG_LINKER_NAME, native_spec.linker_name) @@ -1029,69 +1033,6 @@ return path -def _ParseApkOtherSymbols(*, apk_spec, native_spec, section_ranges, - resources_pathmap_path, metadata): - apk_so_path = native_spec and native_spec.apk_so_path - res_source_mapper = _ResourceSourceMapper(apk_spec.size_info_prefix, - apk_spec.path_defaults) - resource_deobfuscator = _ResourcePathDeobfuscator(resources_pathmap_path) - apk_symbols = [] - dex_size = 0 - zip_info_total = 0 - zipalign_total = 0 - with zipfile.ZipFile(apk_spec.apk_path) as z: - signing_block_size = zip_util.MeasureApkSignatureBlock(z) - for zip_info in z.infolist(): - zip_info_total += zip_info.compress_size - # Account for zipalign overhead that exists in local file header. - zipalign_total += zip_util.ReadZipInfoExtraFieldLength(z, zip_info) - # Account for zipalign overhead that exists in central directory header. - # Happens when python aligns entries in apkbuilder.py, but does not - # exist when using Android's zipalign. E.g. for bundle .apks files. - zipalign_total += len(zip_info.extra) - # Skip files that we explicitly analyze: .so, .dex, and .pak. - if zip_info.filename == apk_so_path: - continue - if apk_spec.analyze_dex and zip_info.filename.endswith('.dex'): - dex_size += zip_info.file_size - continue - if zip_info.filename.endswith('.pak'): - continue - - resource_filename = resource_deobfuscator.MaybeRemapPath( - zip_info.filename) - source_path = res_source_mapper.FindSourceForPath(resource_filename) - if source_path is None: - source_path = os.path.join(models.APK_PREFIX_PATH, resource_filename) - apk_symbols.append( - models.Symbol( - models.SECTION_OTHER, - zip_info.compress_size, - source_path=source_path, - full_name=resource_filename)) # Full name must disambiguate - - # Store zipalign overhead and signing block size as metadata rather than an - # "Overhead:" symbol because they fluctuate in size, and would be a source of - # noise in symbol diffs if included as symbols (http://crbug.com/1130754). - # Might be even better if we had an option in Tiger Viewer to ignore certain - # symbols, but taking this as a short-cut for now. - metadata[models.METADATA_ZIPALIGN_OVERHEAD] = zipalign_total - metadata[models.METADATA_SIGNING_BLOCK_SIZE] = signing_block_size - - # Overhead includes: - # * Size of all local zip headers (minus zipalign padding). - # * Size of central directory & end of central directory. - overhead_size = (os.path.getsize(apk_spec.apk_path) - zip_info_total - - zipalign_total - signing_block_size) - assert overhead_size >= 0, 'Apk overhead must be non-negative' - zip_overhead_symbol = models.Symbol( - models.SECTION_OTHER, overhead_size, full_name='Overhead: APK file') - apk_symbols.append(zip_overhead_symbol) - _ExtendSectionRange(section_ranges, models.SECTION_OTHER, - sum(s.size for s in apk_symbols)) - return dex_size, apk_symbols - - def _CalculateElfOverhead(section_ranges, elf_path): if elf_path: section_sizes_total_without_bss = sum( @@ -1192,98 +1133,85 @@ return source_mapper, ninja_elf_object_paths -def CreateContainerSymbols(*, - container_name, - metadata, - apk_spec, - pak_spec, - native_spec, - source_directory, - output_directory=None, - resources_pathmap_path=None, - pak_id_map=None): - """Creates a Container (with sections sizes) and symbols for a SizeInfo. +def _CreateNativeSymbols(*, + metadata, + apk_spec, + native_spec, + output_directory=None, + pak_id_map=None): + """Creates native symbols for the given native_spec. Args: - container_name: Name for the created Container. May be '' if only one - Container exists. metadata: Metadata dict from CreateMetadata(). apk_spec: Instance of ApkSpec, or None. - pak_spec: Instance of PakSpec, or None. - native_spec: Instance of NativeSpec, or None. + native_spec: Instance of NativeSpec. output_directory: Build output directory. If None, source_paths and symbol alias information will not be recorded. - source_directory: Path to source root. - resources_pathmap_path: Path to the pathmap file that maps original - resource paths to shortened resource paths. - pak_id_map: Instance of PakIdMap, or None. + pak_id_map: Instance of PakIdMap. Returns: - List of symbols. + A tuple of (section_ranges, raw_symbols). """ apk_elf_result = None - if apk_spec and native_spec and native_spec.apk_so_path: + if apk_spec and native_spec.apk_so_path: # Extraction takes around 1 second, so do it in parallel. apk_elf_result = parallel.ForkAndCall( _ElfInfoFromApk, (apk_spec.apk_path, native_spec.apk_so_path, native_spec.tool_prefix)) + raw_symbols = [] ninja_source_mapper = None dwarf_source_mapper = None section_ranges = {} - raw_symbols = [] - object_paths_by_name = None - if native_spec: - ninja_elf_object_paths = None - if output_directory and native_spec.map_path: - # Finds all objects passed to the linker and creates a map of .o -> .cc. - ninja_source_mapper, ninja_elf_object_paths = _ParseNinjaFiles( - output_directory, native_spec.elf_path) - elif native_spec.elf_path: - logging.info('Parsing source path info via dwarfdump') - dwarf_source_mapper = dwarfdump.CreateAddressSourceMapper( - native_spec.elf_path, native_spec.tool_prefix) - logging.info('Found %d source paths across %s ranges', - dwarf_source_mapper.NumberOfPaths(), - dwarf_source_mapper.num_ranges) + ninja_elf_object_paths = None + if output_directory and native_spec.map_path: + # Finds all objects passed to the linker and creates a map of .o -> .cc. + ninja_source_mapper, ninja_elf_object_paths = _ParseNinjaFiles( + output_directory, native_spec.elf_path) + elif native_spec.elf_path: + logging.info('Parsing source path info via dwarfdump') + dwarf_source_mapper = dwarfdump.CreateAddressSourceMapper( + native_spec.elf_path, native_spec.tool_prefix) + logging.info('Found %d source paths across %s ranges', + dwarf_source_mapper.NumberOfPaths(), + dwarf_source_mapper.num_ranges) - # Start by finding elf_object_paths so that nm can run on them while the - # linker .map is being parsed. - if ninja_elf_object_paths: - elf_object_paths, thin_archives = ar.ExpandThinArchives( - ninja_elf_object_paths, output_directory) - known_inputs = set(elf_object_paths) - known_inputs.update(ninja_elf_object_paths) + # Start by finding elf_object_paths so that nm can run on them while the + # linker .map is being parsed. + if ninja_elf_object_paths: + elf_object_paths, thin_archives = ar.ExpandThinArchives( + ninja_elf_object_paths, output_directory) + known_inputs = set(elf_object_paths) + known_inputs.update(ninja_elf_object_paths) + else: + elf_object_paths = [] + known_inputs = None + # When we don't know which elf file is used, just search all paths. + # TODO(agrieve): Seems to be used only for tests. Remove? + if ninja_source_mapper: + thin_archives = set( + p for p in ninja_source_mapper.IterAllPaths() if p.endswith('.a') + and ar.IsThinArchive(os.path.join(output_directory, p))) else: - elf_object_paths = [] - known_inputs = None - # When we don't know which elf file is used, just search all paths. - # TODO(agrieve): Seems to be used only for tests. Remove? - if ninja_source_mapper: - thin_archives = set( - p for p in ninja_source_mapper.IterAllPaths() if p.endswith('.a') - and ar.IsThinArchive(os.path.join(output_directory, p))) - else: - thin_archives = None + thin_archives = None - outdir_context = None - if output_directory: - outdir_context = _OutputDirectoryContext( - elf_object_paths=elf_object_paths, - known_inputs=known_inputs, - output_directory=output_directory, - thin_archives=thin_archives) + outdir_context = None + if output_directory: + outdir_context = _OutputDirectoryContext(elf_object_paths=elf_object_paths, + known_inputs=known_inputs, + output_directory=output_directory, + thin_archives=thin_archives) - if native_spec.elf_path or native_spec.map_path: - section_ranges, raw_symbols, object_paths_by_name = _ParseElfInfo( - native_spec, outdir_context=outdir_context) - - if pak_id_map and native_spec.map_path: - # For trichrome, pak files are in different apks than native library, - # so need to pass along pak_id_map separately and ensure - # TrichromeLibrary appears first in .ssargs file. - logging.debug('Extracting pak IDs from symbol names') - pak_id_map.Update(object_paths_by_name, ninja_source_mapper) + object_paths_by_name = None + if native_spec.elf_path or native_spec.map_path: + section_ranges, raw_symbols, object_paths_by_name = _ParseElfInfo( + native_spec, outdir_context=outdir_context) + if pak_id_map and native_spec.map_path: + # For trichrome, pak files are in different apks than native library, + # so need to pass along pak_id_map separately and ensure + # TrichromeLibrary appears first in .ssargs file. + logging.debug('Extracting pak IDs from symbol names') + pak_id_map.Update(object_paths_by_name, ninja_source_mapper) if apk_elf_result: logging.debug('Extracting section sizes from .so within .apk') @@ -1291,7 +1219,7 @@ if metadata and models.METADATA_ELF_BUILD_ID in metadata: assert apk_build_id == metadata[models.METADATA_ELF_BUILD_ID], ( 'BuildID from apk_elf_result did not match') - elif native_spec and native_spec.elf_path: + elif native_spec.elf_path: # Strip ELF before capturing section information to avoid recording # debug sections. with tempfile.NamedTemporaryFile( @@ -1303,79 +1231,16 @@ native_spec.tool_prefix) elf_overhead_size = _CalculateElfOverhead(section_ranges, f.name) - if native_spec: - raw_symbols, other_elf_symbols = _AddUnattributedSectionSymbols( - raw_symbols, section_ranges) + raw_symbols, other_elf_symbols = _AddUnattributedSectionSymbols( + raw_symbols, section_ranges) - other_symbols = [] - if apk_spec and apk_spec.size_info_prefix: - # Can modify |section_ranges|. - dex_size, other_symbols = _ParseApkOtherSymbols( - apk_spec=apk_spec, - native_spec=native_spec, - section_ranges=section_ranges, - resources_pathmap_path=resources_pathmap_path, - metadata=metadata) - - if apk_spec.analyze_dex: - logging.info('Analyzing Dex') - dex_symbols = apkanalyzer.CreateDexSymbols(apk_spec.apk_path, - apk_spec.mapping_path, - apk_spec.size_info_prefix) - - # We can't meaningfully track section size of dex methods vs other, so - # just fake the size of dex methods as the sum of symbols, and make - # "dex other" responsible for any unattributed bytes. - dex_method_size = int( - round( - sum(s.pss for s in dex_symbols - if s.section_name == models.SECTION_DEX_METHOD))) - section_ranges[models.SECTION_DEX_METHOD] = (0, dex_method_size) - section_ranges[models.SECTION_DEX] = (0, dex_size - dex_method_size) - - dex_other_size = int( - round( - sum(s.pss for s in dex_symbols - if s.section_name == models.SECTION_DEX))) - unattributed_dex = section_ranges[models.SECTION_DEX][1] - dex_other_size - # Compare against -5 instead of 0 to guard against round-off errors. - assert unattributed_dex >= -5, ('Dex symbols take up more space than ' - 'the dex sections have available') - if unattributed_dex > 0: - dex_symbols.append( - models.Symbol( - models.SECTION_DEX, - unattributed_dex, - full_name='** .dex (unattributed - includes string literals)')) - raw_symbols.extend(dex_symbols) - - if pak_spec: - logging.debug('Creating Pak symbols') - if pak_spec.apk_pak_paths: - assert apk_spec.size_info_prefix - # Can modify |section_ranges|. - raw_symbols += pakfile.CreatePakSymbolsFromApk(section_ranges, - apk_spec.apk_path, - pak_spec.apk_pak_paths, - apk_spec.size_info_prefix, - pak_id_map) - else: - # Can modify |section_ranges|. - raw_symbols += pakfile.CreatePakSymbolsFromFiles(section_ranges, - pak_spec.pak_paths, - pak_spec.pak_info_path, - output_directory, - pak_id_map) - - if native_spec: - other_symbols.extend(other_elf_symbols) - if native_spec.elf_path: - elf_overhead_symbol = models.Symbol(models.SECTION_OTHER, - elf_overhead_size, - full_name='Overhead: ELF file') - _ExtendSectionRange(section_ranges, models.SECTION_OTHER, - elf_overhead_size) - other_symbols.append(elf_overhead_symbol) + other_symbols = other_elf_symbols + if native_spec.elf_path: + elf_overhead_symbol = models.Symbol(models.SECTION_OTHER, + elf_overhead_size, + full_name='Overhead: ELF file') + _ExtendSectionRange(section_ranges, models.SECTION_OTHER, elf_overhead_size) + other_symbols.append(elf_overhead_symbol) # Always have .other come last. other_symbols.sort(key=lambda s: (s.IsOverhead(), s.full_name.startswith( @@ -1386,20 +1251,226 @@ _AddSourcePathsUsingObjectPaths(ninja_source_mapper, raw_symbols) elif dwarf_source_mapper: _AddSourcePathsUsingAddress(dwarf_source_mapper, raw_symbols) + + # Path normalization must come before compacting aliases so that + # ancestor paths do not mix generated and non-generated paths. _NormalizePaths(raw_symbols) + logging.info('Converting excessive aliases into shared-path symbols') + _CompactLargeAliasesIntoSharedSymbols(raw_symbols) + + if native_spec.elf_path or native_spec.map_path: + logging.debug('Connecting nm aliases') + _ConnectNmAliases(raw_symbols) + + return section_ranges, raw_symbols + + +def _CreatePakSymbols(*, pak_spec, pak_id_map, apk_spec, output_directory): + logging.debug('Creating Pak symbols') + section_ranges = {} + if apk_spec: + assert apk_spec.size_info_prefix + # Can modify |section_ranges|. + raw_symbols = pakfile.CreatePakSymbolsFromApk(section_ranges, + apk_spec.apk_path, + pak_spec.apk_pak_paths, + apk_spec.size_info_prefix, + pak_id_map) + else: + # Can modify |section_ranges|. + raw_symbols = pakfile.CreatePakSymbolsFromFiles(section_ranges, + pak_spec.pak_paths, + pak_spec.pak_info_path, + output_directory, + pak_id_map) + return section_ranges, raw_symbols + + +def _CreateDexSymbols(*, apk_spec): + """Create dex symbols for the given apk_spec. + + Args: + apk_spec: Instance of ApkSpec or None. + + Returns: + A tuple of (section_ranges, raw_symbols). + """ + logging.info('Analyzing classes.dex for %s', apk_spec.split_name + or apk_spec.apk_path) + + def compute_dex_size(): + with zipfile.ZipFile(apk_spec.apk_path) as z: + return sum(i.file_size for i in z.infolist() + if i.filename.endswith('.dex')) + + dex_size_result = parallel.CallOnThread(compute_dex_size) + raw_symbols = apkanalyzer.CreateDexSymbols(apk_spec.apk_path, + apk_spec.mapping_path, + apk_spec.size_info_prefix) + dex_size = dex_size_result.get() + + sizes = collections.Counter() + for s in raw_symbols: + sizes[s.section_name] += s.pss + assert len(sizes) <= 2, 'Unexpected: ' + str(sizes) + dex_method_size = round(sizes[models.SECTION_DEX_METHOD]) + dex_other_size = round(sizes[models.SECTION_DEX]) + + unattributed_dex = dex_size - dex_method_size - dex_other_size + # Compare against -5 instead of 0 to guard against round-off errors. + assert unattributed_dex >= -5, ( + 'sum(dex_symbols.size) > filesize(classes.dex). {} vs {}'.format( + dex_method_size + dex_other_size, dex_size)) + + if unattributed_dex > 0: + raw_symbols.append( + models.Symbol( + models.SECTION_DEX, + unattributed_dex, + full_name='** .dex (unattributed - includes string literals)')) + + # We can't meaningfully track section size of dex methods vs other, so + # just fake the size of dex methods as the sum of symbols, and make + # "dex other" responsible for any unattributed bytes. + section_ranges = { + models.SECTION_DEX_METHOD: (0, dex_method_size), + models.SECTION_DEX: (0, dex_size - dex_method_size), + } + + return section_ranges, raw_symbols + + +def _CreateApkOtherSymbols(*, + metadata, + apk_spec, + native_spec, + resources_pathmap_path=None): + """Creates a Container (with sections sizes) and symbols for a SizeInfo. + + Args: + metadata: Metadata dict from CreateMetadata(). + apk_spec: Instance of ApkSpec or None. + native_spec: Instance of NativeSpec or None. + resources_pathmap_path: Path to the pathmap file that maps original + resource paths to shortened resource paths. + + Returns: + A tuple of (section_ranges, raw_symbols). + """ + logging.info('Creating symbols for other APK entries') + apk_so_path = native_spec and native_spec.apk_so_path + res_source_mapper = _ResourceSourceMapper(apk_spec.size_info_prefix, + apk_spec.path_defaults) + resource_deobfuscator = _ResourcePathDeobfuscator(resources_pathmap_path) + raw_symbols = [] + zip_info_total = 0 + zipalign_total = 0 + with zipfile.ZipFile(apk_spec.apk_path) as z: + signing_block_size = zip_util.MeasureApkSignatureBlock(z) + for zip_info in z.infolist(): + zip_info_total += zip_info.compress_size + # Account for zipalign overhead that exists in local file header. + zipalign_total += zip_util.ReadZipInfoExtraFieldLength(z, zip_info) + # Account for zipalign overhead that exists in central directory header. + # Happens when python aligns entries in apkbuilder.py, but does not + # exist when using Android's zipalign. E.g. for bundle .apks files. + zipalign_total += len(zip_info.extra) + # Skip files that we explicitly analyze: .so, .dex, and .pak. + if zip_info.filename == apk_so_path: + continue + if apk_spec.analyze_dex and zip_info.filename.endswith('.dex'): + continue + if zip_info.filename.endswith('.pak'): + continue + + resource_filename = resource_deobfuscator.MaybeRemapPath( + zip_info.filename) + source_path = res_source_mapper.FindSourceForPath(resource_filename) + if source_path is None: + source_path = os.path.join(models.APK_PREFIX_PATH, resource_filename) + raw_symbols.append( + models.Symbol( + models.SECTION_OTHER, + zip_info.compress_size, + source_path=source_path, + full_name=resource_filename)) # Full name must disambiguate + + # Store zipalign overhead and signing block size as metadata rather than an + # "Overhead:" symbol because they fluctuate in size, and would be a source of + # noise in symbol diffs if included as symbols (http://crbug.com/1130754). + # Might be even better if we had an option in Tiger Viewer to ignore certain + # symbols, but taking this as a short-cut for now. + metadata[models.METADATA_ZIPALIGN_OVERHEAD] = zipalign_total + metadata[models.METADATA_SIGNING_BLOCK_SIZE] = signing_block_size + + # Overhead includes: + # * Size of all local zip headers (minus zipalign padding). + # * Size of central directory & end of central directory. + overhead_size = (os.path.getsize(apk_spec.apk_path) - zip_info_total - + zipalign_total - signing_block_size) + assert overhead_size >= 0, 'Apk overhead must be non-negative' + zip_overhead_symbol = models.Symbol(models.SECTION_OTHER, + overhead_size, + full_name='Overhead: APK file') + raw_symbols.append(zip_overhead_symbol) + + section_ranges = {} + _ExtendSectionRange(section_ranges, models.SECTION_OTHER, + sum(s.size for s in raw_symbols)) + return section_ranges, raw_symbols + + +def CreateContainerSymbols(*, container_name, metadata, apk_spec, pak_spec, + native_spec, source_directory, output_directory, + resources_pathmap_path, pak_id_map): + raw_symbols = [] + section_sizes = {} + + def add_syms(section_ranges, new_raw_symbols): + new_section_sizes = { + k: size + for k, (address, size) in section_ranges.items() + } + if models.SECTION_OTHER in new_section_sizes: + section_sizes[models.SECTION_OTHER] = section_sizes.get( + models.SECTION_OTHER, 0) + new_section_sizes[models.SECTION_OTHER] + del new_section_sizes[models.SECTION_OTHER] + + assert not (set(section_sizes) & set(new_section_sizes)), ( + 'Section collision: {}\n\n {}'.format(section_sizes, new_section_sizes)) + section_sizes.update(new_section_sizes) + + # _CreateNativeSymbols() already calls _NormalizePaths(). + if new_raw_symbols and not new_raw_symbols[0].IsNative(): + _NormalizePaths(new_raw_symbols) + raw_symbols.extend(new_raw_symbols) + + if native_spec: + add_syms(*_CreateNativeSymbols(metadata=metadata, + apk_spec=apk_spec, + native_spec=native_spec, + output_directory=output_directory, + pak_id_map=pak_id_map)) + + if pak_spec: + add_syms(*_CreatePakSymbols(pak_spec=pak_spec, + pak_id_map=pak_id_map, + apk_spec=apk_spec, + output_directory=output_directory)) + if apk_spec: + if apk_spec.analyze_dex: + add_syms(*_CreateDexSymbols(apk_spec=apk_spec)) + add_syms( + *_CreateApkOtherSymbols(metadata=metadata, + apk_spec=apk_spec, + native_spec=native_spec, + resources_pathmap_path=resources_pathmap_path)) + default_component = apk_spec.default_component if apk_spec else '' dir_metadata.PopulateComponents(raw_symbols, source_directory, default_component=default_component) - logging.info('Converting excessive aliases into shared-path symbols') - _CompactLargeAliasesIntoSharedSymbols(raw_symbols) - - if native_spec: - logging.debug('Connecting nm aliases') - _ConnectNmAliases(raw_symbols) - - section_sizes = {k: size for k, (address, size) in section_ranges.items()} container = models.Container(name=container_name, metadata=metadata, section_sizes=section_sizes) @@ -1828,6 +1899,8 @@ or top_args.output_directory) analyze_native = not (sub_args.java_only or sub_args.no_native or top_args.java_only or top_args.no_native) + analyze_dex = not (sub_args.native_only or sub_args.no_java + or top_args.native_only or top_args.no_java) apk_path = apk_path or sub_args.apk_file if split_name: @@ -1856,8 +1929,7 @@ apk_spec.size_info_prefix = os.path.join(top_args.output_directory, 'size-info', os.path.basename(apk_prefix)) - apk_spec.analyze_dex = not (sub_args.native_only or sub_args.no_java - or top_args.native_only or top_args.no_java) + apk_spec.analyze_dex = bool(analyze_dex and apk_spec.size_info_prefix) apk_spec.path_defaults = { k: v['source_path'] for k, v in json_config.get('apk_files', {}).items()
diff --git a/tools/binary_size/libsupersize/integration_test.py b/tools/binary_size/libsupersize/integration_test.py index c90890f..50ec9b8 100755 --- a/tools/binary_size/libsupersize/integration_test.py +++ b/tools/binary_size/libsupersize/integration_test.py
@@ -296,6 +296,7 @@ native_spec=native_spec, source_directory=_TEST_SOURCE_DIR, output_directory=output_directory, + resources_pathmap_path=None, pak_id_map=pak_id_map) raw_symbols_list.append(raw_symbols)
diff --git a/tools/clang/plugins/CheckLayoutObjectMethodsVisitor.cpp b/tools/clang/plugins/CheckLayoutObjectMethodsVisitor.cpp index caa5482..b360366 100644 --- a/tools/clang/plugins/CheckLayoutObjectMethodsVisitor.cpp +++ b/tools/clang/plugins/CheckLayoutObjectMethodsVisitor.cpp
@@ -75,8 +75,9 @@ cxxRecordDecl(isSameOrDerivedFrom("::blink::LayoutObject"))), has(compoundStmt()), // Avoid matching the following cases - unless(anyOf(isConstexpr(), isDefaulted(), cxxConstructorDecl(), - cxxDestructorDecl(), isStaticStorageClass(), + unless(anyOf(isConstexpr(), isDefaulted(), isPure(), + cxxConstructorDecl(), cxxDestructorDecl(), + isStaticStorageClass(), // Do not trace lambdas (no name, possibly tracking // more parameters than intended because of [&]). hasParent(cxxRecordDecl(isLambda())), @@ -95,12 +96,14 @@ const auto* stmt = method->getBody(); assert(stmt); - auto* stmts = llvm::dyn_cast<clang::CompoundStmt>(stmt)->body_front(); - if (clang::CXXMemberCallExpr::classof(stmts)) { - auto* call = llvm::dyn_cast<clang::CXXMemberCallExpr>(stmts); - const std::string& name = call->getMethodDecl()->getNameAsString(); - if (name == "CheckIsNotDestroyed") - return; + if (!llvm::dyn_cast<clang::CompoundStmt>(stmt)->body_empty()) { + auto* stmts = llvm::dyn_cast<clang::CompoundStmt>(stmt)->body_front(); + if (clang::CXXMemberCallExpr::classof(stmts)) { + auto* call = llvm::dyn_cast<clang::CXXMemberCallExpr>(stmts); + const std::string& name = call->getMethodDecl()->getNameAsString(); + if (name == "CheckIsNotDestroyed") + return; + } } auto* type = method->getParent();
diff --git a/tools/clang/plugins/tests/layout_object_methods.h b/tools/clang/plugins/tests/layout_object_methods.h index cdaae5b..08239b5 100644 --- a/tools/clang/plugins/tests/layout_object_methods.h +++ b/tools/clang/plugins/tests/layout_object_methods.h
@@ -15,7 +15,7 @@ public: // These methods should be ignored. static void StaticMethod() {} - void CheckIsNotDestroyed() {} + void CheckIsNotDestroyed() const {} void Trace(Visitor*) const {} int ShouldPass1() { @@ -28,19 +28,24 @@ foo(); return 0; } + + virtual void VirtualEmptyMethod() = 0; + void EmptyMethod() {} }; class LayoutBoxModelObject : public LayoutObject { public: - int ShouldPass2() { + int ShouldPass2() const { CheckIsNotDestroyed(); return 0; } - int ShouldFail2() { + int ShouldFail2() const { ShouldPass2(); CheckIsNotDestroyed(); // This should be the first statement. return 0; } + + void VirtualEmptyMethod() override {} }; class LayoutBox : public LayoutBoxModelObject { @@ -54,6 +59,8 @@ CheckIsNotDestroyed(); // This should be the first statement. return 0; } + + void VirtualEmptyMethod() override {} }; } // namespace blink
diff --git a/tools/clang/plugins/tests/layout_object_methods.txt b/tools/clang/plugins/tests/layout_object_methods.txt index 71e025a..a57e685 100644 --- a/tools/clang/plugins/tests/layout_object_methods.txt +++ b/tools/clang/plugins/tests/layout_object_methods.txt
@@ -2,10 +2,19 @@ ./layout_object_methods.h:26:3: warning: [layout] LayoutObject's method 'ShouldFail1' in 'LayoutObject' must call CheckIsNotDestroyed() at the beginning. int ShouldFail1() { ^~~~~~~~~~~~~~~~~~~ -./layout_object_methods.h:39:3: warning: [layout] LayoutObject's method 'ShouldFail2' in 'LayoutBoxModelObject' must call CheckIsNotDestroyed() at the beginning. - int ShouldFail2() { - ^~~~~~~~~~~~~~~~~~~ -./layout_object_methods.h:52:3: warning: [layout] LayoutObject's method 'ShouldFail3' in 'LayoutBox' must call CheckIsNotDestroyed() at the beginning. +./layout_object_methods.h:33:3: warning: [layout] LayoutObject's method 'EmptyMethod' in 'LayoutObject' must call CheckIsNotDestroyed() at the beginning. + void EmptyMethod() {} + ^~~~~~~~~~~~~~~~~~~~~ +./layout_object_methods.h:42:3: warning: [layout] LayoutObject's method 'ShouldFail2' in 'LayoutBoxModelObject' must call CheckIsNotDestroyed() at the beginning. + int ShouldFail2() const { + ^~~~~~~~~~~~~~~~~~~~~~~~~ +./layout_object_methods.h:48:3: warning: [layout] LayoutObject's method 'VirtualEmptyMethod' in 'LayoutBoxModelObject' must call CheckIsNotDestroyed() at the beginning. + void VirtualEmptyMethod() override {} + ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +./layout_object_methods.h:57:3: warning: [layout] LayoutObject's method 'ShouldFail3' in 'LayoutBox' must call CheckIsNotDestroyed() at the beginning. int ShouldFail3() { ^~~~~~~~~~~~~~~~~~~ -3 warnings generated. +./layout_object_methods.h:63:3: warning: [layout] LayoutObject's method 'VirtualEmptyMethod' in 'LayoutBox' must call CheckIsNotDestroyed() at the beginning. + void VirtualEmptyMethod() override {} + ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +6 warnings generated.
diff --git a/tools/clang/scripts/update.py b/tools/clang/scripts/update.py index 4b55f6f..23d0f00 100755 --- a/tools/clang/scripts/update.py +++ b/tools/clang/scripts/update.py
@@ -35,7 +35,7 @@ # Reverting problematic clang rolls is safe, though. # This is the output of `git describe` and is usable as a commit-ish. CLANG_REVISION = 'llvmorg-14-init-12719-gc4b45eeb' -CLANG_SUB_REVISION = 2 +CLANG_SUB_REVISION = 3 PACKAGE_VERSION = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION) RELEASE_VERSION = '14.0.0'
diff --git a/tools/gritsettings/resource_ids.spec b/tools/gritsettings/resource_ids.spec index a552aea9..90c9af0 100644 --- a/tools/gritsettings/resource_ids.spec +++ b/tools/gritsettings/resource_ids.spec
@@ -529,7 +529,7 @@ "includes": [3320], }, "<(SHARED_INTERMEDIATE_DIR)/ash/webui/personalization_app/resources/ash_personalization_app_resources.grd": { - "META": {"sizes": {"includes": [50],}}, + "META": {"sizes": {"includes": [120],}}, "includes": [3340], }, "<(SHARED_INTERMEDIATE_DIR)/ash/webui/demo_mode_app_ui/ash_demo_mode_app_resources.grd": {
diff --git a/tools/mac/power/BUILD.gn b/tools/mac/power/BUILD.gn index cb6091ead..b64ca06 100644 --- a/tools/mac/power/BUILD.gn +++ b/tools/mac/power/BUILD.gn
@@ -31,6 +31,8 @@ "power_sampler/csv_exporter.h", "power_sampler/json_exporter.cc", "power_sampler/json_exporter.h", + "power_sampler/m1_sampler.h", + "power_sampler/m1_sampler.mm", "power_sampler/main_display_sampler.cc", "power_sampler/main_display_sampler.h", "power_sampler/monitor.cc", @@ -78,6 +80,7 @@ "power_sampler/battery_sampler_unittest.cc", "power_sampler/csv_exporter_unittest.cc", "power_sampler/json_exporter_unittest.cc", + "power_sampler/m1_sampler_unittest.mm", "power_sampler/main_display_sampler_unittest.cc", "power_sampler/resource_coalition_sampler_unittest.cc", "power_sampler/sampling_controller_unittest.cc",
diff --git a/tools/mac/power/power_sampler/m1_sampler.h b/tools/mac/power/power_sampler/m1_sampler.h new file mode 100644 index 0000000..5542be6 --- /dev/null +++ b/tools/mac/power/power_sampler/m1_sampler.h
@@ -0,0 +1,44 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef TOOLS_MAC_POWER_POWER_SAMPLER_M1_SAMPLER_H_ +#define TOOLS_MAC_POWER_POWER_SAMPLER_M1_SAMPLER_H_ + +#include <memory> + +#include "tools/mac/power/power_sampler/sampler.h" + +namespace power_metrics { +class M1SensorsReader; +} + +namespace power_sampler { + +// The M1 sensors sampler samples the temperature of M1 P-Cores and E-Cores. +class M1Sampler : public Sampler { + public: + static constexpr char kSamplerName[] = "m1"; + + ~M1Sampler() override; + + // Creates and initializes a new sampler, if possible. + // Returns nullptr on failure. + static std::unique_ptr<M1Sampler> Create(); + + // Sampler implementation. + std::string GetName() override; + DatumNameUnits GetDatumNameUnits() override; + Sample GetSample(base::TimeTicks sample_time) override; + + private: + friend class M1SamplerTest; + + M1Sampler(std::unique_ptr<power_metrics::M1SensorsReader> reader); + + std::unique_ptr<power_metrics::M1SensorsReader> reader_; +}; + +} // namespace power_sampler + +#endif // TOOLS_MAC_POWER_POWER_SAMPLER_M1_SAMPLER_H_
diff --git a/tools/mac/power/power_sampler/m1_sampler.mm b/tools/mac/power/power_sampler/m1_sampler.mm new file mode 100644 index 0000000..4f6049f --- /dev/null +++ b/tools/mac/power/power_sampler/m1_sampler.mm
@@ -0,0 +1,61 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "tools/mac/power/power_sampler/m1_sampler.h" + +#include "base/memory/ptr_util.h" +#include "base/strings/string_piece.h" +#include "components/power_metrics/m1_sensors_mac.h" + +namespace power_sampler { + +namespace { + +void MaybeAddToSample(Sampler::Sample* sample, + base::StringPiece name, + absl::optional<double> val) { + if (val.has_value()) + sample->emplace(name, val.value()); +} + +} // namespace + +M1Sampler::~M1Sampler() = default; + +// static +std::unique_ptr<M1Sampler> M1Sampler::Create() { + std::unique_ptr<power_metrics::M1SensorsReader> reader = + power_metrics::M1SensorsReader::Create(); + if (!reader) + return nullptr; + return base::WrapUnique(new M1Sampler(std::move(reader))); +} + +std::string M1Sampler::GetName() { + return kSamplerName; +} + +Sampler::DatumNameUnits M1Sampler::GetDatumNameUnits() { + DatumNameUnits ret{{"p_cores_temperature", "C"}, + {"e_cores_temperature", "C"}}; + return ret; +} + +Sampler::Sample M1Sampler::GetSample(base::TimeTicks sample_time) { + Sample sample; + power_metrics::M1SensorsReader::TemperaturesCelsius temperatures = + reader_->ReadTemperatures(); + + MaybeAddToSample(&sample, "p_cores_temperature", temperatures.p_cores); + MaybeAddToSample(&sample, "e_cores_temperature", temperatures.e_cores); + + return sample; +} + +M1Sampler::M1Sampler(std::unique_ptr<power_metrics::M1SensorsReader> reader) + : reader_(std::move(reader)) { + DCHECK(reader_); +} + +} // namespace power_sampler
diff --git a/tools/mac/power/power_sampler/m1_sampler_unittest.mm b/tools/mac/power/power_sampler/m1_sampler_unittest.mm new file mode 100644 index 0000000..681a2f1 --- /dev/null +++ b/tools/mac/power/power_sampler/m1_sampler_unittest.mm
@@ -0,0 +1,100 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "tools/mac/power/power_sampler/m1_sampler.h" + +#include <memory> + +#include "base/containers/flat_map.h" +#include "base/memory/ptr_util.h" +#include "components/power_metrics/m1_sensors_mac.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/abseil-cpp/absl/types/optional.h" +#include "tools/mac/power/power_sampler/battery_sampler.h" + +namespace power_sampler { + +namespace { + +using testing::UnorderedElementsAre; + +class TestM1SensorsReader : public power_metrics::M1SensorsReader { + public: + TestM1SensorsReader() + : power_metrics::M1SensorsReader( + base::ScopedCFTypeRef<IOHIDEventSystemClientRef>()) {} + + void set_temperatures(TemperaturesCelsius temperatures) { + temperatures_ = temperatures; + } + + // power_metrics::M1SensorsReader: + TemperaturesCelsius ReadTemperatures() override { return temperatures_; } + + private: + TemperaturesCelsius temperatures_; +}; + +} // namespace + +class M1SamplerTest : public testing::Test { + public: + M1SamplerTest() { + std::unique_ptr<TestM1SensorsReader> reader = + std::make_unique<TestM1SensorsReader>(); + reader_ = reader.get(); + sampler_ = base::WrapUnique(new M1Sampler(std::move(reader))); + } + + TestM1SensorsReader* reader_ = nullptr; + std::unique_ptr<M1Sampler> sampler_; +}; + +TEST_F(M1SamplerTest, NameAndGetDatumNameUnits) { + EXPECT_EQ("m1", sampler_->GetName()); + + auto datum_name_units = sampler_->GetDatumNameUnits(); + EXPECT_THAT(datum_name_units, + UnorderedElementsAre(std::make_pair("p_cores_temperature", "C"), + std::make_pair("e_cores_temperature", "C"))); +} + +TEST_F(M1SamplerTest, GetSample_AllFieldsAvailable) { + power_metrics::M1SensorsReader::TemperaturesCelsius temperatures; + temperatures.p_cores = 1; + temperatures.e_cores = 2; + reader_->set_temperatures(temperatures); + + Sampler::Sample sample = sampler_->GetSample(base::TimeTicks()); + EXPECT_THAT(sample, + UnorderedElementsAre(std::make_pair("p_cores_temperature", 1), + std::make_pair("e_cores_temperature", 2))); +} + +TEST_F(M1SamplerTest, GetSample_IndividualFieldNotAvailable) { + { + power_metrics::M1SensorsReader::TemperaturesCelsius temperatures; + temperatures.p_cores = absl::nullopt; + temperatures.e_cores = 2; + reader_->set_temperatures(temperatures); + + Sampler::Sample sample = sampler_->GetSample(base::TimeTicks()); + EXPECT_THAT(sample, + UnorderedElementsAre(std::make_pair("e_cores_temperature", 2))); + } + + { + power_metrics::M1SensorsReader::TemperaturesCelsius temperatures; + temperatures.p_cores = 1; + temperatures.e_cores = absl::nullopt; + reader_->set_temperatures(temperatures); + + Sampler::Sample sample = sampler_->GetSample(base::TimeTicks()); + EXPECT_THAT(sample, + UnorderedElementsAre(std::make_pair("p_cores_temperature", 1))); + } +} + +} // namespace power_sampler
diff --git a/tools/mac/power/power_sampler/power_sampler_main.cc b/tools/mac/power/power_sampler/power_sampler_main.cc index 9250d2b..da9a9f1 100644 --- a/tools/mac/power/power_sampler/power_sampler_main.cc +++ b/tools/mac/power/power_sampler/power_sampler_main.cc
@@ -20,6 +20,7 @@ #include "tools/mac/power/power_sampler/battery_sampler.h" #include "tools/mac/power/power_sampler/csv_exporter.h" #include "tools/mac/power/power_sampler/json_exporter.h" +#include "tools/mac/power/power_sampler/m1_sampler.h" #include "tools/mac/power/power_sampler/main_display_sampler.h" #include "tools/mac/power/power_sampler/resource_coalition_sampler.h" #include "tools/mac/power/power_sampler/sample_counter.h" @@ -234,6 +235,13 @@ return kStatusRuntimeError; } } + if (ConsumeSamplerName(power_sampler::M1Sampler::kSamplerName, + sampler_names) || + all_samplers) { + if (!MaybeAddSamplerToController<power_sampler::M1Sampler>(controller)) { + return kStatusRuntimeError; + } + } if (ConsumeSamplerName(power_sampler::UserIdleLevelSampler::kSamplerName, sampler_names) || all_samplers) {
diff --git a/tools/mac/power/power_sampler/smc_sampler_unittest.mm b/tools/mac/power/power_sampler/smc_sampler_unittest.mm index b696da3..4e7b850 100644 --- a/tools/mac/power/power_sampler/smc_sampler_unittest.mm +++ b/tools/mac/power/power_sampler/smc_sampler_unittest.mm
@@ -3,6 +3,7 @@ // found in the LICENSE file. #include "tools/mac/power/power_sampler/smc_sampler.h" + #include <memory> #include "base/containers/flat_map.h"
diff --git a/tools/mb/mb_config.pyl b/tools/mb/mb_config.pyl index b9c4463..3d2a243 100644 --- a/tools/mb/mb_config.pyl +++ b/tools/mb/mb_config.pyl
@@ -277,8 +277,8 @@ 'reclient': 'gpu_tests_release_bot_reclient', }, 'Comparison Windows': { - 'goma': 'gpu_tests_release_bot', - 'reclient': 'gpu_tests_release_bot_reclient', + 'goma': 'gpu_tests_release_bot_minimal_symbols', + 'reclient': 'gpu_tests_release_bot_minimal_symbols_reclient', }, 'Libfuzzer Upload Chrome OS ASan': 'libfuzzer_chromeos_asan_release_bot', 'Libfuzzer Upload Linux ASan': 'libfuzzer_asan_release_bot', @@ -291,7 +291,7 @@ 'Linux Builder (j-500) (reclient)': 'gpu_tests_release_bot_reclient', 'Linux CFI (reclient shadow)': 'cfi_full_cfi_icall_cfi_diag_thin_lto_release_static_dcheck_always_on_reclient', 'Linux Builder (j-500) (reclient)': 'gpu_tests_release_bot_reclient', - 'Linux Viz': 'release_trybot_minimal_symbols', + 'Linux Viz': 'release_trybot_minimal_symbols_reclient', 'lacros-amd64-generic-rel-fyi': 'chromeos_amd64-generic_lacros_rel', 'linux-ash-chromium-builder-fyi-rel': 'chromeos_with_codecs_release_bot', 'linux-lacros-builder-fyi-rel': 'lacros_on_linux_release_bot', @@ -453,7 +453,7 @@ 'GPU Mac Builder': 'gpu_tests_release_trybot_minimal_symbols', 'GPU Mac Builder (dbg)': 'gpu_tests_debug_bot', 'GPU Linux Builder': 'gpu_tests_release_trybot_minimal_symbols', - 'GPU Linux Builder (dbg)': 'gpu_tests_debug_bot', + 'GPU Linux Builder (dbg)': 'gpu_tests_debug_bot_reclient', 'GPU Win x64 Builder': 'gpu_tests_release_bot_dcheck_always_on_resource_allowlisting', 'GPU Win x64 Builder Code Coverage': 'gpu_tests_release_trybot_resource_allowlisting_code_coverage', 'GPU Win x64 Builder (dbg)': 'gpu_tests_debug_bot', @@ -480,7 +480,7 @@ 'GPU FYI Win x64 DX12 Vulkan Builder': 'gpu_fyi_tests_dx12vk_release_trybot', 'GPU FYI Win x64 DX12 Vulkan Builder (dbg)': 'gpu_fyi_tests_dx12vk_debug_trybot', 'GPU FYI XR Win x64 Builder': 'gpu_fyi_tests_release_trybot', - 'Linux FYI GPU TSAN Release': 'gpu_fyi_tests_release_trybot_tsan', + 'Linux FYI GPU TSAN Release': 'gpu_fyi_tests_release_trybot_tsan_reclient', 'Linux FYI SkiaRenderer Dawn Release (Intel HD 630)': 'gpu_tests_sk_dawn_release_trybot', 'Optional Android Release (Nexus 5X)': 'gpu_tests_android_release_trybot_arm64', 'Optional Android Release (Pixel 4)': 'gpu_tests_android_release_trybot', @@ -1194,6 +1194,7 @@ 'win10_chromium_x64_rel_ng_exp': 'release_trybot', 'win10_chromium_x64_rel_ng_rts': 'gpu_tests_release_trybot_resource_allowlisting_code_coverage', 'win10-rel-orchestrator': 'gpu_tests_release_trybot_resource_allowlisting_code_coverage', + 'win11-x64-fyi-rel': 'gpu_tests_release_trybot_resource_allowlisting_code_coverage', 'win-annotator-rel': 'release_trybot', 'win-asan': 'asan_clang_fuzzer_static_v8_heap_minimal_symbols_release', 'win-celab-try-rel': 'release_bot_minimal_symbols', @@ -2236,6 +2237,10 @@ 'gpu_fyi_tests', 'release_trybot_minimal_symbols', 'tsan', 'disable_nacl', ], + 'gpu_fyi_tests_release_trybot_tsan_reclient': [ + 'gpu_fyi_tests', 'release_trybot_minimal_symbols_reclient', 'tsan', 'disable_nacl', + ], + 'gpu_fyi_tests_release_trybot_x86': [ 'gpu_fyi_tests', 'release_trybot_minimal_symbols', 'x86', 'disable_nacl', ],
diff --git a/tools/mb/mb_config_expectations/chromium.fyi.json b/tools/mb/mb_config_expectations/chromium.fyi.json index 2320959..d458782 100644 --- a/tools/mb/mb_config_expectations/chromium.fyi.json +++ b/tools/mb/mb_config_expectations/chromium.fyi.json
@@ -56,6 +56,7 @@ "is_component_build": false, "is_debug": false, "proprietary_codecs": true, + "symbol_level": 1, "use_goma": true }, "reclient": { @@ -64,6 +65,7 @@ "is_component_build": false, "is_debug": false, "proprietary_codecs": true, + "symbol_level": 1, "use_rbe": true, "use_remoteexec": true } @@ -230,7 +232,8 @@ "is_component_build": false, "is_debug": false, "symbol_level": 1, - "use_goma": true + "use_rbe": true, + "use_remoteexec": true } }, "Mac Builder (reclient compare)": {
diff --git a/tools/mb/mb_config_expectations/chromium.gpu.fyi.json b/tools/mb/mb_config_expectations/chromium.gpu.fyi.json index c88bf19..725ec0a 100644 --- a/tools/mb/mb_config_expectations/chromium.gpu.fyi.json +++ b/tools/mb/mb_config_expectations/chromium.gpu.fyi.json
@@ -296,7 +296,8 @@ "is_tsan": true, "proprietary_codecs": true, "symbol_level": 1, - "use_goma": true + "use_rbe": true, + "use_remoteexec": true } }, "Linux FYI SkiaRenderer Dawn Release (Intel HD 630)": {
diff --git a/tools/mb/mb_config_expectations/chromium.gpu.json b/tools/mb/mb_config_expectations/chromium.gpu.json index 287c485..360fc0c 100644 --- a/tools/mb/mb_config_expectations/chromium.gpu.json +++ b/tools/mb/mb_config_expectations/chromium.gpu.json
@@ -36,7 +36,8 @@ "is_debug": true, "proprietary_codecs": true, "symbol_level": 1, - "use_goma": true + "use_rbe": true, + "use_remoteexec": true } }, "GPU Mac Builder": {
diff --git a/tools/mb/mb_config_expectations/tryserver.chromium.win.json b/tools/mb/mb_config_expectations/tryserver.chromium.win.json index f5a90f8..91960ea 100644 --- a/tools/mb/mb_config_expectations/tryserver.chromium.win.json +++ b/tools/mb/mb_config_expectations/tryserver.chromium.win.json
@@ -334,6 +334,21 @@ "use_goma": true } }, + "win11-x64-fyi-rel": { + "gn_args": { + "blink_enable_generated_code_formatting": false, + "coverage_instrumentation_input_file": "//.code-coverage/files_to_instrument.txt", + "dcheck_always_on": true, + "enable_resource_allowlist_generation": false, + "ffmpeg_branding": "Chrome", + "is_component_build": false, + "is_debug": false, + "proprietary_codecs": true, + "symbol_level": 0, + "use_clang_coverage": true, + "use_goma": true + } + }, "win7-rel": { "gn_args": { "blink_enable_generated_code_formatting": false,
diff --git a/tools/metrics/histograms/enums.xml b/tools/metrics/histograms/enums.xml index d0c9ef6..a13f62a 100644 --- a/tools/metrics/histograms/enums.xml +++ b/tools/metrics/histograms/enums.xml
@@ -32124,6 +32124,18 @@ <int value="1213" label="Pairing Failed: Error UI Dismissed By User"/> </enum> +<enum name="FastPairHandshakeFailureReason"> + <int value="0" label="Failed to initialize GATT connection"/> + <int value="1" label="Failed to create FastPairDataEncryptor"/> + <int value="2" + label="Failed to write to the Key-based Pairing Characteristic"/> + <int value="3" + label="Failed to decrypt the the Key-based Pairing Characteristic + response"/> + <int value="4" + label="Incorrect Key-based Pairing Characteristic response type"/> +</enum> + <enum name="FastPairPairFailure"> <int value="0" label="Failed to create a GATT connection to the device"/> <int value="1" label="Failed to find the expected GATT service"/> @@ -36417,7 +36429,8 @@ <int value="3884" label="WebAppManifestProtocolHandlers"/> <int value="3885" label="RTCPeerConnectionOfferAllowExtmapMixedFalse"/> <int value="3886" label="NewCanvas2DAPI"/> - <int value="3887" label="ServiceWorkerSubresourceFilterBypassedRequest"/> + <int value="3887" + label="OBSOLETE_ServiceWorkerSubresourceFilterBypassedRequest"/> <int value="3888" label="WebGPU"/> <int value="3889" label="CSSFilterColorMatrix"/> <int value="3890" label="HTMLFencedFrameElement"/> @@ -72989,6 +73002,8 @@ <int value="20" label="kPendingNotificationCloseEvent"/> <int value="21" label="kFeedbackDialog"/> <int value="22" label="kWebAppUpdate"/> + <int value="23" label="kGettingWebAppInfo"/> + <int value="24" label="kCrxInstaller"/> </enum> <enum name="ProfileMenuActionableItem"> @@ -83995,6 +84010,12 @@ <int value="9" label="Failed"/> </enum> +<enum name="StatefulFormat"> + <summary>ChromeOS Stateful Partition Format</summary> + <int value="0" label="Raw partition"/> + <int value="1" label="Logical volume"/> +</enum> + <enum name="StateStoreInitResult"> <int value="0" label="Kept pref: platform-specific store non-existent or not
diff --git a/tools/metrics/histograms/metadata/arc/histograms.xml b/tools/metrics/histograms/metadata/arc/histograms.xml index c8d57a5a..74b1434 100644 --- a/tools/metrics/histograms/metadata/arc/histograms.xml +++ b/tools/metrics/histograms/metadata/arc/histograms.xml
@@ -617,8 +617,8 @@ <histogram name="Arc.CompanionLibraryApisCounter" enum="CompanionLibraryApisList" expires_after="2022-04-10"> <owner>sstan@google.com</owner> - <owner>bartfab@chromium.org</owner> - <owner>giovax@google.com</owner> + <owner>mhasank@google.com</owner> + <owner>arc-commercial@google.com</owner> <summary> Records the number of times ChromeOS Companion Library API called. Counter adding when its Stub library receive the call from applications. @@ -1857,7 +1857,7 @@ <histogram name="Arc.Supervision.Transition.Result" enum="ArcSupervisionTransitionResult" expires_after="2022-08-03"> - <owner>giovax@chromium.org</owner> + <owner>mhasank@chromium.org</owner> <owner>arc-commercial@google.com</owner> <summary> The result (success or the type of failure) of ARC supervision transition
diff --git a/tools/metrics/histograms/metadata/bluetooth/histograms.xml b/tools/metrics/histograms/metadata/bluetooth/histograms.xml index 78ebb77..ae5e567 100644 --- a/tools/metrics/histograms/metadata/bluetooth/histograms.xml +++ b/tools/metrics/histograms/metadata/bluetooth/histograms.xml
@@ -383,6 +383,58 @@ </summary> </histogram> +<histogram name="Bluetooth.ChromeOS.FastPair.FastPairRepository.Cache.Result" + enum="Boolean" expires_after="2022-09-20"> + <owner>shanefitz@google.com</owner> + <owner>julietlevesque@google.com</owner> + <owner>chromeos-cross-device-eng@google.com</owner> + <summary> + Records whether or not device metadata we retrieve is in the cache in the + repository. Records true when the metadata is in the cache, and false when + the metadata is not in the cache. Emitted in the FastPairRepository on a + GetDeviceMetadata request. + </summary> +</histogram> + +<histogram name="Bluetooth.ChromeOS.FastPair.FootprintsFetcher.Delete.Result" + enum="BooleanSuccess" expires_after="2022-09-20"> + <owner>shanefitz@google.com</owner> + <owner>julietlevesque@google.com</owner> + <owner>chromeos-cross-device-eng@google.com</owner> + <summary> + Records the success or failure of a Delete request in the FootprintsFetcher. + A failure is considered no response from the server. A success is considered + a response that is able to be parsed. Emitted when the HTTP response is + received from the Footprints server. + </summary> +</histogram> + +<histogram name="Bluetooth.ChromeOS.FastPair.FootprintsFetcher.Get.Result" + enum="BooleanSuccess" expires_after="2022-09-20"> + <owner>shanefitz@google.com</owner> + <owner>julietlevesque@google.com</owner> + <owner>chromeos-cross-device-eng@google.com</owner> + <summary> + Records the success or failure of a Get request in the FootprintsFetcher. A + failure is considered either no response, or a response that is unable to be + parsed. A success is considered a response that is able to be parsed. + Emitted when the HTTP response is received from the Footprints server. + </summary> +</histogram> + +<histogram name="Bluetooth.ChromeOS.FastPair.FootprintsFetcher.Post.Result" + enum="BooleanSuccess" expires_after="2022-09-20"> + <owner>shanefitz@google.com</owner> + <owner>julietlevesque@google.com</owner> + <owner>chromeos-cross-device-eng@google.com</owner> + <summary> + Records the success or failure of a Post request in the FootprintsFetcher. A + failure is considered no response from the server. A success is considered a + response that is able to be parsed. Emitted when the HTTP response is + received from the Footprints server. + </summary> +</histogram> + <histogram name="Bluetooth.ChromeOS.FastPair.GattConnection.ErrorReason" enum="BluetoothDeviceConnectErrorCode" expires_after="2022-09-20"> <owner>shanefitz@google.com</owner> @@ -407,6 +459,32 @@ </summary> </histogram> +<histogram name="Bluetooth.ChromeOS.FastPair.Handshake.FailureReason" + enum="FastPairHandshakeFailureReason" expires_after="2022-09-20"> + <owner>shanefitz@google.com</owner> + <owner>julietlevesque@google.com</owner> + <owner>chromeos-cross-device-eng@google.com</owner> + <summary> + Records reason behind the failure of the big picture Fast Pair Handshake + with a device (see 'Bluetooth.ChromeOS.FastPair.Handshake.Result'). Emitted + on Handshake failure. + </summary> +</histogram> + +<histogram name="Bluetooth.ChromeOS.FastPair.Handshake.Result" + enum="BooleanSuccess" expires_after="2022-09-20"> + <owner>shanefitz@google.com</owner> + <owner>julietlevesque@google.com</owner> + <owner>chromeos-cross-device-eng@google.com</owner> + <summary> + Records the success or failure of the big picture Fast Pair Handshake with a + device. Possible failures include : creating a Gatt connection, creating the + data encryptor, writing to the device, parsing the response bytes from the + device. Success is considered when we receive a valid response from the + device. Emitted when the Handshake reaches a terminal point. + </summary> +</histogram> + <histogram name="Bluetooth.ChromeOS.FastPair.KeyBasedPairing.DecryptResult" enum="BooleanSuccess" expires_after="2022-09-20"> <owner>shanefitz@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/platform/histograms.xml b/tools/metrics/histograms/metadata/platform/histograms.xml index 99ff3dbc..d262006 100644 --- a/tools/metrics/histograms/metadata/platform/histograms.xml +++ b/tools/metrics/histograms/metadata/platform/histograms.xml
@@ -884,6 +884,13 @@ </summary> </histogram> +<histogram name="Platform.StatefulFormat" enum="StatefulFormat" + expires_after="2022-10-21"> + <owner>sarthakkukreti@chromium.org</owner> + <owner>gwendal@chromium.org</owner> + <summary>Chrome OS stateful partition format. Sampled once per boot.</summary> +</histogram> + <histogram name="Platform.StatefulFreeSpace" units="MB" expires_after="2022-10-21"> <owner>asavery@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/startup/histograms.xml b/tools/metrics/histograms/metadata/startup/histograms.xml index 40fc591..078e604 100644 --- a/tools/metrics/histograms/metadata/startup/histograms.xml +++ b/tools/metrics/histograms/metadata/startup/histograms.xml
@@ -259,6 +259,19 @@ </summary> </histogram> +<histogram name="Startup.Android.LastVisitedTabIsSRPWhenOverviewShownAtLaunch" + enum="Boolean" expires_after="2022-08-09"> + <owner>hanxi@chromium.org</owner> + <owner>spdonghao@chromium.org</owner> + <owner>fredmello@chromium.org</owner> + <summary> + Records whether or not the last visited tab is a search result page when + StartSurface is shown at launch. This histogram is only recorded when + StartSurface is shown at launch due to "return to tab switcher" + feature. + </summary> +</histogram> + <histogram name="Startup.Android.ShowChromeStartSegmentationResult" enum="ShowChromeStartSegmentationResult" expires_after="2022-08-09"> <owner>hanxi@chromium.org</owner>
diff --git a/tools/traffic_annotation/summary/annotations.xml b/tools/traffic_annotation/summary/annotations.xml index 8b127ee3..d44bde8 100644 --- a/tools/traffic_annotation/summary/annotations.xml +++ b/tools/traffic_annotation/summary/annotations.xml
@@ -352,4 +352,5 @@ <item id="wallpaper_backdrop_surprise_me_image" added_in_milestone="99" content_hash_code="0383b9db" os_list="chromeos" file_path="chrome/browser/ash/wallpaper_handlers/wallpaper_handlers.cc" /> <item id="nearby_connections_wifi_lan" added_in_milestone="99" content_hash_code="06b30010" os_list="chromeos" file_path="chrome/services/sharing/nearby/platform/wifi_lan_medium.cc" /> <item id="timezone_lookup" added_in_milestone="98" content_hash_code="01e64e71" os_list="chromeos" file_path="ash/components/timezone/timezone_request.cc" /> + <item id="oma_download_handler_android" added_in_milestone="99" content_hash_code="03226fe2" os_list="android" file_path="chrome/android/java/src/org/chromium/chrome/browser/download/OMADownloadHandler.java" /> </annotations>
diff --git a/tools/traffic_annotation/summary/grouping.xml b/tools/traffic_annotation/summary/grouping.xml index 22e449c..2b3f9adc0 100644 --- a/tools/traffic_annotation/summary/grouping.xml +++ b/tools/traffic_annotation/summary/grouping.xml
@@ -36,6 +36,7 @@ <traffic_annotation unique_id="gstatic_onboarding_definition"/> <traffic_annotation unique_id="kids_chrome_management_client_classify_url"/> <traffic_annotation unique_id="minidump_uploader_android"/> + <traffic_annotation unique_id="oma_download_handler_android"/> <traffic_annotation unique_id="partner_bookmarks_reader_get_favicon"/> <traffic_annotation unique_id="permission_request_creator"/> <traffic_annotation unique_id="publish_note_request"/>
diff --git a/ui/accessibility/accessibility_features.cc b/ui/accessibility/accessibility_features.cc index 2297086..bbe6a2867 100644 --- a/ui/accessibility/accessibility_features.cc +++ b/ui/accessibility/accessibility_features.cc
@@ -150,7 +150,7 @@ const base::Feature kExperimentalAccessibilityDictationExtension{ "ExperimentalAccessibilityDictationExtension", - base::FEATURE_DISABLED_BY_DEFAULT}; + base::FEATURE_ENABLED_BY_DEFAULT}; bool IsExperimentalAccessibilityDictationExtensionEnabled() { return base::FeatureList::IsEnabled(
diff --git a/ui/accessibility/platform/inspect/ax_api_type.cc b/ui/accessibility/platform/inspect/ax_api_type.cc index a2b511b..4d6ebf32 100644 --- a/ui/accessibility/platform/inspect/ax_api_type.cc +++ b/ui/accessibility/platform/inspect/ax_api_type.cc
@@ -4,6 +4,8 @@ #include "ui/accessibility/platform/inspect/ax_api_type.h" +#include <cstring> + namespace ui { namespace {
diff --git a/ui/base/models/table_model.cc b/ui/base/models/table_model.cc index 228008a..42a19ca 100644 --- a/ui/base/models/table_model.cc +++ b/ui/base/models/table_model.cc
@@ -7,6 +7,7 @@ #include "base/check.h" #include "base/i18n/string_compare.h" #include "base/notreached.h" +#include "third_party/icu/source/i18n/unicode/coll.h" #include "ui/base/l10n/l10n_util.h" #include "ui/base/models/image_model.h" @@ -43,7 +44,7 @@ // TableModel ----------------------------------------------------------------- // Used for sorting. -static icu::Collator* g_collator = NULL; +static icu::Collator* g_collator = nullptr; ui::ImageModel TableModel::GetIcon(int row) { return ui::ImageModel(); @@ -54,8 +55,11 @@ } int TableModel::CompareValues(int row1, int row2, int column_id) { - DCHECK(row1 >= 0 && row1 < RowCount() && - row2 >= 0 && row2 < RowCount()); + DCHECK_GE(row1, 0); + DCHECK_LT(row1, RowCount()); + DCHECK_GE(row2, 0); + DCHECK_LT(row2, RowCount()); + std::u16string value1 = GetText(row1, column_id); std::u16string value2 = GetText(row2, column_id); icu::Collator* collator = GetCollator(); @@ -69,15 +73,17 @@ void TableModel::ClearCollator() { delete g_collator; - g_collator = NULL; + g_collator = nullptr; } +TableModel::~TableModel() = default; + icu::Collator* TableModel::GetCollator() { if (!g_collator) { UErrorCode create_status = U_ZERO_ERROR; g_collator = icu::Collator::createInstance(create_status); if (!U_SUCCESS(create_status)) { - g_collator = NULL; + g_collator = nullptr; NOTREACHED(); } }
diff --git a/ui/base/models/table_model.h b/ui/base/models/table_model.h index 7cd5fca..f32af40 100644 --- a/ui/base/models/table_model.h +++ b/ui/base/models/table_model.h
@@ -6,10 +6,14 @@ #define UI_BASE_MODELS_TABLE_MODEL_H_ #include <string> -#include <vector> #include "base/component_export.h" -#include "third_party/icu/source/i18n/unicode/coll.h" +#include "third_party/icu/source/common/unicode/uversion.h" + +// third_party/icu/source/common/unicode/uversion.h will set namespace icu. +namespace U_ICU_NAMESPACE { +class Collator; +} namespace ui { @@ -55,7 +59,7 @@ void ClearCollator(); protected: - virtual ~TableModel() {} + virtual ~TableModel(); // Returns the collator used by CompareValues. icu::Collator* GetCollator();
diff --git a/ui/compositor/compositor_animation_observer.cc b/ui/compositor/compositor_animation_observer.cc index 6c5d76d8..b3a9724 100644 --- a/ui/compositor/compositor_animation_observer.cc +++ b/ui/compositor/compositor_animation_observer.cc
@@ -13,10 +13,11 @@ namespace ui { -// Do not fail on SANITIZER builds as they run slow. -#if !DCHECK_IS_ON() || defined(ADDRESS_SANITIZER) || \ - defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) || \ - defined(LEAK_SANITIZER) || defined(UNDEFINED_SANITIZER) +// Do not fail on builds that run slow, such as SANITIZER, debug. +#if !DCHECK_IS_ON() || defined(ADDRESS_SANITIZER) || \ + defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) || \ + defined(LEAK_SANITIZER) || defined(UNDEFINED_SANITIZER) || \ + !defined(NDEBUG) #define NOTREACHED_OR_WARN() LOG(WARNING) #else #define NOTREACHED_OR_WARN() NOTREACHED()
diff --git a/ui/ozone/platform/drm/gpu/drm_overlay_candidates.cc b/ui/ozone/platform/drm/gpu/drm_overlay_candidates.cc index 24f2610..1ca2737 100644 --- a/ui/ozone/platform/drm/gpu/drm_overlay_candidates.cc +++ b/ui/ozone/platform/drm/gpu/drm_overlay_candidates.cc
@@ -4,6 +4,7 @@ #include "ui/ozone/platform/drm/gpu/drm_overlay_candidates.h" +#include "media/media_buildflags.h" #include "ui/ozone/platform/drm/gpu/drm_overlay_manager.h" #include "ui/ozone/public/overlay_surface_candidate.h" @@ -13,11 +14,20 @@ gfx::AcceleratedWidget widget) : overlay_manager_(manager), widget_(widget) {} -DrmOverlayCandidates::~DrmOverlayCandidates() = default; +DrmOverlayCandidates::~DrmOverlayCandidates() { + overlay_manager_->RegisterOverlayRequirement(widget_, false); +} void DrmOverlayCandidates::CheckOverlaySupport( std::vector<OverlaySurfaceCandidate>* candidates) { overlay_manager_->CheckOverlaySupport(candidates, widget_); } +void DrmOverlayCandidates::RegisterOverlayRequirement(bool requires_overlay) { +#if !BUILDFLAG(USE_CHROMEOS_PROTECTED_MEDIA) + DCHECK(!requires_overlay); +#endif + overlay_manager_->RegisterOverlayRequirement(widget_, requires_overlay); +} + } // namespace ui
diff --git a/ui/ozone/platform/drm/gpu/drm_overlay_candidates.h b/ui/ozone/platform/drm/gpu/drm_overlay_candidates.h index a680368..cbbd6d7 100644 --- a/ui/ozone/platform/drm/gpu/drm_overlay_candidates.h +++ b/ui/ozone/platform/drm/gpu/drm_overlay_candidates.h
@@ -29,6 +29,7 @@ // OverlayCandidatesOzone: void CheckOverlaySupport( std::vector<OverlaySurfaceCandidate>* candidates) override; + void RegisterOverlayRequirement(bool requires_overlay) override; private: DrmOverlayManager* const overlay_manager_; // Not owned.
diff --git a/ui/ozone/platform/drm/gpu/drm_overlay_manager.cc b/ui/ozone/platform/drm/gpu/drm_overlay_manager.cc index 29030d9c..79227e5 100644 --- a/ui/ozone/platform/drm/gpu/drm_overlay_manager.cc +++ b/ui/ozone/platform/drm/gpu/drm_overlay_manager.cc
@@ -66,6 +66,15 @@ TRACE_EVENT0("hwoverlays", "DrmOverlayManager::CheckOverlaySupport"); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + // Check if another display has an overlay requirement, and if so do not + // allow overlays. Some ChromeOS boards only support one overlay across all + // displays so we want the overlay to go somewhere that requires it first vs. + // a display that will just be using it as an optimization. + if (!widgets_with_required_overlays_.empty() && + !widgets_with_required_overlays_.contains(widget)) { + return; + } + std::vector<OverlaySurfaceCandidate> result_candidates; for (auto& candidate : *candidates) { bool can_handle = CanHandleCandidate(candidate, widget); @@ -140,6 +149,15 @@ cache_hit); } +void DrmOverlayManager::RegisterOverlayRequirement( + gfx::AcceleratedWidget widget, + bool requires_overlay) { + if (requires_overlay) + widgets_with_required_overlays_.insert(widget); + else + widgets_with_required_overlays_.erase(widget); +} + bool DrmOverlayManager::CanHandleCandidate( const OverlaySurfaceCandidate& candidate, gfx::AcceleratedWidget widget) const {
diff --git a/ui/ozone/platform/drm/gpu/drm_overlay_manager.h b/ui/ozone/platform/drm/gpu/drm_overlay_manager.h index 43b9fac..c9927dd4 100644 --- a/ui/ozone/platform/drm/gpu/drm_overlay_manager.h +++ b/ui/ozone/platform/drm/gpu/drm_overlay_manager.h
@@ -8,6 +8,7 @@ #include <memory> #include <vector> +#include "base/containers/flat_set.h" #include "base/containers/lru_cache.h" #include "base/threading/thread_checker.h" #include "ui/gfx/native_widget_types.h" @@ -43,6 +44,12 @@ void CheckOverlaySupport(std::vector<OverlaySurfaceCandidate>* candidates, gfx::AcceleratedWidget widget); + // Should be called by the overlay processor to indicate if a widget has a + // candidate that requires an overlay. This is to prioritize which display + // gets the overlay in a multiple display environment. + void RegisterOverlayRequirement(gfx::AcceleratedWidget widget, + bool requires_overlay); + protected: // Sends a request to see if overlay configuration will work. Implementations // should call UpdateCacheForOverlayCandidates() with the response. @@ -89,6 +96,8 @@ std::map<gfx::AcceleratedWidget, OverlayCandidatesListCache> widget_cache_map_; + base::flat_set<gfx::AcceleratedWidget> widgets_with_required_overlays_; + THREAD_CHECKER(thread_checker_); };
diff --git a/ui/ozone/platform/drm/gpu/drm_overlay_manager_unittest.cc b/ui/ozone/platform/drm/gpu/drm_overlay_manager_unittest.cc index d49fed0..24e93a0 100644 --- a/ui/ozone/platform/drm/gpu/drm_overlay_manager_unittest.cc +++ b/ui/ozone/platform/drm/gpu/drm_overlay_manager_unittest.cc
@@ -271,4 +271,41 @@ manager.CheckOverlaySupport(&candidates, kPrimaryWidget); } +TEST(DrmOverlayManagerTest, RequiredOverlayMultiDisplay) { + TestDrmOverlayManager manager; + + // Primary has a requirement, secondary does not, should only make a request + // on the primary. + std::vector<OverlaySurfaceCandidate> candidates1 = { + CreateCandidate(gfx::Rect(0, 0, 100, 100), 0)}; + + manager.RegisterOverlayRequirement(kPrimaryWidget, true); + manager.RegisterOverlayRequirement(kSecondaryWidget, false); + // Call 4 Times to go beyond the Throttle Request Size + for (int i = 0; i < 4; ++i) + manager.CheckOverlaySupport(&candidates1, kPrimaryWidget); + EXPECT_EQ(manager.requests().size(), 1u); + // Call 4 Times to go beyond the Throttle Request Size + for (int i = 0; i < 4; ++i) + manager.CheckOverlaySupport(&candidates1, kSecondaryWidget); + EXPECT_EQ(manager.requests().size(), 1u); + manager.requests().clear(); + + // Secondary has a requirement, primary does not, should only make a request + // on the secondary. + std::vector<OverlaySurfaceCandidate> candidates2 = { + CreateCandidate(gfx::Rect(0, 0, 200, 200), 0)}; + + manager.RegisterOverlayRequirement(kPrimaryWidget, false); + manager.RegisterOverlayRequirement(kSecondaryWidget, true); + // Call 4 Times to go beyond the Throttle Request Size + for (int i = 0; i < 4; ++i) + manager.CheckOverlaySupport(&candidates2, kPrimaryWidget); + EXPECT_TRUE(manager.requests().empty()); + // Call 4 Times to go beyond the Throttle Request Size + for (int i = 0; i < 4; ++i) + manager.CheckOverlaySupport(&candidates2, kSecondaryWidget); + EXPECT_EQ(manager.requests().size(), 1u); +} + } // namespace ui
diff --git a/ui/ozone/platform/scenic/scenic_surface.cc b/ui/ozone/platform/scenic/scenic_surface.cc index 45b1e8f..9d61fa3 100644 --- a/ui/ozone/platform/scenic/scenic_surface.cc +++ b/ui/ozone/platform/scenic/scenic_surface.cc
@@ -136,13 +136,6 @@ UpdateViewHolderScene(); break; } - case fuchsia::ui::gfx::Event::kViewDetachedFromScene: { - DCHECK(event.gfx().view_detached_from_scene().view_id == parent_->id()); - // Present an empty frame to ensure that the outdated content doesn't - // become visible if the view is attached again. - PresentEmptyImage(); - break; - } default: break; } @@ -526,32 +519,6 @@ safe_presenter_.QueuePresent(); } -void ScenicSurface::PresentEmptyImage() { - if (last_frame_present_time_ == base::TimeTicks()) - return; - - fuchsia::sysmem::BufferCollectionTokenSyncPtr dummy_collection_token; - zx_status_t status = - sysmem_buffer_manager_->GetAllocator()->AllocateSharedCollection( - dummy_collection_token.NewRequest()); - if (status != ZX_OK) { - ZX_DLOG(ERROR, status) - << "fuchsia.sysmem.Allocator.AllocateSharedCollection()"; - return; - } - - const uint32_t image_id = ++next_unique_id_; - image_pipe_->AddBufferCollection(image_id, std::move(dummy_collection_token)); - fuchsia::sysmem::ImageFormat_2 image_format; - image_format.coded_width = 1; - image_format.coded_height = 1; - image_pipe_->AddImage(image_id, image_id, 0, image_format); - - image_pipe_->PresentImage(image_id, last_frame_present_time_.ToZxTime(), {}, - {}, [](fuchsia::images::PresentationInfo) {}); - image_pipe_->RemoveBufferCollection(image_id); -} - ScenicSurface::PresentedFrame::PresentedFrame( uint32_t ordinal, uint32_t image_id,
diff --git a/ui/ozone/platform/scenic/scenic_surface.h b/ui/ozone/platform/scenic/scenic_surface.h index 5587aec6..b242c9a 100644 --- a/ui/ozone/platform/scenic/scenic_surface.h +++ b/ui/ozone/platform/scenic/scenic_surface.h
@@ -141,8 +141,6 @@ void OnPresentComplete(fuchsia::images::PresentationInfo presentation_info); void UpdateViewHolderScene(); - void PresentEmptyImage(); - scenic::Session scenic_session_; std::unique_ptr<scenic::View> parent_;
diff --git a/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc b/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc index ecd049c..b1efbab 100644 --- a/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc +++ b/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc
@@ -572,7 +572,9 @@ // 5. Drag it a bit more (within window 2) and then calls EndMoveLoop(), // emulating a window snap), and then // 6. With the window in "snapped" state, drag it further and then drop. -TEST_P(WaylandWindowDragControllerTest, DragToOtherWindowSnapDragDrop_TOUCH) { +// TODO(crbug.com/1285380): Reenable test when flakiness is fixed. +TEST_P(WaylandWindowDragControllerTest, + DISABLED_DragToOtherWindowSnapDragDrop_TOUCH) { // Init and open |target_window|. PlatformWindowInitProperties properties{gfx::Rect{80, 80}}; properties.type = PlatformWindowType::kWindow;
diff --git a/ui/ozone/public/overlay_candidates_ozone.h b/ui/ozone/public/overlay_candidates_ozone.h index 2f8df59e..27e32a9a 100644 --- a/ui/ozone/public/overlay_candidates_ozone.h +++ b/ui/ozone/public/overlay_candidates_ozone.h
@@ -27,6 +27,10 @@ // if necessary. virtual void CheckOverlaySupport(OverlaySurfaceCandidateList* surfaces); + // This should be invoked during overlay processing to indicate if there are + // any candidates for this processor that have an overlay requirement. + virtual void RegisterOverlayRequirement(bool requires_overlay) {} + virtual ~OverlayCandidatesOzone(); };
diff --git a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.html b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.html index eebb09c..b35427c 100644 --- a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.html +++ b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.html
@@ -7,6 +7,13 @@ max-height: 32px; } + #deviceName { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + width: 404px; + } + #container:hover { background-color: var(--cr-hover-background-color); cursor: pointer;
diff --git a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.js b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.js index b5e6130a0..fc3cc99f 100644 --- a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.js +++ b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_item.js
@@ -73,6 +73,16 @@ }; } + /** @override */ + focus() { + // Prevent scroll stops iron list from trying to bring this element to view, + // if it is the |lastFocused| element and scrolled out of view. This can + // happen if this element is tabbed to or selected and then scrolled out of + // view. + // TODO(b/210743107) Add a test for this. + this.$.container.focus({preventScroll: true}); + } + /** * @return {boolean} * @private
diff --git a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.html b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.html index ec4bfab..395c47cb 100644 --- a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.html +++ b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.html
@@ -21,9 +21,9 @@ <localized-link localized-string="[[i18nAdvanced('bluetoothPairingLearnMoreLabel')]]"> </localized-link> - <div id="deviceListTitle" tabindex="0"> + <h2 id="deviceListTitle" aria-live="polite"> [[getDeviceListTitle_(devices.*, isBluetoothEnabled)]] - </div> + </h2> <template is="dom-if" if="[[shouldShowDeviceList_(devices.*, isBluetoothEnabled)]]" restamp> <div id="container" class="layout vertical flex" scrollable
diff --git a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.js b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.js index 4ab6acc..448e27d 100644 --- a/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.js +++ b/ui/webui/resources/cr_components/chromeos/bluetooth/bluetooth_pairing_device_selection_page.js
@@ -47,6 +47,7 @@ devices: { type: Array, value: [], + observer: 'onDevicesChanged_', }, /** @@ -132,6 +133,15 @@ } /** @private */ + onDevicesChanged_() { + // CrScrollableBehaviorInterface method required for list items to be + // properly rendered when devices updates. This is because iron-list size + // is not fixed, if this is not called iron-list container would not be + // properly sized. + this.updateScrollableContents(); + } + + /** @private */ onDevicePendingPairingChanged_() { // If |devicePendingPairing_| has changed to a defined value, it was the // last selected device. |devicePendingPairing_| gets reset to null whenever
diff --git a/ui/webui/resources/cr_elements/cr_button/cr_button.html b/ui/webui/resources/cr_elements/cr_button/cr_button.html index 481c386..214f571e 100644 --- a/ui/webui/resources/cr_elements/cr_button/cr_button.html +++ b/ui/webui/resources/cr_elements/cr_button/cr_button.html
@@ -11,18 +11,18 @@ <style include="cr-hidden-style"> :host { --active-shadow-rgb: var(--google-grey-800-rgb); - --active-shadow-action-rgb: var(--google-blue-refresh-500-rgb); + --active-shadow-action-rgb: var(--google-blue-500-rgb); --bg-action: var(--google-blue-600); - --border-color: var(--google-grey-refresh-300); - --disabled-bg-action: var(--google-grey-refresh-100); + --border-color: var(--google-grey-300); + --disabled-bg-action: var(--google-grey-100); --disabled-bg: white; - --disabled-border-color: var(--google-grey-refresh-100); + --disabled-border-color: var(--google-grey-100); --disabled-text-color: var(--google-grey-600); --focus-shadow-color: rgba(var(--google-blue-600-rgb), .4); --hover-bg-action: rgba(var(--google-blue-600-rgb), .9); - --hover-bg-color: rgba(var(--google-blue-refresh-500-rgb), .04); - --hover-border-color: var(--google-blue-refresh-100); - --hover-shadow-action-rgb: var(--google-blue-refresh-500-rgb); + --hover-bg-color: rgba(var(--google-blue-500-rgb), .04); + --hover-border-color: var(--google-blue-100); + --hover-shadow-action-rgb: var(--google-blue-500-rgb); --ink-color-action: white; /* Blue-ish color used either as a background or as a text color, * depending on the type of button. */ @@ -39,24 +39,24 @@ --active-bg: black linear-gradient(rgba(255, 255, 255, .06), rgba(255, 255, 255, .06)); --active-shadow-rgb: 0, 0, 0; - --active-shadow-action-rgb: var(--google-blue-refresh-500-rgb); - --bg-action: var(--google-blue-refresh-300); - --border-color: var(--google-grey-refresh-700); + --active-shadow-action-rgb: var(--google-blue-500-rgb); + --bg-action: var(--google-blue-300); + --border-color: var(--google-grey-700); --disabled-bg-action: var(--google-grey-800); /* TODO(dbeam): get --disabled-bg from Namrata. */ --disabled-bg: transparent; --disabled-border-color: var(--google-grey-800); - --disabled-text-color: var(--google-grey-refresh-500); - --focus-shadow-color: rgba(var(--google-blue-refresh-300-rgb), .5); + --disabled-text-color: var(--google-grey-500); + --focus-shadow-color: rgba(var(--google-blue-300-rgb), .5); --hover-bg-action: var(--bg-action) linear-gradient(rgba(0, 0, 0, .08), rgba(0, 0, 0, .08)); - --hover-bg-color: rgba(var(--google-blue-refresh-300-rgb), .08); + --hover-bg-color: rgba(var(--google-blue-300-rgb), .08); --ink-color-action: black; - --ink-color: var(--google-blue-refresh-300); + --ink-color: var(--google-blue-300); --ripple-opacity-action: .16; --ripple-opacity: .16; --text-color-action: var(--google-grey-900); - --text-color: var(--google-blue-refresh-300); + --text-color: var(--google-blue-300); } }
diff --git a/ui/webui/resources/cr_elements/cr_checkbox/cr_checkbox.html b/ui/webui/resources/cr_elements/cr_checkbox/cr_checkbox.html index 0014143..05ff980 100644 --- a/ui/webui/resources/cr_elements/cr_checkbox/cr_checkbox.html +++ b/ui/webui/resources/cr_elements/cr_checkbox/cr_checkbox.html
@@ -46,7 +46,7 @@ --cr-checkbox-checked-ripple-opacity: .2; --cr-checkbox-mark-color: white; --cr-checkbox-ripple-unchecked-color: var(--google-grey-900); - --cr-checkbox-unchecked-box-color: var(--google-grey-refresh-700); + --cr-checkbox-unchecked-box-color: var(--google-grey-700); --cr-checkbox-unchecked-ripple-opacity: .15; } @@ -55,8 +55,8 @@ /* Dark mode colors. */ --cr-checkbox-checked-ripple-opacity: .4; --cr-checkbox-mark-color: var(--google-grey-900); - --cr-checkbox-ripple-unchecked-color: var(--google-grey-refresh-500); - --cr-checkbox-unchecked-box-color: var(--google-grey-refresh-500); + --cr-checkbox-ripple-unchecked-color: var(--google-grey-500); + --cr-checkbox-unchecked-box-color: var(--google-grey-500); --cr-checkbox-unchecked-ripple-opacity: .4; } }
diff --git a/ui/webui/resources/cr_elements/cr_dialog/cr_dialog.html b/ui/webui/resources/cr_elements/cr_dialog/cr_dialog.html index 42af22f..087a2ad 100644 --- a/ui/webui/resources/cr_elements/cr_dialog/cr_dialog.html +++ b/ui/webui/resources/cr_elements/cr_dialog/cr_dialog.html
@@ -30,7 +30,7 @@ @media (prefers-color-scheme: dark) { dialog { - --scroll-border-color: var(--google-grey-refresh-700); + --scroll-border-color: var(--google-grey-700); background-color: var(--google-grey-900); /* Note: the colors in linear-gradient() are intentionally the same to * add a 4% white layer on top of the fully opaque background-color. */
diff --git a/ui/webui/resources/cr_elements/cr_icon_button/cr_icon_button.html b/ui/webui/resources/cr_elements/cr_icon_button/cr_icon_button.html index efb1d0d..e7ec685 100644 --- a/ui/webui/resources/cr_elements/cr_icon_button/cr_icon_button.html +++ b/ui/webui/resources/cr_elements/cr_icon_button/cr_icon_button.html
@@ -7,7 +7,7 @@ <template> <style> :host { - --cr-icon-button-fill-color: var(--google-grey-refresh-700); + --cr-icon-button-fill-color: var(--google-grey-700); --cr-icon-button-icon-start-offset: 0; --cr-icon-button-icon-size: 20px; --cr-icon-button-size: 36px; @@ -103,7 +103,7 @@ @media (prefers-color-scheme: dark) { :host { - --cr-icon-button-fill-color: var(--google-grey-refresh-500); + --cr-icon-button-fill-color: var(--google-grey-500); } } </style>
diff --git a/ui/webui/resources/cr_elements/cr_icons_css.html b/ui/webui/resources/cr_elements/cr_icons_css.html index 24df8913..ab07360 100644 --- a/ui/webui/resources/cr_elements/cr_icons_css.html +++ b/ui/webui/resources/cr_elements/cr_icons_css.html
@@ -81,7 +81,7 @@ -webkit-mask-position: center; -webkit-mask-repeat: no-repeat; -webkit-mask-size: var(--cr-icon-size); - background-color: var(--cr-icon-color, var(--google-grey-refresh-700)); + background-color: var(--cr-icon-color, var(--google-grey-700)); flex-shrink: 0; height: var(--cr-icon-ripple-size); margin-inline-end: var(--cr-icon-ripple-margin); @@ -101,7 +101,7 @@ @media (prefers-color-scheme: dark) { .cr-icon { - background-color: var(--cr-icon-color, var(--google-grey-refresh-500)); + background-color: var(--cr-icon-color, var(--google-grey-500)); } } </style>
diff --git a/ui/webui/resources/cr_elements/cr_input/cr_input_style_css.html b/ui/webui/resources/cr_elements/cr_input/cr_input_style_css.html index 3ce3bbb..8422436 100644 --- a/ui/webui/resources/cr_elements/cr_input/cr_input_style_css.html +++ b/ui/webui/resources/cr_elements/cr_input/cr_input_style_css.html
@@ -8,7 +8,7 @@ <template> <style> :host { - --cr-input-background-color: var(--google-grey-refresh-100); + --cr-input-background-color: var(--google-grey-100); --cr-input-color: var(--cr-primary-text-color); --cr-input-error-color: var(--google-red-600); --cr-input-focus-color: var(--google-blue-600); @@ -21,8 +21,8 @@ @media (prefers-color-scheme: dark) { :host { --cr-input-background-color: rgba(0, 0, 0, .3); - --cr-input-error-color: var(--google-red-refresh-300); - --cr-input-focus-color: var(--google-blue-refresh-300); + --cr-input-error-color: var(--google-red-300); + --cr-input-focus-color: var(--google-blue-300); } }
diff --git a/ui/webui/resources/cr_elements/cr_link_row/cr_link_row.html b/ui/webui/resources/cr_elements/cr_link_row/cr_link_row.html index aa663ee..d3eb55a5 100644 --- a/ui/webui/resources/cr_elements/cr_link_row/cr_link_row.html +++ b/ui/webui/resources/cr_elements/cr_link_row/cr_link_row.html
@@ -15,7 +15,7 @@ #startIcon { --iron-icon-fill-color: var(--cr-link-row-start-icon-color, - var(--google-grey-refresh-700)); + var(--google-grey-700)); display: flex; flex-shrink: 0; padding-inline-end: var(--cr-icon-button-margin-start); @@ -25,7 +25,7 @@ @media (prefers-color-scheme: dark) { #startIcon { --iron-icon-fill-color: var(--cr-link-row-start-icon-color, - var(--google-grey-refresh-500)); + var(--google-grey-500)); } }
diff --git a/ui/webui/resources/cr_elements/cr_nav_menu_item_style.html b/ui/webui/resources/cr_elements/cr_nav_menu_item_style.html index 5cbb718..02d2cb0 100644 --- a/ui/webui/resources/cr_elements/cr_nav_menu_item_style.html +++ b/ui/webui/resources/cr_elements/cr_nav_menu_item_style.html
@@ -13,7 +13,7 @@ } :host-context([enable-branding-update]) .cr-nav-menu-item { - --iron-icon-fill-color: var(--google-grey-refresh-700); + --iron-icon-fill-color: var(--google-grey-700); border-end-end-radius: 100px; border-start-end-radius: 100px; box-sizing: border-box; @@ -49,12 +49,12 @@ :host-context([enable-branding-update]) .cr-nav-menu-item[selected] { --iron-icon-fill-color: var(--google-blue-600); background: var(--google-blue-50); - color: var(--google-blue-refresh-700); + color: var(--google-blue-700); } @media (prefers-color-scheme: dark) { :host-context([enable-branding-update]) .cr-nav-menu-item { - --iron-icon-fill-color: var(--google-grey-refresh-500); + --iron-icon-fill-color: var(--google-grey-500); color: white; } @@ -65,7 +65,7 @@ :host-context([enable-branding-update]) .cr-nav-menu-item[selected] { --iron-icon-fill-color: black; - background: var(--google-blue-refresh-300); + background: var(--google-blue-300); color: var(--google-grey-900); }
diff --git a/ui/webui/resources/cr_elements/cr_profile_avatar_selector/cr_profile_avatar_selector.html b/ui/webui/resources/cr_elements/cr_profile_avatar_selector/cr_profile_avatar_selector.html index fe45be1..4d4059f 100644 --- a/ui/webui/resources/cr_elements/cr_profile_avatar_selector/cr_profile_avatar_selector.html +++ b/ui/webui/resources/cr_elements/cr_profile_avatar_selector/cr_profile_avatar_selector.html
@@ -6,7 +6,7 @@ } #avatar-grid .avatar { - --avatar-focus-color: var(--google-grey-refresh-700); + --avatar-focus-color: var(--google-grey-700); --avatar-gap-color: white; --avatar-gap-width: 2px; --avatar-selected-color: var(--google-blue-500); @@ -29,9 +29,9 @@ @media (prefers-color-scheme: dark) { #avatar-grid .avatar { - --avatar-focus-color: var(--google-grey-refresh-500); + --avatar-focus-color: var(--google-grey-500); --avatar-gap-color: var(--google-grey-800); - --avatar-selected-color: var(--google-blue-refresh-300); + --avatar-selected-color: var(--google-blue-300); } }
diff --git a/ui/webui/resources/cr_elements/cr_radio_button/cr_radio_button_style_css.html b/ui/webui/resources/cr_elements/cr_radio_button/cr_radio_button_style_css.html index 38704200..a3fdf64c 100644 --- a/ui/webui/resources/cr_elements/cr_radio_button/cr_radio_button_style_css.html +++ b/ui/webui/resources/cr_elements/cr_radio_button/cr_radio_button_style_css.html
@@ -13,7 +13,7 @@ rgba(var(--google-blue-600-rgb), .2); --cr-radio-button-ink-size: 40px; --cr-radio-button-size: 16px; - --cr-radio-button-unchecked-color: var(--google-grey-refresh-700); + --cr-radio-button-unchecked-color: var(--google-grey-700); --cr-radio-button-unchecked-ripple-color: rgba(var(--google-grey-600-rgb), .15); @@ -27,12 +27,12 @@ @media (prefers-color-scheme: dark) { :host { - --cr-radio-button-checked-color: var(--google-blue-refresh-300); + --cr-radio-button-checked-color: var(--google-blue-300); --cr-radio-button-checked-ripple-color: - rgba(var(--google-blue-refresh-300-rgb), .4); - --cr-radio-button-unchecked-color: var(--google-grey-refresh-500); + rgba(var(--google-blue-300-rgb), .4); + --cr-radio-button-unchecked-color: var(--google-grey-500); --cr-radio-button-unchecked-ripple-color: - rgba(var(--google-grey-refresh-300-rgb), .4); + rgba(var(--google-grey-300-rgb), .4); } }
diff --git a/ui/webui/resources/cr_elements/cr_search_field/cr_search_field.html b/ui/webui/resources/cr_elements/cr_search_field/cr_search_field.html index 9fea65eb5..a05bc091 100644 --- a/ui/webui/resources/cr_elements/cr_search_field/cr_search_field.html +++ b/ui/webui/resources/cr_elements/cr_search_field/cr_search_field.html
@@ -2,7 +2,7 @@ :host { display: flex; user-select: none; - --cr-search-field-clear-icon-fill: var(--google-grey-refresh-700); + --cr-search-field-clear-icon-fill: var(--google-grey-700); --cr-search-field-clear-icon-margin-end : -4px; --cr-search-field-input-border-bottom: 1px solid var(--cr-secondary-text-color); }
diff --git a/ui/webui/resources/cr_elements/cr_slider/cr_slider.html b/ui/webui/resources/cr_elements/cr_slider/cr_slider.html index a0e8f6e9..98732f99 100644 --- a/ui/webui/resources/cr_elements/cr_slider/cr_slider.html +++ b/ui/webui/resources/cr_elements/cr_slider/cr_slider.html
@@ -23,17 +23,17 @@ @media (prefers-color-scheme: dark) { :host { - --cr-slider-active-color: var(--google-blue-refresh-300); + --cr-slider-active-color: var(--google-blue-300); --cr-slider-container-color: - rgba(var(--google-blue-refresh-500-rgb), .48); + rgba(var(--google-blue-500-rgb), .48); --cr-slider-container-disabled-color: rgba(var(--google-grey-600-rgb), .48); /* --cr-slider-disabled-color is the same in dark mode (GG600). */ - --cr-slider-knob-color-rgb: var(--google-blue-refresh-300-rgb); + --cr-slider-knob-color-rgb: var(--google-blue-300-rgb); --cr-slider-knob-disabled-color: var(--google-grey-900-white-4-percent); - --cr-slider-marker-active-color: var(--google-blue-refresh-300); - --cr-slider-marker-color: var(--google-blue-refresh-300); + --cr-slider-marker-active-color: var(--google-blue-300); + --cr-slider-marker-color: var(--google-blue-300); --cr-slider-marker-disabled-color: rgba(255, 255, 255, .54); --cr-slider-ripple-color: rgba(var(--cr-slider-knob-color-rgb), .4); }
diff --git a/ui/webui/resources/cr_elements/cr_tabs/cr_tabs.html b/ui/webui/resources/cr_elements/cr_tabs/cr_tabs.html index 1005206..c8158da 100644 --- a/ui/webui/resources/cr_elements/cr_tabs/cr_tabs.html +++ b/ui/webui/resources/cr_elements/cr_tabs/cr_tabs.html
@@ -37,7 +37,7 @@ @media (prefers-color-scheme: dark) { .selected { - color: var(--cr-tabs-selected-color, var(--google-blue-refresh-300)); + color: var(--cr-tabs-selected-color, var(--google-blue-300)); } } @@ -62,7 +62,7 @@ @media (prefers-color-scheme: dark) { .selected .tab-icon { - background-color: var(--cr-tabs-selected-color, var(--google-blue-refresh-300)); + background-color: var(--cr-tabs-selected-color, var(--google-blue-300)); } } @@ -96,7 +96,7 @@ @media (prefers-color-scheme: dark) { .tab-indicator { - background: var(--cr-tabs-selected-color, var(--google-blue-refresh-300)); + background: var(--cr-tabs-selected-color, var(--google-blue-300)); } } </style>
diff --git a/ui/webui/resources/cr_elements/cr_toast/cr_toast.html b/ui/webui/resources/cr_elements/cr_toast/cr_toast.html index 7547715..ba028b86 100644 --- a/ui/webui/resources/cr_elements/cr_toast/cr_toast.html +++ b/ui/webui/resources/cr_elements/cr_toast/cr_toast.html
@@ -9,7 +9,7 @@ :host { --cr-toast-background: var(--google-grey-900) linear-gradient(rgba(255, 255, 255, .06), rgba(255, 255, 255, .06)); - --cr-toast-button-color: var(--google-blue-refresh-300); + --cr-toast-button-color: var(--google-blue-300); --cr-toast-text-color: var(--google-grey-200); } }
diff --git a/ui/webui/resources/cr_elements/cr_toggle/cr_toggle.html b/ui/webui/resources/cr_elements/cr_toggle/cr_toggle.html index c7e12ea..41e0f42 100644 --- a/ui/webui/resources/cr_elements/cr_toggle/cr_toggle.html +++ b/ui/webui/resources/cr_elements/cr_toggle/cr_toggle.html
@@ -26,26 +26,26 @@ @media (prefers-color-scheme: dark) { :host { - --cr-toggle-checked-bar-color: var(--google-blue-refresh-300); - --cr-toggle-checked-button-color: var(--google-blue-refresh-300); + --cr-toggle-checked-bar-color: var(--google-blue-300); + --cr-toggle-checked-button-color: var(--google-blue-300); --cr-toggle-checked-ripple-color: - rgba(var(--google-blue-refresh-300-rgb), .4); - --cr-toggle-unchecked-bar-color: var(--google-grey-refresh-500); - --cr-toggle-unchecked-button-color: var(--google-grey-refresh-300); + rgba(var(--google-blue-300-rgb), .4); + --cr-toggle-unchecked-bar-color: var(--google-grey-500); + --cr-toggle-unchecked-button-color: var(--google-grey-300); --cr-toggle-unchecked-ripple-color: - rgba(var(--google-grey-refresh-300-rgb), .4); + rgba(var(--google-grey-300-rgb), .4); } } /* Keep the prefers-color-scheme and [dark] rules the same. */ :host([dark]) { - --cr-toggle-checked-bar-color: var(--google-blue-refresh-300); - --cr-toggle-checked-button-color: var(--google-blue-refresh-300); + --cr-toggle-checked-bar-color: var(--google-blue-300); + --cr-toggle-checked-button-color: var(--google-blue-300); --cr-toggle-checked-ripple-color: - rgba(var(--google-blue-refresh-300-rgb), .4); - --cr-toggle-unchecked-bar-color: var(--google-grey-refresh-500); - --cr-toggle-unchecked-button-color: var(--google-grey-refresh-300); + rgba(var(--google-blue-300-rgb), .4); + --cr-toggle-unchecked-bar-color: var(--google-grey-500); + --cr-toggle-unchecked-button-color: var(--google-grey-300); --cr-toggle-unchecked-ripple-color: - rgba(var(--google-grey-refresh-300-rgb), .4); + rgba(var(--google-grey-300-rgb), .4); } :host([disabled]) {
diff --git a/ui/webui/resources/cr_elements/cr_toolbar/cr_toolbar_search_field.html b/ui/webui/resources/cr_elements/cr_toolbar/cr_toolbar_search_field.html index 499e941..c854656 100644 --- a/ui/webui/resources/cr_elements/cr_toolbar/cr_toolbar_search_field.html +++ b/ui/webui/resources/cr_elements/cr_toolbar/cr_toolbar_search_field.html
@@ -30,7 +30,7 @@ :host-context([enable-branding-update]) cr-icon-button { --cr-icon-button-fill-color: var( --cr-toolbar-search-field-input-icon-color, - var(--google-grey-refresh-700)); + var(--google-grey-700)); } } @@ -38,7 +38,7 @@ cr-icon-button { --cr-icon-button-fill-color: var( --cr-toolbar-search-field-input-icon-color, - var(--google-grey-refresh-500)); + var(--google-grey-500)); } } @@ -54,7 +54,7 @@ @media (prefers-color-scheme: light) { :host-context([enable-branding-update]) #prompt { color: var(--cr-toolbar-search-field-prompt-color, - var(--google-grey-refresh-700)); + var(--google-grey-700)); } } @@ -80,7 +80,7 @@ :host-context([enable-branding-update]) paper-spinner-lite { --paper-spinner-color: var( --cr-toolbar-search-field-input-icon-color, - var(--google-grey-refresh-700)); + var(--google-grey-700)); } } @@ -132,7 +132,7 @@ @media (prefers-color-scheme: light) { :host-context([enable-branding-update]) input { - caret-color: var(--google-blue-refresh-700); + caret-color: var(--google-blue-700); color: var(--cr-toolbar-search-field-input-text-color, var(--google-grey-900)); } @@ -168,7 +168,7 @@ :host-context([enable-branding-update]):host(:not([narrow])) { background: var(--cr-toolbar-search-field-background, - var(--google-grey-refresh-100)); + var(--google-grey-100)); } }
diff --git a/ui/webui/resources/cr_elements/md_select_css.html b/ui/webui/resources/cr_elements/md_select_css.html index 704f1f1..6ce0769 100644 --- a/ui/webui/resources/cr_elements/md_select_css.html +++ b/ui/webui/resources/cr_elements/md_select_css.html
@@ -8,7 +8,7 @@ <style> .md-select { --md-arrow-width: 10px; - --md-select-bg-color: var(--google-grey-refresh-100); + --md-select-bg-color: var(--google-grey-100); --md-select-focus-shadow-color: rgba(var(--google-blue-600-rgb), .4); --md-select-option-bg-color: white; --md-select-side-padding: 8px; @@ -42,7 +42,7 @@ .md-select { --md-select-bg-color: rgba(0, 0, 0, .3); --md-select-focus-shadow-color: - rgba(var(--google-blue-refresh-300-rgb), .5); + rgba(var(--google-blue-300-rgb), .5); --md-select-option-bg-color: var(--google-grey-900-white-4-percent); --md-select-text-color: var(--cr-primary-text-color); background-image: url(chrome://resources/images/dark/arrow_down.svg);
diff --git a/ui/webui/resources/cr_elements/mwb_shared_vars.html b/ui/webui/resources/cr_elements/mwb_shared_vars.html index b5640ee..27f26e1 100644 --- a/ui/webui/resources/cr_elements/mwb_shared_vars.html +++ b/ui/webui/resources/cr_elements/mwb_shared_vars.html
@@ -2,7 +2,7 @@ <style> html { --mwb-background-color: white; - --mwb-icon-button-fill-color: var(--google-grey-refresh-700); + --mwb-icon-button-fill-color: var(--google-grey-700); --mwb-icon-button-focus-ring-color: var(--google-blue-600); --mwb-icon-button-hover-background-color: rgba(var(--google-grey-900-rgb), 0.1); --mwb-icon-size: 16px; @@ -13,8 +13,8 @@ --mwb-list-section-title-font-size: 11px; --mwb-list-section-title-height: 40px; --mwb-primary-text-font-size: 13px; - --mwb-scrollbar-thumb-color: var(--google-grey-refresh-300); - --mwb-scrollbar-thumb-hover-color: var(--google-grey-refresh-500); + --mwb-scrollbar-thumb-color: var(--google-grey-300); + --mwb-scrollbar-thumb-hover-color: var(--google-grey-500); --mwb-scrollbar-track-color: var(--mwb-background-color); --mwb-scrollbar-width: 4px; --mwb-secondary-text-font-size: 12px; @@ -23,13 +23,13 @@ @media (prefers-color-scheme: dark) { html { --mwb-background-color: var(--google-grey-900); - --mwb-icon-button-fill-color: var(--google-grey-refresh-300); + --mwb-icon-button-fill-color: var(--google-grey-300); --mwb-icon-button-focus-ring-color: var(--google-blue-300); --mwb-icon-button-hover-background-color: rgba(255, 255, 255, 0.1); --mwb-list-item-hover-background-color: rgb(55, 56, 58); /* #37383a */ --mwb-list-item-selected-background-color: rgb(68, 69, 71); /* #444547 */ - --mwb-scrollbar-thumb-color: var(--google-grey-refresh-500); - --mwb-scrollbar-thumb-hover-color: var(--google-grey-refresh-300); + --mwb-scrollbar-thumb-color: var(--google-grey-500); + --mwb-scrollbar-thumb-hover-color: var(--google-grey-300); } } </style>
diff --git a/ui/webui/resources/cr_elements/policy/cr_tooltip_icon.html b/ui/webui/resources/cr_elements/policy/cr_tooltip_icon.html index d1ddc2d..4a07dec 100644 --- a/ui/webui/resources/cr_elements/policy/cr_tooltip_icon.html +++ b/ui/webui/resources/cr_elements/policy/cr_tooltip_icon.html
@@ -17,7 +17,7 @@ --iron-icon-width: var(--cr-icon-size); --iron-icon-height: var(--cr-icon-size); --iron-icon-fill-color: - var(--cr-tooltip-icon-fill-color, var(--google-grey-refresh-700)); + var(--cr-tooltip-icon-fill-color, var(--google-grey-700)); } </style> <iron-icon id="indicator" tabindex="0" aria-label$="[[iconAriaLabel]]"
diff --git a/ui/webui/resources/cr_elements/shared_style_css.html b/ui/webui/resources/cr_elements/shared_style_css.html index d862a9a..c0976a8 100644 --- a/ui/webui/resources/cr_elements/shared_style_css.html +++ b/ui/webui/resources/cr_elements/shared_style_css.html
@@ -11,13 +11,13 @@ <style include="cr-hidden-style cr-icons"> html, :host { - --scrollable-border-color: var(--google-grey-refresh-300); + --scrollable-border-color: var(--google-grey-300); } @media (prefers-color-scheme: dark) { html, :host { - --scrollable-border-color: var(--google-grey-refresh-700); + --scrollable-border-color: var(--google-grey-700); } }
diff --git a/weblayer/browser/android/javatests/src/org/chromium/weblayer/test/WebMessageTest.java b/weblayer/browser/android/javatests/src/org/chromium/weblayer/test/WebMessageTest.java index 6067332..ef61144 100644 --- a/weblayer/browser/android/javatests/src/org/chromium/weblayer/test/WebMessageTest.java +++ b/weblayer/browser/android/javatests/src/org/chromium/weblayer/test/WebMessageTest.java
@@ -99,7 +99,12 @@ assertNotNull(webMessageCallback.mLastMessage); assertNotNull(webMessageCallback.mLastProxy); assertEquals("from page", webMessageCallback.mLastMessage.getContents()); - WebMessageReplyProxy lastProxy = webMessageCallback.mLastProxy; + final WebMessageReplyProxy lastProxy = webMessageCallback.mLastProxy; + int majorVersion = TestThreadUtils.runOnUiThreadBlockingNoException( + () -> WebLayer.getSupportedMajorVersion(mActivityTestRule.getActivity())); + if (majorVersion >= 99) { + assertNotNull(runOnUiThreadBlocking(() -> { return lastProxy.getPage(); })); + } webMessageCallback.reset(); int currentCallCount = callbackHelper.getCallCount();
diff --git a/weblayer/browser/java/org/chromium/weblayer_private/WebMessageReplyProxyImpl.java b/weblayer/browser/java/org/chromium/weblayer_private/WebMessageReplyProxyImpl.java index 9eeef1e..2b213f47 100644 --- a/weblayer/browser/java/org/chromium/weblayer_private/WebMessageReplyProxyImpl.java +++ b/weblayer/browser/java/org/chromium/weblayer_private/WebMessageReplyProxyImpl.java
@@ -25,12 +25,16 @@ private final int mId; private WebMessageReplyProxyImpl(long nativeWebMessageReplyProxyImpl, int id, - IWebMessageCallbackClient client, boolean isMainFrame, String sourceOrigin) { + IWebMessageCallbackClient client, boolean isMainFrame, String sourceOrigin, + PageImpl page) { mNativeWebMessageReplyProxyImpl = nativeWebMessageReplyProxyImpl; mClient = client; mId = id; try { client.onNewReplyProxy(this, mId, isMainFrame, sourceOrigin); + if (WebLayerFactoryImpl.getClientMajorVersion() >= 99) { + client.onSetPage(mId, page.getClientPage()); + } } catch (RemoteException e) { throw new APICallException(e); } @@ -38,9 +42,10 @@ @CalledByNative private static WebMessageReplyProxyImpl create(long nativeWebMessageReplyProxyImpl, int id, - IWebMessageCallbackClient client, boolean isMainFrame, String sourceOrigin) { + IWebMessageCallbackClient client, boolean isMainFrame, String sourceOrigin, + PageImpl page) { return new WebMessageReplyProxyImpl( - nativeWebMessageReplyProxyImpl, id, client, isMainFrame, sourceOrigin); + nativeWebMessageReplyProxyImpl, id, client, isMainFrame, sourceOrigin, page); } @CalledByNative
diff --git a/weblayer/browser/java/org/chromium/weblayer_private/interfaces/IWebMessageCallbackClient.aidl b/weblayer/browser/java/org/chromium/weblayer_private/interfaces/IWebMessageCallbackClient.aidl index d42f3a0..641b855 100644 --- a/weblayer/browser/java/org/chromium/weblayer_private/interfaces/IWebMessageCallbackClient.aidl +++ b/weblayer/browser/java/org/chromium/weblayer_private/interfaces/IWebMessageCallbackClient.aidl
@@ -4,6 +4,7 @@ package org.chromium.weblayer_private.interfaces; +import org.chromium.weblayer_private.interfaces.IClientPage; import org.chromium.weblayer_private.interfaces.IWebMessageReplyProxy; interface IWebMessageCallbackClient { @@ -16,4 +17,7 @@ // @since 90 void onReplyProxyActiveStateChanged(in int proxyId) = 3; + + // @since 99 + void onSetPage(in int proxyId, IClientPage page) = 4; }
diff --git a/weblayer/browser/js_communication/web_message_browsertest.cc b/weblayer/browser/js_communication/web_message_browsertest.cc index 064273a..a0b2e9d 100644 --- a/weblayer/browser/js_communication/web_message_browsertest.cc +++ b/weblayer/browser/js_communication/web_message_browsertest.cc
@@ -134,6 +134,8 @@ ASSERT_EQ(2u, current_connection->messages().size()); EXPECT_EQ(u"from page", current_connection->messages()[0]); EXPECT_EQ(u"bouncing from c++", current_connection->messages()[1]); + // WebLayer's Page has no functions, verify it can be requested. + current_connection->proxy()->GetPage(); } class WebMessageTestWithBfCache : public WebLayerBrowserTest { @@ -179,6 +181,7 @@ ASSERT_TRUE(web_message_host); EXPECT_FALSE(web_message_host->proxy()->IsInBackForwardCache()); EXPECT_EQ(0, web_message_host->back_forward_cache_state_changed_call_count()); + Page* original_page = &(web_message_host->proxy()->GetPage()); // Navigate to a new host. The old page should go into the cache. OneShotNavigationObserver observer1(shell()); @@ -196,5 +199,7 @@ observer2.WaitForNavigation(); web_message_host->WaitForBackForwardStateToBe(false); EXPECT_EQ(2, web_message_host->back_forward_cache_state_changed_call_count()); + EXPECT_EQ(original_page, &(web_message_host->proxy()->GetPage())); } + } // namespace weblayer
diff --git a/weblayer/browser/js_communication/web_message_host_factory_wrapper.cc b/weblayer/browser/js_communication/web_message_host_factory_wrapper.cc index e5914a6..c3ef27a 100644 --- a/weblayer/browser/js_communication/web_message_host_factory_wrapper.cc +++ b/weblayer/browser/js_communication/web_message_host_factory_wrapper.cc
@@ -8,6 +8,9 @@ #include "components/js_injection/browser/web_message.h" #include "components/js_injection/browser/web_message_host.h" #include "components/js_injection/browser/web_message_reply_proxy.h" +#include "content/public/browser/page.h" +#include "content/public/browser/render_frame_host.h" +#include "weblayer/browser/page_impl.h" #include "weblayer/public/js_communication/web_message.h" #include "weblayer/public/js_communication/web_message_host.h" #include "weblayer/public/js_communication/web_message_host_factory.h" @@ -50,6 +53,16 @@ bool IsInBackForwardCache() override { return proxy_->IsInBackForwardCache(); } + Page& GetPage() override { + // In general WebLayer avoids exposing child frames. As such, GetPage() + // returns the Page of the main frame. + PageImpl* page = + PageImpl::GetForPage(proxy_->GetPage().GetMainDocument().GetPage()); + // NavigationControllerImpl creates the PageImpl when navigation finishes so + // that by the time this is called the Page should have been created. + DCHECK(page); + return *page; + } private: raw_ptr<js_injection::WebMessageReplyProxy> proxy_;
diff --git a/weblayer/browser/js_communication/web_message_reply_proxy_impl.cc b/weblayer/browser/js_communication/web_message_reply_proxy_impl.cc index d01203d..d484d9d 100644 --- a/weblayer/browser/js_communication/web_message_reply_proxy_impl.cc +++ b/weblayer/browser/js_communication/web_message_reply_proxy_impl.cc
@@ -9,6 +9,7 @@ #include "base/android/jni_android.h" #include "base/android/jni_string.h" #include "weblayer/browser/java/jni/WebMessageReplyProxyImpl_jni.h" +#include "weblayer/browser/page_impl.h" #include "weblayer/public/js_communication/web_message.h" #include "weblayer/public/js_communication/web_message_reply_proxy.h" @@ -24,7 +25,8 @@ auto* env = base::android::AttachCurrentThread(); java_object_ = Java_WebMessageReplyProxyImpl_create( env, reinterpret_cast<intptr_t>(this), id, client, is_main_frame, - base::android::ConvertUTF8ToJavaString(env, origin_string)); + base::android::ConvertUTF8ToJavaString(env, origin_string), + static_cast<PageImpl&>(reply_proxy->GetPage()).java_page()); } WebMessageReplyProxyImpl::~WebMessageReplyProxyImpl() {
diff --git a/weblayer/public/java/org/chromium/weblayer/Tab.java b/weblayer/public/java/org/chromium/weblayer/Tab.java index 166b022..c503d97 100644 --- a/weblayer/public/java/org/chromium/weblayer/Tab.java +++ b/weblayer/public/java/org/chromium/weblayer/Tab.java
@@ -15,6 +15,7 @@ import org.chromium.weblayer_private.interfaces.APICallException; import org.chromium.weblayer_private.interfaces.IClientNavigation; +import org.chromium.weblayer_private.interfaces.IClientPage; import org.chromium.weblayer_private.interfaces.IContextMenuParams; import org.chromium.weblayer_private.interfaces.IErrorPageCallbackClient; import org.chromium.weblayer_private.interfaces.IExternalIntentInIncognitoCallbackClient; @@ -836,6 +837,13 @@ assert proxy != null; mCallback.onWebMessageReplyProxyActiveStateChanged(proxy); } + + @Override + public void onSetPage(int proxyId, IClientPage clientPage) { + StrictModeWorkaround.apply(); + assert mProxyIdToProxy.get(proxyId) != null; + mProxyIdToProxy.get(proxyId).setPage((Page) clientPage); + } } private final class TabClientImpl extends ITabClient.Stub {
diff --git a/weblayer/public/java/org/chromium/weblayer/WebMessageReplyProxy.java b/weblayer/public/java/org/chromium/weblayer/WebMessageReplyProxy.java index 5075249..098c4b2 100644 --- a/weblayer/public/java/org/chromium/weblayer/WebMessageReplyProxy.java +++ b/weblayer/public/java/org/chromium/weblayer/WebMessageReplyProxy.java
@@ -23,6 +23,8 @@ private final boolean mIsMainFrame; private final String mSourceOrigin; private boolean mIsClosed; + // Added in 99. + private Page mPage; // Constructor for test mocking. protected WebMessageReplyProxy() { @@ -101,4 +103,25 @@ throw new APICallException(e); } } + + /** + * Returns the Page associated with this proxy. For child frame, the Page of the main frame is + * returned. + * + * @return The Page. + * + * @since 99 + */ + public Page getPage() { + ThreadCheck.ensureOnUiThread(); + if (WebLayer.getSupportedMajorVersionInternal() < 99) { + throw new UnsupportedOperationException(); + } + return mPage; + } + + // Only called in >= 99. + void setPage(Page page) { + mPage = page; + } }
diff --git a/weblayer/public/js_communication/web_message_reply_proxy.h b/weblayer/public/js_communication/web_message_reply_proxy.h index e464b80..6406364 100644 --- a/weblayer/public/js_communication/web_message_reply_proxy.h +++ b/weblayer/public/js_communication/web_message_reply_proxy.h
@@ -9,6 +9,7 @@ namespace weblayer { +class Page; struct WebMessage; // Used to send messages to the page. @@ -19,6 +20,10 @@ // Returns true if the page is in the back/forward cache. virtual bool IsInBackForwardCache() = 0; + // Returns the Page this proxy was created for. This always returns the Page + // of the main frame. + virtual Page& GetPage() = 0; + protected: virtual ~WebMessageReplyProxy() = default; };
diff --git a/weblayer/public/page.h b/weblayer/public/page.h index 5a65af2..d50a476b 100644 --- a/weblayer/public/page.h +++ b/weblayer/public/page.h
@@ -9,9 +9,9 @@ // This objects tracks the lifetime of a loaded web page. Most of the time there // is only one Page object per tab. However features like back-forward cache, -// prerendering etc... sometime involve the creation of additional Page object. +// prerendering etc... sometime involve the creation of additional Page objects. // Navigation::getPage() will return the Page for a given navigation. Similarly -// it'll the same Page object that's passed in +// it'll be the same Page object that's passed in // NavigationObserver::OnPageDestroyed(). class Page { protected: