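arc: networkd: Refactor epoll loop.

Move the epoll loop out of Run() into Poll(), with per-event handling
in ProcessEvents(), and keep the forwarding buffers as members. The
destructor now calls Stop() to end the loop instead of resetting the
sockets while the forwarder thread may still be using them, and
epoll_wait() now times out every second so that Poll() re-checks its
stop flag.
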
BUG=None
TEST=manual, autotests and tast tests

Change-Id: I245548a3ce44b762f6c572b8e3a8b9bbe78e98fb
Reviewed-on: https://chromium-review.googlesource.com/1617169
Commit-Ready: Garrick Evans <garrick@chromium.org>
Tested-by: Garrick Evans <garrick@chromium.org>
Legacy-Commit-Queue: Commit Bot <commit-bot@chromium.org>
Reviewed-by: Hugo Benichi <hugobenichi@google.com>
diff --git a/arc/network/socket_forwarder.cc b/arc/network/socket_forwarder.cc
index 4fbc641..f911602 100644
--- a/arc/network/socket_forwarder.cc
+++ b/arc/network/socket_forwarder.cc
@@ -20,7 +20,7 @@
 
 namespace arc_networkd {
 namespace {
-constexpr int kBufSize = 4096;
+constexpr int kWaitTimeoutMs = 1000;  // Poll()'s epoll_wait() timeout (ms).
 // Maximum number of epoll events to process per wait.
 constexpr int kMaxEvents = 4;
 }  // namespace
@@ -30,14 +30,24 @@
                                  std::unique_ptr<Socket> sock1)
     : base::SimpleThread(name),
       sock0_(std::move(sock0)),
-      sock1_(std::move(sock1)) {}
+      sock1_(std::move(sock1)),
+      len0_(0),
+      len1_(0),
+      poll_(false),
+      done_(false) {
+  DCHECK(sock0_);
+  DCHECK(sock1_);
+}
 
 SocketForwarder::~SocketForwarder() {
-  sock1_.reset();
-  sock0_.reset();
+  Stop();  // Ask Poll() to exit before joining the forwarder thread.
   Join();
 }
 
+bool SocketForwarder::IsValid() const {
+  return !done_;  // False once Stop() has been called.
+}
+
 void SocketForwarder::Run() {
   LOG(INFO) << "Starting forwarder: " << *sock0_ << " <-> " << *sock1_;
 
@@ -50,142 +60,165 @@
     return;
   }
 
-  base::ScopedFD efd(epoll_create1(0));
-  if (!efd.is_valid()) {
+  Poll();
+  Stop();
+
+  LOG(INFO) << "Stopping forwarder: " << *sock0_ << " <-> " << *sock1_;
+  sock1_.reset();
+  sock0_.reset();
+}
+
+// Asks Poll() to return. Only the atomic flags are written (the sockets are
+// never dereferenced here), so this is safe to call from the destructor while
+// the forwarder thread may still be inside Run() using or resetting them.
+void SocketForwarder::Stop() {
+  poll_ = false;
+  done_ = true;
+}
+
+// Runs the epoll loop that forwards data between the sockets. Returns when
+// ProcessEvents() asks to stop, an epoll call fails, or Stop() is called.
+void SocketForwarder::Poll() {
+  base::ScopedFD cfd(epoll_create1(0));
+  if (!cfd.is_valid()) {
     PLOG(ERROR) << "epoll_create1 failed";
     return;
   }
   struct epoll_event ev;
   ev.events = EPOLLIN | EPOLLRDHUP;
   ev.data.fd = sock0_->fd();
-  if (epoll_ctl(efd.get(), EPOLL_CTL_ADD, sock0_->fd(), &ev) == -1) {
+  if (epoll_ctl(cfd.get(), EPOLL_CTL_ADD, sock0_->fd(), &ev) == -1) {
     PLOG(ERROR) << "epoll_ctl failed";
     return;
   }
   ev.data.fd = sock1_->fd();
-  if (epoll_ctl(efd.get(), EPOLL_CTL_ADD, sock1_->fd(), &ev) == -1) {
+  if (epoll_ctl(cfd.get(), EPOLL_CTL_ADD, sock1_->fd(), &ev) == -1) {
     PLOG(ERROR) << "epoll_ctl failed";
     return;
   }
 
+  poll_ = true;
   struct epoll_event events[kMaxEvents];
-  char buf0[kBufSize] = {0}, buf1[kBufSize] = {0};
-  ssize_t len0 = 0, len1 = 0;
-  bool run = true;
-  while (run) {
-    int n = epoll_wait(efd.get(), events, kMaxEvents, -1);
+  // Also check done_: if Stop() ran before |poll_ = true| above, that store
+  // undid its request and ~SocketForwarder() would block in Join() until
+  // ProcessEvents() or epoll_wait() ended the loop some other way.
+  while (poll_ && !done_) {
+    int n = epoll_wait(cfd.get(), events, kMaxEvents, kWaitTimeoutMs);
     if (n == -1) {
       PLOG(ERROR) << "epoll_wait failed";
-      break;
+      return;
     }
     for (int i = 0; i < n; ++i) {
-      if (events[i].events & EPOLLERR) {
-        PLOG(WARNING) << "Socket error: " << *sock0_ << " <-> " << *sock1_;
-        run = false;
-        break;
-      }
-      if (events[i].events & (EPOLLHUP | EPOLLRDHUP)) {
-        LOG(INFO) << "Peer closed connection: " << *sock0_ << " <-> "
-                  << *sock1_;
-        run = false;
-        break;
-      }
+      if (!ProcessEvents(events[i].events, events[i].data.fd, cfd.get()))
+        return;
+    }
+  }
+}
 
-      if (events[i].events & EPOLLOUT) {
-        Socket* dst;
-        char* buf;
-        ssize_t* len;
-        if (sock0_->fd() == events[i].data.fd) {
-          dst = sock0_.get();
-          buf = buf1;
-          len = &len1;
-        } else {
-          dst = sock1_.get();
-          buf = buf0;
-          len = &len0;
-        }
+// Handles the epoll |events| reported for socket |efd| (one of sock0_/sock1_)
+// on the epoll instance |cfd|. Returns false when forwarding should stop.
+bool SocketForwarder::ProcessEvents(uint32_t events, int efd, int cfd) {
+  if (events & EPOLLERR) {
+    // errno is stale here (epoll_wait() succeeded), so don't use PLOG.
+    LOG(WARNING) << "Socket error: " << *sock0_ << " <-> " << *sock1_;
+    return false;
+  }
+  // Note: EPOLLRDHUP may be reported together with EPOLLIN; returning here
+  // drops data still queued on the socket or pending in buf0_/buf1_.
+  if (events & (EPOLLHUP | EPOLLRDHUP)) {
+    LOG(INFO) << "Peer closed connection: " << *sock0_ << " <-> " << *sock1_;
+    return false;
+  }
 
-        ssize_t bytes = dst->SendTo(buf, *len);
-        if (bytes < 0) {
-          run = false;
-          break;
-        }
-        // Still unavailable.
-        if (bytes == 0)
-          continue;
-        // Partial write.
-        if (bytes < *len)
-          memmove(&buf[0], &buf[bytes], *len - bytes);
-        *len -= bytes;
+  if (events & EPOLLOUT) {
+    Socket* dst;
+    char* buf;
+    ssize_t* len;
+    if (sock0_->fd() == efd) {
+      dst = sock0_.get();
+      buf = buf1_;
+      len = &len1_;
+    } else {
+      dst = sock1_.get();
+      buf = buf0_;
+      len = &len0_;
+    }
 
-        if (*len == 0) {
-          ev.events = EPOLLIN | EPOLLRDHUP;
-          ev.data.fd = dst->fd();
-          if (epoll_ctl(efd.get(), EPOLL_CTL_MOD, dst->fd(), &ev) == -1) {
-            PLOG(ERROR) << "epoll_ctl failed";
-            run = false;
-            break;
-          }
-        }
-      }
+    ssize_t bytes = dst->SendTo(buf, *len);
+    if (bytes < 0)
+      return false;
 
-      if (events[i].events & EPOLLIN) {
-        Socket *src, *dst;
-        char* buf;
-        ssize_t* len;
-        if (sock0_->fd() == events[i].data.fd) {
-          src = sock0_.get();
-          dst = sock1_.get();
-          buf = buf0;
-          len = &len0;
-        } else {
-          src = sock1_.get();
-          dst = sock0_.get();
-          buf = buf1;
-          len = &len1;
-        }
-        // Skip the read if this buffer is still pending write: requires that
-        // epoll_wait is in level-triggered mode.
-        if (*len > 0) {
-          continue;
-        }
-        *len = src->RecvFrom(buf, kBufSize);
-        if (*len < 0) {
-          run = false;
-          break;
-        }
-        if (*len == 0) {
-          continue;
-        }
+    // Still unavailable.
+    if (bytes == 0)
+      return true;
 
-        ssize_t bytes = dst->SendTo(buf, *len);
-        if (bytes < 0) {
-          run = false;
-          break;
-        }
-        if (bytes > 0) {
-          // Partial write.
-          if (bytes < *len)
-            memmove(&buf[0], &buf[bytes], *len - bytes);
-          *len -= bytes;
-        }
+    // Partial write.
+    if (bytes < *len)
+      memmove(&buf[0], &buf[bytes], *len - bytes);
+    *len -= bytes;
 
-        if (*len > 0) {
-          ev.events = EPOLLOUT | EPOLLRDHUP;
-          ev.data.fd = dst->fd();
-          if (epoll_ctl(efd.get(), EPOLL_CTL_MOD, dst->fd(), &ev) == -1) {
-            PLOG(ERROR) << "epoll_ctl failed";
-            run = false;
-            break;
-          }
-        }
+    if (*len == 0) {
+      struct epoll_event ev;
+      ev.events = EPOLLIN | EPOLLRDHUP;
+      ev.data.fd = dst->fd();
+      if (epoll_ctl(cfd, EPOLL_CTL_MOD, dst->fd(), &ev) == -1) {
+        PLOG(ERROR) << "epoll_ctl failed";
+        return false;
       }
     }
   }
 
-  LOG(INFO) << "Stopping forwarder: " << *sock0_ << " <-> " << *sock1_;
-  sock1_.reset();
-  sock0_.reset();
+  if (events & EPOLLIN) {
+    Socket *src, *dst;
+    char* buf;
+    ssize_t* len;
+    if (sock0_->fd() == efd) {
+      src = sock0_.get();
+      dst = sock1_.get();
+      buf = buf0_;
+      len = &len0_;
+    } else {
+      src = sock1_.get();
+      dst = sock0_.get();
+      buf = buf1_;
+      len = &len1_;
+    }
+
+    // Skip the read if this buffer is still pending write: requires that
+    // epoll_wait is in level-triggered mode.
+    if (*len > 0)
+      return true;
+
+    *len = src->RecvFrom(buf, kBufSize);
+    if (*len < 0)
+      return false;
+
+    if (*len == 0)
+      return true;
+
+    ssize_t bytes = dst->SendTo(buf, *len);
+    if (bytes < 0)
+      return false;
+
+    if (bytes > 0) {
+      // Partial write.
+      if (bytes < *len)
+        memmove(&buf[0], &buf[bytes], *len - bytes);
+      *len -= bytes;
+    }
+
+    if (*len > 0) {
+      struct epoll_event ev;
+      ev.events = EPOLLOUT | EPOLLRDHUP;
+      ev.data.fd = dst->fd();
+      if (epoll_ctl(cfd, EPOLL_CTL_MOD, dst->fd(), &ev) == -1) {
+        PLOG(ERROR) << "epoll_ctl failed";
+        return false;
+      }
+    }
+  }
+
+  return true;
 }
 
 }  // namespace arc_networkd
diff --git a/arc/network/socket_forwarder.h b/arc/network/socket_forwarder.h
index 19288cd..2d98ff6 100644
--- a/arc/network/socket_forwarder.h
+++ b/arc/network/socket_forwarder.h
@@ -8,6 +8,7 @@
 #include <netinet/ip.h>
 #include <sys/socket.h>
 
+#include <atomic>
 #include <memory>
 #include <string>
 
@@ -18,7 +19,6 @@
 #include "arc/network/socket.h"
 
 namespace arc_networkd {
-
 // Forwards data between a pair of sockets.
 // This is a simple implementation as a thread main function.
 class SocketForwarder : public base::SimpleThread {
@@ -31,11 +31,35 @@
   // Runs the forwarder. The sockets are closed and released on exit,
   // so this can only be run once.
   void Run() override;
-  bool IsValid() const { return sock0_ && sock1_; }
+  // Returns false once Stop() has been called: after Run()'s epoll loop
+  // exits, or when the destructor starts shutting the forwarder down.
+  bool IsValid() const;
 
  private:
+  static constexpr int kBufSize = 4096;
+
+  // Runs the epoll loop; returns when ProcessEvents() asks to stop, an epoll
+  // call fails, or Stop() is called.
+  void Poll();
+  // Asks Poll() to exit and makes IsValid() return false.
+  void Stop();
+  // Handles the epoll |events| reported for socket fd |efd| on the epoll
+  // instance |cfd|. Returns false when forwarding should stop.
+  bool ProcessEvents(uint32_t events, int efd, int cfd);
+
   std::unique_ptr<Socket> sock0_;
   std::unique_ptr<Socket> sock1_;
+  // Data read from sock0_ (buf0_) or sock1_ (buf1_) that still has to be
+  // written to the other socket; len0_/len1_ count those pending bytes.
+  char buf0_[kBufSize] = {0};
+  char buf1_[kBufSize] = {0};
+  ssize_t len0_;
+  ssize_t len1_;
+
+  // Atomic because Stop() and IsValid() may run on a different thread than
+  // the forwarder thread that executes Run()/Poll().
+  std::atomic<bool> poll_;
+  std::atomic<bool> done_;
 
   DISALLOW_COPY_AND_ASSIGN(SocketForwarder);
 };