| // Copyright 2021 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/elevation_service/scoped_mock_context.h" |
| |
| #include <objbase.h> |
| #include <objidl.h> |
| #include <unknwn.h> |
| #include <wrl/implements.h> |
| |
| #include "base/win/com_init_util.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace { |
| |
| // A mock implementation of IServerSecurity that allows for the production code |
| // that calls ::CoImpersonateClient() to work. |
| class MockServerSecurity |
| : public Microsoft::WRL::RuntimeClass< |
| Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, |
| IServerSecurity> { |
| public: |
| MockServerSecurity() = default; |
| MockServerSecurity(const MockServerSecurity&) = delete; |
| MockServerSecurity& operator=(const MockServerSecurity&) = delete; |
| |
| IFACEMETHODIMP QueryBlanket(DWORD* authentication_service, |
| DWORD* authorization_service, |
| OLECHAR** server_principal_name, |
| DWORD* authentication_level, |
| DWORD* impersonation_level, |
| void** privilege, |
| DWORD* capabilities) override { |
| return E_NOTIMPL; |
| } |
| IFACEMETHODIMP ImpersonateClient() override { |
| is_impersonating_ = true; |
| return S_OK; |
| } |
| IFACEMETHODIMP RevertToSelf() override { |
| is_impersonating_ = false; |
| return S_OK; |
| } |
| IFACEMETHODIMP_(BOOL) IsImpersonating() override { return is_impersonating_; } |
| |
| private: |
| ~MockServerSecurity() override { EXPECT_FALSE(is_impersonating_); } |
| |
| bool is_impersonating_ = false; |
| }; |
| |
| } // namespace |
| |
| namespace elevation_service { |
| |
| ScopedMockContext::ScopedMockContext() { |
| base::win::AssertComInitialized(); |
| auto mock_call_context = Microsoft::WRL::Make<MockServerSecurity>(); |
| |
| // We set the call context to a mock object that implements IServerSecurity. |
| // This allows for the production code that calls ::CoImpersonateClient() to |
| // succeed. |
| auto hresult = |
| ::CoSwitchCallContext(mock_call_context.Get(), &original_call_context_); |
| EXPECT_HRESULT_SUCCEEDED(hresult); |
| if (FAILED(hresult)) |
| return; |
| |
| mock_call_context_ = std::move(mock_call_context); |
| EXPECT_EQ(original_call_context_, nullptr); |
| } |
| |
| ScopedMockContext::~ScopedMockContext() { |
| base::win::AssertComInitialized(); |
| if (!Succeeded()) |
| return; |
| |
| IUnknown* this_call_context = nullptr; |
| EXPECT_HRESULT_SUCCEEDED( |
| ::CoSwitchCallContext(original_call_context_, &this_call_context)); |
| EXPECT_EQ(this_call_context, mock_call_context_.Get()) |
| << "CoSwitchCallContext switched out someone else's context."; |
| } |
| |
| } // namespace elevation_service |