blob: e688afe5a5ecfc0bbfbaa1fbb4be4256a47fa2ba [file] [log] [blame] [edit]
// Copyright 2017 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "vm_tools/vsh/utils.h"
#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <base/files/file_util.h>
#include <base/functional/bind.h>
#include <base/logging.h>
#include <base/posix/eintr_wrapper.h>
#include <brillo/message_loops/message_loop.h>
using google::protobuf::MessageLite;
namespace vm_tools {
namespace vsh {
namespace {
bool SendAllBytes(int sockfd, const uint8_t* buf, uint32_t buf_size) {
if (!base::WriteFileDescriptor(
sockfd, base::as_bytes(base::make_span(buf, buf_size)))) {
PLOG(ERROR) << "Failed to write message to socket";
return false;
}
return true;
}
void ShutdownTask() {
brillo::MessageLoop::current()->BreakLoop();
}
} // namespace
struct MessagePacket {
uint32_t size;
uint8_t payload[kMaxMessageSize];
};
static_assert(sizeof(MessagePacket) == sizeof(uint32_t) + kMaxMessageSize,
"MessagePacket must not have implicit paddings");
bool SendMessage(int sockfd, const MessageLite& message) {
size_t msg_size = message.ByteSizeLong();
if (msg_size > kMaxMessageSize) {
LOG(ERROR) << "Serialized message too large: " << msg_size;
return false;
}
// Pack size and data into a buffer to send them through 1 vsock packet
MessagePacket msg;
msg.size = htole32(msg_size);
if (!message.SerializeToArray(msg.payload, kMaxMessageSize)) {
LOG(ERROR) << "Failed to serialize message";
return false;
}
if (!SendAllBytes(sockfd, reinterpret_cast<uint8_t*>(&msg),
sizeof(uint32_t) + msg_size)) {
return false;
}
return true;
}
bool RecvMessage(int sockfd, MessageLite* message) {
MessagePacket msg;
if (!base::ReadFromFD(sockfd, reinterpret_cast<char*>(&msg.size),
sizeof(msg.size))) {
LOG(ERROR) << "Failed to read message from socket";
return false;
}
msg.size = le32toh(msg.size);
if (msg.size > kMaxMessageSize) {
LOG(ERROR) << "Message size of " << msg.size << " exceeds max message size "
<< kMaxMessageSize;
return false;
}
if (!base::ReadFromFD(sockfd, reinterpret_cast<char*>(msg.payload),
msg.size)) {
LOG(ERROR) << "Failed to read message from socket";
return false;
}
if (!message->ParseFromArray(msg.payload, msg.size)) {
LOG(ERROR) << "Failed to parse message";
return false;
}
return true;
}
// Posts a shutdown task to the main message loop.
void Shutdown() {
brillo::MessageLoop::current()->PostTask(FROM_HERE,
base::BindOnce(&ShutdownTask));
}
} // namespace vsh
} // namespace vm_tools