blob: 98238ee1a1ba502c01caac11bd414d0a38a65677 [file] [log] [blame] [edit]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// This file is mostly copied from chromium repo:
// //components/assist_ranker/ranker_example_util_unittest.cc
#include "ml/example_preprocessor/ranker_example_util.h"
#include <limits>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace assist_ranker {
using ::testing::ElementsAreArray;
class RankerExampleUtilTest : public ::testing::Test {
 protected:
  // Seeds |example_| with one feature of each kind exercised below
  // (bool, int32, float and string), keyed by the names declared in this
  // fixture.
  void SetUp() override {
    auto* const feature_map = example_.mutable_features();
    (*feature_map)[bool_name_].set_bool_value(bool_value_);
    (*feature_map)[int32_name_].set_int32_value(int32_value_);
    (*feature_map)[float_name_].set_float_value(float_value_);
    (*feature_map)[one_hot_name_].set_string_value(one_hot_value_);
  }

  // Example under test; populated by SetUp().
  RankerExample example_;

  // Feature names paired with the values SetUp() stores under them.
  const std::string bool_name_{"bool_feature"};
  const bool bool_value_{true};
  const std::string int32_name_{"int32_feature"};
  const int int32_value_{2};
  const std::string float_name_{"float_feature"};
  const float float_value_{3.0f};
  const std::string one_hot_name_{"one_hot_feature"};
  const std::string elem1_{"elem1"};
  const std::string elem2_{"elem2"};
  // Copies |elem1_|, so it must stay declared after it: non-static members are
  // initialized in declaration order.
  const std::string one_hot_value_{elem1_};
  // Tolerance passed to EXPECT_NEAR. NOTE(review): for values around 1.0-3.0
  // this is below one float ULP, so EXPECT_NEAR effectively requires an exact
  // match.
  const float epsilon_{0.00000001f};
};
TEST_F(RankerExampleUtilTest, CheckFeature) {
  // Each lookup passes a null out-param, so only the presence result of
  // SafeGetFeature is checked here.
  const auto has_feature = [this](const std::string& name) {
    return SafeGetFeature(name, example_, nullptr);
  };

  // Every feature seeded by SetUp() is found...
  EXPECT_TRUE(has_feature(bool_name_));
  EXPECT_TRUE(has_feature(int32_name_));
  EXPECT_TRUE(has_feature(float_name_));
  EXPECT_TRUE(has_feature(one_hot_name_));

  // ...while the empty name and an unknown name are not.
  EXPECT_FALSE(has_feature(""));
  EXPECT_FALSE(has_feature("foo"));
}
TEST_F(RankerExampleUtilTest, SafeGetFeature) {
  // Fetches each seeded feature into |feature| and checks that its value
  // matches the fixture. |feature| is cleared between lookups.
  Feature feature;

  EXPECT_TRUE(SafeGetFeature(bool_name_, example_, &feature));
  // Compare against the fixture value, as the other cases do, instead of
  // assuming it is true.
  EXPECT_EQ(bool_value_, feature.bool_value());
  feature.Clear();

  EXPECT_TRUE(SafeGetFeature(int32_name_, example_, &feature));
  EXPECT_EQ(int32_value_, feature.int32_value());
  feature.Clear();

  EXPECT_TRUE(SafeGetFeature(float_name_, example_, &feature));
  EXPECT_NEAR(float_value_, feature.float_value(), epsilon_);
  feature.Clear();

  EXPECT_TRUE(SafeGetFeature(one_hot_name_, example_, &feature));
  EXPECT_EQ(one_hot_value_, feature.string_value());
  feature.Clear();

  // Empty and unknown names are reported as missing.
  EXPECT_FALSE(SafeGetFeature("", example_, &feature));
  EXPECT_FALSE(SafeGetFeature("foo", example_, &feature));
}
TEST_F(RankerExampleUtilTest, GetFeatureValueAsFloat) {
  // Initialized so that a lookup which unexpectedly fails (and so may not
  // write the out-param) cannot lead to reading an indeterminate value below.
  float value = 0.0f;

  // Numeric features convert to float. Expected values derive from the
  // fixture rather than repeating its literals.
  EXPECT_TRUE(GetFeatureValueAsFloat(bool_name_, example_, &value));
  EXPECT_NEAR(static_cast<float>(bool_value_), value, epsilon_);
  EXPECT_TRUE(GetFeatureValueAsFloat(int32_name_, example_, &value));
  EXPECT_NEAR(static_cast<float>(int32_value_), value, epsilon_);
  EXPECT_TRUE(GetFeatureValueAsFloat(float_name_, example_, &value));
  EXPECT_NEAR(float_value_, value, epsilon_);

  // A string feature is not convertible.
  EXPECT_FALSE(GetFeatureValueAsFloat(one_hot_name_, example_, &value));
  // Value remains unchanged if GetFeatureValueAsFloat returns false; it still
  // holds the float feature read just above.
  EXPECT_NEAR(float_value_, value, epsilon_);

  // Empty and unknown names are rejected.
  EXPECT_FALSE(GetFeatureValueAsFloat("", example_, &value));
  EXPECT_FALSE(GetFeatureValueAsFloat("foo", example_, &value));
}
TEST_F(RankerExampleUtilTest, GetOneHotValue) {
  std::string value;
  // Reads the named feature as a one-hot (string) value into |value|.
  const auto read_one_hot = [this, &value](const std::string& name) {
    return GetOneHotValue(name, example_, &value);
  };

  // Non-string features are rejected.
  EXPECT_FALSE(read_one_hot(bool_name_));
  EXPECT_FALSE(read_one_hot(int32_name_));
  EXPECT_FALSE(read_one_hot(float_name_));

  // The string feature yields its stored value.
  EXPECT_TRUE(read_one_hot(one_hot_name_));
  EXPECT_EQ(one_hot_value_, value);

  // Empty and unknown names are rejected as well.
  EXPECT_FALSE(read_one_hot(""));
  EXPECT_FALSE(read_one_hot("foo"));
}
TEST_F(RankerExampleUtilTest, ScalarFeatureInt64Conversion) {
  Feature feature;
  // Initialized so that an unexpected FeatureToInt64 failure cannot lead to
  // reading an indeterminate value in the following EXPECT_EQ.
  int64_t int64_value = 0;

  // Expected values are written in hex (decimal in the trailing comments).
  // The top byte differs per feature kind (0x01 bool, 0x03 int32,
  // 0x04 string), and the payload sits in the low 32 bits.
  // NOTE(review): this layout is read off the expected constants; confirm
  // against FeatureToInt64.
  feature.set_bool_value(true);
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x0100000000000001LL, int64_value);  // 72057594037927937

  // The low 32 bits are the two's-complement bits of the int32 value.
  feature.set_int32_value(std::numeric_limits<int32_t>::max());
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x030000007FFFFFFFLL, int64_value);  // 216172784261267455
  feature.set_int32_value(std::numeric_limits<int32_t>::lowest());
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x0300000080000000LL, int64_value);  // 216172784261267456

  // The low 32 bits are presumably a hash of "foo".
  feature.set_string_value("foo");
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x040000004CC2F85CLL, int64_value);  // 288230377439557724
}
TEST_F(RankerExampleUtilTest, FloatFeatureInt64Conversion) {
  Feature feature;
  // Initialized so that an unexpected FeatureToInt64 failure cannot lead to
  // reading an indeterminate value in the following EXPECT_EQ.
  int64_t int64_value = 0;

  // Expected values are written in hex (decimal in the trailing comments).
  // The top byte is 0x02 for float features. The low 32 bits equal the
  // IEEE-754 binary32 bit pattern of the value: 0x34000000 is FLT_EPSILON,
  // 0x7F7FFFFF is FLT_MAX, and the negated values add the sign bit.
  feature.set_float_value(std::numeric_limits<float>::epsilon());
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x0200000034000000LL, int64_value);  // 144115188948271104

  feature.set_float_value(-std::numeric_limits<float>::epsilon());
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x02000000B4000000LL, int64_value);  // 144115191095754752

  feature.set_float_value(std::numeric_limits<float>::max());
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x020000007F7FFFFFLL, int64_value);  // 144115190214950911

  feature.set_float_value(std::numeric_limits<float>::lowest());
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  EXPECT_EQ(0x02000000FF7FFFFFLL, int64_value);  // 144115192362434559
}
TEST_F(RankerExampleUtilTest, StringListInt64Conversion) {
  Feature feature;
  // Initialized so that an unexpected FeatureToInt64 failure cannot lead to
  // reading an indeterminate value in the following EXPECT_EQ.
  int64_t int64_value = 0;

  feature.mutable_string_list()->add_string_value("");
  feature.mutable_string_list()->add_string_value("TEST");
  // Converts the list element at index 1 ("TEST").
  EXPECT_TRUE(FeatureToInt64(feature, &int64_value, 1));
  // In hex: top byte 0x05, bits 32-55 hold 1 (presumably the index passed
  // above), and the low 32 bits are presumably a hash of "TEST".
  // NOTE(review): layout read off the constant; confirm against
  // FeatureToInt64.
  EXPECT_EQ(0x050000011168D7E4LL, int64_value);  // 360287974776690660
}
TEST_F(RankerExampleUtilTest, HashExampleFeatureNames) {
  auto hashed_example = HashExampleFeatureNames(example_);

  // Hashed example has the same number of features.
  EXPECT_EQ(example_.features().size(), hashed_example.features().size());

  // But the feature names have changed.
  EXPECT_FALSE(SafeGetFeature(bool_name_, hashed_example, nullptr));
  EXPECT_FALSE(SafeGetFeature(int32_name_, hashed_example, nullptr));
  EXPECT_FALSE(SafeGetFeature(float_name_, hashed_example, nullptr));
  EXPECT_FALSE(SafeGetFeature(one_hot_name_, hashed_example, nullptr));

  // Every feature is reachable under its hashed name (previously only the
  // bool feature was checked).
  EXPECT_TRUE(
      SafeGetFeature(HashFeatureName(bool_name_), hashed_example, nullptr));
  EXPECT_TRUE(
      SafeGetFeature(HashFeatureName(int32_name_), hashed_example, nullptr));
  EXPECT_TRUE(
      SafeGetFeature(HashFeatureName(float_name_), hashed_example, nullptr));
  EXPECT_TRUE(
      SafeGetFeature(HashFeatureName(one_hot_name_), hashed_example, nullptr));

  // Values have not changed. |float_value| is initialized so that a failed
  // lookup cannot lead to reading an indeterminate value.
  float float_value = 0.0f;
  EXPECT_TRUE(GetFeatureValueAsFloat(HashFeatureName(float_name_),
                                     hashed_example, &float_value));
  EXPECT_EQ(float_value_, float_value);

  std::string string_value;
  EXPECT_TRUE(GetOneHotValue(HashFeatureName(one_hot_name_), hashed_example,
                             &string_value));
  EXPECT_EQ(one_hot_value_, string_value);
}
} // namespace assist_ranker