| // Copyright 2021 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| // This file is mostly copied from chromium repo: |
| // //components/assist_ranker/ranker_example_util.cc |
| |
| #include "ml/example_preprocessor/ranker_example_util.h" |
| |
#include <math.h>

#include <cmath>
#include <limits>

#include <base/bit_cast.h>
#include <base/format_macros.h>
#include <base/logging.h>
#include <base/metrics/metrics_hashes.h>
#include <base/notreached.h>
#include <base/strings/stringprintf.h>
| |
| namespace assist_ranker { |
| namespace { |
// Mask that keeps only the low 32 bits of a 64-bit value.
constexpr uint64_t MASK32Bits = 0xFFFFFFFFull;
// Number of explicitly stored mantissa (fraction) bits in an IEEE-754
// single-precision float.
constexpr int kFloatMainDigits = 23;
| // Returns lower 32 bits of the hash of the input. |
| int32_t StringToIntBits(const std::string& str) { |
| return base::HashMetricName(str) & MASK32Bits; |
| } |
| |
| // Converts float to int32 |
| int32_t FloatToIntBits(float f) { |
| if (std::numeric_limits<float>::is_iec559) { |
| // Directly bit_cast if float follows ieee754 standard. |
| return bit_cast<int32_t>(f); |
| } else { |
| // Otherwise, manually calculate sign, exp and mantissa. |
| // For sign. |
| const uint32_t sign = f < 0; |
| |
| // For exponent. |
| int exp; |
| f = std::abs(std::frexp(f, &exp)); |
| // Add 126 to get non-negative format of exp. |
| // This should not be 127 because the return of frexp is different from |
| // ieee754 with a multiple of 2. |
| const uint32_t exp_u = exp + 126; |
| |
| // Get mantissa. |
| const uint32_t mantissa = std::ldexp(f * 2.0f - 1.0f, kFloatMainDigits); |
| // Set each bits and return. |
| return (sign << 31) | (exp_u << kFloatMainDigits) | mantissa; |
| } |
| } |
| |
// Packs a feature type, a list index and a 32-bit value into one int64:
//   bits 56..63: type, bits 32..55: index, bits 0..31: value.
// NOTE(review): the fields are not masked. Type bits above the low 8 are
// shifted out, and an index >= 2^24 spills into the type bits.
int64_t PairInt(const uint64_t type,
                const uint32_t value,
                const uint64_t index) {
  const uint64_t type_field = type << 56;
  const uint64_t index_field = index << 32;
  const uint64_t value_field = value;
  return type_field | index_field | value_field;
}
| |
| } // namespace |
| |
| bool SafeGetFeature(const std::string& key, |
| const RankerExample& example, |
| Feature* feature) { |
| auto p_feature = example.features().find(key); |
| if (p_feature != example.features().end()) { |
| if (feature) |
| *feature = p_feature->second; |
| return true; |
| } |
| return false; |
| } |
| |
| bool GetFeatureValueAsFloat(const std::string& key, |
| const RankerExample& example, |
| float* value) { |
| Feature feature; |
| if (!SafeGetFeature(key, example, &feature)) { |
| return false; |
| } |
| switch (feature.feature_type_case()) { |
| case Feature::kBoolValue: |
| *value = static_cast<float>(feature.bool_value()); |
| break; |
| case Feature::kInt32Value: |
| *value = static_cast<float>(feature.int32_value()); |
| break; |
| case Feature::kFloatValue: |
| *value = feature.float_value(); |
| break; |
| default: |
| return false; |
| } |
| return true; |
| } |
| |
| bool FeatureToInt64(const Feature& feature, |
| int64_t* const res, |
| const int index) { |
| int32_t value = -1; |
| int32_t type = feature.feature_type_case(); |
| switch (type) { |
| case Feature::kBoolValue: |
| value = static_cast<int32_t>(feature.bool_value()); |
| break; |
| case Feature::kFloatValue: |
| value = FloatToIntBits(feature.float_value()); |
| break; |
| case Feature::kInt32Value: |
| value = feature.int32_value(); |
| break; |
| case Feature::kStringValue: |
| value = StringToIntBits(feature.string_value()); |
| break; |
| case Feature::kStringList: |
| if (index >= 0 && index < feature.string_list().string_value_size()) { |
| value = StringToIntBits(feature.string_list().string_value(index)); |
| } else { |
| DVLOG(3) << "Invalid index for string list: " << index; |
| NOTREACHED(); |
| return false; |
| } |
| break; |
| default: |
| DVLOG(3) << "Feature type is supported for logging: " << type; |
| NOTREACHED(); |
| return false; |
| } |
| *res = PairInt(type, value, index); |
| return true; |
| } |
| |
| bool GetOneHotValue(const std::string& key, |
| const RankerExample& example, |
| std::string* value) { |
| Feature feature; |
| if (!SafeGetFeature(key, example, &feature)) { |
| return false; |
| } |
| if (feature.feature_type_case() != Feature::kStringValue) { |
| DVLOG(1) << "Feature " << key |
| << " exists, but is not the right type (Expected: " |
| << Feature::kStringValue |
| << " vs. Actual: " << feature.feature_type_case() << ")"; |
| return false; |
| } |
| *value = feature.string_value(); |
| return true; |
| } |
| |
| // Converts string to a hex hash string. |
| std::string HashFeatureName(const std::string& feature_name) { |
| uint64_t feature_key = base::HashMetricName(feature_name); |
| return base::StringPrintf("%016" PRIx64, feature_key); |
| } |
| |
| RankerExample HashExampleFeatureNames(const RankerExample& example) { |
| RankerExample hashed_example; |
| auto& output_features = *hashed_example.mutable_features(); |
| for (const auto& feature : example.features()) { |
| output_features[HashFeatureName(feature.first)] = feature.second; |
| } |
| *hashed_example.mutable_target() = example.target(); |
| return hashed_example; |
| } |
| |
| } // namespace assist_ranker |