// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// This file tests the multiprocess-related interfaces of
// `MachineLearningService`. Consider migrating tests from
// "machine_learning_service_impl_test.cc" here once the corresponding
// interface has been made multiprocess.

#include <signal.h>
#include <sys/types.h>
#include <unistd.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/bind.h>
#include <base/macros.h>
#include <base/run_loop.h>
#include <base/test/bind.h>
#include <base/threading/thread_task_runner_handle.h>
#include <base/time/time.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/remote.h>

#include "ml/handwriting.h"
#include "ml/machine_learning_service_impl.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/mojom/web_platform_handwriting.mojom.h"
#include "ml/process.h"
#include "ml/test_utils.h"
#include "ml/web_platform_handwriting_proto_mojom_conversion.h"

namespace ml {
namespace {

// We deliberately do not add `using` declarations for names from
// `chromeos::machine_learning::web_platform::mojom::*`, to avoid confusing
// them with names from `chromeos::machine_learning::mojom`.
using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
using ::chromeos::machine_learning::mojom::MachineLearningService;

// Points that are used to generate a single stroke for handwriting. Each entry
// is an {x, y} pair. The number of points is deduced from the initializer so
// that consumers can iterate over the array without a hard-coded count.
constexpr float kHandwritingTestPoints[][2] = {
    {1.928, 0.827}, {1.828, 0.826}, {1.73, 0.858},  {1.667, 0.901},
    {1.617, 0.955}, {1.567, 1.043}, {1.548, 1.148}, {1.569, 1.26},
    {1.597, 1.338}, {1.641, 1.408}, {1.688, 1.463}, {1.783, 1.473},
    {1.853, 1.418}, {1.897, 1.362}, {1.938, 1.278}, {1.968, 1.204},
    {1.999, 1.112}, {2.003, 1.004}, {1.984, 0.905}, {1.988, 1.043},
    {1.98, 1.178},  {1.976, 1.303}, {1.984, 1.415},
};

// A version of MachineLearningServiceImpl that loads models from the testing
// model directory.
class MachineLearningServiceImplForTesting final
    : public MachineLearningServiceImpl {
 public:
  // Binds `receiver`, passes an empty callback and uses the testing model
  // directory (`GetTestModelDir()`).
  explicit MachineLearningServiceImplForTesting(
      mojo::PendingReceiver<
          chromeos::machine_learning::mojom::MachineLearningService> receiver)
      : MachineLearningServiceImpl(
            std::move(receiver), base::Closure(), GetTestModelDir()) {}
};

}  // namespace

// Loads the web platform handwriting model with this process acting as the
// control process, runs one recognition, then disconnects the recognizer and
// verifies that the worker process exits, is reaped and is unregistered.
TEST(WebPlatformHandwritingModel, LoadModelAndRecognize) {
  // Nothing to test on an unsupported platform.
  if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) {
    return;
  }

  // Quit when the worker process has been reaped (successfully or not), or
  // when the safety timeout below fires.
  base::RunLoop runloop;

  // Sets the process to be the control process to test multiprocess code.
  Process::GetInstance()->SetTypeForTesting(Process::Type::kControlForTest);

  // Quits the run loop once the worker process has been reaped successfully.
  Process::GetInstance()->SetReapWorkerProcessSucceedCallbackForTesting(
      base::BindLambdaForTesting([&]() { runloop.Quit(); }));

  // Quits the run loop if the worker process fails to be reaped, recording the
  // failure so that it can be reported after the run loop returns.
  bool reap_worker_process_succeeded = true;
  std::string reap_worker_process_fail_reason;
  Process::GetInstance()->SetReapWorkerProcessFailCallbackForTesting(
      base::BindLambdaForTesting([&](std::string reason) {
        reap_worker_process_succeeded = false;
        reap_worker_process_fail_reason = std::move(reason);
        runloop.Quit();
      }));

  // Sets the mlservice binary path, which should be in the same directory as
  // the test binary.
  Process::GetInstance()->SetMlServicePathForTesting(GetMlServicePath());

  mojo::Remote<MachineLearningService> ml_service;
  auto ml_service_impl = std::make_unique<MachineLearningServiceImplForTesting>(
      ml_service.BindNewPipeAndPassReceiver());

  // Tries to load a model.
  mojo::Remote<
      chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer>
      recognizer;

  bool model_callback_done = false;
  auto constraint = chromeos::machine_learning::web_platform::mojom::
      HandwritingModelConstraint::New();
  constraint->languages.push_back("en");
  ml_service->LoadWebPlatformHandwritingModel(
      std::move(constraint), recognizer.BindNewPipeAndPassReceiver(),
      base::BindLambdaForTesting([&](const LoadHandwritingModelResult result) {
        ASSERT_EQ(result, LoadHandwritingModelResult::OK);
        // Check exactly one worker process is registered. This must be a
        // fatal assertion because the map's first entry is read below.
        const auto& worker_pid_info_map =
            Process::GetInstance()->GetWorkerPidInfoMap();
        ASSERT_EQ(worker_pid_info_map.size(), 1u);
        // Check the worker process is alive.
        const pid_t worker_pid = worker_pid_info_map.begin()->first;
        ASSERT_GT(worker_pid, 0);
        EXPECT_EQ(kill(worker_pid, 0), 0);

        model_callback_done = true;
      }));

  // Tries to get the prediction result.
  // Set default inputs.
  auto hints =
      chromeos::machine_learning::web_platform::mojom::HandwritingHints::New();
  hints->alternatives = 1u;
  auto stroke =
      chromeos::machine_learning::web_platform::mojom::HandwritingStroke::New();
  for (const auto& xy : kHandwritingTestPoints) {
    auto location = gfx::mojom::PointF::New();
    location->x = xy[0];
    location->y = xy[1];
    auto point = chromeos::machine_learning::web_platform::mojom::
        HandwritingPoint::New();
    point->location = std::move(location);
    stroke->points.push_back(std::move(point));
  }
  std::vector<
      chromeos::machine_learning::web_platform::mojom::HandwritingStrokePtr>
      strokes;
  strokes.push_back(std::move(stroke));

  bool prediction_callback_done = false;
  // Set by the prediction callback; stays -1 if that callback never completes.
  pid_t worker_pid = -1;
  recognizer->GetPrediction(
      std::move(strokes), std::move(hints),
      base::BindLambdaForTesting(
          [&](base::Optional<
              std::vector<chromeos::machine_learning::web_platform::mojom::
                              HandwritingPredictionPtr>> predictions) {
            // Post a task to disconnect the mojom connection to test whether
            // the worker process exits. It is posted before any assertion so
            // that a failing assertion below still lets the worker exit and
            // the run loop quit, instead of waiting for the timeout. The task
            // only runs after this callback returns.
            base::ThreadTaskRunnerHandle::Get()->PostTask(
                FROM_HERE,
                base::BindLambdaForTesting([&]() { recognizer.reset(); }));

            // Check that the inference succeeded and gives
            // the expected number of outputs.
            ASSERT_TRUE(predictions.has_value());
            ASSERT_EQ(predictions->size(), 1u);
            EXPECT_EQ(predictions->at(0)->text, "a");

            // Verify exactly one worker process is registered. Fatal, because
            // the map's first entry is read below.
            const auto& worker_pid_info_map =
                Process::GetInstance()->GetWorkerPidInfoMap();
            ASSERT_EQ(worker_pid_info_map.size(), 1u);
            worker_pid = worker_pid_info_map.begin()->first;
            // Check the worker process is valid, is a different process from
            // this one, and is alive.
            ASSERT_GT(worker_pid, 0);
            ASSERT_NE(worker_pid, getpid());
            EXPECT_EQ(kill(worker_pid, 0), 0);

            prediction_callback_done = true;
          }));

  // For safety, sets a timeout of 5 minutes. This only guarantees that the
  // test does not hang.
  bool is_timeout = false;
  base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
      FROM_HERE, base::BindLambdaForTesting([&]() {
        is_timeout = true;
        runloop.Quit();
      }),
      base::TimeDelta::FromMinutes(5));

  runloop.Run();

  // If timeout, the unit test failed.
  ASSERT_FALSE(is_timeout);

  // Fail the test if the worker process can not be reaped.
  ASSERT_TRUE(reap_worker_process_succeeded) << reap_worker_process_fail_reason;

  // Both callbacks must have completed. In particular `worker_pid` is only
  // meaningful after the prediction callback completed; `kill(-1, 0)` would
  // instead probe every process the caller may signal.
  ASSERT_TRUE(model_callback_done);
  ASSERT_TRUE(prediction_callback_done);

  // Verify the worker process has exited.
  EXPECT_NE(kill(worker_pid, 0), 0);
  // Verify the worker process has been unregistered.
  EXPECT_EQ(Process::GetInstance()->GetWorkerPidInfoMap().size(), 0u);
}

// This tests that, on non-supported boards, the
// `LoadWebPlatformHandwritingModel` API does not crash and reports
// `LOAD_MODEL_ERROR`.
TEST(WebPlatformHandwritingModel, NoCrashOnNonsupportedBoards) {
  // Skip if ondevice HWR is supported. We do not need to worry about whether
  // asan is enabled because dlopen will not be called in the test.
  if (ml::HandwritingLibrary::IsHandwritingLibrarySupported()) {
    return;
  }

  // Quit once the model-loading callback has run.
  base::RunLoop runloop;

  // Use `kSingleProcessForTest` rather than the control process type: a crash
  // in a worker process would not fail the unit test, whereas a crash in this
  // process does.
  Process::GetInstance()->SetTypeForTesting(
      Process::Type::kSingleProcessForTest);

  mojo::Remote<MachineLearningService> ml_service;
  auto ml_service_impl = std::make_unique<MachineLearningServiceImplForTesting>(
      ml_service.BindNewPipeAndPassReceiver());

  // Tries to load a model.
  mojo::Remote<
      chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer>
      recognizer;

  bool model_callback_done = false;
  auto constraint = chromeos::machine_learning::web_platform::mojom::
      HandwritingModelConstraint::New();
  constraint->languages.push_back("en");
  ml_service->LoadWebPlatformHandwritingModel(
      std::move(constraint), recognizer.BindNewPipeAndPassReceiver(),
      base::BindLambdaForTesting([&](const LoadHandwritingModelResult result) {
        // Non-fatal on purpose: a failing ASSERT_* would return from this
        // lambda before `runloop.Quit()` and hang the test instead of failing.
        EXPECT_EQ(result, LoadHandwritingModelResult::LOAD_MODEL_ERROR);
        model_callback_done = true;
        runloop.Quit();
      }));

  runloop.Run();
  EXPECT_TRUE(model_callback_done);
}

}  // namespace ml
