ml: add performance metrics for HWR recognization.

Add performance metrics in HandwritingRecognizerImpl::Recognize.

Also, simplify RequestMetrics class while making specifying an Event enum type optional.

BUG=chromium:1099555
TEST='All unit tests passed.'

Change-Id: Ice8f677e8d9836e25e38fdf9e572d905f23e3235
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/2269917
Reviewed-by: Charles . <charleszhao@chromium.org>
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
Tested-by: Charles . <charleszhao@chromium.org>
Commit-Queue: Charles . <charleszhao@chromium.org>
diff --git a/ml/graph_executor_impl.cc b/ml/graph_executor_impl.cc
index 3c902b1..698626f 100644
--- a/ml/graph_executor_impl.cc
+++ b/ml/graph_executor_impl.cc
@@ -175,8 +175,7 @@
                                 ExecuteCallback callback) {
   DCHECK(!metrics_model_name_.empty());
 
-  RequestMetrics<ExecuteResult> request_metrics(metrics_model_name_,
-                                                kMetricsRequestName);
+  RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName);
   request_metrics.StartRecordingPerformanceMetrics();
 
   // Validate input and output names (before executing graph, for efficiency).
diff --git a/ml/handwriting_recognizer_impl.cc b/ml/handwriting_recognizer_impl.cc
index fb7ca6d..1467739 100644
--- a/ml/handwriting_recognizer_impl.cc
+++ b/ml/handwriting_recognizer_impl.cc
@@ -9,6 +9,7 @@
 
 #include "ml/handwriting_path.h"
 #include "ml/handwriting_proto_mojom_conversion.h"
+#include "ml/request_metrics.h"
 
 namespace ml {
 namespace {
@@ -67,6 +68,9 @@
 
 void HandwritingRecognizerImpl::Recognize(HandwritingRecognitionQueryPtr query,
                                           RecognizeCallback callback) {
+  RequestMetrics request_metrics("HandwritingModel", "Recognize");
+  request_metrics.StartRecordingPerformanceMetrics();
+
   chrome_knowledge::HandwritingRecognizerResult result_proto;
 
   if (ml::HandwritingLibrary::GetInstance()->RecognizeHandwriting(
@@ -74,11 +78,15 @@
           &result_proto)) {
     // Recognition succeeded, run callback on the result.
     std::move(callback).Run(HandwritingRecognizerResultFromProto(result_proto));
+    request_metrics.FinishRecordingPerformanceMetrics();
+    request_metrics.RecordRequestEvent(HandwritingRecognizerResult::Status::OK);
   } else {
     // Recognition failed, run callback on empty result and status = ERROR.
     std::move(callback).Run(HandwritingRecognizerResult::New(
         HandwritingRecognizerResult::Status::ERROR,
         std::vector<HandwritingRecognizerCandidatePtr>()));
+    request_metrics.RecordRequestEvent(
+        HandwritingRecognizerResult::Status::ERROR);
   }
 }
 
diff --git a/ml/machine_learning_service_impl.cc b/ml/machine_learning_service_impl.cc
index df2fd9f..44e4b4d 100644
--- a/ml/machine_learning_service_impl.cc
+++ b/ml/machine_learning_service_impl.cc
@@ -103,8 +103,8 @@
 
   DCHECK(!metadata.metrics_model_name.empty());
 
-  RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
-                                                  kMetricsRequestName);
+  RequestMetrics request_metrics(metadata.metrics_model_name,
+                                 kMetricsRequestName);
   request_metrics.StartRecordingPerformanceMetrics();
 
   // Attempt to load model.
@@ -134,8 +134,7 @@
     LoadFlatBufferModelCallback callback) {
   DCHECK(!spec->metrics_model_name.empty());
 
-  RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
-                                                  kMetricsRequestName);
+  RequestMetrics request_metrics(spec->metrics_model_name, kMetricsRequestName);
   request_metrics.StartRecordingPerformanceMetrics();
 
   // Take the ownership of the content of `model_string` because `ModelImpl` has
@@ -169,8 +168,7 @@
 void MachineLearningServiceImpl::LoadTextClassifier(
     mojo::PendingReceiver<TextClassifier> receiver,
     LoadTextClassifierCallback callback) {
-  RequestMetrics<LoadModelResult> request_metrics("TextClassifier",
-                                                  kMetricsRequestName);
+  RequestMetrics request_metrics("TextClassifier", kMetricsRequestName);
   request_metrics.StartRecordingPerformanceMetrics();
 
   // Attempt to load model.
@@ -216,8 +214,7 @@
     HandwritingRecognizerSpecPtr spec,
     mojo::PendingReceiver<HandwritingRecognizer> receiver,
     LoadHandwritingModelCallback callback) {
-  RequestMetrics<LoadModelResult> request_metrics("HandwritingModel",
-                                                  kMetricsRequestName);
+  RequestMetrics request_metrics("HandwritingModel", kMetricsRequestName);
   request_metrics.StartRecordingPerformanceMetrics();
 
   // Load HandwritingLibrary.
diff --git a/ml/model_impl.cc b/ml/model_impl.cc
index fbf1d78..0c70783 100644
--- a/ml/model_impl.cc
+++ b/ml/model_impl.cc
@@ -101,8 +101,7 @@
     CreateGraphExecutorCallback callback) {
   DCHECK(!metrics_model_name_.empty());
 
-  RequestMetrics<CreateGraphExecutorResult> request_metrics(
-      metrics_model_name_, kMetricsRequestName);
+  RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName);
   request_metrics.StartRecordingPerformanceMetrics();
 
   if (model_ == nullptr) {
diff --git a/ml/request_metrics.cc b/ml/request_metrics.cc
index 933a5f1..8678142 100644
--- a/ml/request_metrics.cc
+++ b/ml/request_metrics.cc
@@ -4,7 +4,7 @@
 
 #include "ml/request_metrics.h"
 
-#include <metrics/metrics_library.h>
+#include <base/logging.h>
 
 #include "ml/mojom/machine_learning_service.mojom.h"
 
@@ -12,6 +12,69 @@
 
 using chromeos::machine_learning::mojom::LoadModelResult;
 
+RequestMetrics::RequestMetrics(
+    const std::string& model_name, const std::string& request_name)
+    : name_base_(std::string(kGlobalMetricsPrefix) + model_name + "." +
+                 request_name),
+      process_metrics_(nullptr) {}
+
+void RequestMetrics::StartRecordingPerformanceMetrics() {
+  DCHECK(process_metrics_ == nullptr);
+  process_metrics_ = base::ProcessMetrics::CreateCurrentProcessMetrics();
+  // Call GetPlatformIndependentCPUUsage in order to set the "zero" point of the
+  // CPU usage counter of process_metrics_.
+  process_metrics_->GetPlatformIndependentCPUUsage();
+  timer_.Start();
+  // Query memory usage.
+  size_t usage = 0;
+  if (!GetTotalProcessMemoryUsage(&usage)) {
+    LOG(DFATAL) << "Getting process memory usage failed.";
+    return;
+  }
+  initial_memory_ = static_cast<int64_t>(usage);
+}
+
+void RequestMetrics::FinishRecordingPerformanceMetrics() {
+  DCHECK(process_metrics_ != nullptr);
+  // To get CPU time, we multiply elapsed (wall) time by CPU usage percentage.
+  timer_.Stop();
+  base::TimeDelta elapsed_time;
+  DCHECK(timer_.GetElapsedTime(&elapsed_time));
+  const int64_t elapsed_time_microsec = elapsed_time.InMicroseconds();
+
+  // CPU usage, 12.34 means 12.34%, and the range is 0 to 100 * numCPUCores.
+  // That's to say it can exceed 100 when there're multi CPUs.
+  // For example, if the device has 4 CPUs and the process fully uses 2 of
+  // them, the percent will be 200%.
+  const double cpu_usage_percent =
+      process_metrics_->GetPlatformIndependentCPUUsage();
+
+  // CPU time, as mentioned above, "100 microseconds" means "1 CPU core fully
+  // utilized for 100 microseconds".
+  const int64_t cpu_time_microsec =
+      static_cast<int64_t>(cpu_usage_percent * elapsed_time_microsec / 100.);
+
+  // Memory usage
+  size_t usage = 0;
+  if (!GetTotalProcessMemoryUsage(&usage)) {
+    LOG(DFATAL) << "Getting process memory usage failed.";
+    return;
+  }
+  const int64_t memory_usage_kb =
+      static_cast<int64_t>(usage) - initial_memory_;
+
+  metrics_library_.SendToUMA(name_base_ + kTotalMemoryDeltaSuffix,
+                             memory_usage_kb,
+                             kMemoryDeltaMinKb,
+                             kMemoryDeltaMaxKb,
+                             kMemoryDeltaBuckets);
+  metrics_library_.SendToUMA(name_base_ + kCpuTimeSuffix,
+                             cpu_time_microsec,
+                             kCpuTimeMinMicrosec,
+                             kCpuTimeMaxMicrosec,
+                             kCpuTimeBuckets);
+}
+
 // Records in MachineLearningService.LoadModelResult rather than a
 // model-specific enum histogram because the model name is unknown.
 void RecordModelSpecificationErrorEvent() {
diff --git a/ml/request_metrics.h b/ml/request_metrics.h
index 91d25f5..622602e 100644
--- a/ml/request_metrics.h
+++ b/ml/request_metrics.h
@@ -5,17 +5,11 @@
 #ifndef ML_REQUEST_METRICS_H_
 #define ML_REQUEST_METRICS_H_
 
-#include <algorithm>
 #include <memory>
 #include <string>
-#include <vector>
 
-#include <base/bind.h>
-#include <base/files/file_path.h>
-#include <base/logging.h>
 #include <base/macros.h>
 #include <base/process/process_metrics.h>
-#include <base/system/sys_info.h>
 #include <base/time/time.h>
 #include <metrics/metrics_library.h>
 
@@ -31,7 +25,6 @@
 // specific actions, currently we reuse the enum classes defined in mojoms. The
 // enum class generally contains an OK and several different Errors, besides,
 // there should be a kMax which shares the value of the highest enumerator.
-template <class RequestEventEnum>
 class RequestMetrics {
  public:
   // Creates a RequestMetrics with the specified model and request names.
@@ -41,6 +34,7 @@
                  const std::string& request_name);
 
   // Logs (to UMA) the specified `event` associated with this request.
+  template <class RequestEventEnum>
   void RecordRequestEvent(RequestEventEnum event);
 
   // When you want to record metrics of some action, call Start func at the
@@ -77,80 +71,13 @@
 constexpr int kCpuTimeBuckets = 100;
 
 template <class RequestEventEnum>
-RequestMetrics<RequestEventEnum>::RequestMetrics(
-    const std::string& model_name, const std::string& request_name)
-    : name_base_(std::string(kGlobalMetricsPrefix) + model_name + "." +
-                 request_name),
-      process_metrics_(nullptr) {}
-
-template <class RequestEventEnum>
-void RequestMetrics<RequestEventEnum>::RecordRequestEvent(
-    RequestEventEnum event) {
+void RequestMetrics::RecordRequestEvent(RequestEventEnum event) {
   metrics_library_.SendEnumToUMA(
       name_base_ + kEventSuffix, static_cast<int>(event),
       static_cast<int>(RequestEventEnum::kMaxValue) + 1);
   process_metrics_.reset(nullptr);
 }
 
-template <class RequestEventEnum>
-void RequestMetrics<RequestEventEnum>::StartRecordingPerformanceMetrics() {
-  DCHECK(process_metrics_ == nullptr);
-  process_metrics_ = base::ProcessMetrics::CreateCurrentProcessMetrics();
-  // Call GetPlatformIndependentCPUUsage in order to set the "zero" point of the
-  // CPU usage counter of process_metrics_.
-  process_metrics_->GetPlatformIndependentCPUUsage();
-  timer_.Start();
-  // Query memory usage.
-  size_t usage = 0;
-  if (!GetTotalProcessMemoryUsage(&usage)) {
-    LOG(DFATAL) << "Getting process memory usage failed.";
-    return;
-  }
-  initial_memory_ = static_cast<int64_t>(usage);
-}
-
-template <class RequestEventEnum>
-void RequestMetrics<RequestEventEnum>::FinishRecordingPerformanceMetrics() {
-  DCHECK(process_metrics_ != nullptr);
-  // To get CPU time, we multiply elapsed (wall) time by CPU usage percentage.
-  timer_.Stop();
-  base::TimeDelta elapsed_time;
-  DCHECK(timer_.GetElapsedTime(&elapsed_time));
-  const int64_t elapsed_time_microsec = elapsed_time.InMicroseconds();
-
-  // CPU usage, 12.34 means 12.34%, and the range is 0 to 100 * numCPUCores.
-  // That's to say it can exceed 100 when there're multi CPUs.
-  // For example, if the device has 4 CPUs and the process fully uses 2 of
-  // them, the percent will be 200%.
-  const double cpu_usage_percent =
-      process_metrics_->GetPlatformIndependentCPUUsage();
-
-  // CPU time, as mentioned above, "100 microseconds" means "1 CPU core fully
-  // utilized for 100 microseconds".
-  const int64_t cpu_time_microsec =
-      static_cast<int64_t>(cpu_usage_percent * elapsed_time_microsec / 100.);
-
-  // Memory usage
-  size_t usage = 0;
-  if (!GetTotalProcessMemoryUsage(&usage)) {
-    LOG(DFATAL) << "Getting process memory usage failed.";
-    return;
-  }
-  const int64_t memory_usage_kb =
-      static_cast<int64_t>(usage) - initial_memory_;
-
-  metrics_library_.SendToUMA(name_base_ + kTotalMemoryDeltaSuffix,
-                             memory_usage_kb,
-                             kMemoryDeltaMinKb,
-                             kMemoryDeltaMaxKb,
-                             kMemoryDeltaBuckets);
-  metrics_library_.SendToUMA(name_base_ + kCpuTimeSuffix,
-                             cpu_time_microsec,
-                             kCpuTimeMinMicrosec,
-                             kCpuTimeMaxMicrosec,
-                             kCpuTimeBuckets);
-}
-
 // Records a generic model specification error event during a model loading
 // (LoadBuiltinModel or LoadFlatBufferModel) request.
 void RecordModelSpecificationErrorEvent();
diff --git a/ml/text_classifier_impl.cc b/ml/text_classifier_impl.cc
index 3f7de36..ec073af 100644
--- a/ml/text_classifier_impl.cc
+++ b/ml/text_classifier_impl.cc
@@ -78,8 +78,7 @@
 
 void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
                                   AnnotateCallback callback) {
-  RequestMetrics<TextAnnotationResult> request_metrics("TextClassifier",
-                                                       "Annotate");
+  RequestMetrics request_metrics("TextClassifier", "Annotate");
   request_metrics.StartRecordingPerformanceMetrics();
 
   // Parse and set up the options.
@@ -142,13 +141,11 @@
   std::move(callback).Run(std::move(annotations));
 
   request_metrics.FinishRecordingPerformanceMetrics();
-  request_metrics.RecordRequestEvent(TextAnnotationResult::OK);
 }
 
 void TextClassifierImpl::SuggestSelection(
     TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
-  RequestMetrics<SuggestSelectionResult> request_metrics("TextClassifier",
-                                                         "SuggestSelection");
+  RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
   request_metrics.StartRecordingPerformanceMetrics();
 
   libtextclassifier3::SelectionOptions option;
@@ -174,13 +171,11 @@
   std::move(callback).Run(std::move(result_span));
 
   request_metrics.FinishRecordingPerformanceMetrics();
-  request_metrics.RecordRequestEvent(SuggestSelectionResult::OK);
 }
 
 void TextClassifierImpl::FindLanguages(const std::string& text,
                                        FindLanguagesCallback callback) {
-  RequestMetrics<FindLanguagesResult> request_metrics("TextClassifier",
-                                                      "FindLanguages");
+  RequestMetrics request_metrics("TextClassifier", "FindLanguages");
   request_metrics.StartRecordingPerformanceMetrics();
 
   const std::vector<std::pair<std::string, float>> languages =