# Copyright 2022 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Utilities needed for Fingerprint Study Analysis."""
from __future__ import annotations
import collections
import enum
import timeit
from typing import Any, Iterable, Literal
from IPython.display import display
from IPython.display import Markdown
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.stats import norm
class DataFrameSetAccess:
"""Provides a quick method of checking if a given row exists in the table.
    This lookup method takes hundreds of nanoseconds, versus other methods
    that take hundreds of microseconds. Given the number of times we must
    query certain tables, this order of magnitude difference is unacceptable.
The constructor is very slow, as it builds the cache, but all accessor
methods are very fast.
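
    Example (illustrative, with a small hypothetical table):
        >>> tbl = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
        >>> access = DataFrameSetAccess(tbl)
        >>> access.isin((1, 3))
        True
        >>> access.isin((1, 4))
        False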
"""
def __init__(self, tbl: pd.DataFrame, cols: list[str] | None = None):
"""This performs the expensive caching operation, that must occur once.
Args:
tbl: The DataFrame to cache.
            cols: The specific columns of the DataFrame to cache in the set.
If None, all columns are used.
"""
if not cols:
cols = list(tbl.columns)
self.cols = cols
        # There is nothing incompatible with having duplicate rows in this
        # data structure, but providing duplicates might indicate an
        # attempt to analyze cross sections of data using the wrong tool.
        # For example, suppose we are caching only the first three columns
        # of a DataFrame. If this DataFrame contains match decisions, then
        # there will be many rows that share the same first three columns.
        # This data structure would collapse all such rows into one entry.
assert not tbl.duplicated(subset=cols).any()
# This is an expensive operation.
self.set = {tuple(row) for row in np.array(tbl[cols])}
def isin(self, values: tuple[Any, ...]) -> bool:
"""Check if the values appeared as a cached row."""
        # This must remain very fast, so do not add additional asserts/checks.
return values in self.set
class DataFrameCountTrieAccess:
"""Provides a quick method of checking the number of matching rows.
This implementation builds a trie with all partial row columns,
from the empty tuple (all rows) to the full row in tuple form.
The count of all downstream nodes is saved at each trie node.
The constructor is very slow, as it builds the cache, but all accessor
methods are very fast.
    Lookup performance is on par with `DataFrameSetAccess`, though still
    tens of nanoseconds slower per query.
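
    Example (illustrative, with a small hypothetical table):
        >>> tbl = pd.DataFrame({"a": [1, 1, 2], "b": [3, 4, 4]})
        >>> trie = DataFrameCountTrieAccess(tbl)
        >>> trie.counts((1,))
        2
        >>> trie.counts((1, 3))
        1
        >>> trie.isin((2, 3))
        False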
"""
def __init__(self, tbl: pd.DataFrame, cols: list[str] | None = None):
"""This performs the expensive caching operation, that must occur once.
Args:
tbl: The DataFrame to index into the Trie.
cols: The specific columns of the DataFrame to index into the Trie.
If None, all columns are used.
"""
if not cols:
cols = list(tbl.columns)
self.cols = cols
self.counts_dict: collections.Counter[
tuple[Any, ...]
] = collections.Counter()
for row in np.array(tbl[cols]):
            # Update all partial trie nodes. For example, for row (val1, val2),
            # we increment the counts for (), (val1,), and (val1, val2).
for i in range(len(cols) + 1):
# We include the empty tuple (row[0:0]) count also.
t = tuple(row)[0:i]
self.counts_dict[t] += 1
def isin(self, values: tuple[Any, ...]) -> bool:
"""A tuple will only be in the cache if the count is at least 1."""
        # This must remain very fast, so do not add additional asserts/checks.
return values in self.counts_dict
def counts(self, values: tuple[Any, ...]) -> int:
"""Get the number of rows that start with `values` tuple."""
# This must remain very fast, so do not add additional asserts/check.
return self.counts_dict[values]
# pylint: disable=unused-argument
def fake_use(v: Any):
"""Make a variable seem to be used to avoid linter warnings.
    This is useful when benchmarking, since the expression being tested is
    passed in as a statement string for evaluation.
"""
return
def boot_sample(
# This is the fastest input to rng.choice, other than a scalar.
a: npt.NDArray[Any],
*,
n: int | None = None,
    # Note: This default rng is constructed once, at module import time.
    rng: np.random.Generator = np.random.default_rng(),
) -> npt.NDArray:
"""Sample with replacement the same number of elements given.
If `n` is given, do `n` number of repeat bootstrap samples
and return an `n x a.size` ndarray.
Equivalent to `rng.choice(a, size=(a.size, n), replace=True)`.
NOTE: It is slightly faster when invoking rng.choice with a scalar as the
first argument, instead of giving it an np.array.
See `boot_sample_range`.
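
    Example (illustrative; only the shapes are deterministic):
        >>> rng = np.random.default_rng(0)
        >>> boot_sample(np.arange(3), rng=rng).shape
        (3,)
        >>> boot_sample(np.arange(3), n=5, rng=rng).shape
        (5, 3)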
"""
if n:
return rng.choice(a, size=(n, a.size), replace=True)
else:
return rng.choice(a, size=a.size, replace=True)
def boot_sample_range(
# Scalar input is the fastest invocation to rng.choice.
range_max: int,
n: int | None = None,
rng: np.random.Generator = np.random.default_rng(),
) -> npt.NDArray[np.int64]:
"""Sample with replacement `range_max` elements from `0` to `range_max`.
This is slightly faster than `fpsutils.boot_sample`.
Equivalent to `rng.choice(range_max, size=range, replace=True)`.
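
    Example (illustrative; only the shapes are deterministic):
        >>> boot_sample_range(4, rng=np.random.default_rng(0)).shape
        (4,)
        >>> boot_sample_range(4, n=2, rng=np.random.default_rng(0)).shape
        (2, 4)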
"""
    if n:
        return rng.choice(range_max, size=(n, range_max), replace=True)
    return rng.choice(range_max, size=range_max, replace=True)
def plot_pd_column_hist_discrete(
tbl: pd.DataFrame, column: str, title_prefix: str | None = None
):
"""Plot the histogram of a single column of a DataFrame"""
vals = np.unique(tbl[column], return_counts=True)
plt.bar(*vals)
plt.xticks(vals[0], rotation="vertical", fontsize=5)
plt.xlabel(column)
plt.ylabel("Count")
if title_prefix:
plt.title(f"{title_prefix} by {column}")
def plot_pd_hist_discrete(
tbl: pd.DataFrame,
title_prefix: str | None = None,
figsize: tuple | None = None,
):
"""Plot the histograms of each column in a DataFrame.
    This is different from `pd.DataFrame.hist`, because it ensures that all
    unique elements of the column are represented in the bar plot. Other
    implementations try to bin multiple values together and don't center
    the graphical bars on the values.
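
    Example (illustrative; assumes `decisions` is a DataFrame of
    discrete-valued columns):
        plot_pd_hist_discrete(decisions, title_prefix="FAR Decisions")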
"""
num_plots = len(tbl.columns)
if not figsize:
figsize = (10, 6 * num_plots)
plt.figure(figsize=figsize)
for index, col in enumerate(tbl.columns):
plt.subplot(num_plots, 1, index + 1)
plot_pd_column_hist_discrete(tbl, col, title_prefix)
plt.show()
def discrete_hist(data) -> tuple[npt.NDArray, npt.NDArray]:
"""Return a tuple of unique items and their counts.
Returns:
([items], [counts])
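
    Example:
        >>> discrete_hist([1, 1, 2])
        (array([1, 2]), array([2, 1]))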
"""
return np.unique(data, return_counts=True)
def has_columns(df: pd.DataFrame, cols: Iterable[enum.Enum | str]) -> bool:
"""Check if the DataFrame `df` contains all `cols`.
This allows for specifying a list of Enums, whose `value` is the column
name that is expected.
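
    Example (illustrative):
        >>> df = pd.DataFrame(columns=["Decision", "User"])
        >>> has_columns(df, ["Decision", "User"])
        True
        >>> has_columns(df, ["Decision", "Group"])
        False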
"""
    col_strings = {
        c.value if isinstance(c, enum.Enum) else str(c) for c in cols
    }
return col_strings <= set(df.columns)
def plt_discrete_hist(data):
    """Plot a histogram of non-negative integer `data` using np.bincount.

    The plot zooms in to the range of nonzero counts and overlays a fitted
    normal curve.
    """
    counts = np.bincount(data)
    # We need to zoom in, since there could be tens of thousands of bars
    # that are zero near the tail end.
    nonzero_indices = np.nonzero(counts)
    first_index = np.min(nonzero_indices)
    last_index = np.max(nonzero_indices)
    if (last_index - first_index) < 2000:
x = np.arange(start=first_index, stop=last_index + 1)
h = counts[first_index : last_index + 1]
        plt.bar(x, h)
        plt.ylabel("Frequency")
        # Overlay a fitted normal curve.
        mu, std = norm.fit(data)
        mean = np.mean(data)
        print(f"first={first_index} last={last_index}")
        print(f"mu={mu}, std={std}, 3*std={3*std}, np.mean(data)={mean}")
        x_curve = np.linspace(mean - 3 * std, mean + 3 * std, 50)
p = norm.pdf(x_curve, mu, std)
p_scaled = p * np.sum(h)
plt.plot(x_curve, p_scaled, "k", linewidth=2)
plt.xticks([mean - 3 * std, mean, mean + 3 * std])
else:
display(
Markdown(
f"Plot is too large (first={first_index} last={last_index}),"
" not diplaying."
)
)
def plt_discrete_hist2(data):
"""Plot a histogram of `data`, where all values are represented on x-axis.
It also fit a normal curve and places vertical lines for the mean
and 2x standard deviation limits.
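
    Example (illustrative, assuming integer-valued samples):
        plt_discrete_hist2(np.random.default_rng(0).poisson(30, size=1000))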
"""
vals, counts = np.unique(data, return_counts=True)
plt.bar(vals, counts)
plt.xticks(vals, rotation="vertical", fontsize=5)
    plt.ylabel("Frequency")
# Overlay normal curve that spans 3x standard deviations.
mean, std = norm.fit(data)
x_curve = np.linspace(mean - 3 * std, mean + 3 * std, 50)
p = norm.pdf(x_curve, mean, std)
p_scaled = p * np.sum(counts)
plt.plot(x_curve, p_scaled, "k", linewidth=2)
# Place 2x standard deviation confidence lines and mean.
mid_y = np.max(counts) / 2
for ci_x, ci_label in [
(mean - 2 * std, "lower 2*std"),
(mean, "mean"),
(mean + 2 * std, "upper 2*std"),
]:
plt.axvline(ci_x, color="red")
plt.text(
ci_x + 1,
mid_y,
f"{ci_x:.3f} is {ci_label}",
rotation=90,
color="red",
)
def elapsed_time_str(seconds: float) -> str:
"""Convert a seconds value into a more easily interpretable units str.
Example: elapsed_time_str(0.003) -> "3ms"
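    Example: elapsed_time_str(3725.5) -> "1hr 2min 5s 500ms"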
"""
# TODO: See if numpy.timedelta64 can be used.
hrs = int(seconds / 60.0**2)
seconds -= hrs * 60**2
mins = int(seconds / 60.0)
seconds -= mins * 60.0
s = int(seconds)
ms = int(seconds * 1e3) % 1000
us = int(seconds * 1e6) % 1000
ns = (seconds * 1e9) % 1000
string = ""
string += hrs and f"{hrs}hr " or ""
string += mins and f"{mins}min " or ""
string += s and f"{s}s " or ""
string += ms and f"{ms}ms " or ""
string += us and f"{us}us " or ""
string += ns and f"{ns:3.3f}ns " or ""
return string.rstrip()
def benchmark(
    stmt: str,
    setup: str = "pass",
    # Note: This default dict is built once, at module import time,
    # capturing this module's globals.
    global_vars: dict[str, Any] = {**locals(), **globals()},
) -> tuple[int, float, float]:
"""Measure the runtime of `stmt`.
    This method invokes timeit.Timer.autorange and prints the results.
Returns:
(num_loops, sec_total, sec_per_loop)
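
    Example (illustrative; timing output varies by machine):
        loops, sec_total, sec_per_loop = benchmark("sorted(range(1000))")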
"""
loops, sec = timeit.Timer(
stmt, setup=setup, globals=global_vars
).autorange()
    print(
        f'Ran "{stmt}" {loops} times over {sec}s. '
        f"It took {elapsed_time_str(sec / loops)} per loop."
    )
    return loops, sec, sec / loops
def fmt_far(
far_value: float, fmt: Literal["k", "s"] = "k", decimal_places: int = 3
) -> str:
"""Pretty print an FAR value.
Args:
`far_value` is the FAR value to show.
`fmt` indicates the output format.
Either `'k'` for `1/<FAR>k` or `'s'` for scientific notation.
        `decimal_places` indicates the number of digits to show past
            the decimal point.
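
    Example:
        >>> fmt_far(1 / 100_000, "k")
        '1/100.000k'
        >>> fmt_far(1 / 100_000, "s")
        '1.000e-05'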
"""
if fmt not in ["k", "s"]:
        raise TypeError("fmt must be 'k' or 's'.")
if fmt == "k":
if far_value == 0.0:
return "0"
return f"1/{{:.{decimal_places}f}}k".format(1 / (far_value * 1000))
else:
return f"{{:.{decimal_places}e}}".format(far_value)
def fmt_frr(frr_value: float, decimal_places: int = 3) -> str:
"""Pretty print an FRR value as a percentage.
Args:
`frr_value` is the FRR value to show.
        `decimal_places` indicates the number of digits to show past
            the decimal point.
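
    Example:
        >>> fmt_frr(0.015)
        '1.500%'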
"""
return f"{{:.{decimal_places}%}}".format(frr_value)