// Copyright 2019 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "patchpanel/socket_forwarder.h"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/ip.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>

#include <utility>

#include <base/bind.h>
#include <base/logging.h>
#include <base/time/time.h>

namespace patchpanel {
namespace {
// Timeout, in milliseconds, passed to every epoll_wait() call of the polling
// loop. It bounds how long the loop can block before it re-checks whether it
// should keep polling.
constexpr int kWaitTimeoutMs = 1000;
// Maximum number of epoll events to process per wait.
constexpr int kMaxEvents = 4;

// Streams a human-readable description of |event| (its fd in the stream's
// current base and its event mask in hexadecimal) for log messages.
//
// The formatting flags of |stream| are restored before returning: std::hex is
// sticky, so without this every integer inserted into |stream| after |event|
// would silently be printed in hexadecimal as well.
std::ostream& operator<<(std::ostream& stream,
                         const struct epoll_event& event) {
  const std::ios_base::fmtflags saved_flags = stream.flags();
  stream << "{ fd: " << event.data.fd << ", events: 0x" << std::hex
         << event.events << "}";
  stream.flags(saved_flags);
  return stream;
}

// Replaces the set of epoll events monitored for |socket| on the epoll
// instance |cfd| with |events|. The update uses EPOLL_CTL_MOD, so |socket| is
// expected to be registered with |cfd| already. Returns true on success;
// otherwise logs the failure and returns false.
bool SetPollEvents(Socket* socket, int cfd, uint32_t events) {
  const int fd = socket->fd();

  struct epoll_event ev;
  ev.data.fd = fd;
  ev.events = events;

  if (epoll_ctl(cfd, EPOLL_CTL_MOD, fd, &ev) != -1)
    return true;

  PLOG(ERROR) << "epoll_ctl(" << ev << ") failed";
  return false;
}

}  // namespace

// Takes ownership of |sock0| and |sock1| and passes |name| to the
// base::SimpleThread constructor. Both buffer lengths start at 0, no socket
// has reported EOF yet (|eof_| is -1) and the forwarder is neither polling
// nor done. Both sockets must be non-null (DCHECKed).
SocketForwarder::SocketForwarder(const std::string& name,
                                 std::unique_ptr<Socket> sock0,
                                 std::unique_ptr<Socket> sock1)
    : base::SimpleThread(name),
      sock0_(std::move(sock0)),
      sock1_(std::move(sock1)),
      len0_(0),
      len1_(0),
      eof_(-1),
      poll_(false),
      done_(false) {
  DCHECK(sock0_);
  DCHECK(sock1_);
}

// Asks the polling loop to stop, then joins the forwarder thread.
SocketForwarder::~SocketForwarder() {
  // Ensure the polling loop exits.
  poll_ = false;
  // NOTE(review): Poll() sets |poll_| back to true right before entering its
  // loop, so if the destructor runs before the thread gets that far the flag
  // is overwritten and Join() presumably blocks until the loop stops for
  // another reason - confirm whether callers can destroy the forwarder that
  // early.
  Join();
}

// Reports whether the forwarder has yet to finish: true until Run() flags the
// forwarder as done.
bool SocketForwarder::IsRunning() const {
  const bool finished = done_;
  return !finished;
}

// Test-only hook: stores |closure|, which Run() invokes right before it
// returns, whether or not the forwarder managed to start polling.
void SocketForwarder::SetStopQuitClosureForTesting(base::OnceClosure closure) {
  stop_quit_closure_for_testing_ = std::move(closure);
}

// Thread entry point. Switches both sockets to non-blocking mode, forwards
// data between them until Poll() returns, then releases the sockets, flags the
// forwarder as done and runs the test-only quit closure if one was set.
//
// The forwarder is flagged as done on every exit path, including when the
// sockets cannot be made non-blocking, so that IsRunning() never keeps
// reporting a forwarder whose thread has already exited.
void SocketForwarder::Run() {
  LOG(INFO) << "Starting forwarder: " << *sock0_ << " <-> " << *sock1_;

  // Adds O_NONBLOCK to the file status flags of |fd|. The current flags are
  // validated before being reused: F_GETFL returns -1 on failure, and OR-ing
  // that value into the F_SETFL argument would request every flag at once.
  const auto set_non_blocking = [](int fd) {
    const int flags = fcntl(fd, F_GETFL, 0);
    return flags != -1 && fcntl(fd, F_SETFL, flags | O_NONBLOCK) != -1;
  };

  // We need these sockets to be non-blocking.
  if (set_non_blocking(sock0_->fd()) && set_non_blocking(sock1_->fd())) {
    Poll();
    LOG(INFO) << "Forwarder stopped: " << *sock0_ << " <-> " << *sock1_;
  } else {
    PLOG(ERROR) << "fcntl failed";
  }

  // Common teardown for both the normal and the failure path.
  done_ = true;
  sock1_.reset();
  sock0_.reset();
  if (stop_quit_closure_for_testing_)
    std::move(stop_quit_closure_for_testing_).Run();
}

void SocketForwarder::Poll() {
  base::ScopedFD cfd(epoll_create1(0));
  if (!cfd.is_valid()) {
    PLOG(ERROR) << "epoll_create1 failed";
    return;
  }
  struct epoll_event ev;
  ev.events = EPOLLIN;
  ev.data.fd = sock0_->fd();
  if (epoll_ctl(cfd.get(), EPOLL_CTL_ADD, sock0_->fd(), &ev) == -1) {
    PLOG(ERROR) << "epoll_ctl failed";
    return;
  }
  ev.data.fd = sock1_->fd();
  if (epoll_ctl(cfd.get(), EPOLL_CTL_ADD, sock1_->fd(), &ev) == -1) {
    PLOG(ERROR) << "epoll_ctl failed";
    return;
  }

  poll_ = true;
  struct epoll_event events[kMaxEvents];
  while (poll_) {
    int n = epoll_wait(cfd.get(), events, kMaxEvents, kWaitTimeoutMs);
    if (n == -1) {
      if (errno == EINTR) {
        LOG(INFO) << "Resume epoll_wait from interruption.";
        continue;
      }
      PLOG(ERROR) << "epoll_wait failed";
      return;
    }
    for (int i = 0; i < n; ++i) {
      if (!poll_ ||
          !ProcessEvents(events[i].events, events[i].data.fd, cfd.get()))
        return;
    }
  }
}

// Handles the epoll |events| reported for the socket whose fd is |efd| on the
// epoll instance |cfd|. Returns true to keep polling and false when the
// forwarder must stop (socket error, failed read/write/epoll update, hang up,
// or EOF read from both sockets).
//
// |buf0_|/|len0_| hold data read from |sock0_| that still has to be written to
// |sock1_|; |buf1_|/|len1_| hold the data going in the opposite direction.
bool SocketForwarder::ProcessEvents(uint32_t events, int efd, int cfd) {
  if (events & EPOLLERR) {
    // Start from 0 so that a failed getsockopt() never logs an indeterminate
    // value.
    int so_error = 0;
    socklen_t optlen = sizeof(so_error);
    if (getsockopt(efd, SOL_SOCKET, SO_ERROR, &so_error, &optlen) == 0) {
      // PLOG describes errno, which at this point is unrelated to the pending
      // socket error: make it describe the error that was actually reported.
      // If getsockopt() itself failed, errno already describes that failure.
      errno = so_error;
    }
    PLOG(WARNING) << "Socket error: (" << so_error << ") " << *sock0_ << " <-> "
                  << *sock1_;
    return false;
  }

  if (events & EPOLLOUT) {
    // |efd| is writable again: flush the data buffered for it, i.e. the data
    // that was read from the other socket.
    Socket* dst;
    char* buf;
    ssize_t* len;
    if (sock0_->fd() == efd) {
      dst = sock0_.get();
      buf = buf1_;
      len = &len1_;
    } else {
      dst = sock1_.get();
      buf = buf0_;
      len = &len0_;
    }

    ssize_t bytes = dst->SendTo(buf, *len);
    if (bytes < 0)
      return false;

    // Still unavailable.
    if (bytes == 0)
      return true;

    // Partial write.
    if (bytes < *len)
      memmove(&buf[0], &buf[bytes], *len - bytes);
    *len -= bytes;

    // Once all the buffered data has been written, stop listening for
    // EPOLLOUT: the socket is polled in level-triggered mode, so a writable
    // socket would otherwise be reported by every single epoll_wait() call.
    // Go back to read events, unless EOF was already read from |dst|, in which
    // case no event is requested, as set by HandleConnectionClosed().
    if (*len == 0) {
      const uint32_t next_events = (eof_ == dst->fd()) ? 0 : EPOLLIN;
      if (!SetPollEvents(dst, cfd, next_events))
        return false;
    }
  }

  // From here on |efd| is handled as the source of data: |src| is read into
  // its own buffer, which is then written to |dst|.
  Socket *src, *dst;
  char* buf;
  ssize_t* len;
  if (sock0_->fd() == efd) {
    src = sock0_.get();
    dst = sock1_.get();
    buf = buf0_;
    len = &len0_;
  } else {
    src = sock1_.get();
    dst = sock0_.get();
    buf = buf1_;
    len = &len1_;
  }

  // Skip the read if this buffer is still pending write: requires that
  // epoll_wait is in level-triggered mode.
  if (*len > 0)
    return true;

  if (events & EPOLLIN) {
    *len = src->RecvFrom(buf, kBufSize);
    if (*len < 0)
      return false;

    // Reading 0 bytes is handled as |src| having closed its side.
    if (*len == 0)
      return HandleConnectionClosed(src, dst, cfd);

    ssize_t bytes = dst->SendTo(buf, *len);
    if (bytes < 0)
      return false;

    if (bytes > 0) {
      // Partial write.
      if (bytes < *len)
        memmove(&buf[0], &buf[bytes], *len - bytes);
      *len -= bytes;
    }

    // Some data could not be written: wait for |dst| to become writable
    // before reading anything else into this buffer.
    if (*len > 0 && !SetPollEvents(dst, cfd, EPOLLOUT))
      return false;
  }

  if (events & EPOLLHUP) {
    LOG(INFO) << "Peer closed connection: " << *sock0_ << " <-> " << *sock1_;
    return false;
  }
  return true;
}

// Handles EOF having been read from |src|, whose data is forwarded to |dst|,
// on the epoll instance |cfd|. If EOF was already read from |dst| as well,
// returns false so that the forwarder stops. Otherwise stops monitoring
// |src|, shuts |dst| down for writing, records |src| in |eof_| and returns
// true; returns false if any of these steps fails.
bool SocketForwarder::HandleConnectionClosed(Socket* src,
                                             Socket* dst,
                                             int cfd) {
  LOG(INFO) << "Peer closed connection: " << *src;
  if (eof_ == dst->fd()) {
    // Stop the forwarder since the other peer has already closed the
    // connection.
    LOG(INFO) << "Closed connection: " << *sock0_ << " <-> " << *sock1_;
    return false;
  }
  // Stop listening for read ready events from |src|.
  if (!SetPollEvents(src, cfd, 0))
    return false;

  // Propagate the shut down for writing to the other peer. This is safe
  // to do since reading the EOF on |src| only happens if the buffer
  // associated with the |src| socket is empty, so there's no outstanding
  // data to be written to |dst|.
  if (shutdown(dst->fd(), SHUT_WR) == -1) {
    // Log |dst|, the socket that failed to shut down. A bare |socket| here
    // would name the libc socket() function, not a Socket object.
    PLOG(ERROR) << "Shutting down " << *dst << " for writing failed";
    return false;
  }

  eof_ = src->fd();
  return true;
}
}  // namespace patchpanel
