| // Copyright 2020 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifdef UNSAFE_BUFFERS_BUILD |
| // TODO(crbug.com/40285824): Remove this and convert code to safer constructs. |
| #pragma allow_unsafe_buffers |
| #endif |
| |
| #include "chrome/browser/webshare/win/fake_random_access_stream.h" |
| |
| #include <robuffer.h> |
| #include <vector> |
| |
| #include "base/functional/bind.h" |
| #include "base/functional/callback_helpers.h" |
| #include "base/memory/weak_ptr.h" |
| #include "base/task/single_thread_task_runner.h" |
| #include "base/test/bind.h" |
| #include "base/test/fake_iasync_operation_win.h" |
| #include "chrome/browser/webshare/win/fake_iasync_operation_with_progress.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| using ABI::Windows::Foundation::IAsyncOperation; |
| using ABI::Windows::Foundation::IAsyncOperationWithProgress; |
| using ABI::Windows::Storage::Streams::IBuffer; |
| using ABI::Windows::Storage::Streams::IInputStream; |
| using ABI::Windows::Storage::Streams::InputStreamOptions; |
| using ABI::Windows::Storage::Streams::IOutputStream; |
| using ABI::Windows::Storage::Streams::IRandomAccessStream; |
| using Microsoft::WRL::ComPtr; |
| using Microsoft::WRL::Make; |
| using Windows::Storage::Streams::IBufferByteAccess; |
| |
| namespace ABI { |
| namespace Windows { |
| namespace Foundation { |
| |
| // Define template specializations for the types used. These uuids were randomly |
| // generated. |
| // |
| // Each specialization only attaches a uuid to the matching *_impl base for |
| // |unsigned char|; neither one adds members of its own. |
| // NOTE(review): the rest of this file names IAsyncOperation<bool>, not |
| // IAsyncOperation<unsigned char> - presumably the Windows headers map one onto |
| // the other; confirm before changing either specialization. |
| template <> |
| struct __declspec(uuid("99159E96-2AAD-4F4C-91AD-DBD5A92ACF12")) |
| IAsyncOperation<unsigned char> : IAsyncOperation_impl<unsigned char> {}; |
| |
| // Completion-handler counterpart of the specialization above. |
| template <> |
| struct __declspec(uuid("9AF0D4FD-CD18-492E-A17E-27056D4F6481")) |
| IAsyncOperationCompletedHandler<unsigned char> |
| : IAsyncOperationCompletedHandler_impl<unsigned char> {}; |
| |
| } // namespace Foundation |
| } // namespace Windows |
| } // namespace ABI |
| |
| namespace webshare { |
| // In-memory byte store that backs one or more FakeRandomAccessStream objects. |
| // |
| // Streams handed out by GetInputStreamAt()/GetOutputStreamAt() reference the |
| // same StreamData as the stream they came from, while each stream passes in its |
| // own |position| counter. |
| // |
| // At most one asynchronous operation (read, write or flush) may be pending at a |
| // time. Starting another one, or resizing while one is pending, records a test |
| // failure and returns E_ILLEGAL_METHOD_CALL. Pending operations are finished by |
| // a task posted to the current default SingleThreadTaskRunner, bound through a |
| // weak pointer to this object. |
| class StreamData final : public base::RefCountedThreadSafe<StreamData> { |
|  public: |
|   StreamData() = default; |
|   StreamData(const StreamData& other) = delete; |
|   StreamData& operator=(const StreamData&) = delete; |
| |
|   // Stores the current number of bytes held in |value|. Always returns S_OK. |
|   HRESULT get_Size(UINT64* value) { |
|     *value = data_.size(); |
|     return S_OK; |
|   } |
| |
|   // Resizes the stored data to |value| bytes. Fails with E_ILLEGAL_METHOD_CALL |
|   // (and a test failure) while any asynchronous operation is pending. |
|   HRESULT put_Size(UINT64 value) { |
|     if (FailIfAsyncInProgress("put_Size")) { |
|       return E_ILLEGAL_METHOD_CALL; |
|     } |
| |
|     data_.resize(value); |
|     return S_OK; |
|   } |
| |
|   // Starts an asynchronous read of up to |count| bytes, beginning at |
|   // |position|, into |buffer|. |
|   // |
|   // On success |*operation| receives the fake async operation, a task that |
|   // performs the read is posted, and S_OK is returned. |options| is not used. |
|   // |
|   // Returns E_ILLEGAL_METHOD_CALL if another operation is pending, the |
|   // QueryInterface error if |operation| cannot be filled in, or |
|   // E_ASYNC_OPERATION_NOT_STARTED if the task could not be posted. |
|   HRESULT ReadAsync(scoped_refptr<base::RefCountedData<UINT64>> position, |
|                     IBuffer* buffer, |
|                     UINT32 count, |
|                     InputStreamOptions options, |
|                     IAsyncOperationWithProgress<IBuffer*, UINT32>** operation) { |
|     if (FailIfAsyncInProgress("ReadAsync")) { |
|       return E_ILLEGAL_METHOD_CALL; |
|     } |
| |
|     // Hold a reference so |buffer| stays alive until the posted task runs. |
|     ComPtr<IBuffer> captured_buffer = buffer; |
|     auto fake_iasync_operation = |
|         Make<FakeIAsyncOperationWithProgress<IBuffer*, UINT32>>(); |
| |
|     HRESULT hr = fake_iasync_operation->QueryInterface(IID_PPV_ARGS(operation)); |
|     if (FAILED(hr)) { |
|       EXPECT_HRESULT_SUCCEEDED(hr); |
|       return hr; |
|     } |
| |
|     bool success = base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
|         FROM_HERE, |
|         base::BindOnce(&StreamData::OnReadAsync, weak_factory_.GetWeakPtr(), |
|                        position, fake_iasync_operation, captured_buffer, |
|                        count)); |
|     if (!success) { |
|       EXPECT_TRUE(success); |
|       return E_ASYNC_OPERATION_NOT_STARTED; |
|     } |
| |
|     read_async_in_progress_ = true; |
|     return S_OK; |
|   } |
| |
|   // Starts an asynchronous write of the whole of |buffer| at |position|. |
|   // |
|   // On success |*operation| receives the fake async operation, a task that |
|   // performs the write is posted, and S_OK is returned. Error returns match |
|   // those of ReadAsync(). |
|   HRESULT |
|   WriteAsync(scoped_refptr<base::RefCountedData<UINT64>> position, |
|              IBuffer* buffer, |
|              IAsyncOperationWithProgress<UINT32, UINT32>** operation) { |
|     if (FailIfAsyncInProgress("WriteAsync")) { |
|       return E_ILLEGAL_METHOD_CALL; |
|     } |
| |
|     // Hold a reference so |buffer| stays alive until the posted task runs. |
|     ComPtr<IBuffer> captured_buffer = buffer; |
|     auto fake_iasync_operation = |
|         Make<FakeIAsyncOperationWithProgress<UINT32, UINT32>>(); |
| |
|     HRESULT hr = fake_iasync_operation->QueryInterface(IID_PPV_ARGS(operation)); |
|     if (FAILED(hr)) { |
|       EXPECT_HRESULT_SUCCEEDED(hr); |
|       return hr; |
|     } |
| |
|     bool success = base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
|         FROM_HERE, |
|         base::BindOnce(&StreamData::OnWriteAsync, weak_factory_.GetWeakPtr(), |
|                        position, fake_iasync_operation, captured_buffer)); |
|     if (!success) { |
|       EXPECT_TRUE(success); |
|       return E_ASYNC_OPERATION_NOT_STARTED; |
|     } |
| |
|     write_async_in_progress_ = true; |
|     return S_OK; |
|   } |
| |
|   // Starts an asynchronous flush. The posted task completes |*operation| with |
|   // a result of true. Error returns match those of ReadAsync(). |
|   HRESULT |
|   FlushAsync(IAsyncOperation<bool>** operation) { |
|     if (FailIfAsyncInProgress("FlushAsync")) { |
|       return E_ILLEGAL_METHOD_CALL; |
|     } |
| |
|     auto fake_iasync_operation = Make<base::win::FakeIAsyncOperation<bool>>(); |
| |
|     HRESULT hr = fake_iasync_operation->QueryInterface(IID_PPV_ARGS(operation)); |
|     if (FAILED(hr)) { |
|       EXPECT_HRESULT_SUCCEEDED(hr); |
|       return hr; |
|     } |
| |
|     bool success = base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
|         FROM_HERE, |
|         base::BindOnce(&StreamData::OnFlushAsync, weak_factory_.GetWeakPtr(), |
|                        fake_iasync_operation)); |
|     if (!success) { |
|       EXPECT_TRUE(success); |
|       return E_ASYNC_OPERATION_NOT_STARTED; |
|     } |
| |
|     flush_async_in_progress_ = true; |
|     return S_OK; |
|   } |
| |
|  private: |
|   friend class base::RefCountedThreadSafe<StreamData>; |
| |
|   // Destroying the data while an operation is still pending is reported as a |
|   // test failure, since that operation can no longer be completed from here. |
|   ~StreamData() { |
|     EXPECT_FALSE(flush_async_in_progress_) |
|         << "StreamData destroyed while flush operation is in progress."; |
|     EXPECT_FALSE(read_async_in_progress_) |
|         << "StreamData destroyed while read operation is in progress."; |
|     EXPECT_FALSE(write_async_in_progress_) |
|         << "StreamData destroyed while write operation is in progress."; |
|   } |
| |
|   // Returns true, after recording a test failure that names |caller|, when a |
|   // flush, read or write operation is still pending. Returns false otherwise. |
|   // The operations are checked in that order and only the first one found is |
|   // reported. |
|   bool FailIfAsyncInProgress(const char* caller) { |
|     const char* pending = nullptr; |
|     if (flush_async_in_progress_) { |
|       pending = "flush"; |
|     } else if (read_async_in_progress_) { |
|       pending = "read"; |
|     } else if (write_async_in_progress_) { |
|       pending = "write"; |
|     } |
|     if (!pending) { |
|       return false; |
|     } |
| |
|     ADD_FAILURE() << caller << " called while a " << pending |
|                   << " operation is in progress."; |
|     return true; |
|   } |
| |
|   // Posted by FlushAsync(): clears the pending flag and completes the |
|   // operation with true. |
|   void OnFlushAsync( |
|       ComPtr<base::win::FakeIAsyncOperation<bool>> fake_iasync_operation) { |
|     ASSERT_TRUE(flush_async_in_progress_); |
|     flush_async_in_progress_ = false; |
| |
|     fake_iasync_operation->CompleteWithResults(true); |
|   } |
| |
|   // Posted by ReadAsync(): copies up to |count| bytes starting at |position| |
|   // into |buffer|, sets the buffer length to the number of bytes copied, |
|   // advances |position| by that amount and completes the operation with |
|   // |buffer|. |
|   // |
|   // If the buffer cannot be accessed or rejects the length, a fatal test |
|   // failure is recorded and the function returns without completing the |
|   // operation. |
|   void OnReadAsync(scoped_refptr<base::RefCountedData<UINT64>> position, |
|                    ComPtr<FakeIAsyncOperationWithProgress<IBuffer*, UINT32>> |
|                        fake_iasync_operation, |
|                    ComPtr<IBuffer> buffer, |
|                    UINT32 count) { |
|     ASSERT_TRUE(read_async_in_progress_); |
|     read_async_in_progress_ = false; |
| |
|     // Clamp |count| to the bytes available at |position|. The remainder is |
|     // computed without subtracting a larger position from a smaller size, so |
|     // a position past the end of |data_| gives an empty read instead of an |
|     // unsigned wrap-around followed by an out-of-bounds copy. |
|     const UINT64 size = data_.size(); |
|     const UINT64 start = position->data; |
|     const UINT64 available = start < size ? size - start : 0; |
|     if (count > available) { |
|       count = static_cast<UINT32>(available); |
|     } |
| |
|     // Fetch the raw buffer to write to. These are fatal checks: continuing |
|     // after a failure would dereference a null interface or an unset pointer. |
|     ComPtr<IBufferByteAccess> buffer_byte_access; |
|     ASSERT_HRESULT_SUCCEEDED(buffer.As(&buffer_byte_access)); |
|     byte* raw_buffer = nullptr; |
|     ASSERT_HRESULT_SUCCEEDED(buffer_byte_access->Buffer(&raw_buffer)); |
| |
|     // Write the data to the buffer, updating the position and the buffer's |
|     // length. Nothing is copied into a buffer that rejected the length. |
|     ASSERT_HRESULT_SUCCEEDED(buffer->put_Length(count)); |
|     for (UINT32 i = 0; i < count; i++) { |
|       raw_buffer[i] = data_[start + i]; |
|     } |
|     position->data += count; |
| |
|     fake_iasync_operation->CompleteWithResults(buffer.Get()); |
|   } |
| |
|   // Posted by WriteAsync(): copies the full contents of |buffer| into |data_| |
|   // at |position| (growing |data_| when needed), advances |position| by the |
|   // buffer length and completes the operation with that length. |
|   // |
|   // If the buffer cannot be accessed, a fatal test failure is recorded and the |
|   // function returns without completing the operation. |
|   void OnWriteAsync(scoped_refptr<base::RefCountedData<UINT64>> position, |
|                     ComPtr<FakeIAsyncOperationWithProgress<UINT32, UINT32>> |
|                         fake_iasync_operation, |
|                     ComPtr<IBuffer> buffer) { |
|     ASSERT_TRUE(write_async_in_progress_); |
|     write_async_in_progress_ = false; |
| |
|     UINT32 length = 0; |
|     ASSERT_HRESULT_SUCCEEDED(buffer->get_Length(&length)); |
| |
|     // Fetch the raw buffer to read from. These are fatal checks: continuing |
|     // after a failure would dereference a null interface or an unset pointer. |
|     ComPtr<IBufferByteAccess> buffer_byte_access; |
|     ASSERT_HRESULT_SUCCEEDED(buffer.As(&buffer_byte_access)); |
|     byte* raw_buffer = nullptr; |
|     ASSERT_HRESULT_SUCCEEDED(buffer_byte_access->Buffer(&raw_buffer)); |
| |
|     // If reading the full buffer would take more room than is currently in our |
|     // inner |data_|, resize it to fit. |
|     if (position->data + length > data_.size()) { |
|       data_.resize(position->data + length); |
|     } |
| |
|     // Write the buffer to our inner |data_| and update the position. |
|     for (UINT32 i = 0; i < length; i++) { |
|       data_[position->data + i] = raw_buffer[i]; |
|     } |
|     position->data += length; |
| |
|     fake_iasync_operation->CompleteWithResults(length); |
|   } |
| |
|   // The stored bytes. |
|   std::vector<unsigned char> data_; |
|   // Set when the matching *Async() call succeeds, cleared when its posted task |
|   // starts running. |
|   bool flush_async_in_progress_ = false; |
|   bool read_async_in_progress_ = false; |
|   bool write_async_in_progress_ = false; |
|   base::WeakPtrFactory<StreamData> weak_factory_{this}; |
| }; |
| |
| // Creates a stream with its own freshly allocated backing data and its own |
| // position counter. |
| FakeRandomAccessStream::FakeRandomAccessStream() { |
|   shared_data_ = base::MakeRefCounted<StreamData>(); |
|   position_ = base::MakeRefCounted<base::RefCountedData<UINT64>>(); |
| } |
| // Records a non-fatal test failure when the stream is destroyed without |
| // Close() having been called on it first. |
| FakeRandomAccessStream::~FakeRandomAccessStream() { |
| EXPECT_TRUE(is_closed_) |
| << "FakeRandomAccessStream destroyed without being closed."; |
| } |
| |
| // Reports the size of the shared data through |value|. Calling this on a |
| // closed stream records a test failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::get_Size(UINT64* value) { |
|   if (!is_closed_) { |
|     return shared_data_->get_Size(value); |
|   } |
| |
|   ADD_FAILURE() << "get_Size called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Forwards a resize request to the shared data. Calling this on a closed |
| // stream records a test failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::put_Size(UINT64 value) { |
|   if (!is_closed_) { |
|     return shared_data_->put_Size(value); |
|   } |
| |
|   ADD_FAILURE() << "put_Size called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Creates a new FakeRandomAccessStream that shares this stream's data but has |
| // its own position, initialised to |position|, and returns it through |
| // |stream| as an IInputStream. |
| // |
| // Returns RO_E_CLOSED (with a test failure) on a closed stream. If the new |
| // object cannot be converted to the requested interface, a test failure is |
| // recorded and the QueryInterface error is returned, so callers never see |
| // S_OK without a usable |*stream|. |
| IFACEMETHODIMP |
| FakeRandomAccessStream::GetInputStreamAt(UINT64 position, |
|                                          IInputStream** stream) { |
|   if (is_closed_) { |
|     ADD_FAILURE() |
|         << "GetInputStreamAt called on closed FakeRandomAccessStream."; |
|     return RO_E_CLOSED; |
|   } |
|   auto copy = Make<FakeRandomAccessStream>(); |
|   copy->position_->data = position; |
|   copy->shared_data_ = shared_data_; |
|   const HRESULT hr = copy->QueryInterface(IID_PPV_ARGS(stream)); |
|   EXPECT_HRESULT_SUCCEEDED(hr); |
|   return hr; |
| } |
| |
| // Creates a new FakeRandomAccessStream that shares this stream's data but has |
| // its own position, initialised to |position|, and returns it through |
| // |stream| as an IOutputStream. |
| // |
| // Returns RO_E_CLOSED (with a test failure) on a closed stream. If the new |
| // object cannot be converted to the requested interface, a test failure is |
| // recorded and the QueryInterface error is returned, so callers never see |
| // S_OK without a usable |*stream|. |
| IFACEMETHODIMP |
| FakeRandomAccessStream::GetOutputStreamAt(UINT64 position, |
|                                           IOutputStream** stream) { |
|   if (is_closed_) { |
|     ADD_FAILURE() |
|         << "GetOutputStreamAt called on closed FakeRandomAccessStream."; |
|     return RO_E_CLOSED; |
|   } |
|   auto copy = Make<FakeRandomAccessStream>(); |
|   copy->position_->data = position; |
|   copy->shared_data_ = shared_data_; |
|   const HRESULT hr = copy->QueryInterface(IID_PPV_ARGS(stream)); |
|   EXPECT_HRESULT_SUCCEEDED(hr); |
|   return hr; |
| } |
| |
| // Reports this stream's current position through |value|. Calling this on a |
| // closed stream records a test failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::get_Position(UINT64* value) { |
|   if (!is_closed_) { |
|     *value = position_->data; |
|     return S_OK; |
|   } |
| |
|   ADD_FAILURE() << "get_Position called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Moves this stream's position to |position|, which may be at most the |
| // current size of the shared data. |
| // |
| // Returns RO_E_CLOSED on a closed stream, the error from get_Size() if the |
| // size cannot be read, or E_BOUNDS when |position| is past the end; the first |
| // and last of these also record a test failure. |
| IFACEMETHODIMP FakeRandomAccessStream::Seek(UINT64 position) { |
|   if (is_closed_) { |
|     ADD_FAILURE() << "Seek called on closed FakeRandomAccessStream."; |
|     return RO_E_CLOSED; |
|   } |
| |
|   UINT64 stream_size; |
|   const HRESULT size_result = shared_data_->get_Size(&stream_size); |
|   if (FAILED(size_result)) { |
|     return size_result; |
|   } |
| |
|   if (position <= stream_size) { |
|     position_->data = position; |
|     return S_OK; |
|   } |
| |
|   // Though it is technically legal to call Seek with an invalid |position| |
|   // value, there is no good reason to do so, so presumably points to a coding |
|   // error. |
|   // https://docs.microsoft.com/en-us/uwp/api/windows.storage.streams.irandomaccessstream.seek#remarks |
|   ADD_FAILURE() << "Seek called with position outside the known valid range."; |
|   return E_BOUNDS; |
| } |
| |
| // Not implemented by this fake: the body is only NOTREACHED() and |stream| is |
| // never written. |
| IFACEMETHODIMP |
| FakeRandomAccessStream::CloneStream(IRandomAccessStream** stream) { |
| NOTREACHED(); |
| } |
| |
| // Always reports TRUE through |value| for an open stream. Calling this on a |
| // closed stream records a test failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::get_CanRead(boolean* value) { |
|   if (!is_closed_) { |
|     *value = TRUE; |
|     return S_OK; |
|   } |
| |
|   ADD_FAILURE() << "get_CanRead called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Always reports TRUE through |value| for an open stream. Calling this on a |
| // closed stream records a test failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::get_CanWrite(boolean* value) { |
|   if (!is_closed_) { |
|     *value = TRUE; |
|     return S_OK; |
|   } |
| |
|   ADD_FAILURE() << "get_CanWrite called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Marks the stream closed and then runs the handler registered through |
| // OnClose(), if there is one. Closing an already closed stream records a test |
| // failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::Close() { |
|   if (!is_closed_) { |
|     is_closed_ = true; |
|     if (on_close_) { |
|       std::move(on_close_).Run(); |
|     } |
|     return S_OK; |
|   } |
| |
|   ADD_FAILURE() << "Close called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Forwards the read request, together with this stream's position, to the |
| // shared data. Calling this on a closed stream records a test failure and |
| // returns RO_E_CLOSED. |
| IFACEMETHODIMP FakeRandomAccessStream::ReadAsync( |
|     IBuffer* buffer, |
|     UINT32 count, |
|     InputStreamOptions options, |
|     IAsyncOperationWithProgress<IBuffer*, UINT32>** operation) { |
|   if (!is_closed_) { |
|     return shared_data_->ReadAsync(position_, buffer, count, options, |
|                                    operation); |
|   } |
| |
|   ADD_FAILURE() << "ReadAsync called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Forwards the write request, together with this stream's position, to the |
| // shared data. Calling this on a closed stream records a test failure and |
| // returns RO_E_CLOSED. |
| IFACEMETHODIMP |
| FakeRandomAccessStream::WriteAsync( |
|     IBuffer* buffer, |
|     IAsyncOperationWithProgress<UINT32, UINT32>** operation) { |
|   if (!is_closed_) { |
|     return shared_data_->WriteAsync(position_, buffer, operation); |
|   } |
| |
|   ADD_FAILURE() << "WriteAsync called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Forwards the flush request to the shared data. Calling this on a closed |
| // stream records a test failure and returns RO_E_CLOSED. |
| IFACEMETHODIMP |
| FakeRandomAccessStream::FlushAsync(IAsyncOperation<bool>** operation) { |
|   if (!is_closed_) { |
|     return shared_data_->FlushAsync(operation); |
|   } |
| |
|   ADD_FAILURE() << "FlushAsync called on closed FakeRandomAccessStream."; |
|   return RO_E_CLOSED; |
| } |
| |
| // Stores |on_close| so that Close() can run it. Both preconditions are fatal |
| // test assertions: the stream must not already be closed, and no handler may |
| // already be stored. When either fails, |on_close| is not stored. |
| void FakeRandomAccessStream::OnClose(base::OnceClosure on_close) { |
| ASSERT_FALSE(is_closed_) |
| << "OnClose called on closed FakeRandomAccessStream."; |
| ASSERT_FALSE(on_close_) << "OnClose called on FakeRandomAccessStream that " |
| "already has an OnClose handler defined."; |
| on_close_ = std::move(on_close); |
| } |
| |
| } // namespace webshare |