// blob: 9fa643e2c3a9eabd71765070da76b91cc1b6cf5b [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// This file tests the multiprocess-related interfaces of
// `MachineLearningService`. Consider migrating tests from
// "machine_learning_service_impl_test.cc" to this file once the interface
// they cover has been made multiprocess.
#include <string>
#include <utility>
#include <vector>
#include <base/bind.h>
#include <base/macros.h>
#include <base/run_loop.h>
#include <base/test/bind.h>
#include <base/time/time.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/remote.h>
#include "ml/handwriting.h"
#include "ml/machine_learning_service_impl.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/mojom/web_platform_handwriting.mojom.h"
#include "ml/process.h"
#include "ml/test_utils.h"
#include "ml/web_platform_handwriting_proto_mojom_conversion.h"
namespace ml {
namespace {
// We intentionally do not add `using` declarations for
// `chromeos::machine_learning::web_platform::mojom::*`, to avoid confusion.
using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
using ::chromeos::machine_learning::mojom::MachineLearningService;
// (x, y) coordinates of the points that make up the single stroke used as
// handwriting input. Column 0 is x and column 1 is y.
constexpr float kHandwritingTestPoints[][2] = {
    {1.928, 0.827}, {1.828, 0.826}, {1.73, 0.858},
    {1.667, 0.901}, {1.617, 0.955}, {1.567, 1.043},
    {1.548, 1.148}, {1.569, 1.26},  {1.597, 1.338},
    {1.641, 1.408}, {1.688, 1.463}, {1.783, 1.473},
    {1.853, 1.418}, {1.897, 1.362}, {1.938, 1.278},
    {1.968, 1.204}, {1.999, 1.112}, {2.003, 1.004},
    {1.984, 0.905}, {1.988, 1.043}, {1.98, 1.178},
    {1.976, 1.303}, {1.984, 1.415},
};
// The row count is deduced from the initializer; keep it pinned to the 23
// points the stroke is expected to contain.
static_assert(sizeof(kHandwritingTestPoints) /
                      sizeof(kHandwritingTestPoints[0]) ==
                  23,
              "kHandwritingTestPoints must contain exactly 23 points");
// Test-only flavour of MachineLearningServiceImpl that is pointed at the
// testing model directory (`GetTestModelDir()`).
class MachineLearningServiceImplForTesting
    : public MachineLearningServiceImpl {
 public:
  // Hands `receiver` to the base class together with a default-constructed
  // (empty) `base::Closure` and the testing model directory.
  explicit MachineLearningServiceImplForTesting(
      mojo::PendingReceiver<
          ::chromeos::machine_learning::mojom::MachineLearningService>
          receiver)
      : MachineLearningServiceImpl(std::move(receiver),
                                   base::Closure(),
                                   GetTestModelDir()) {}
};
} // namespace
// Exercises the multiprocess path of `LoadWebPlatformHandwritingModel`:
//   1. loads the "en" model and checks that one live worker process is
//      registered;
//   2. runs a single-stroke recognition and checks the result;
//   3. drops the recognizer connection and checks that the worker process is
//      reaped, has exited, and has been unregistered.
TEST(WebPlatformHandwritingModel, LoadModelAndRecognize) {
  // Nothing to test on an unsupported platform.
  if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) {
    return;
  }

  base::RunLoop runloop;
  // Mark this process as the control process so the multiprocess code is used.
  Process::GetInstance()->SetTypeForTesting(Process::Type::kControlForTest);

  // Quit the run loop once the worker process has been reaped successfully.
  Process::GetInstance()->SetReapWorkerProcessSucceedCallbackForTesting(
      base::BindLambdaForTesting([&]() { runloop.Quit(); }));

  // Also quit the run loop when reaping fails, remembering the failure and its
  // reason so they can be reported after the loop returns.
  bool reap_worker_process_succeeded = true;
  std::string reap_worker_process_fail_reason;
  Process::GetInstance()->SetReapWorkerProcessFailCallbackForTesting(
      base::BindLambdaForTesting([&](std::string reason) {
        reap_worker_process_succeeded = false;
        reap_worker_process_fail_reason = std::move(reason);
        runloop.Quit();
      }));

  // Sets the mlservice binary path which should be at the same dir of the test
  // binary.
  Process::GetInstance()->SetMlServicePathForTesting(GetMlServicePath());

  mojo::Remote<MachineLearningService> ml_service;
  auto ml_service_impl = std::make_unique<MachineLearningServiceImplForTesting>(
      ml_service.BindNewPipeAndPassReceiver());

  // Tries to load a model.
  mojo::Remote<
      chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer>
      recognizer;
  bool model_callback_done = false;
  auto constraint = chromeos::machine_learning::web_platform::mojom::
      HandwritingModelConstraint::New();
  constraint->languages.push_back("en");
  ml_service->LoadWebPlatformHandwritingModel(
      std::move(constraint), recognizer.BindNewPipeAndPassReceiver(),
      base::BindLambdaForTesting([&](const LoadHandwritingModelResult result) {
        ASSERT_EQ(result, LoadHandwritingModelResult::OK);
        // Exactly one worker process must be registered. This is fatal (for
        // this callback) so that begin() is never dereferenced on an empty
        // map.
        const auto& workers = Process::GetInstance()->GetWorkerPidInfoMap();
        ASSERT_EQ(workers.size(), 1u);
        // Check the worker process is alive (signal 0 only probes).
        const pid_t loaded_worker_pid = workers.begin()->first;
        ASSERT_GT(loaded_worker_pid, 0);
        EXPECT_EQ(kill(loaded_worker_pid, 0), 0);
        model_callback_done = true;
      }));

  // Tries to get the prediction result.
  // Set default inputs.
  auto hints =
      chromeos::machine_learning::web_platform::mojom::HandwritingHints::New();
  hints->alternatives = 1u;
  auto stroke =
      chromeos::machine_learning::web_platform::mojom::HandwritingStroke::New();
  // Build the stroke from every test point; iterating the array directly
  // avoids repeating its length here.
  for (const auto& coords : kHandwritingTestPoints) {
    auto point = chromeos::machine_learning::web_platform::mojom::
        HandwritingPoint::New();
    auto location = gfx::mojom::PointF::New();
    location->x = coords[0];
    location->y = coords[1];
    point->location = std::move(location);
    stroke->points.push_back(std::move(point));
  }
  std::vector<
      chromeos::machine_learning::web_platform::mojom::HandwritingStrokePtr>
      strokes;
  strokes.push_back(std::move(stroke));

  bool prediction_callback_done = false;
  // Stays at -1 unless the prediction callback records the worker's pid.
  pid_t worker_pid = -1;
  recognizer->GetPrediction(
      std::move(strokes), std::move(hints),
      base::BindLambdaForTesting(
          [&](base::Optional<
              std::vector<chromeos::machine_learning::web_platform::mojom::
                              HandwritingPredictionPtr>> predictions) {
            // Post the disconnection task before any assertion: a fatal
            // assertion only returns from this lambda, and without the
            // disconnection the test would sit in the run loop until the
            // safety timeout fires. The task itself runs after this callback
            // returns and is used to test whether the worker process exits.
            base::ThreadTaskRunnerHandle::Get()->PostTask(
                FROM_HERE,
                base::BindLambdaForTesting([&]() { recognizer.reset(); }));

            // Check that the inference succeeded and gives
            // the expected number of outputs.
            ASSERT_TRUE(predictions.has_value());
            ASSERT_EQ(predictions->size(), 1u);
            EXPECT_EQ(predictions->at(0)->text, "a");

            // Verify the worker process is registered. Fatal here so that
            // begin() is never dereferenced on an empty map.
            const auto& workers =
                Process::GetInstance()->GetWorkerPidInfoMap();
            ASSERT_EQ(workers.size(), 1u);
            worker_pid = workers.begin()->first;
            // Verify the worker process is a different one.
            ASSERT_NE(worker_pid, getpid());
            // Check the worker process is alive.
            ASSERT_GT(worker_pid, 0);
            EXPECT_EQ(kill(worker_pid, 0), 0);

            prediction_callback_done = true;
          }));

  // For safety, sets a timeout of 5min. This is just to guarantee the test will
  // not hang.
  bool is_timeout = false;
  base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
      FROM_HERE, base::BindLambdaForTesting([&]() {
        is_timeout = true;
        runloop.Quit();
      }),
      base::TimeDelta::FromMilliseconds(1000 * 60 * 5));
  runloop.Run();

  // If timeout, the unit test failed.
  ASSERT_FALSE(is_timeout);
  // Fail the test if the worker process can not be reaped.
  ASSERT_TRUE(reap_worker_process_succeeded) << reap_worker_process_fail_reason;
  // Verify the worker process has exited. Only probe a pid that was actually
  // recorded: kill(-1, 0) would address every process this test may signal
  // instead of the worker.
  EXPECT_GT(worker_pid, 0);
  if (worker_pid > 0) {
    EXPECT_NE(kill(worker_pid, 0), 0);
  }
  // Verify the worker process has been unregistered.
  EXPECT_EQ(Process::GetInstance()->GetWorkerPidInfoMap().size(), 0u);
  EXPECT_TRUE(model_callback_done);
  EXPECT_TRUE(prediction_callback_done);
}
// Verifies that on non-supported boards the `LoadWebPlatformHandwritingModel`
// API does not crash and reports `LOAD_MODEL_ERROR`.
TEST(WebPlatformHandwritingModel, NoCrashOnNonsupportedBoards) {
  // Skip if ondevice HWR is supported. We do not need to worry about whether
  // asan is enabled because dlopen will not be called in the test.
  if (ml::HandwritingLibrary::IsHandwritingLibrarySupported()) {
    return;
  }

  base::RunLoop runloop;
  // Run everything in a single process: we need to use
  // `kSingleProcessForTest` because the worker process' crash does not fail
  // the unit test.
  Process::GetInstance()->SetTypeForTesting(
      Process::Type::kSingleProcessForTest);

  mojo::Remote<MachineLearningService> ml_service;
  auto ml_service_impl = std::make_unique<MachineLearningServiceImplForTesting>(
      ml_service.BindNewPipeAndPassReceiver());

  // Tries to load a model.
  mojo::Remote<
      chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer>
      recognizer;
  bool model_callback_done = false;
  auto constraint = chromeos::machine_learning::web_platform::mojom::
      HandwritingModelConstraint::New();
  constraint->languages.push_back("en");
  ml_service->LoadWebPlatformHandwritingModel(
      std::move(constraint), recognizer.BindNewPipeAndPassReceiver(),
      base::BindLambdaForTesting([&](const LoadHandwritingModelResult result) {
        // Deliberately non-fatal: a fatal assertion would return from this
        // lambda before `runloop.Quit()`, and with no timeout in this test the
        // run loop would then never finish.
        EXPECT_EQ(result, LoadHandwritingModelResult::LOAD_MODEL_ERROR);
        model_callback_done = true;
        runloop.Quit();
      }));
  runloop.Run();
  EXPECT_TRUE(model_callback_done);
}
} // namespace ml