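arc: networkd: Refactor epoll loop.

Split SocketForwarder::Run() into Poll(), which owns the epoll instance and
the wait loop, and ProcessEvents(), which handles the events reported for a
single socket. The I/O buffers move into members, and a new Stop() lets the
destructor end the loop: epoll_wait now times out every kWaitTimeoutMs so a
stop request is noticed.

BUG=None
TEST=manual, autotests and tast tests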
Change-Id: I245548a3ce44b762f6c572b8e3a8b9bbe78e98fb
Reviewed-on: https://chromium-review.googlesource.com/1617169
Commit-Ready: Garrick Evans <garrick@chromium.org>
Tested-by: Garrick Evans <garrick@chromium.org>
Legacy-Commit-Queue: Commit Bot <commit-bot@chromium.org>
Reviewed-by: Hugo Benichi <hugobenichi@google.com>
diff --git a/arc/network/socket_forwarder.cc b/arc/network/socket_forwarder.cc
index 4fbc641..f911602 100644
--- a/arc/network/socket_forwarder.cc
+++ b/arc/network/socket_forwarder.cc
@@ -20,7 +20,7 @@
namespace arc_networkd {
namespace {
-constexpr int kBufSize = 4096;
+// Upper bound, in milliseconds, on how long Poll() blocks in one epoll_wait
+// call, and so roughly on how long a Stop() request can go unnoticed.
+constexpr int kWaitTimeoutMs = 1000;
// Maximum number of epoll events to process per wait.
constexpr int kMaxEvents = 4;
} // namespace
@@ -30,14 +30,24 @@
std::unique_ptr<Socket> sock1)
: base::SimpleThread(name),
sock0_(std::move(sock0)),
- sock1_(std::move(sock1)) {}
+ sock1_(std::move(sock1)),
+ len0_(0),
+ len1_(0),
+ poll_(false),
+ done_(false) {
+ DCHECK(sock0_);
+ DCHECK(sock1_);
+}
SocketForwarder::~SocketForwarder() {
- sock1_.reset();
- sock0_.reset();
+ // Ask the polling loop to exit, then wait for the forwarder thread.
+ // Run() resets the sockets itself once polling ends, so they are no
+ // longer reset here.
+ Stop();
Join();
}
+// Reports whether the forwarder is still usable: true until Stop() has run,
+// either from Run() after polling ends or from the destructor.
+bool SocketForwarder::IsValid() const {
+ const bool stopped = done_.load();
+ return !stopped;
+}
+
void SocketForwarder::Run() {
LOG(INFO) << "Starting forwarder: " << *sock0_ << " <-> " << *sock1_;
@@ -50,142 +60,155 @@
return;
}
- base::ScopedFD efd(epoll_create1(0));
- if (!efd.is_valid()) {
+ Poll();
+ Stop();
+
+ sock1_.reset();
+ sock0_.reset();
+}
+
+// Asks the forwarding loop in Poll() to exit. May be called from any thread
+// (done_ and poll_ are atomics) and more than once; only the first call has
+// any effect.
+void SocketForwarder::Stop() {
+ // exchange() tests and sets done_ in one atomic step, so when the
+ // forwarder thread (via Run()) and the destructor race here, exactly one
+ // of them gets past this check.
+ if (done_.exchange(true))
+ return;
+
+ poll_ = false;
+ // Intentionally does not dereference sock0_/sock1_: when Stop() runs on
+ // the destructor's thread, Run() may be resetting them concurrently on the
+ // forwarder thread as soon as polling ends.
+ LOG(INFO) << "Stopping forwarder";
+}
+
+// Registers both sockets with a private epoll instance and dispatches their
+// events to ProcessEvents() until Stop() is called, epoll setup or waiting
+// fails, or ProcessEvents() reports that forwarding should end.
+void SocketForwarder::Poll() {
+ base::ScopedFD cfd(epoll_create1(0));
+ if (!cfd.is_valid()) {
PLOG(ERROR) << "epoll_create1 failed";
return;
}
struct epoll_event ev;
ev.events = EPOLLIN | EPOLLRDHUP;
ev.data.fd = sock0_->fd();
- if (epoll_ctl(efd.get(), EPOLL_CTL_ADD, sock0_->fd(), &ev) == -1) {
+ if (epoll_ctl(cfd.get(), EPOLL_CTL_ADD, sock0_->fd(), &ev) == -1) {
PLOG(ERROR) << "epoll_ctl failed";
return;
}
ev.data.fd = sock1_->fd();
- if (epoll_ctl(efd.get(), EPOLL_CTL_ADD, sock1_->fd(), &ev) == -1) {
+ if (epoll_ctl(cfd.get(), EPOLL_CTL_ADD, sock1_->fd(), &ev) == -1) {
PLOG(ERROR) << "epoll_ctl failed";
return;
}
+ poll_ = true;
struct epoll_event events[kMaxEvents];
- char buf0[kBufSize] = {0}, buf1[kBufSize] = {0};
- ssize_t len0 = 0, len1 = 0;
- bool run = true;
- while (run) {
- int n = epoll_wait(efd.get(), events, kMaxEvents, -1);
+ // Stop() may already have run on another thread (e.g. from the destructor)
+ // before poll_ was set above, so done_ is checked as well; otherwise that
+ // assignment would undo the stop request, the loop would never end and the
+ // destructor's Join() would hang. The finite timeout bounds how long a
+ // stop request can go unnoticed.
+ while (poll_ && !done_) {
+ int n = epoll_wait(cfd.get(), events, kMaxEvents, kWaitTimeoutMs);
if (n == -1) {
+ // A signal interrupted the wait; that is not a forwarding failure.
+ if (errno == EINTR)
+ continue;
PLOG(ERROR) << "epoll_wait failed";
- break;
+ return;
}
for (int i = 0; i < n; ++i) {
- if (events[i].events & EPOLLERR) {
- PLOG(WARNING) << "Socket error: " << *sock0_ << " <-> " << *sock1_;
- run = false;
- break;
- }
- if (events[i].events & (EPOLLHUP | EPOLLRDHUP)) {
- LOG(INFO) << "Peer closed connection: " << *sock0_ << " <-> "
- << *sock1_;
- run = false;
- break;
- }
+ if (!ProcessEvents(events[i].events, events[i].data.fd, cfd.get()))
+ return;
+ }
+ }
+}
- if (events[i].events & EPOLLOUT) {
- Socket* dst;
- char* buf;
- ssize_t* len;
- if (sock0_->fd() == events[i].data.fd) {
- dst = sock0_.get();
- buf = buf1;
- len = &len1;
- } else {
- dst = sock1_.get();
- buf = buf0;
- len = &len0;
- }
+// Handles the epoll |events| reported for socket descriptor |efd| on the
+// epoll instance |cfd|: flushes data pending for |efd| once it is writable
+// and forwards data read from |efd| to the other socket.
+// Returns false when forwarding should stop (socket error, hangup, I/O
+// failure or epoll_ctl failure) and true otherwise.
+bool SocketForwarder::ProcessEvents(uint32_t events, int efd, int cfd) {
+ // Sets the epoll interest of |sock| from the buffer state: EPOLLIN only
+ // while |read_backlog| is 0 (there is room to read from |sock|), EPOLLOUT
+ // only while |write_backlog| > 0 (data is waiting to be written to
+ // |sock|). The sockets are registered level-triggered, so leaving EPOLLIN
+ // armed while the buffer is still pending would make every epoll_wait
+ // return immediately and spin the thread until the peer drains.
+ auto rearm = [cfd](Socket* sock, ssize_t read_backlog,
+ ssize_t write_backlog) {
+ struct epoll_event ev = {};
+ ev.events = EPOLLRDHUP;
+ if (read_backlog == 0)
+ ev.events |= EPOLLIN;
+ if (write_backlog > 0)
+ ev.events |= EPOLLOUT;
+ ev.data.fd = sock->fd();
+ if (epoll_ctl(cfd, EPOLL_CTL_MOD, sock->fd(), &ev) == -1) {
+ PLOG(ERROR) << "epoll_ctl failed";
+ return false;
+ }
+ return true;
+ };
+ // buf0_/len0_ hold data read from sock0_ for sock1_, buf1_/len1_ the
+ // reverse. Both sockets are re-armed whenever a buffer fills or drains,
+ // which also keeps each socket readable for the opposite direction.
+ auto rearm_all = [&]() {
+ return rearm(sock0_.get(), len0_, len1_) &&
+ rearm(sock1_.get(), len1_, len0_);
+ };
+
+ if (events & EPOLLERR) {
+ // errno does not describe an epoll error event, so PLOG would append a
+ // stale, misleading error string.
+ LOG(WARNING) << "Socket error: " << *sock0_ << " <-> " << *sock1_;
+ return false;
+ }
+ if (events & (EPOLLHUP | EPOLLRDHUP)) {
+ // NOTE(review): returning here drops any data still buffered or unread
+ // for this connection; a graceful half-close would need more state.
+ LOG(INFO) << "Peer closed connection: " << *sock0_ << " <-> " << *sock1_;
+ return false;
+ }
- ssize_t bytes = dst->SendTo(buf, *len);
- if (bytes < 0) {
- run = false;
- break;
- }
- // Still unavailable.
- if (bytes == 0)
- continue;
- // Partial write.
- if (bytes < *len)
- memmove(&buf[0], &buf[bytes], *len - bytes);
- *len -= bytes;
+ if (events & EPOLLOUT) {
+ // |efd| became writable: flush the data destined for it.
+ Socket* dst;
+ char* buf;
+ ssize_t* len;
+ if (sock0_->fd() == efd) {
+ dst = sock0_.get();
+ buf = buf1_;
+ len = &len1_;
+ } else {
+ dst = sock1_.get();
+ buf = buf0_;
+ len = &len0_;
+ }
- if (*len == 0) {
- ev.events = EPOLLIN | EPOLLRDHUP;
- ev.data.fd = dst->fd();
- if (epoll_ctl(efd.get(), EPOLL_CTL_MOD, dst->fd(), &ev) == -1) {
- PLOG(ERROR) << "epoll_ctl failed";
- run = false;
- break;
- }
- }
- }
+ ssize_t bytes = dst->SendTo(buf, *len);
+ if (bytes < 0)
+ return false;
- if (events[i].events & EPOLLIN) {
- Socket *src, *dst;
- char* buf;
- ssize_t* len;
- if (sock0_->fd() == events[i].data.fd) {
- src = sock0_.get();
- dst = sock1_.get();
- buf = buf0;
- len = &len0;
- } else {
- src = sock1_.get();
- dst = sock0_.get();
- buf = buf1;
- len = &len1;
- }
- // Skip the read if this buffer is still pending write: requires that
- // epoll_wait is in level-triggered mode.
- if (*len > 0) {
- continue;
- }
- *len = src->RecvFrom(buf, kBufSize);
- if (*len < 0) {
- run = false;
- break;
- }
- if (*len == 0) {
- continue;
- }
+ // Still unavailable.
+ if (bytes == 0)
+ return true;
- ssize_t bytes = dst->SendTo(buf, *len);
- if (bytes < 0) {
- run = false;
- break;
- }
- if (bytes > 0) {
- // Partial write.
- if (bytes < *len)
- memmove(&buf[0], &buf[bytes], *len - bytes);
- *len -= bytes;
- }
+ // Partial write: keep the unsent tail at the start of the buffer.
+ if (bytes < *len)
+ memmove(&buf[0], &buf[bytes], *len - bytes);
+ *len -= bytes;
- if (*len > 0) {
- ev.events = EPOLLOUT | EPOLLRDHUP;
- ev.data.fd = dst->fd();
- if (epoll_ctl(efd.get(), EPOLL_CTL_MOD, dst->fd(), &ev) == -1) {
- PLOG(ERROR) << "epoll_ctl failed";
- run = false;
- break;
- }
- }
+ // Fully flushed: stop waiting for writability and resume reading from
+ // the socket this data came from.
+ if (*len == 0) {
+ if (!rearm_all()) {
+ return false;
}
}
}
- LOG(INFO) << "Stopping forwarder: " << *sock0_ << " <-> " << *sock1_;
- sock1_.reset();
- sock0_.reset();
+ if (events & EPOLLIN) {
+ // |efd| has data: read it and forward it to the other socket.
+ Socket *src, *dst;
+ char* buf;
+ ssize_t* len;
+ if (sock0_->fd() == efd) {
+ src = sock0_.get();
+ dst = sock1_.get();
+ buf = buf0_;
+ len = &len0_;
+ } else {
+ src = sock1_.get();
+ dst = sock0_.get();
+ buf = buf1_;
+ len = &len1_;
+ }
+
+ // Skip the read while this buffer is still pending write. rearm_all()
+ // disarms EPOLLIN on |src| in that state, so this is only a safety net.
+ if (*len > 0)
+ return true;
+
+ *len = src->RecvFrom(buf, kBufSize);
+ if (*len < 0)
+ return false;
+
+ if (*len == 0)
+ return true;
+
+ ssize_t bytes = dst->SendTo(buf, *len);
+ if (bytes < 0)
+ return false;
+
+ if (bytes > 0) {
+ // Partial write: keep the unsent tail at the start of the buffer.
+ if (bytes < *len)
+ memmove(&buf[0], &buf[bytes], *len - bytes);
+ *len -= bytes;
+ }
+
+ // Data is still pending: wait for |dst| to become writable and stop
+ // reading |src| until the buffer drains.
+ if (*len > 0 && !rearm_all())
+ return false;
+ }
+
+ return true;
}
} // namespace arc_networkd
diff --git a/arc/network/socket_forwarder.h b/arc/network/socket_forwarder.h
index 19288cd..2d98ff6 100644
--- a/arc/network/socket_forwarder.h
+++ b/arc/network/socket_forwarder.h
@@ -8,6 +8,8 @@
+#include <errno.h>
#include <netinet/ip.h>
#include <sys/socket.h>
+#include <atomic>
#include <memory>
#include <string>
@@ -18,7 +19,6 @@
#include "arc/network/socket.h"
namespace arc_networkd {
-
// Forwards data between a pair of sockets.
// This is a simple implementation as a thread main function.
class SocketForwarder : public base::SimpleThread {
@@ -31,11 +31,24 @@
// Runs the forwarder. The sockets are closed and released on exit,
// so this can only be run once.
void Run() override;
- bool IsValid() const { return sock0_ && sock1_; }
+ bool IsValid() const;
private:
+ static constexpr int kBufSize = 4096;
+
+ void Poll();
+ void Stop();
+ bool ProcessEvents(uint32_t events, int efd, int cfd);
+
std::unique_ptr<Socket> sock0_;
std::unique_ptr<Socket> sock1_;
+ char buf0_[kBufSize] = {0};
+ char buf1_[kBufSize] = {0};
+ ssize_t len0_;
+ ssize_t len1_;
+
+ std::atomic<bool> poll_;
+ std::atomic<bool> done_;
DISALLOW_COPY_AND_ASSIGN(SocketForwarder);
};