| // Copyright 2018 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include <base/bind.h> |
| #include <base/containers/flat_map.h> |
| #include <base/files/file_util.h> |
| #include <base/macros.h> |
| #include <base/optional.h> |
| #include <base/run_loop.h> |
| #include <base/stl_util.h> |
| #include <base/time/time.h> |
| #include <gmock/gmock.h> |
| #include <gtest/gtest.h> |
| #include <mojo/public/cpp/bindings/receiver.h> |
| #include <mojo/public/cpp/bindings/remote.h> |
| |
| #include "ml/grammar_library.h" |
| #include "ml/grammar_proto_mojom_conversion.h" |
| #include "ml/handwriting.h" |
| #include "ml/handwriting_proto_mojom_conversion.h" |
| #include "ml/machine_learning_service_impl.h" |
| #include "ml/mojom/grammar_checker.mojom.h" |
| #include "ml/mojom/graph_executor.mojom.h" |
| #include "ml/mojom/handwriting_recognizer.mojom.h" |
| #include "ml/mojom/machine_learning_service.mojom.h" |
| #include "ml/mojom/model.mojom.h" |
| #include "ml/mojom/soda.mojom.h" |
| #include "ml/mojom/text_classifier.mojom.h" |
| #include "ml/mojom/text_suggester.mojom.h" |
| #include "ml/tensor_view.h" |
| #include "ml/test_utils.h" |
| #include "ml/text_suggester_proto_mojom_conversion.h" |
| #include "ml/text_suggestions.h" |
| |
| namespace ml { |
| namespace { |
| |
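| // Test inputs for the models exercised below. Each array is a feature vector |
| // whose length must match the input size expected by the corresponding model; |
| // the expected inference outputs are hard-coded in the individual tests. |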
| constexpr double kSearchRanker20190923TestInput[] = { |
| 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, |
| }; |
| |
| constexpr double kSmartDim20181115TestInput[] = { |
| 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, |
| 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| }; |
| |
| constexpr double kSmartDim20190221TestInput[] = { |
| 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| }; |
| |
| constexpr double kSmartDim20190521TestInput[] = { |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, |
| 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, |
| }; |
| |
| constexpr double kSmartDim20200206TestInput[] = { |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, |
| 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| }; |
| |
| constexpr double kSmartDim20210201TestInput[] = { |
| 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, |
| 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, |
| }; |
| |
| // Points that are used to generate a stroke for handwriting. |
| constexpr float kHandwritingTestPoints[23][2] = { |
| {1.928, 0.827}, {1.828, 0.826}, {1.73, 0.858}, {1.667, 0.901}, |
| {1.617, 0.955}, {1.567, 1.043}, {1.548, 1.148}, {1.569, 1.26}, |
| {1.597, 1.338}, {1.641, 1.408}, {1.688, 1.463}, {1.783, 1.473}, |
| {1.853, 1.418}, {1.897, 1.362}, {1.938, 1.278}, {1.968, 1.204}, |
| {1.999, 1.112}, {2.003, 1.004}, {1.984, 0.905}, {1.988, 1.043}, |
| {1.98, 1.178}, {1.976, 1.303}, {1.984, 1.415}, |
| }; |
| |
| // The words "unknownword" and "a.bcd" should not be detected by the new |
| // vocabulary-based dictionary annotator. |
| constexpr char kTextClassifierTestInput[] = |
| "user.name@gmail.com. 123 George Street. unfathomable. 12pm. 350°F. " |
| "unknownword. a.bcd"; |
| |
| using ::chromeos::machine_learning::mojom::BuiltinModelId; |
| using ::chromeos::machine_learning::mojom::BuiltinModelSpec; |
| using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr; |
| using ::chromeos::machine_learning::mojom::CodepointSpan; |
| using ::chromeos::machine_learning::mojom::CodepointSpanPtr; |
| using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult; |
| using ::chromeos::machine_learning::mojom::ExecuteResult; |
| using ::chromeos::machine_learning::mojom::FlatBufferModelSpec; |
| using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr; |
| using ::chromeos::machine_learning::mojom::GrammarChecker; |
| using ::chromeos::machine_learning::mojom::GrammarCheckerResult; |
| using ::chromeos::machine_learning::mojom::GrammarCheckerResultPtr; |
| using ::chromeos::machine_learning::mojom::GraphExecutor; |
| using ::chromeos::machine_learning::mojom::HandwritingRecognitionQuery; |
| using ::chromeos::machine_learning::mojom::HandwritingRecognitionQueryPtr; |
| using ::chromeos::machine_learning::mojom::HandwritingRecognizer; |
| using ::chromeos::machine_learning::mojom::HandwritingRecognizerResult; |
| using ::chromeos::machine_learning::mojom::HandwritingRecognizerResultPtr; |
| using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpec; |
| using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult; |
| using ::chromeos::machine_learning::mojom::LoadModelResult; |
| using ::chromeos::machine_learning::mojom::MachineLearningService; |
| using ::chromeos::machine_learning::mojom::Model; |
| using ::chromeos::machine_learning::mojom::SodaClient; |
| using ::chromeos::machine_learning::mojom::SodaConfig; |
| using ::chromeos::machine_learning::mojom::SodaRecognizer; |
| using ::chromeos::machine_learning::mojom::TensorPtr; |
| using ::chromeos::machine_learning::mojom::TextAnnotationPtr; |
| using ::chromeos::machine_learning::mojom::TextAnnotationRequest; |
| using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr; |
| using ::chromeos::machine_learning::mojom::TextClassifier; |
| using ::chromeos::machine_learning::mojom::TextLanguagePtr; |
| using ::chromeos::machine_learning::mojom::TextSuggester; |
| using ::chromeos::machine_learning::mojom::TextSuggesterQuery; |
| using ::chromeos::machine_learning::mojom::TextSuggesterQueryPtr; |
| using ::chromeos::machine_learning::mojom::TextSuggesterResult; |
| using ::chromeos::machine_learning::mojom::TextSuggesterResultPtr; |
| using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequest; |
| using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr; |
| |
| using ::chromeos::machine_learning::mojom::NextWordCompletionCandidate; |
| using ::chromeos::machine_learning::mojom::NextWordCompletionCandidatePtr; |
| |
| using ::testing::DoubleEq; |
| using ::testing::DoubleNear; |
| using ::testing::ElementsAre; |
| using ::testing::StrictMock; |
| |
| // A version of MachineLearningServiceImpl that loads from the testing model |
| // directory. |
| class MachineLearningServiceImplForTesting : public MachineLearningServiceImpl { |
| public: |
| // Pass an empty callback and use the testing model directory. |
| explicit MachineLearningServiceImplForTesting( |
| mojo::PendingReceiver< |
| chromeos::machine_learning::mojom::MachineLearningService> receiver) |
| : MachineLearningServiceImpl( |
| std::move(receiver), base::Closure(), GetTestModelDir()) {} |
| }; |
| |
| // A simple SODA client for testing. |
| class MockSodaClientImpl |
| : public chromeos::machine_learning::mojom::SodaClient { |
| public: |
| MOCK_METHOD(void, OnStop, (), (override)); |
| MOCK_METHOD(void, OnStart, (), (override)); |
| MOCK_METHOD( |
| void, |
| OnSpeechRecognizerEvent, |
| (chromeos::machine_learning::mojom::SpeechRecognizerEventPtr event), |
| (override)); |
| }; |
| |
| // Loads builtin model specified by `model_id`, binding the impl to `model`. |
| // Returns true on success. |
| bool LoadBuiltinModelForTesting( |
| const mojo::Remote<MachineLearningService>& ml_service, |
| BuiltinModelId model_id, |
| mojo::Remote<Model>* model) { |
| // Set up model spec. |
| BuiltinModelSpecPtr spec = BuiltinModelSpec::New(); |
| spec->id = model_id; |
| |
| bool model_callback_done = false; |
| ml_service->LoadBuiltinModel( |
| std::move(spec), model->BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| return model_callback_done; |
| } |
| |
| // Loads flatbuffer model specified by `spec`, binding the impl to `model`. |
| // Returns true on success. |
| bool LoadFlatBufferModelForTesting( |
| const mojo::Remote<MachineLearningService>& ml_service, |
| FlatBufferModelSpecPtr spec, |
| mojo::Remote<Model>* model) { |
| bool model_callback_done = false; |
| ml_service->LoadFlatBufferModel( |
| std::move(spec), model->BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| return model_callback_done; |
| } |
| |
| // Creates a graph executor for `model`, binding the impl to `graph_executor`. |
| // Returns true on success. |
| bool CreateGraphExecutorForTesting( |
| const mojo::Remote<Model>& model, |
| mojo::Remote<GraphExecutor>* graph_executor) { |
| bool ge_callback_done = false; |
| model->CreateGraphExecutor( |
| graph_executor->BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* ge_callback_done, const CreateGraphExecutorResult result) { |
| EXPECT_EQ(result, CreateGraphExecutorResult::OK); |
| *ge_callback_done = true; |
| }, |
| &ge_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| return ge_callback_done; |
| } |
| |
| // Checks that `result` is OK and that `outputs` contains a tensor matching |
| // `expected_shape` and `expected_value`. Sets `infer_callback_done` to true so |
| // that this function can be used to verify that a Mojo callback has been run. |
| // TODO(alanlxl): currently the output size of all models is 1, and the value |
| // types are all double. Parameterization may be necessary for future models. |
| void CheckOutputTensor(const std::vector<int64_t> expected_shape, |
| const double expected_value, |
| bool* infer_callback_done, |
| ExecuteResult result, |
| base::Optional<std::vector<TensorPtr>> outputs) { |
| // Check that the inference succeeded and gives the expected number |
| // of outputs. |
| EXPECT_EQ(result, ExecuteResult::OK); |
| ASSERT_TRUE(outputs.has_value()); |
| // Currently all the models here have the same output size of 1. |
| ASSERT_EQ(outputs->size(), 1); |
| |
| // Check that the output tensor has the right type and format. |
| const TensorView<double> out_tensor((*outputs)[0]); |
| EXPECT_TRUE(out_tensor.IsValidType()); |
| EXPECT_TRUE(out_tensor.IsValidFormat()); |
| |
| // Check the output tensor has the expected shape and values. |
| EXPECT_EQ(out_tensor.GetShape(), expected_shape); |
| EXPECT_THAT(out_tensor.GetValues(), |
| ElementsAre(DoubleNear(expected_value, 1e-5))); |
| *infer_callback_done = true; |
| } |
| |
| // Tests that Clone() connects to a working impl. |
| TEST(MachineLearningServiceImplTest, Clone) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Call Clone to bind another MachineLearningService. |
| mojo::Remote<MachineLearningService> ml_service_2; |
| ml_service->Clone(ml_service_2.BindNewPipeAndPassReceiver()); |
| |
| // Verify that the new MachineLearningService works with a simple call: |
| // Loading the TEST_MODEL. |
| BuiltinModelSpecPtr spec = BuiltinModelSpec::New(); |
| spec->id = BuiltinModelId::TEST_MODEL; |
| mojo::Remote<Model> model; |
| bool model_callback_done = false; |
| ml_service_2->LoadBuiltinModel( |
| std::move(spec), model.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| EXPECT_TRUE(model_callback_done); |
| EXPECT_TRUE(model.is_bound()); |
| } |
| |
| TEST(MachineLearningServiceImplTest, TestBadModel) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Set up model spec to specify an invalid model. |
| BuiltinModelSpecPtr spec = BuiltinModelSpec::New(); |
| spec->id = BuiltinModelId::UNSUPPORTED_UNKNOWN; |
| |
| // Load model. |
| mojo::Remote<Model> model; |
| bool model_callback_done = false; |
| ml_service->LoadBuiltinModel( |
| std::move(spec), model.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::MODEL_SPEC_ERROR); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| } |
| |
| // Tests loading an empty model through the downloaded model api. |
| TEST(MachineLearningServiceImplTest, EmptyModelString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| FlatBufferModelSpecPtr spec = FlatBufferModelSpec::New(); |
| spec->model_string = ""; |
| spec->inputs["x"] = 1; |
| spec->inputs["y"] = 2; |
| spec->outputs["z"] = 0; |
| spec->metrics_model_name = "TestModel"; |
| |
| // Load model from an empty model string. |
| mojo::Remote<Model> model; |
| bool model_callback_done = false; |
| ml_service->LoadFlatBufferModel( |
| std::move(spec), model.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::LOAD_MODEL_ERROR); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| } |
| |
| // Tests loading a bad model string through the downloaded model api. |
| TEST(MachineLearningServiceImplTest, BadModelString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| FlatBufferModelSpecPtr spec = FlatBufferModelSpec::New(); |
| spec->model_string = "bad model string"; |
| spec->inputs["x"] = 1; |
| spec->inputs["y"] = 2; |
| spec->outputs["z"] = 0; |
| spec->metrics_model_name = "TestModel"; |
| |
| // Load model from a bad model string. |
| mojo::Remote<Model> model; |
| bool model_callback_done = false; |
| ml_service->LoadFlatBufferModel( |
| std::move(spec), model.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::LOAD_MODEL_ERROR); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| } |
| |
| // Tests loading TEST_MODEL through the builtin model api. |
| TEST(MachineLearningServiceImplTest, TestModel) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Model loading and graph executor creation are left inline here, rather than |
| // using the helpers above, to demonstrate the usage details. |
| // Set up model spec. |
| BuiltinModelSpecPtr spec = BuiltinModelSpec::New(); |
| spec->id = BuiltinModelId::TEST_MODEL; |
| |
| // Load model. |
| mojo::Remote<Model> model; |
| bool model_callback_done = false; |
| ml_service->LoadBuiltinModel( |
| std::move(spec), model.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| ASSERT_TRUE(model.is_bound()); |
| |
| // Get graph executor. |
| mojo::Remote<GraphExecutor> graph_executor; |
| bool ge_callback_done = false; |
| model->CreateGraphExecutor( |
| graph_executor.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* ge_callback_done, const CreateGraphExecutorResult result) { |
| EXPECT_EQ(result, CreateGraphExecutorResult::OK); |
| *ge_callback_done = true; |
| }, |
| &ge_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(ge_callback_done); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace("x", NewTensor<double>({1}, {0.5})); |
| inputs.emplace("y", NewTensor<double>({1}, {0.25})); |
| std::vector<std::string> outputs({"z"}); |
| std::vector<int64_t> expected_shape{1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, 0.75, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests loading TEST_MODEL through the downloaded model api. |
| TEST(MachineLearningServiceImplTest, TestModelString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load the TEST_MODEL model file into a string. |
| std::string model_string; |
| ASSERT_TRUE(base::ReadFileToString( |
| base::FilePath(GetTestModelDir() + |
| "mlservice-model-test_add-20180914.tflite"), |
| &model_string)); |
| |
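| // The map values are the tensor indices of the named inputs and outputs in |
| // the TFLite graph. |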
| FlatBufferModelSpecPtr spec = FlatBufferModelSpec::New(); |
| spec->model_string = std::move(model_string); |
| spec->inputs["x"] = 1; |
| spec->inputs["y"] = 2; |
| spec->outputs["z"] = 0; |
| spec->metrics_model_name = "TestModel"; |
| |
| // Load model. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE( |
| LoadFlatBufferModelForTesting(ml_service, std::move(spec), &model)); |
| ASSERT_NE(model.get(), nullptr); |
| ASSERT_TRUE(model.is_bound()); |
| |
| // Get graph executor. |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace("x", NewTensor<double>({1}, {0.5})); |
| inputs.emplace("y", NewTensor<double>({1}, {0.25})); |
| std::vector<std::string> outputs({"z"}); |
| std::vector<int64_t> expected_shape{1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, 0.75, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the Smart Dim (20181115) model file loads correctly and produces |
| // the expected inference result. |
| TEST(BuiltinModelInferenceTest, SmartDim20181115) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load model. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE(LoadBuiltinModelForTesting( |
| ml_service, BuiltinModelId::SMART_DIM_20181115, &model)); |
| ASSERT_TRUE(model.is_bound()); |
| |
| // Get graph executor. |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace( |
| "input", NewTensor<double>( |
| {1, base::size(kSmartDim20181115TestInput)}, |
| std::vector<double>(std::begin(kSmartDim20181115TestInput), |
| std::end(kSmartDim20181115TestInput)))); |
| std::vector<std::string> outputs({"output"}); |
| std::vector<int64_t> expected_shape{1L, 1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, |
| -3.36311, &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the Smart Dim (20190221) model file loads correctly and produces |
| // the expected inference result. |
| TEST(BuiltinModelInferenceTest, SmartDim20190221) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load model and create graph executor. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE(LoadBuiltinModelForTesting( |
| ml_service, BuiltinModelId::SMART_DIM_20190221, &model)); |
| ASSERT_TRUE(model.is_bound()); |
| |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace( |
| "input", NewTensor<double>( |
| {1, base::size(kSmartDim20190221TestInput)}, |
| std::vector<double>(std::begin(kSmartDim20190221TestInput), |
| std::end(kSmartDim20190221TestInput)))); |
| std::vector<std::string> outputs({"output"}); |
| std::vector<int64_t> expected_shape{1L, 1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, |
| -0.900591, &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the Smart Dim (20190521) model file loads correctly and produces |
| // the expected inference result. |
| TEST(BuiltinModelInferenceTest, SmartDim20190521) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load model and create graph executor. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE(LoadBuiltinModelForTesting( |
| ml_service, BuiltinModelId::SMART_DIM_20190521, &model)); |
| ASSERT_TRUE(model.is_bound()); |
| |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace( |
| "input", NewTensor<double>( |
| {1, base::size(kSmartDim20190521TestInput)}, |
| std::vector<double>(std::begin(kSmartDim20190521TestInput), |
| std::end(kSmartDim20190521TestInput)))); |
| std::vector<std::string> outputs({"output"}); |
| std::vector<int64_t> expected_shape{1L, 1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, |
| 0.66962254, &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the Search Ranker (20190923) model file loads correctly and |
| // produces the expected inference result. |
| TEST(BuiltinModelInferenceTest, SearchRanker20190923) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load model and create graph executor. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE(LoadBuiltinModelForTesting( |
| ml_service, BuiltinModelId::SEARCH_RANKER_20190923, &model)); |
| ASSERT_TRUE(model.is_bound()); |
| |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace("input", NewTensor<double>( |
| {1, base::size(kSearchRanker20190923TestInput)}, |
| std::vector<double>( |
| std::begin(kSearchRanker20190923TestInput), |
| std::end(kSearchRanker20190923TestInput)))); |
| std::vector<std::string> outputs({"output"}); |
| std::vector<int64_t> expected_shape{1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, |
| 0.658488, &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the Smart Dim (20200206) model file loads correctly and |
| // produces the expected inference result. |
| TEST(DownloadableModelInferenceTest, SmartDim20200206) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load the SmartDim model file into a string. |
| std::string model_string; |
| ASSERT_TRUE(base::ReadFileToString( |
| base::FilePath(GetTestModelDir() + |
| "mlservice-model-smart_dim-20200206-downloadable.tflite"), |
| &model_string)); |
| |
| FlatBufferModelSpecPtr spec = FlatBufferModelSpec::New(); |
| spec->model_string = std::move(model_string); |
| spec->inputs["input"] = 0; |
| spec->outputs["output"] = 6; |
| spec->metrics_model_name = "SmartDimModel_20200206"; |
| |
| // Load model. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE( |
| LoadFlatBufferModelForTesting(ml_service, std::move(spec), &model)); |
| ASSERT_NE(model.get(), nullptr); |
| ASSERT_TRUE(model.is_bound()); |
| |
| // Get graph executor. |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace( |
| "input", NewTensor<double>( |
| {1, base::size(kSmartDim20200206TestInput)}, |
| std::vector<double>(std::begin(kSmartDim20200206TestInput), |
| std::end(kSmartDim20200206TestInput)))); |
| std::vector<std::string> outputs({"output"}); |
| std::vector<int64_t> expected_shape{1L, 1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, |
| -1.07195, &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the Smart Dim (20210201) model file loads correctly and |
| // produces the expected inference result. |
| TEST(DownloadableModelInferenceTest, SmartDim20210201) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load the SmartDim model file into a string. |
| std::string model_string; |
| ASSERT_TRUE(base::ReadFileToString( |
| base::FilePath(GetTestModelDir() + |
| "mlservice-model-smart_dim-20210201-downloadable.tflite"), |
| &model_string)); |
| |
| FlatBufferModelSpecPtr spec = FlatBufferModelSpec::New(); |
| spec->model_string = std::move(model_string); |
| spec->inputs["input"] = 0; |
| spec->outputs["output"] = 20; |
| spec->metrics_model_name = "SmartDimModel_20210201"; |
| |
| // Load model. |
| mojo::Remote<Model> model; |
| ASSERT_TRUE( |
| LoadFlatBufferModelForTesting(ml_service, std::move(spec), &model)); |
| ASSERT_NE(model.get(), nullptr); |
| ASSERT_TRUE(model.is_bound()); |
| |
| // Get graph executor. |
| mojo::Remote<GraphExecutor> graph_executor; |
| ASSERT_TRUE(CreateGraphExecutorForTesting(model, &graph_executor)); |
| ASSERT_TRUE(graph_executor.is_bound()); |
| |
| // Construct input. |
| base::flat_map<std::string, TensorPtr> inputs; |
| inputs.emplace( |
| "input", NewTensor<double>( |
| {1, base::size(kSmartDim20210201TestInput)}, |
| std::vector<double>(std::begin(kSmartDim20210201TestInput), |
| std::end(kSmartDim20210201TestInput)))); |
| std::vector<std::string> outputs({"output"}); |
| std::vector<int64_t> expected_shape{1L, 1L}; |
| |
| // Perform inference. |
| bool infer_callback_done = false; |
| graph_executor->Execute(std::move(inputs), std::move(outputs), |
| base::Bind(&CheckOutputTensor, expected_shape, |
| 0.76872265, &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests loading the text classifier only, without running inference. |
| TEST(LoadTextClassifierTest, NoInference) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| } |
| |
| // Tests text classifier annotator for empty string. |
| TEST(TextClassifierAnnotateTest, EmptyString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| TextAnnotationRequestPtr request = TextAnnotationRequest::New(); |
| request->text = ""; |
| bool infer_callback_done = false; |
| text_classifier->Annotate(std::move(request), |
| base::Bind( |
| [](bool* infer_callback_done, |
| std::vector<TextAnnotationPtr> annotations) { |
| *infer_callback_done = true; |
| EXPECT_EQ(annotations.size(), 0); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests text classifier annotator for a complex string. |
| TEST(TextClassifierAnnotateTest, ComplexString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| TextAnnotationRequestPtr request = TextAnnotationRequest::New(); |
| request->text = kTextClassifierTestInput; |
| bool infer_callback_done = false; |
| text_classifier->Annotate( |
| std::move(request), |
| base::Bind( |
| [](bool* infer_callback_done, |
| std::vector<TextAnnotationPtr> annotations) { |
| *infer_callback_done = true; |
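| // The test input should yield five annotations: an email address, a street |
| // address, a dictionary word, a datetime, and a unit. |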
| EXPECT_EQ(annotations.size(), 5); |
| EXPECT_EQ(annotations[0]->start_offset, 0); |
| EXPECT_EQ(annotations[0]->end_offset, 19); |
| ASSERT_GE(annotations[0]->entities.size(), 1); |
| EXPECT_EQ(annotations[0]->entities[0]->name, "email"); |
| EXPECT_EQ(annotations[0]->entities[0]->data->get_string_value(), |
| "user.name@gmail.com"); |
| EXPECT_EQ(annotations[1]->start_offset, 21); |
| EXPECT_EQ(annotations[1]->end_offset, 38); |
| ASSERT_GE(annotations[1]->entities.size(), 1); |
| EXPECT_EQ(annotations[1]->entities[0]->name, "address"); |
| EXPECT_EQ(annotations[1]->entities[0]->data->get_string_value(), |
| "123 George Street"); |
| EXPECT_EQ(annotations[2]->start_offset, 40); |
| EXPECT_EQ(annotations[2]->end_offset, 52); |
| ASSERT_GE(annotations[2]->entities.size(), 1); |
| EXPECT_EQ(annotations[2]->entities[0]->name, "dictionary"); |
| EXPECT_EQ(annotations[2]->entities[0]->data->get_string_value(), |
| "unfathomable"); |
| EXPECT_EQ(annotations[3]->start_offset, 54); |
| EXPECT_EQ(annotations[3]->end_offset, 59); |
| ASSERT_GE(annotations[3]->entities.size(), 1); |
| EXPECT_EQ(annotations[3]->entities[0]->name, "datetime"); |
| EXPECT_EQ(annotations[3]->entities[0]->data->get_string_value(), |
| "12pm."); |
| EXPECT_EQ(annotations[4]->start_offset, 60); |
| EXPECT_EQ(annotations[4]->end_offset, 65); |
| ASSERT_GE(annotations[4]->entities.size(), 1); |
| EXPECT_EQ(annotations[4]->entities[0]->name, "unit"); |
| EXPECT_EQ(annotations[4]->entities[0]->data->get_string_value(), |
| "350°F"); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests text classifier selection suggestion for an empty string. |
| // In this situation, the text classifier will return the input span. |
| TEST(TextClassifierSelectionTest, EmptyString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| TextSuggestSelectionRequestPtr request = TextSuggestSelectionRequest::New(); |
| request->text = ""; |
| request->user_selection = CodepointSpan::New(); |
| request->user_selection->start_offset = 1; |
| request->user_selection->end_offset = 2; |
| bool infer_callback_done = false; |
| text_classifier->SuggestSelection( |
| std::move(request), |
| base::Bind( |
| [](bool* infer_callback_done, CodepointSpanPtr suggested_span) { |
| *infer_callback_done = true; |
| EXPECT_EQ(suggested_span->start_offset, 1); |
| EXPECT_EQ(suggested_span->end_offset, 2); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests text classifier selection suggestion for a complex string. |
| TEST(TextClassifierSelectionTest, ComplexString) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| TextSuggestSelectionRequestPtr request = TextSuggestSelectionRequest::New(); |
| request->text = kTextClassifierTestInput; |
| request->user_selection = CodepointSpan::New(); |
| request->user_selection->start_offset = 25; |
| request->user_selection->end_offset = 26; |
| bool infer_callback_done = false; |
| text_classifier->SuggestSelection( |
| std::move(request), |
| base::Bind( |
| [](bool* infer_callback_done, CodepointSpanPtr suggested_span) { |
| *infer_callback_done = true; |
| EXPECT_EQ(suggested_span->start_offset, 21); |
| EXPECT_EQ(suggested_span->end_offset, 38); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests text classifier selection suggestion with an invalid input span |
| // (start_offset > end_offset). In this situation, the text classifier will |
| // return the input span unchanged. |
| TEST(TextClassifierSelectionTest, WrongInput) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| TextSuggestSelectionRequestPtr request = TextSuggestSelectionRequest::New(); |
| request->text = kTextClassifierTestInput; |
| request->user_selection = CodepointSpan::New(); |
| request->user_selection->start_offset = 30; |
| request->user_selection->end_offset = 26; |
| bool infer_callback_done = false; |
| text_classifier->SuggestSelection( |
| std::move(request), |
| base::Bind( |
| [](bool* infer_callback_done, CodepointSpanPtr suggested_span) { |
| *infer_callback_done = true; |
| EXPECT_EQ(suggested_span->start_offset, 30); |
| EXPECT_EQ(suggested_span->end_offset, 26); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests text classifier language identification with some valid inputs. |
| TEST(TextClassifierLangIdTest, ValidInput) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| bool infer_callback_done = false; |
| text_classifier->FindLanguages( |
| "Bonjour", |
| base::Bind( |
| [](bool* infer_callback_done, std::vector<TextLanguagePtr> result) { |
| *infer_callback_done = true; |
| ASSERT_GT(result.size(), 0); |
| EXPECT_EQ(result[0]->locale, "fr"); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests text classifier language identification with empty input. |
| // Empty input should produce empty result. |
| TEST(TextClassifierLangIdTest, EmptyInput) { |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| mojo::Remote<TextClassifier> text_classifier; |
| bool model_callback_done = false; |
| ml_service->LoadTextClassifier( |
| text_classifier.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| EXPECT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| |
| bool infer_callback_done = false; |
| text_classifier->FindLanguages( |
| "", |
| base::Bind( |
| [](bool* infer_callback_done, std::vector<TextLanguagePtr> result) { |
| *infer_callback_done = true; |
| EXPECT_EQ(result.size(), 0); |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Test fixture for HandwritingRecognizer tests. |
| class HandwritingRecognizerTest : public testing::Test { |
| protected: |
| void SetUp() override { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| // Set ml_service. |
| ml_service_impl_ = std::make_unique<MachineLearningServiceImplForTesting>( |
| ml_service_.BindNewPipeAndPassReceiver()); |
| |
| // Set default request. |
| request_.set_max_num_results(1); |
| auto& stroke = *request_.mutable_ink()->add_strokes(); |
| for (int i = 0; i < 23; ++i) { |
| auto& point = *stroke.add_points(); |
| point.set_x(kHandwritingTestPoints[i][0]); |
| point.set_y(kHandwritingTestPoints[i][1]); |
| } |
| } |
| |
| // recognizer_ should be loaded successfully for this `language`. |
| // Uses the new API (LoadHandwritingModel) when `use_load_handwriting_model` |
| // is true; otherwise uses LoadHandwritingModelWithSpec. |
| void LoadRecognizerWithLanguage( |
| const std::string& language, |
| const bool use_load_handwriting_model = false) { |
| bool model_callback_done = false; |
| if (use_load_handwriting_model) { |
| ml_service_->LoadHandwritingModel( |
| HandwritingRecognizerSpec::New(language), |
| recognizer_.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, |
| const LoadHandwritingModelResult result) { |
| ASSERT_EQ(result, LoadHandwritingModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| } else { |
| ml_service_->LoadHandwritingModelWithSpec( |
| HandwritingRecognizerSpec::New(language), |
| recognizer_.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| ASSERT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| } |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| ASSERT_TRUE(recognizer_.is_bound()); |
| } |
| |
| // Recognizing on request_ should produce the expected text and score. |
| void ExpectRecognizeResult(const std::string& text, const float score) { |
| // Perform inference. |
| bool infer_callback_done = false; |
| recognizer_->Recognize( |
| HandwritingRecognitionQueryFromProtoForTesting(request_), |
| base::Bind( |
| [](bool* infer_callback_done, const std::string& text, |
| const float score, const HandwritingRecognizerResultPtr result) { |
| // Check that the inference succeeded and gives |
| // the expected number of outputs. |
| EXPECT_EQ(result->status, |
| HandwritingRecognizerResult::Status::OK); |
| ASSERT_EQ(result->candidates.size(), 1); |
| EXPECT_EQ(result->candidates.at(0)->text, text); |
| EXPECT_NEAR(result->candidates.at(0)->score, score, 1e-4); |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done, text, score)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| std::unique_ptr<MachineLearningServiceImplForTesting> ml_service_impl_; |
| mojo::Remote<MachineLearningService> ml_service_; |
| mojo::Remote<HandwritingRecognizer> recognizer_; |
| chrome_knowledge::HandwritingRecognizerRequest request_; |
| }; |
| |
| // Tests that the HandwritingRecognizer recognition returns expected scores. |
| TEST_F(HandwritingRecognizerTest, GetExpectedScores) { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| |
| // Load Recognizer successfully. |
| LoadRecognizerWithLanguage("en"); |
| |
| // Run Recognition on the default request_. |
| ExpectRecognizeResult("a", 0.50640869f); |
| |
| // Modify the request_ by setting fake time. |
| for (int i = 0; i < 23; ++i) { |
| request_.mutable_ink()->mutable_strokes(0)->mutable_points(i)->set_t(i * i * |
| 100); |
| } |
| ExpectRecognizeResult("a", 0.5121f); |
| } |
| |
| // Tests that LoadHandwritingModel also performs as expected. |
| TEST_F(HandwritingRecognizerTest, LoadHandwritingModel) { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| |
| // Load Recognizer successfully. |
| LoadRecognizerWithLanguage("en", true); |
| |
| // Clear the ink inside request. |
| request_.clear_ink(); |
| |
| // Performing inference should return an error. |
| bool infer_callback_done = false; |
| recognizer_->Recognize( |
| HandwritingRecognitionQueryFromProtoForTesting(request_), |
| base::Bind( |
| [](bool* infer_callback_done, |
| const HandwritingRecognizerResultPtr result) { |
| // Check that the inference failed. |
| EXPECT_EQ(result->status, |
| HandwritingRecognizerResult::Status::ERROR); |
| EXPECT_EQ(result->candidates.size(), 0); |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the HandwritingRecognizer Recognition should fail on empty ink. |
| TEST_F(HandwritingRecognizerTest, FailOnEmptyInk) { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| |
| // Load Recognizer successfully. |
| LoadRecognizerWithLanguage("en"); |
| |
| // Clear the ink inside request. |
| request_.clear_ink(); |
| |
| // Performing inference should return an error. |
| bool infer_callback_done = false; |
| recognizer_->Recognize( |
| HandwritingRecognitionQueryFromProtoForTesting(request_), |
| base::Bind( |
| [](bool* infer_callback_done, |
| const HandwritingRecognizerResultPtr result) { |
| // Check that the inference failed. |
| EXPECT_EQ(result->status, |
| HandwritingRecognizerResult::Status::ERROR); |
| EXPECT_EQ(result->candidates.size(), 0); |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
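| // Matcher that compares the argument against the expected mojo struct `n` |
| // using the mojo-generated Equals() method. |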
| MATCHER_P(StructPtrEq, n, "") { |
| return n.get().Equals(arg); |
| } |
| |
| // Test fixture for WebPlatformHandwritingRecognizer tests. |
| class WebPlatformHandwritingRecognizerTest : public testing::Test { |
| protected: |
| void SetUp() override { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| // Set ml_service. |
| ml_service_impl_ = std::make_unique<MachineLearningServiceImplForTesting>( |
| ml_service_.BindNewPipeAndPassReceiver()); |
| |
| // Set default inputs. |
| hints_ = chromeos::machine_learning::web_platform::mojom::HandwritingHints:: |
| New(); |
| hints_->alternatives = 1u; |
| auto stroke = chromeos::machine_learning::web_platform::mojom:: |
| HandwritingStroke::New(); |
| for (int i = 0; i < 23; ++i) { |
| auto point = chromeos::machine_learning::web_platform::mojom:: |
| HandwritingPoint::New(); |
| auto location = gfx::mojom::PointF::New(); |
| location->x = kHandwritingTestPoints[i][0]; |
| location->y = kHandwritingTestPoints[i][1]; |
| point->location = std::move(location); |
| stroke->points.push_back(std::move(point)); |
| } |
| strokes_.push_back(std::move(stroke)); |
| } |
| |
| // recognizer_ should be loaded successfully for this `language`. |
| void LoadRecognizerWithLanguage(const std::string& language) { |
| bool model_callback_done = false; |
| auto constraint = chromeos::machine_learning::web_platform::mojom:: |
| HandwritingModelConstraint::New(); |
| constraint->languages.push_back(language); |
| ml_service_->LoadWebPlatformHandwritingModel( |
| std::move(constraint), recognizer_.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, |
| const LoadHandwritingModelResult result) { |
| ASSERT_EQ(result, LoadHandwritingModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| ASSERT_TRUE(recognizer_.is_bound()); |
| } |
| |
| // Recognizing on strokes_ and hints_ should produce the expected text. |
| void ExpectRecognizeResult(const std::string& text) { |
| // Perform inference. |
| bool infer_callback_done = false; |
| // Make a copy of strokes and hints to avoid them being cleared after |
| // `GetPrediction()`. |
| recognizer_->GetPrediction( |
| GetDefaultStrokes(), hints_.Clone(), |
| base::Bind( |
| [](bool* infer_callback_done, const std::string& text, |
| base::Optional< |
| std::vector<chromeos::machine_learning::web_platform::mojom:: |
| HandwritingPredictionPtr>> predictions) { |
| // Check that the inference succeeded and gives |
| // the expected number of outputs. |
| ASSERT_TRUE(predictions.has_value()); |
| ASSERT_EQ(predictions->size(), 1u); |
| EXPECT_EQ(predictions->at(0)->text, text); |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done, text)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Returns a copy of strokes_ so the originals are not consumed by |
| // `GetPrediction()`. |
| std::vector< |
| chromeos::machine_learning::web_platform::mojom::HandwritingStrokePtr> |
| GetDefaultStrokes() { |
| std::vector< |
| chromeos::machine_learning::web_platform::mojom::HandwritingStrokePtr> |
| strokes_clone; |
| for (const auto& stroke : strokes_) { |
| strokes_clone.push_back(stroke.Clone()); |
| } |
| return strokes_clone; |
| } |
| |
| std::unique_ptr<MachineLearningServiceImplForTesting> ml_service_impl_; |
| mojo::Remote<MachineLearningService> ml_service_; |
| mojo::Remote< |
| chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer> |
| recognizer_; |
| std::vector< |
| chromeos::machine_learning::web_platform::mojom::HandwritingStrokePtr> |
| strokes_; |
| chromeos::machine_learning::web_platform::mojom::HandwritingHintsPtr hints_; |
| }; |
| |
| // Tests that web_platform::mojom::HandwritingRecognizer::GetPrediction |
| // returns the expected recognized text. |
| TEST_F(WebPlatformHandwritingRecognizerTest, GetExpectedRecognizedText) { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| |
| // Load Recognizer successfully. |
| LoadRecognizerWithLanguage("en"); |
| |
| // Run Recognition on the default strokes_. |
| ExpectRecognizeResult("a"); |
| |
| // Attach fake timestamps to the points in strokes_; recognition should |
| // still produce the same text. |
| ASSERT_EQ(strokes_.size(), 1u); |
| ASSERT_EQ(strokes_[0]->points.size(), 23u); |
| for (int i = 0; i < 23; ++i) { |
| strokes_[0]->points[i]->t = base::TimeDelta::FromMilliseconds(i * i * 100); |
| } |
| ExpectRecognizeResult("a"); |
| } |
| |
| TEST_F(WebPlatformHandwritingRecognizerTest, FailOnEmptyStrokes) { |
| // Nothing to test on an unsupported platform. |
| if (!ml::HandwritingLibrary::IsHandwritingLibraryUnitTestSupported()) { |
| return; |
| } |
| |
| // Load Recognizer successfully. |
| LoadRecognizerWithLanguage("en"); |
| |
| // Performing inference on empty strokes should return an error. |
| bool infer_callback_done = false; |
| recognizer_->GetPrediction( |
| {}, hints_.Clone(), |
| base::Bind( |
| [](bool* infer_callback_done, |
| base::Optional< |
| std::vector<chromeos::machine_learning::web_platform::mojom:: |
| HandwritingPredictionPtr>> predictions) { |
| // Check that the inference failed. |
| EXPECT_FALSE(predictions.has_value()); |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| // Tests that the SODA CrOS mojo callback of the fake implementation returns |
| // the expected error string. |
| TEST(SODARecognizerTest, FakeImplMojoCallback) { |
| #ifdef USE_ONDEVICE_SPEECH |
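| // When the real on-device speech library is available, the fake |
| // implementation exercised below does not apply, so there is nothing to |
| // test. |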
| return; |
| #else |
| StrictMock<MockSodaClientImpl> soda_client_impl; |
| mojo::Receiver<SodaClient> soda_client(&soda_client_impl); |
| auto soda_config = SodaConfig::New(); |
| mojo::Remote<SodaRecognizer> soda_recognizer; |
| |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| ml_service->LoadSpeechRecognizer(std::move(soda_config), |
| soda_client.BindNewPipeAndPassRemote(), |
| soda_recognizer.BindNewPipeAndPassReceiver(), |
| base::BindOnce([](LoadModelResult) {})); |
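| // Build the SpeechRecognizerEvent that the fake implementation is expected |
| // to emit for every recognizer call below. |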
| chromeos::machine_learning::mojom::SpeechRecognizerEventPtr event = |
| chromeos::machine_learning::mojom::SpeechRecognizerEvent::New(); |
| chromeos::machine_learning::mojom::FinalResultPtr final_result = |
| chromeos::machine_learning::mojom::FinalResult::New(); |
| final_result->final_hypotheses.push_back( |
| "On-device speech is not supported."); |
| final_result->endpoint_reason = |
| chromeos::machine_learning::mojom::EndpointReason::ENDPOINT_UNKNOWN; |
| event->set_final_result(std::move(final_result)); |
| |
| // TODO(robsc): Update this unittest to use regular Eq() once |
| // https://chromium-review.googlesource.com/c/chromium/src/+/2456184 is |
| // submitted. |
| EXPECT_CALL(soda_client_impl, |
| OnSpeechRecognizerEvent(StructPtrEq(std::ref(event)))) |
| .Times(1); |
| soda_recognizer->Start(); |
| base::RunLoop().RunUntilIdle(); |
| |
| EXPECT_CALL(soda_client_impl, |
| OnSpeechRecognizerEvent(StructPtrEq(std::ref(event)))) |
| .Times(1); |
| soda_recognizer->AddAudio({}); |
| base::RunLoop().RunUntilIdle(); |
| |
| EXPECT_CALL(soda_client_impl, |
| OnSpeechRecognizerEvent(StructPtrEq(std::ref(event)))) |
| .Times(1); |
| soda_recognizer->MarkDone(); |
| base::RunLoop().RunUntilIdle(); |
| |
| EXPECT_CALL(soda_client_impl, |
| OnSpeechRecognizerEvent(StructPtrEq(std::ref(event)))) |
| .Times(1); |
| soda_recognizer->Stop(); |
| base::RunLoop().RunUntilIdle(); |
| #endif |
| } |
| |
| TEST(GrammarCheckerTest, LoadModelAndInference) { |
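| // Nothing to test if the grammar library is not supported on the platform. |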
| if (ml::GrammarLibrary::GetInstance()->GetStatus() == |
| ml::GrammarLibrary::Status::kNotSupported) { |
| return; |
| } |
| |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load GrammarChecker. |
| mojo::Remote<GrammarChecker> checker; |
| bool model_callback_done = false; |
| ml_service->LoadGrammarChecker( |
| checker.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| ASSERT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| ASSERT_TRUE(checker.is_bound()); |
| |
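| // Construct a request containing a grammatically incorrect sentence. |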
| chrome_knowledge::GrammarCheckerRequest request; |
| request.set_text("They is student."); |
| request.set_language("en-US"); |
| |
| bool infer_callback_done = false; |
| checker->Check( |
| GrammarCheckerQueryFromProtoForTesting(request), |
| base::Bind( |
| [](bool* infer_callback_done, const GrammarCheckerResultPtr result) { |
| EXPECT_EQ(result->status, GrammarCheckerResult::Status::OK); |
| ASSERT_GE(result->candidates.size(), 1u); |
| EXPECT_EQ(result->candidates.at(0)->text, "They are students."); |
| |
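| // The suggested replacement should cover "is student" (offset 5, |
| // length 10) and replace it with "are students". |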
| ASSERT_EQ(result->candidates.at(0)->fragments.size(), 1u); |
| EXPECT_EQ(result->candidates.at(0)->fragments.at(0)->offset, 5); |
| EXPECT_EQ(result->candidates.at(0)->fragments.at(0)->length, 10); |
| EXPECT_EQ(result->candidates.at(0)->fragments.at(0)->replacement, |
| "are students"); |
| |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| TEST(TextSuggesterTest, LoadModelAndInference) { |
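| // Nothing to test if the text suggestions library is not supported on the |
| // platform. |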
| if (ml::TextSuggestions::GetInstance()->GetStatus() == |
| ml::TextSuggestions::Status::kNotSupported) { |
| return; |
| } |
| |
| mojo::Remote<MachineLearningService> ml_service; |
| const MachineLearningServiceImplForTesting ml_service_impl( |
| ml_service.BindNewPipeAndPassReceiver()); |
| |
| // Load TextSuggester. |
| mojo::Remote<TextSuggester> suggester; |
| bool model_callback_done = false; |
| ml_service->LoadTextSuggester( |
| suggester.BindNewPipeAndPassReceiver(), |
| base::Bind( |
| [](bool* model_callback_done, const LoadModelResult result) { |
| ASSERT_EQ(result, LoadModelResult::OK); |
| *model_callback_done = true; |
| }, |
| &model_callback_done)); |
| |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(model_callback_done); |
| ASSERT_TRUE(suggester.is_bound()); |
| |
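| // Construct a query with a partial sentence and a single next-word |
| // candidate. |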
| TextSuggesterQueryPtr query = TextSuggesterQuery::New(); |
| query->text = "how are y"; |
| |
| NextWordCompletionCandidatePtr candidate_one = |
| NextWordCompletionCandidate::New(); |
| candidate_one->text = "you"; |
| candidate_one->normalized_score = -1.0f; |
| query->next_word_candidates.push_back(std::move(candidate_one)); |
| |
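| // The suggester should expand the candidate into the multi-word |
| // suggestion "you doing". |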
| bool infer_callback_done = false; |
| suggester->Suggest( |
| std::move(query), |
| base::Bind( |
| [](bool* infer_callback_done, const TextSuggesterResultPtr result) { |
| EXPECT_EQ(result->status, TextSuggesterResult::Status::OK); |
| ASSERT_EQ(result->candidates.size(), 1u); |
| ASSERT_TRUE(result->candidates.at(0)->is_multi_word()); |
| EXPECT_EQ(result->candidates.at(0)->get_multi_word()->text, |
| "you doing"); |
| EXPECT_EQ( |
| result->candidates.at(0)->get_multi_word()->normalized_score, |
| -0.680989f); |
| *infer_callback_done = true; |
| }, |
| &infer_callback_done)); |
| base::RunLoop().RunUntilIdle(); |
| ASSERT_TRUE(infer_callback_done); |
| } |
| |
| } // namespace |
| } // namespace ml |