blob: 8d90a611ff8573442ead6b7f346a922af0ed9a26 [file] [log] [blame] [edit]
// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <base/command_line.h>
#include <base/files/file_util.h>
#include <base/functional/bind.h>
#include <base/memory/raw_ref.h>
#include <base/run_loop.h>
#include <base/strings/string_number_conversions.h>
#include <base/task/single_thread_task_executor.h>
#include <base/task/single_thread_task_runner.h>
#include <base/task/thread_pool/thread_pool_instance.h>
#include <base/time/time.h>
#include <base/uuid.h>
#include <brillo/flag_helper.h>
#include <brillo/syslog_logging.h>
#include <chromeos/mojo/service_constants.h>
#include <mojo/core/embedder/embedder.h>
#include <mojo/core/embedder/scoped_ipc_support.h>
#include <mojo/public/cpp/bindings/remote.h>
#include <mojo_service_manager/lib/connect.h>

#include "odml/mojom/embedding_model.mojom.h"
#include "odml/mojom/on_device_model.mojom.h"
#include "odml/mojom/on_device_model_service.mojom.h"
namespace {
using ::embedding_model::mojom::GenerateEmbeddingRequest;
using ::embedding_model::mojom::OnDeviceEmbeddingModel;
using ::embedding_model::mojom::OnDeviceEmbeddingModelInferenceError;
using ::embedding_model::mojom::OnDeviceEmbeddingModelService;
using ::embedding_model::mojom::TaskType;
using ::on_device_model::mojom::LoadModelResult;
// Names of the command line switches understood by this tool.

// --uuid is the UUID of the embedding model to load (required).
constexpr char kUuid[] = "uuid";
// --generate_embedding requests a single embedding generation.
constexpr char kGenerateEmbedding[] = "generate_embedding";
// --content is the text to embed; its presence also requests a generation.
constexpr char kContent[] = "content";
// --task_type selects the embedding task type (see GetTaskTypeFromString()).
constexpr char kTaskType[] = "task_type";
// --truncate_input, when present, sets truncate_input on the request.
constexpr char kTruncateInput[] = "truncate_input";
// --binary_output is a file path that receives the raw embedding bytes.
constexpr char kBinaryOutput[] = "binary_output";

// --benchmark specifies benchmark mode. A non-empty value is used as the path
// of the JSON result file.
constexpr char kBenchmark[] = "benchmark";

// --benchmark_run_count is the number of inferences we'll make.
constexpr char kBenchmarkRunCount[] = "benchmark_run_count";

// --benchmark_max_seconds is the approximate maximum time we'll run. If we
// cannot finish --benchmark_run_count within the specified time, we'll cut it
// short.
constexpr char kBenchmarkMaxSeconds[] = "benchmark_max_seconds";
// Fixed pool of short prompts the benchmark embeds; requests walk through the
// pool in order and wrap around when they run past its end.
constexpr const char* kBenchmarkPrompts[] = {
    "Tokyo City Guide - What to do in Tokyo",
    "numpy.stack() in Python",
    "Mountain View, CA - Google Map",
    "Best Hiking Trails near San Francisco",
    "Understanding Quantum Mechanics",
    "Top 10 Restaurants in Paris",
    "How to Bake a Perfect Chocolate Cake",
    "Learning to Play the Guitar for Beginners",
    "The History of the Roman Empire",
    "The Benefits of Meditation",
    "Effective Time Management Strategies",
    "The Art of Public Speaking",
    "Climate Change and its Impact",
    "The Future of Artificial Intelligence",
    "Space Exploration and the Search for Life",
    "The Wonders of the Natural World",
    "Investing in the Stock Market",
    "The Psychology of Human Behavior",
    "The Power of Positive Thinking",
    "How to Write a Compelling Story",
};

// Number of entries in kBenchmarkPrompts.
constexpr size_t kBenchmarkPromptsSize =
    sizeof(kBenchmarkPrompts) / sizeof(*kBenchmarkPrompts);
// Converts the textual task type given on the command line into the mojom
// TaskType enum.
//
// Only "clustering" is recognized. Any other value, including the empty string
// produced by a missing switch, terminates the process through LOG(FATAL).
//
// |s| is taken by const reference so that callers do not pay for a copy.
TaskType GetTaskTypeFromString(const std::string& s) {
  if (s == "clustering") {
    return TaskType::kClustering;
  }
  LOG(FATAL) << "Unknown TaskType: " << s;
}
// Upper bound on how many embedding values PrintEmbedding() writes out.
constexpr int kMaxPrintEmbeddingCount = 16;

// Prints the leading values of |embedding| to stdout, four per row with two
// decimals each. At most kMaxPrintEmbeddingCount values are shown; the rest of
// the vector is silently omitted.
void PrintEmbedding(const std::vector<float>& embedding) {
  const int shown_count =
      std::min(static_cast<int>(embedding.size()), kMaxPrintEmbeddingCount);

  printf("Embedding: [\n");
  int values_in_row = 0;
  for (int pos = 0; pos < shown_count; ++pos) {
    printf(" %.2f,", embedding[pos]);
    ++values_in_row;
    // Break the row once it holds four values.
    if (values_in_row == 4) {
      printf("\n");
      values_in_row = 0;
    }
  }
  printf(" ]\n");
}
// Dumps |embedding| to the file at |path| as the raw in-memory bytes of its
// floats (sizeof(float) bytes per value, no header). Crashes via CHECK if the
// file cannot be written.
void WriteEmbedding(const std::vector<float>& embedding,
                    const std::string& path) {
  // View the float storage as bytes and copy it into a byte buffer that
  // base::WriteFile() accepts.
  const uint8_t* const first_byte =
      reinterpret_cast<const uint8_t*>(embedding.data());
  const uint8_t* const past_last_byte =
      reinterpret_cast<const uint8_t*>(embedding.data() + embedding.size());
  const std::vector<uint8_t> raw_bytes(first_byte, past_last_byte);

  const bool written = base::WriteFile(base::FilePath(path), raw_bytes);
  CHECK(written) << "Failed to write file.";
}
// Measures embedding throughput: issues GenerateEmbedding() requests against a
// loaded model until either a target number of requests has completed or a
// time budget is exceeded, then reports the elapsed time.
//
// The number of simultaneously outstanding requests is capped at
// kMaxInflightCount. An instance is single-use: Run() may be called only once.
class EmbeddingBenchmark {
 public:
  // |model| must stay valid for the lifetime of this object. |task_type| is
  // attached to every benchmark request.
  EmbeddingBenchmark(raw_ref<mojo::Remote<OnDeviceEmbeddingModel>> model,
                     TaskType task_type)
      : model_(model), task_type_(task_type) {}

  // Runs the benchmark and blocks (spinning a RunLoop) until it finishes, then
  // prints the statistics.
  //   run_count: number of inferences to perform.
  //   max_seconds: approximate time budget; once exceeded no new requests are
  //     launched and the benchmark ends when in-flight requests return.
  //   output_json_path: if set, the result is also written there as JSON.
  void Run(int run_count,
           int max_seconds,
           const std::optional<base::FilePath>& output_json_path);

 private:
  // How many times do we want to run during the benchmark? -1 means Run() has
  // not been called yet.
  int target_run_count_ = -1;
  // If we hit this amount of seconds, we'll cut the benchmark short.
  // This is needed because certain devices (such as CPU) are much slower and
  // we want to benchmark with a maximum time instead of a specified number of
  // tries.
  int max_seconds_ = 0;
  // How many times have we called GenerateEmbedding()?
  int launched_count_ = 0;
  // How many times has GenerateEmbedding() returned a result?
  int finish_count_ = 0;
  // How many GenerateEmbedding() calls are invoked but not yet finished?
  int in_flight_count_ = 0;
  // Has max_seconds_ been exceeded?
  bool has_been_cut_short_ = false;
  // The start of benchmarking.
  base::TimeTicks start_;
  // End of benchmarking.
  base::TimeTicks end_;
  raw_ref<mojo::Remote<OnDeviceEmbeddingModel>> model_;
  TaskType task_type_;
  // If not empty, will write the result to the specified path.
  std::optional<base::FilePath> output_json_path_;
  std::unique_ptr<base::RunLoop> run_loop_;
  // This is the maximum number of concurrent requests sent to the embedding
  // service backend at the same time. A value of 2 is picked at the moment
  // because we want to ensure it's larger than 1 so that the roundtrip IPC
  // latency doesn't skew the benchmark result and it's not too large so we can
  // efficiently cut off the benchmarking due to the time limit. Also note that
  // we have a maximum in-flight request count because we're measuring
  // throughput and not latency. If we decide to measure latency in the future
  // then we'll need to remove such a cap so as to not distort the percentile
  // latency figures contributed by queueing delay.
  static constexpr int kMaxInflightCount = 2;

  // Returns the benchmark prompt for the |idx|-th request.
  std::string GetContent(int idx);
  // Sends one GenerateEmbedding() request.
  void LaunchOne();
  // Completion callback of a GenerateEmbedding() request.
  void OnFinish(OnDeviceEmbeddingModelInferenceError error,
                const std::vector<float>& embeddings);
  // Prints (and optionally writes as JSON) the benchmark result.
  void PrintStats();
};
void EmbeddingBenchmark::Run(
int run_count,
int max_seconds,
const std::optional<base::FilePath>& output_json_path) {
CHECK_EQ(target_run_count_, -1); // run() should only be called once.
target_run_count_ = run_count;
max_seconds_ = max_seconds;
launched_count_ = 0;
finish_count_ = 0;
in_flight_count_ = 0;
has_been_cut_short_ = false;
output_json_path_ = output_json_path;
run_loop_ = std::make_unique<base::RunLoop>();
start_ = base::TimeTicks::Now();
LaunchOne();
run_loop_->Run();
PrintStats();
}
// Returns the prompt used for request number |idx|; indices past the end of
// kBenchmarkPrompts wrap around to the beginning.
std::string EmbeddingBenchmark::GetContent(int idx) {
  const int slot =
      static_cast<int>(static_cast<size_t>(idx) % kBenchmarkPromptsSize);
  return kBenchmarkPrompts[slot];
}
void EmbeddingBenchmark::LaunchOne() {
GenerateEmbeddingRequest generate_embedding_request;
generate_embedding_request.content = GetContent(launched_count_);
generate_embedding_request.task_type = task_type_;
generate_embedding_request.truncate_input = true;
in_flight_count_++;
launched_count_++;
(*model_)->GenerateEmbedding(
generate_embedding_request.Clone(),
base::BindOnce(&EmbeddingBenchmark::OnFinish, base::Unretained(this)));
}
// Completion callback for one GenerateEmbedding() request.
//
// Updates the counters, enforces the time budget, stops the run loop once the
// target number of requests has completed, and otherwise launches further
// requests. Any inference error aborts the benchmark via CHECK. |embeddings|
// is not inspected.
void EmbeddingBenchmark::OnFinish(OnDeviceEmbeddingModelInferenceError error,
                                  const std::vector<float>& embeddings) {
  base::TimeTicks current_time = base::TimeTicks::Now();
  in_flight_count_--;
  finish_count_++;
  CHECK_EQ(error, OnDeviceEmbeddingModelInferenceError::kSuccess);
  if (!has_been_cut_short_ &&
      ((current_time - start_) > base::Seconds(max_seconds_))) {
    // We've not finished yet, we need to cut it short and let all in-flight
    // requests finish. Capping the target at what has been launched prevents
    // any further launches below.
    target_run_count_ = launched_count_;
    has_been_cut_short_ = true;
  }
  if (finish_count_ >= target_run_count_) {
    CHECK_EQ(finish_count_, target_run_count_);
    end_ = base::TimeTicks::Now();
    run_loop_->Quit();
    return;
  }
  // Top the in-flight window back up to kMaxInflightCount. A loop is required:
  // launching at most one request per completion can never raise the number of
  // outstanding requests, so the window would stay below the cap.
  while (launched_count_ < target_run_count_ &&
         in_flight_count_ < kMaxInflightCount) {
    LaunchOne();
  }
}
void EmbeddingBenchmark::PrintStats() {
base::TimeDelta run_time = end_ - start_;
std::cout << "Embedding benchmark result: " << run_time.InMilliseconds()
<< " ms for " << target_run_count_ << " invocations" << std::endl;
if (output_json_path_.has_value()) {
std::string json_out =
"{\"runtime_ms\": " + std::to_string(run_time.InMilliseconds()) +
", \"count\": " + std::to_string(target_run_count_) + "}\n";
bool ret = base::WriteFile(output_json_path_.value(), json_out);
CHECK(ret) << "Failed to write JSON result file.";
}
}
// Runs the throughput benchmark against |model|, configured from |cl|.
//
// Switches read:
//   --task_type: task type attached to every request (required; an unknown
//     value is fatal).
//   --benchmark_run_count: number of inferences, default 16; must be a
//     positive integer.
//   --benchmark_max_seconds: approximate time budget in seconds, default 20.
//   --benchmark: if it carries a non-empty value, that value is the path of
//     the JSON result file.
//
// Malformed numeric values abort with a message naming the offending switch.
void RunBenchmark(raw_ref<mojo::Remote<OnDeviceEmbeddingModel>> model,
                  base::CommandLine* cl) {
  EmbeddingBenchmark bench(
      model, GetTaskTypeFromString(cl->GetSwitchValueASCII(kTaskType)));

  int run_count = 16;
  int max_seconds = 20;

  const std::string run_count_option =
      cl->GetSwitchValueASCII(kBenchmarkRunCount);
  if (!run_count_option.empty()) {
    const bool parsed = base::StringToInt(run_count_option, &run_count);
    CHECK(parsed) << "Invalid --" << kBenchmarkRunCount << " value: "
                  << run_count_option;
  }

  const std::string max_seconds_option =
      cl->GetSwitchValueASCII(kBenchmarkMaxSeconds);
  if (!max_seconds_option.empty()) {
    const bool parsed = base::StringToInt(max_seconds_option, &max_seconds);
    CHECK(parsed) << "Invalid --" << kBenchmarkMaxSeconds << " value: "
                  << max_seconds_option;
  }

  // The benchmark cannot terminate cleanly without at least one inference.
  CHECK_GT(run_count, 0) << "--" << kBenchmarkRunCount
                         << " must be a positive integer.";

  // The value of --benchmark doubles as the optional JSON output path.
  std::optional<base::FilePath> output_json_path;
  const std::string output_json_option = cl->GetSwitchValueASCII(kBenchmark);
  if (!output_json_option.empty()) {
    output_json_path = base::FilePath(output_json_option);
  }

  bench.Run(run_count, max_seconds, output_json_path);
}
} // namespace
int main(int argc, char** argv) {
// Setup command line and logging.
base::CommandLine::Init(argc, argv);
brillo::InitLog(brillo::kLogToSyslog | brillo::kLogToStderrIfTty);
base::CommandLine* cl = base::CommandLine::ForCurrentProcess();
const std::string uuid = cl->GetSwitchValueASCII(kUuid);
CHECK(!uuid.empty());
// Setup mojo
base::ThreadPoolInstance::CreateAndStartWithDefaultParams("thread_pool");
base::SingleThreadTaskExecutor io_task_executor(base::MessagePumpType::IO);
mojo::core::Init();
mojo::core::ScopedIPCSupport ipc_support(
base::SingleThreadTaskRunner::GetCurrentDefault(),
mojo::core::ScopedIPCSupport::ShutdownPolicy::CLEAN);
// Obtain a remote to the service.
mojo::Remote<OnDeviceEmbeddingModelService> service;
mojo::Remote<chromeos::mojo_service_manager::mojom::ServiceManager>
service_manager;
auto service_manager_remote =
chromeos::mojo_service_manager::ConnectToMojoServiceManager();
if (!service_manager_remote) {
LOG(ERROR) << "Failed to connect to Mojo Service Manager";
return -1;
}
service_manager.Bind(std::move(service_manager_remote));
service_manager.set_disconnect_with_reason_handler(
base::BindOnce([](uint32_t error, const std::string& message) {
LOG(INFO) << "Disconnected from mojo service manager (the mojo "
"broker process). Error: "
<< error << ", message: " << message
<< ". Shutdown and wait for respawn.";
}));
const base::TimeDelta kRemoteRequestTimeout = base::Milliseconds(10 * 1000);
service_manager->Request(
/*service_name=*/chromeos::mojo_services::kCrosEmbeddingModelService,
/*timeout=*/kRemoteRequestTimeout,
service.BindNewPipeAndPassReceiver().PassPipe());
mojo::Remote<OnDeviceEmbeddingModel> model;
{
base::RunLoop run_loop;
service->LoadEmbeddingModel(
base::Uuid::ParseLowercase(uuid), model.BindNewPipeAndPassReceiver(),
mojo::NullRemote(),
base::BindOnce(
[](base::RunLoop* run_loop, LoadModelResult result) {
if (result == LoadModelResult::kSuccess) {
LOG(INFO) << "LOADED";
} else {
LOG(ERROR) << "Fail";
exit(0);
}
run_loop->Quit();
},
&run_loop));
run_loop.Run();
}
if (cl->HasSwitch(kGenerateEmbedding) || cl->HasSwitch(kContent)) {
std::string content = cl->GetSwitchValueASCII(kContent);
GenerateEmbeddingRequest generate_embedding_request;
generate_embedding_request.content = content;
generate_embedding_request.task_type =
GetTaskTypeFromString(cl->GetSwitchValueASCII(kTaskType));
generate_embedding_request.truncate_input = false;
if (cl->HasSwitch(kTruncateInput)) {
generate_embedding_request.truncate_input = true;
}
{
base::RunLoop run_loop;
model->GenerateEmbedding(
generate_embedding_request.Clone(),
base::BindOnce(
[](base::RunLoop* run_loop, base::CommandLine* cl,
OnDeviceEmbeddingModelInferenceError error,
const std::vector<float>& embeddings) {
if (error == OnDeviceEmbeddingModelInferenceError::kSuccess) {
PrintEmbedding(embeddings);
if (cl->HasSwitch(kBinaryOutput)) {
WriteEmbedding(embeddings,
cl->GetSwitchValueASCII(kBinaryOutput));
}
} else {
LOG(ERROR)
<< "Failed to generate embedding, error: " << error;
}
run_loop->Quit();
},
&run_loop, base::Unretained(cl)));
run_loop.Run();
}
} else if (cl->HasSwitch(kBenchmark)) {
RunBenchmark(raw_ref(model), cl);
}
return 0;
}