Source code for mizani.utils

from __future__ import annotations

import math
import sys
from datetime import datetime, timezone
from typing import TYPE_CHECKING, cast, overload
from warnings import warn

import numpy as np
import pandas as pd
import pandas.api.types as pdtypes

if TYPE_CHECKING:
    from datetime import tzinfo
    from typing import Any, Optional, Sequence, TypeGuard

    from mizani.typing import (
        AnyArrayLike,
        DurationUnit,
        FloatArrayLike,
        FloatSeries,
        NDArrayFloat,
        NullType,
        NumericUFunction,
        SeqDatetime,
        TupleFloat2,
    )


__all__ = [
    "round_any",
    "min_max",
    "match",
    "precision",
    "isclose_abs",
    "same_log10_order_of_magnitude",
    "identity",
    "get_categories",
    "get_timezone",
    "has_dtype",
]

# Use sqrt(epsilon) to correct for loss of precision due floating point
# rounding errors. This value is derived for forward difference numerical
# difference, but for our use cases this choice of value is arbitrary.
EPSILON = sys.float_info.epsilon
ROUNDING_ERROR = np.sqrt(EPSILON)
ABS_TOL = 1e-10  # Absolute Tolerance

DISCRETE_KINDS = "ObUS"
CONTINUOUS_KINDS = "ifuc"

SECONDS: dict[DurationUnit, float] = {
    "ns": 1e-9,  # nanosecond
    "us": 1e-6,  # microsecond
    "ms": 1e-3,  # millisecond
    "s": 1,  # second
    "min": 60,  # minutes
    "h": 3600,  # hour
    "day": 24 * 3600,  # day
    "week": 7 * 24 * 3600,  # week
    "month": 31 * 24 * 3600,  # month
    "year": 365 * 24 * 3600,  # year
}

NANOSECONDS: dict[DurationUnit, float] = {
    "ns": 1,  # nanosecond
    "us": 1e3,  # microsecond
    "ms": 1e6,  # millisecond
    "s": 1e9,  # second
    "min": 60e9,  # minutes
    "h": 3600e9,  # hour
    "day": 24 * 3600e9,  # day
    "week": 7 * 24 * 3600e9,  # week
    "month": 31 * 24 * 3600e9,  # month
    "year": 365 * 24 * 3600e9,  # year
}


@overload
def round_any(
    x: FloatArrayLike, accuracy: float, f: NumericUFunction = np.round
) -> NDArrayFloat: ...


@overload
def round_any(
    x: float, accuracy: float, f: NumericUFunction = np.round
) -> float: ...


def round_any(
    x: FloatArrayLike | float, accuracy: float, f: NumericUFunction = np.round
) -> NDArrayFloat | float:
    """
    Round to multiple of any number.
    """
    if not is_vector(x):
        x = np.asarray(x)
    return f(x / accuracy) * accuracy


def min_max(
    x: FloatArrayLike | float, na_rm: bool = False, finite: bool = True
) -> TupleFloat2:
    """
    Return the minimum and maximum of x

    Parameters
    ----------
    x : array_like
        Sequence
    na_rm : bool
        Whether to remove ``nan`` values.
    finite : bool
        Whether to consider only finite values.

    Returns
    -------
    out : tuple
        (minimum, maximum) of x
    """
    x = np.asarray(x)

    if na_rm and finite:
        x = x[np.isfinite(x)]
    elif not na_rm and np.any(np.isnan(x)):
        return np.nan, np.nan
    elif na_rm:
        x = x[~np.isnan(x)]
    elif finite:
        x = x[~np.isinf(x)]

    if len(x):
        return np.min(x), np.max(x)  # type: ignore
    else:
        return float("-inf"), float("inf")


def match(
    v1: AnyArrayLike,
    v2: AnyArrayLike,
    nomatch: int = -1,
    incomparables: Optional[Any] = None,
    start: int = 0,
) -> list[int]:
    """
    Return a vector of the positions of (first)
    matches of its first argument in its second.

    Parameters
    ----------
    v1: array-like
        The values to be matched

    v2: array-like
        The values to be matched against

    nomatch: int
        The value to be returned in the case when
        no match is found.

    incomparables: array-like
        A list of values that cannot be matched.
        Any value in v1 matching a value in this list
        is assigned the nomatch value.
    start: int
        Type of indexing to use. Most likely 0 or 1
    """
    # NOTE: This function gets called a lot. If it can
    # be optimised, it should.
    lookup: dict[Any, int] = {}
    for i, x in enumerate(v2, start=start):
        if x not in lookup:
            lookup[x] = i

    if incomparables:
        skip = set(incomparables)
        lst = [
            lookup.get(x, nomatch) if x not in skip else nomatch for x in v1
        ]
    else:
        lst = [lookup.get(x, nomatch) for x in v1]
    return lst


def precision(x: FloatArrayLike | float) -> float:
    """
    Return the precision of x

    Parameters
    ----------
    x : array_like | numeric
        Value(s) whose for which to compute the precision.

    Returns
    -------
    out : numeric
        The precision of ``x`` or that the values in ``x``.

    Notes
    -----
    The precision is computed in base 10.

    Examples
    --------
    >>> precision([0.08, 0.09])
    0.01

    Maximum precision is 1

    >>> precision([9, 8])
    1
    >>> precision([16, 78])
    1

    A single values have a precision of 1

    >>> precision(0.08)
    1
    >>> precision([325])
    1
    """

    x = np.asarray(x)
    x = x[~(np.isnan(x) | np.isinf(x))]

    if len(x) <= 1:
        return 1

    smallest_diff = np.diff(np.sort(x))[0]
    if smallest_diff < ROUNDING_ERROR:  # pragma: nocover
        return 1
    else:
        # For some intel processors (Skylake), numpy may be compiled
        # to do a fast but precision losing np.log10 calculation.
        # We add a rounding error incase due to lost precision, a
        # result that should be an integer (or slightly greater than
        # one) will get "floored" to the correct value
        res = 10 ** int(np.floor(np.log10(smallest_diff) + ROUNDING_ERROR) - 1)
        has_extra_zeros = (np.round(x / res) % 10 == 0).all()
        if has_extra_zeros:
            res *= 10

        # 1 comes first so that min(1, 1.0) returns an integer
        return min(1, res)


def same_log10_order_of_magnitude(x, delta=0.1):
    """
    Return true if range is approximately in same order of magnitude

    For example these sequences are in the same order of magnitude:

        - [1, 8, 5]     # [1, 10)
        - [35, 20, 80]  # [10 100)
        - [232, 730]    # [100, 1000)

    Parameters
    ----------
    x : array-like
         Values in base 10. Must be size 2 and
        ``rng[0] <= rng[1]``.
    delta : float
        Fuzz factor for approximation. It is multiplicative.
    """
    dmin = np.log10(np.min(x) * (1 - delta))
    dmax = np.log10(np.max(x) * (1 + delta))
    return np.floor(dmin) == np.floor(dmax)


def identity(*args):
    """
    Return whatever is passed in
    """
    return args[0] if len(args) == 1 else args


def get_categories(x):
    """
    Return the categories of x

    Parameters
    ----------
    x : category_like
        Input Values

    Returns
    -------
    out : Index
        Categories of x
    """
    try:
        return x.cat.categories  # series
    except AttributeError:
        try:
            return x.categories  # plain categorical
        except AttributeError as err:
            raise TypeError("x is not a categorical") from err


def log(x, base):
    """
    Calculate the log of x

    Parameters
    ----------
    x : category_like
        Input Values

    base : float
        Base of logarithm

    Returns
    -------
    out : float
        Log of x
    """
    if base == 10:
        res = np.log10(x)
    elif base == 2:
        res = np.log2(x)
    elif base == np.e:
        res = np.log(x)
    else:
        res = np.log(x) / np.log(base)
    return res


def get_timezone(x: SeqDatetime) -> tzinfo | None:
    """
    Return a single timezone for the sequence of datetimes

    Returns the timezone of first item and warns if any other items
    have a different timezone
    """

    # Ref: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
    x0 = next(iter(x))
    if not isinstance(x0, datetime):
        return None

    x = cast(list[datetime], x)
    info = x0.tzinfo
    if info is None:
        return timezone.utc

    # Consistency check
    tzname0 = info.tzname(x0)
    tznames = (dt.tzinfo.tzname(dt) if dt.tzinfo else None for dt in x)

    if any(tzname0 != name for name in tznames):
        msg = (
            "Dates in column have different time zones. "
            f"Choosen `{tzname0}` the time zone of the first date. "
            "To use a different time zone, create a "
            "labeller and pass the time zone."
        )
        warn(msg)
    return info


def get_null_value(x: Any) -> NullType:
    """
    Return a Null value for the type of values
    """
    from datetime import datetime, timedelta

    import pandas as pd

    x0 = next(iter(x))
    numeric_types: Sequence[type] = (np.int64, np.float64, int, float, bool)

    if pdtypes.is_object_dtype(x):
        return None
    elif isinstance(x0, numeric_types):  # type: ignore
        return float("nan")
    # pandas types subclass cypthon types, so check
    # for them first
    elif isinstance(x0, (pd.Timestamp, pd.Timedelta)):
        return type(x0)("NaT")
    elif isinstance(x0, (datetime, timedelta)):
        return None
    elif isinstance(x0, (np.datetime64, np.timedelta64)):
        return type(x0)("NaT")
    else:
        raise ValueError(
            "Cannot get a null value for type: {}".format(type(x[0]))
        )


def isclose_abs(a: float, b: float, tol: float = ABS_TOL) -> bool:
    """
    Return True if a and b are close given the absolute tolerance
    """
    return math.isclose(a, b, rel_tol=0, abs_tol=ABS_TOL)


def is_vector(x: Any) -> TypeGuard[NDArrayFloat | FloatSeries]:
    """
    Return True if x is a numpy array or a pandas series
    """
    return isinstance(x, (np.ndarray, pd.Series))


def has_dtype(
    x: Any,
) -> TypeGuard[np.ndarray | pd.Series | pd.Categorical | pd.Index]:
    """
    Return True if x has the dtype property
    """
    return hasattr(x, "dtype")