// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "shill/net/netlink_socket.h"

#include <linux/if_packet.h>
#include <linux/netlink.h>
#include <sys/socket.h>

#include <string>

#include <base/logging.h>

#include "shill/logging.h"
#include "shill/net/netlink_fd.h"
#include "shill/net/netlink_message.h"
#include "shill/net/sockets.h"

// SOL_NETLINK (the setsockopt() level for netlink options) is missing from
// some versions of linux/socket.h. Define it only when the system headers
// have not already provided it, so it never clashes with a libc definition.
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif

namespace shill {

namespace Logging {
// Per-file logging hooks; presumably looked up by the SLOG() macro from
// shill/logging.h (used in Init() below) -- confirm against that header.
static auto kModuleLogScope = ScopeLogger::kRTNL;

// Every NetlinkSocket is tagged with the same fixed identifier; the object
// pointer itself is not consulted.
static std::string ObjectID(const NetlinkSocket* /* obj */) {
  return std::string("(netlink_socket)");
}
}  // namespace Logging

// Constructs a socket wrapper with no open descriptor (Init() opens it) and a
// zeroed sequence counter, so the first GetSequenceNumber() call returns 1
// unless 1 is the reserved broadcast value.
NetlinkSocket::NetlinkSocket()
    : sequence_number_(0),
      file_descriptor_(Sockets::kInvalidFileDescriptor) {}

// Closes the netlink descriptor through |sockets_|, if one was opened.
NetlinkSocket::~NetlinkSocket() {
  // |sockets_| is only populated by Init() (or injected by a test); without it,
  // or without a valid descriptor, there is nothing to release.
  if (!sockets_ || file_descriptor_ < 0)
    return;
  sockets_->Close(file_descriptor_);
}

// Opens a NETLINK_GENERIC socket.
//
// Creates the default Sockets implementation unless one has already been
// injected (tests do this before calling Init()). Returns true when a valid
// descriptor was obtained; on failure |file_descriptor_| is left holding
// Sockets::kInvalidFileDescriptor.
bool NetlinkSocket::Init() {
  if (!sockets_) {
    sockets_.reset(new Sockets);
  } else {
    LOG(INFO) << "|sockets_| already has a value -- this must be a test.";
  }

  file_descriptor_ = OpenNetlinkSocketFD(sockets_.get(), NETLINK_GENERIC, 0);
  const bool opened = (file_descriptor_ != Sockets::kInvalidFileDescriptor);
  if (opened) {
    SLOG(this, 2) << "Netlink socket started";
  }
  return opened;
}

// Reads the next pending netlink datagram into |message|.
//
// The datagram size is not known up front, so the socket is first peeked with
// MSG_PEEK | MSG_TRUNC: only one byte is copied, but recvfrom() reports the
// full length of the pending datagram without consuming it. |message| is then
// sized to that length and the datagram is consumed with a normal read.
//
// Returns false (after logging) if |message| is null or either recvfrom()
// call fails. On success |message| holds exactly the bytes received.
bool NetlinkSocket::RecvMessage(ByteString* message) {
  if (!message) {
    LOG(ERROR) << "Null |message|";
    return false;
  }

  // Determine the amount of data currently waiting.
  const size_t kDummyReadByteCount = 1;
  ByteString dummy_read(kDummyReadByteCount);
  ssize_t result = sockets_->RecvFrom(file_descriptor_, dummy_read.GetData(),
                                      dummy_read.GetLength(),
                                      MSG_TRUNC | MSG_PEEK, nullptr, nullptr);
  if (result < 0) {
    PLOG(ERROR) << "Socket recvfrom failed.";
    return false;
  }

  // Read the data that was waiting when we did our previous read.
  message->Resize(result);
  result = sockets_->RecvFrom(file_descriptor_, message->GetData(),
                              message->GetLength(), 0, nullptr, nullptr);
  if (result < 0) {
    PLOG(ERROR) << "Second socket recvfrom failed.";
    return false;
  }

  // Trim to the bytes actually received so that a short read cannot leave
  // unfilled bytes at the tail of |message|.
  message->Resize(result);
  return true;
}

// Sends |out_msg| over the netlink socket with a single send() call.
//
// Returns true only if every byte of |out_msg| was accepted. A send() error
// (negative return) is logged together with errno; a short write is logged
// with the byte counts.
bool NetlinkSocket::SendMessage(const ByteString& out_msg) {
  ssize_t result = sockets_->Send(file_descriptor(), out_msg.GetConstData(),
                                  out_msg.GetLength(), 0);
  // send() signals failure with -1 and sets errno. A return of 0 is a (short)
  // successful write, not an error, so only negative values take the PLOG path.
  if (result < 0) {
    PLOG(ERROR) << "Send failed.";
    return false;
  }
  if (result != static_cast<ssize_t>(out_msg.GetLength())) {
    LOG(ERROR) << "Only sent " << result << " bytes out of "
               << out_msg.GetLength() << ".";
    return false;
  }

  return true;
}

// Joins netlink multicast group |group_id| on this socket via
// setsockopt(SOL_NETLINK, NETLINK_ADD_MEMBERSHIP). Returns false and logs
// errno if the kernel rejects the request.
bool NetlinkSocket::SubscribeToEvents(uint32_t group_id) {
  const int status = setsockopt(file_descriptor_, SOL_NETLINK,
                                NETLINK_ADD_MEMBERSHIP, &group_id,
                                sizeof(group_id));
  if (status >= 0)
    return true;

  PLOG(ERROR) << "setsockopt didn't work.";
  return false;
}

// Advances and returns the per-socket message sequence number, skipping
// NetlinkMessage::kBroadcastSequenceNumber so that value is never handed out
// for an outgoing request.
uint32_t NetlinkSocket::GetSequenceNumber() {
  ++sequence_number_;
  if (sequence_number_ == NetlinkMessage::kBroadcastSequenceNumber) {
    ++sequence_number_;
  }
  return sequence_number_;
}

}  // namespace shill.
