blob: fd3b912e683c125519fc9b0c5106412d5d6fa9dd [file] [log] [blame]
//
// Copyright (C) 2012 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "shill/net/netlink_socket.h"
#include <linux/if_packet.h>
#include <linux/netlink.h>
#include <sys/socket.h>
#include <string>
#include <base/logging.h>
#include "shill/logging.h"
#include "shill/net/netlink_fd.h"
#include "shill/net/netlink_message.h"
#include "shill/net/sockets.h"
// SOL_NETLINK is the socket-option level passed to setsockopt() for netlink
// sockets.  It comes from a version of linux/socket.h that we may not have,
// so provide the value here -- but only when the system headers have not
// already defined it, to avoid redefining the macro.
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif
namespace shill {
// Per-file logging hooks (presumably picked up by the SLOG() macro used in
// this file -- confirm against shill/logging.h).
namespace Logging {

// Log scope associated with this file: RTNL.
static auto kModuleLogScope = ScopeLogger::kRTNL;

// Returns the fixed tag "(netlink_socket)"; |obj| itself is never inspected.
static std::string ObjectID(const NetlinkSocket* obj) {
  return std::string("(netlink_socket)");
}

}  // namespace Logging
NetlinkSocket::NetlinkSocket()
: sequence_number_(0), file_descriptor_(Sockets::kInvalidFileDescriptor) {}
// Closes the file descriptor on destruction.  Nothing is closed when no
// |sockets_| wrapper exists or when the descriptor is negative.
NetlinkSocket::~NetlinkSocket() {
  if (!sockets_) {
    return;
  }
  if (file_descriptor_ < 0) {
    return;
  }
  sockets_->Close(file_descriptor_);
}
// Opens the generic-netlink socket and stores its descriptor in
// |file_descriptor_|.
//
// Returns true when OpenNetlinkSocketFD() yields a descriptor other than
// Sockets::kInvalidFileDescriptor, false otherwise.
bool NetlinkSocket::Init() {
  // A test may install its own |sockets_| before calling Init(); only create
  // the default wrapper when none has been supplied.
  if (!sockets_) {
    sockets_.reset(new Sockets);
  } else {
    LOG(INFO) << "|sockets_| already has a value -- this must be a test.";
  }

  file_descriptor_ = OpenNetlinkSocketFD(sockets_.get(), NETLINK_GENERIC, 0);
  const bool opened = (file_descriptor_ != Sockets::kInvalidFileDescriptor);
  if (opened) {
    SLOG(this, 2) << "Netlink socket started";
  }
  return opened;
}
// Reads the next waiting message from the socket into |message|.
//
// The read happens in two steps: a one-byte peek (MSG_TRUNC | MSG_PEEK) whose
// return value is used as the size of the waiting data, followed by a real
// read into |message| after it has been resized to that size.
//
// Returns false if |message| is null or if either RecvFrom() call returns a
// negative value; returns true otherwise.
bool NetlinkSocket::RecvMessage(ByteString* message) {
  if (message == nullptr) {
    LOG(ERROR) << "Null |message|";
    return false;
  }

  // Step 1: peek with a tiny scratch buffer to learn how much data is
  // waiting.  MSG_PEEK leaves the data queued for the read in step 2.
  const size_t kProbeByteCount = 1;
  ByteString probe(kProbeByteCount);
  const ssize_t waiting_bytes = sockets_->RecvFrom(file_descriptor_,
                                                   probe.GetData(),
                                                   probe.GetLength(),
                                                   MSG_TRUNC | MSG_PEEK,
                                                   nullptr,
                                                   nullptr);
  if (waiting_bytes < 0) {
    PLOG(ERROR) << "Socket recvfrom failed.";
    return false;
  }

  // Step 2: size the caller's buffer to what the peek reported and read the
  // data for real.
  message->Resize(waiting_bytes);
  const ssize_t read_bytes = sockets_->RecvFrom(file_descriptor_,
                                                message->GetData(),
                                                message->GetLength(),
                                                0,
                                                nullptr,
                                                nullptr);
  if (read_bytes < 0) {
    PLOG(ERROR) << "Second socket recvfrom failed.";
    return false;
  }
  return true;
}
// Sends the bytes of |out_msg| over the socket.
//
// Returns true only when Send() reports that exactly out_msg.GetLength()
// bytes were written.  Returns false (after logging) when Send() reports an
// error or a short write.
bool NetlinkSocket::SendMessage(const ByteString& out_msg) {
  ssize_t result = sockets_->Send(file_descriptor(), out_msg.GetConstData(),
                                  out_msg.GetLength(), 0);
  // A failed send is signalled by a negative return value (as with the
  // RecvFrom() checks elsewhere in this file), not by zero.  Testing for
  // |result < 0| ensures errno is logged through PLOG on real failures, and
  // that a zero-byte result is judged only by the length comparison below.
  if (result < 0) {
    PLOG(ERROR) << "Send failed.";
    return false;
  }
  // Anything other than a complete write is treated as a failure.
  if (result != static_cast<ssize_t>(out_msg.GetLength())) {
    LOG(ERROR) << "Only sent " << result << " bytes out of "
               << out_msg.GetLength() << ".";
    return false;
  }
  return true;
}
// Adds this socket to the netlink multicast group |group_id| by setting the
// NETLINK_ADD_MEMBERSHIP option at the SOL_NETLINK level.
//
// Returns true when setsockopt() succeeds; logs the error and returns false
// when it returns a negative value.
bool NetlinkSocket::SubscribeToEvents(uint32_t group_id) {
  const int status = setsockopt(file_descriptor_,
                                SOL_NETLINK,
                                NETLINK_ADD_MEMBERSHIP,
                                &group_id,
                                sizeof(group_id));
  if (status >= 0) {
    return true;
  }
  PLOG(ERROR) << "setsockopt didn't work.";
  return false;
}
// Advances the internal sequence counter and returns the new value.  If the
// increment lands on NetlinkMessage::kBroadcastSequenceNumber, the counter is
// advanced once more so that value is never handed out.
uint32_t NetlinkSocket::GetSequenceNumber() {
  ++sequence_number_;
  if (sequence_number_ == NetlinkMessage::kBroadcastSequenceNumber) {
    ++sequence_number_;
  }
  return sequence_number_;
}
} // namespace shill.