| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "connections/implementation/base_endpoint_channel.h" |
| |
| #include <cassert> |
| |
| #include "absl/strings/escaping.h" |
| #include "absl/strings/str_cat.h" |
| #include "connections/implementation/offline_frames.h" |
| #include "internal/platform/byte_array.h" |
| #include "internal/platform/exception.h" |
| #include "internal/platform/logging.h" |
| #include "internal/platform/mutex.h" |
| #include "internal/platform/mutex_lock.h" |
| |
| namespace location { |
| namespace nearby { |
| namespace connections { |
| |
| namespace { |
| |
| std::int32_t BytesToInt(const ByteArray& bytes) { |
| const char* int_bytes = bytes.data(); |
| |
| std::int32_t result = 0; |
| result |= (static_cast<std::int32_t>(int_bytes[0]) & 0x0FF) << 24; |
| result |= (static_cast<std::int32_t>(int_bytes[1]) & 0x0FF) << 16; |
| result |= (static_cast<std::int32_t>(int_bytes[2]) & 0x0FF) << 8; |
| result |= (static_cast<std::int32_t>(int_bytes[3]) & 0x0FF); |
| |
| return result; |
| } |
| |
| ByteArray IntToBytes(std::int32_t value) { |
| char int_bytes[sizeof(std::int32_t)]; |
| int_bytes[0] = static_cast<char>((value >> 24) & 0x0FF); |
| int_bytes[1] = static_cast<char>((value >> 16) & 0x0FF); |
| int_bytes[2] = static_cast<char>((value >> 8) & 0x0FF); |
| int_bytes[3] = static_cast<char>((value)&0x0FF); |
| |
| return ByteArray(int_bytes, sizeof(int_bytes)); |
| } |
| |
| ExceptionOr<ByteArray> ReadExactly(InputStream* reader, std::int64_t size) { |
| ByteArray buffer(size); |
| std::int64_t current_pos = 0; |
| |
| while (current_pos < size) { |
| ExceptionOr<ByteArray> read_bytes = reader->Read(size - current_pos); |
| if (!read_bytes.ok()) { |
| return read_bytes; |
| } |
| ByteArray result = read_bytes.result(); |
| |
| if (result.Empty()) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Empty result when reading bytes."; |
| return ExceptionOr<ByteArray>(Exception::kIo); |
| } |
| |
| buffer.CopyAt(current_pos, result); |
| current_pos += result.size(); |
| } |
| |
| return ExceptionOr<ByteArray>(std::move(buffer)); |
| } |
| |
| ExceptionOr<std::int32_t> ReadInt(InputStream* reader) { |
| ExceptionOr<ByteArray> read_bytes = ReadExactly(reader, sizeof(std::int32_t)); |
| if (!read_bytes.ok()) { |
| return ExceptionOr<std::int32_t>(read_bytes.exception()); |
| } |
| return ExceptionOr<std::int32_t>(BytesToInt(std::move(read_bytes.result()))); |
| } |
| |
| Exception WriteInt(OutputStream* writer, std::int32_t value) { |
| return writer->Write(IntToBytes(value)); |
| } |
| |
| } // namespace |
| |
| BaseEndpointChannel::BaseEndpointChannel(const std::string& channel_name, |
| InputStream* reader, |
| OutputStream* writer) |
| : BaseEndpointChannel( |
| channel_name, reader, writer, |
| // TODO(edwinwu): Below values should be retrieved from a base socket, |
| // the #MediumSocket in Android counterpart, from which all the |
| // derived medium sockets should dervied, and implement the supported |
| // values and leave the default values in base #MediumSocket. |
| /*ConnectionTechnology*/ |
| proto::connections::CONNECTION_TECHNOLOGY_UNKNOWN_TECHNOLOGY, |
| /*ConnectionBand*/ proto::connections::CONNECTION_BAND_UNKNOWN_BAND, |
| /*frequency*/ -1, |
| /*try_count*/ 0) {} |
| |
| BaseEndpointChannel::BaseEndpointChannel( |
| const std::string& channel_name, InputStream* reader, OutputStream* writer, |
| proto::connections::ConnectionTechnology technology, |
| proto::connections::ConnectionBand band, int frequency, int try_count) |
| : channel_name_(channel_name), |
| reader_(reader), |
| writer_(writer), |
| technology_(technology), |
| band_(band), |
| frequency_(frequency), |
| try_count_(try_count) {} |
| |
| ExceptionOr<ByteArray> BaseEndpointChannel::Read() { |
| ByteArray result; |
| { |
| MutexLock lock(&reader_mutex_); |
| |
| ExceptionOr<std::int32_t> read_int = ReadInt(reader_); |
| if (!read_int.ok()) { |
| return ExceptionOr<ByteArray>(read_int.exception()); |
| } |
| |
| if (read_int.result() < 0 || read_int.result() > kMaxAllowedReadBytes) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Read an invalid number of bytes: " |
| << read_int.result(); |
| return ExceptionOr<ByteArray>(Exception::kIo); |
| } |
| |
| ExceptionOr<ByteArray> read_bytes = ReadExactly(reader_, read_int.result()); |
| if (!read_bytes.ok()) { |
| return read_bytes; |
| } |
| result = std::move(read_bytes.result()); |
| } |
| |
| { |
| MutexLock crypto_lock(&crypto_mutex_); |
| if (IsEncryptionEnabledLocked()) { |
| // If encryption is enabled, decode the message. |
| std::string input(std::move(result)); |
| std::unique_ptr<std::string> decrypted_data = |
| crypto_context_->DecodeMessageFromPeer(input); |
| if (decrypted_data) { |
| result = ByteArray(std::move(*decrypted_data)); |
| } else { |
| // It could be a protocol race, where remote party sends a KEEP_ALIVE |
| // before encryption is setup on their side, and we receive it after |
| // we switched to encryption mode. |
| // In this case, we verify that message is indeed a valid KEEP_ALIVE, |
| // and let it through if it is, otherwise message is erased. |
| // TODO(apolyudov): verify this happens at most once per session. |
| result = {}; |
| auto parsed = parser::FromBytes(ByteArray(input)); |
| if (parsed.ok()) { |
| if (parser::GetFrameType(parsed.result()) == V1Frame::KEEP_ALIVE) { |
| NEARBY_LOGS(INFO) |
| << __func__ |
| << ": Read unencrypted KEEP_ALIVE on encrypted channel."; |
| result = ByteArray(input); |
| } else { |
| NEARBY_LOGS(WARNING) |
| << __func__ << ": Read unexpected unencrypted frame of type " |
| << parser::GetFrameType(parsed.result()); |
| } |
| } else { |
| NEARBY_LOGS(WARNING) |
| << __func__ << ": Unable to parse data as unencrypted message."; |
| } |
| } |
| if (result.Empty()) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Unable to parse read result."; |
| return ExceptionOr<ByteArray>(Exception::kInvalidProtocolBuffer); |
| } |
| } |
| } |
| |
| { |
| MutexLock lock(&last_read_mutex_); |
| last_read_timestamp_ = SystemClock::ElapsedRealtime(); |
| } |
| return ExceptionOr<ByteArray>(result); |
| } |
| |
| Exception BaseEndpointChannel::Write(const ByteArray& data) { |
| { |
| MutexLock pause_lock(&is_paused_mutex_); |
| if (is_paused_) { |
| BlockUntilUnpaused(); |
| } |
| } |
| |
| ByteArray encrypted_data; |
| const ByteArray* data_to_write = &data; |
| { |
| // Holding both mutexes is necessary to prevent the keep alive and payload |
| // threads from writing encrypted messages out of order which causes a |
| // failure to decrypt on the reader side. However we need to release the |
| // crypto lock after encrypting to ensure read decryption is not blocked. |
| MutexLock lock(&writer_mutex_); |
| { |
| MutexLock crypto_lock(&crypto_mutex_); |
| if (IsEncryptionEnabledLocked()) { |
| // If encryption is enabled, encode the message. |
| std::unique_ptr<std::string> encrypted = |
| crypto_context_->EncodeMessageToPeer(std::string(data)); |
| if (!encrypted) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Failed to encrypt data."; |
| return {Exception::kIo}; |
| } |
| encrypted_data = ByteArray(std::move(*encrypted)); |
| data_to_write = &encrypted_data; |
| } |
| } |
| |
| Exception write_exception = |
| WriteInt(writer_, static_cast<std::int32_t>(data_to_write->size())); |
| if (write_exception.Raised()) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Failed to write header: " |
| << write_exception.value; |
| return write_exception; |
| } |
| write_exception = writer_->Write(*data_to_write); |
| if (write_exception.Raised()) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Failed to write data: " |
| << write_exception.value; |
| return write_exception; |
| } |
| Exception flush_exception = writer_->Flush(); |
| if (flush_exception.Raised()) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Failed to flush writer: " |
| << flush_exception.value; |
| return flush_exception; |
| } |
| } |
| |
| { |
| MutexLock lock(&last_write_mutex_); |
| last_write_timestamp_ = SystemClock::ElapsedRealtime(); |
| } |
| return {Exception::kSuccess}; |
| } |
| |
| void BaseEndpointChannel::Close() { |
| { |
| // In case channel is paused, resume it first thing. |
| MutexLock lock(&is_paused_mutex_); |
| UnblockPausedWriter(); |
| } |
| CloseIo(); |
| CloseImpl(); |
| } |
| |
| void BaseEndpointChannel::CloseIo() { |
| // Keep this method dedicated to reader and writer handling an nothing else. |
| { |
| // Do not take reader_mutex_ here: read may be in progress, and it will |
| // deadlock. Calling Close() with Read() in progress will terminate the |
| // IO and Read() will proceed normally (with Exception::kIo). |
| Exception exception = reader_->Close(); |
| if (!exception.Ok()) { |
| NEARBY_LOGS(WARNING) << __func__ |
| << ": Exception closing reader: " << exception.value; |
| } |
| } |
| { |
| // Do not take writer_mutex_ here: write may be in progress, and it will |
| // deadlock. Calling Close() with Write() in progress will terminate the |
| // IO and Write() will proceed normally (with Exception::kIo). |
| Exception exception = writer_->Close(); |
| if (!exception.Ok()) { |
| NEARBY_LOGS(WARNING) << __func__ |
| << ": Exception closing writer: " << exception.value; |
| } |
| } |
| } |
| |
| void BaseEndpointChannel::SetAnalyticsRecorder( |
| analytics::AnalyticsRecorder* analytics_recorder, |
| const std::string& endpoint_id) { |
| analytics_recorder_ = analytics_recorder; |
| endpoint_id_ = endpoint_id; |
| } |
| |
| void BaseEndpointChannel::Close( |
| proto::connections::DisconnectionReason reason) { |
| NEARBY_LOGS(INFO) << __func__ |
| << ": Closing endpoint channel, reason: " << reason; |
| Close(); |
| |
| if (analytics_recorder_ != nullptr && !endpoint_id_.empty()) { |
| analytics_recorder_->OnConnectionClosed(endpoint_id_, GetMedium(), reason); |
| } |
| } |
| |
| std::string BaseEndpointChannel::GetType() const { |
| MutexLock crypto_lock(&crypto_mutex_); |
| std::string subtype = IsEncryptionEnabledLocked() ? "ENCRYPTED_" : ""; |
| std::string medium = proto::connections::Medium_Name( |
| proto::connections::Medium::UNKNOWN_MEDIUM); |
| |
| if (GetMedium() != proto::connections::Medium::UNKNOWN_MEDIUM) { |
| medium = |
| absl::StrCat(subtype, proto::connections::Medium_Name(GetMedium())); |
| } |
| return medium; |
| } |
| |
| std::string BaseEndpointChannel::GetName() const { return channel_name_; } |
| |
| int BaseEndpointChannel::GetMaxTransmitPacketSize() const { |
| // Return default value if the medium never define it's chunk size. |
| return kDefaultMaxTransmitPacketSize; |
| } |
| |
| void BaseEndpointChannel::EnableEncryption( |
| std::shared_ptr<EncryptionContext> context) { |
| MutexLock crypto_lock(&crypto_mutex_); |
| crypto_context_ = context; |
| } |
| |
| void BaseEndpointChannel::DisableEncryption() { |
| MutexLock crypto_lock(&crypto_mutex_); |
| crypto_context_.reset(); |
| } |
| |
| bool BaseEndpointChannel::IsPaused() const { |
| MutexLock lock(&is_paused_mutex_); |
| return is_paused_; |
| } |
| |
| void BaseEndpointChannel::Pause() { |
| MutexLock lock(&is_paused_mutex_); |
| is_paused_ = true; |
| } |
| |
| void BaseEndpointChannel::Resume() { |
| MutexLock lock(&is_paused_mutex_); |
| is_paused_ = false; |
| is_paused_cond_.Notify(); |
| } |
| |
| absl::Time BaseEndpointChannel::GetLastReadTimestamp() const { |
| MutexLock lock(&last_read_mutex_); |
| return last_read_timestamp_; |
| } |
| |
| absl::Time BaseEndpointChannel::GetLastWriteTimestamp() const { |
| MutexLock lock(&last_write_mutex_); |
| return last_write_timestamp_; |
| } |
| |
| proto::connections::ConnectionTechnology BaseEndpointChannel::GetTechnology() |
| const { |
| return technology_; |
| } |
| |
| // Returns the used wifi band of this EndpointChannel. |
| proto::connections::ConnectionBand BaseEndpointChannel::GetBand() const { |
| return band_; |
| } |
| |
| // Returns the used wifi frequency of this EndpointChannel. |
| int BaseEndpointChannel::GetFrequency() const { return frequency_; } |
| |
| // Returns the try count of this EndpointChannel. |
| int BaseEndpointChannel::GetTryCount() const { return try_count_; } |
| |
| bool BaseEndpointChannel::IsEncryptionEnabledLocked() const { |
| return crypto_context_ != nullptr; |
| } |
| |
| void BaseEndpointChannel::BlockUntilUnpaused() { |
| // For more on how this works, see |
| // https://docs.oracle.com/javase/tutorial/essential/concurrency/guardmeth.html |
| while (is_paused_) { |
| Exception wait_succeeded = is_paused_cond_.Wait(); |
| if (!wait_succeeded.Ok()) { |
| NEARBY_LOGS(WARNING) << __func__ << ": Failure waiting to unpause: " |
| << wait_succeeded.value; |
| return; |
| } |
| } |
| } |
| |
| void BaseEndpointChannel::UnblockPausedWriter() { |
| // For more on how this works, see |
| // https://docs.oracle.com/javase/tutorial/essential/concurrency/guardmeth.html |
| is_paused_ = false; |
| is_paused_cond_.Notify(); |
| } |
| |
| } // namespace connections |
| } // namespace nearby |
| } // namespace location |