import os
from functools import partial

import numpy
import pytest
from numpy.testing import assert_allclose

from sklearn._config import config_context
from sklearn.base import BaseEstimator
from sklearn.utils._array_api import (
    _asarray_with_order,
    _atol_for_type,
    _average,
    _convert_to_numpy,
    _count_nonzero,
    _estimator_with_converted_arrays,
    _fill_or_add_to_diagonal,
    _get_namespace_device_dtype_ids,
    _is_numpy_namespace,
    _isin,
    _max_precision_float_dtype,
    _nanmax,
    _nanmean,
    _nanmin,
    _ravel,
    device,
    get_namespace,
    get_namespace_and_device,
    indexing_dtype,
    np_compat,
    yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
    SkipTest,
    _array_api_for_tests,
    assert_array_equal,
    skip_if_array_api_compat_not_configured,
)
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS, np_version, parse_version


@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
def test_get_namespace_ndarray_default(X):
    """Without dispatch, get_namespace falls back to the NumPy compat wrapper."""
    namespace, compliant = get_namespace(X)

    # Both ndarrays and plain lists map to np_compat and are not reported as
    # array API compliant when dispatch is disabled.
    assert namespace is np_compat
    assert not compliant


def test_get_namespace_ndarray_creation_device():
    """Creation functions accept device="cpu" and reject any other device."""
    namespace, _ = get_namespace(numpy.asarray([1, 2, 3]))

    filled = namespace.full(10, fill_value=2.0, device="cpu")
    assert_allclose(filled, [2.0] * 10)

    # The NumPy wrapper only knows about the CPU.
    with pytest.raises(ValueError, match="Unsupported device"):
        namespace.zeros(10, device="cuda")


@skip_if_array_api_compat_not_configured
def test_get_namespace_ndarray_with_dispatch():
    """Test get_namespace on NumPy ndarrays."""
    ndarray = numpy.asarray([[1, 2, 3]])

    with config_context(array_api_dispatch=True):
        namespace, compliant = get_namespace(ndarray)
        assert compliant

        # NumPy is still served through the np_compat wrapper. Once NumPy is a
        # fully compliant array API library this could become
        # `namespace is numpy`.
        assert namespace is np_compat


@skip_if_array_api_compat_not_configured
def test_get_namespace_array_api(monkeypatch):
    """Test get_namespace for ArrayAPI arrays.

    Checks that:
    - array_api_strict inputs are reported as array API compliant,
    - mixing them with NumPy arrays raises a TypeError,
    - a RuntimeError is raised when SciPy's own array API support is disabled.
    """
    xp = pytest.importorskip("array_api_strict")

    X_np = numpy.asarray([[1, 2, 3]])
    X_xp = xp.asarray(X_np)
    with config_context(array_api_dispatch=True):
        _, is_array_api_compliant = get_namespace(X_xp)
        assert is_array_api_compliant

        # Arrays coming from different namespaces cannot be mixed.
        with pytest.raises(TypeError):
            get_namespace(X_xp, X_np)

        # Disable SciPy's array API support through the environment itself.
        # Replacing `os.environ.get` with a one-argument stand-in would break
        # every `os.environ.get(key, default)` call made while it is active
        # and hide all other variables. `monkeypatch.setenv` restores the
        # original value on teardown.
        monkeypatch.setenv("SCIPY_ARRAY_API", "0")
        assert os.environ.get("SCIPY_ARRAY_API") != "1"
        with pytest.raises(
            RuntimeError,
            match="scipy's own support is not enabled.",
        ):
            get_namespace(X_xp)


@pytest.mark.parametrize("array_api", ["numpy", "array_api_strict"])
def test_asarray_with_order(array_api):
    """Check that _asarray_with_order honours order="F" (inspected via NumPy)."""
    xp = pytest.importorskip(array_api)

    fortran_ordered = _asarray_with_order(
        xp.asarray([1.2, 3.4, 5.1]), order="F", xp=xp
    )
    assert numpy.asarray(fortran_ordered).flags["F_CONTIGUOUS"]


@pytest.mark.parametrize(
    "array_namespace, device_, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
    "weights, axis, normalize, expected",
    [
        # normalize = True
        (None, None, True, 3.5),
        (None, 0, True, [2.5, 3.5, 4.5]),
        (None, 1, True, [2, 5]),
        ([True, False], 0, True, [1, 2, 3]),  # boolean weights
        ([True, True, False], 1, True, [1.5, 4.5]),  # boolean weights
        ([0.4, 0.1], 0, True, [1.6, 2.6, 3.6]),
        ([0.4, 0.2, 0.2], 1, True, [1.75, 4.75]),
        ([1, 2], 0, True, [3, 4, 5]),
        ([1, 1, 2], 1, True, [2.25, 5.25]),
        ([[1, 2, 3], [1, 2, 3]], 0, True, [2.5, 3.5, 4.5]),
        ([[1, 2, 1], [2, 2, 2]], 1, True, [2, 5]),
        # normalize = False
        (None, None, False, 21),
        (None, 0, False, [5, 7, 9]),
        (None, 1, False, [6, 15]),
        ([True, False], 0, False, [1, 2, 3]),  # boolean weights
        ([True, True, False], 1, False, [3, 9]),  # boolean weights
        ([0.4, 0.1], 0, False, [0.8, 1.3, 1.8]),
        ([0.4, 0.2, 0.2], 1, False, [1.4, 3.8]),
        ([1, 2], 0, False, [9, 12, 15]),
        ([1, 1, 2], 1, False, [9, 21]),
        ([[1, 2, 3], [1, 2, 3]], 0, False, [5, 14, 27]),
        ([[1, 2, 1], [2, 2, 2]], 1, False, [8, 30]),
    ],
)
def test_average(
    array_namespace, device_, dtype_name, weights, axis, normalize, expected
):
    """Check _average against reference values for each namespace/device/dtype.

    Also checks that the result is placed on the same device as the input.
    """
    xp = _array_api_for_tests(array_namespace, device_)
    array_in = numpy.asarray([[1, 2, 3], [4, 5, 6]], dtype=dtype_name)
    array_in = xp.asarray(array_in, device=device_)
    if weights is not None:
        weights = numpy.asarray(weights, dtype=dtype_name)
        weights = xp.asarray(weights, device=device_)

    with config_context(array_api_dispatch=True):
        result = _average(array_in, axis=axis, weights=weights, normalize=normalize)

        # Devices must be compared while dispatch is enabled: with dispatch
        # disabled `device` returns None for every input, which would make
        # this check pass unconditionally.
        if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
            # NumPy 2.0 has a problem with the device attribute of scalar arrays:
            # https://github.com/numpy/numpy/issues/26850
            assert device(array_in) == device(result)

    result = _convert_to_numpy(result, xp)
    assert_allclose(result, expected, atol=_atol_for_type(dtype_name))


@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(include_numpy_namespaces=False),
    ids=_get_namespace_device_dtype_ids,
)
def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
    """_average must reject complex-valued inputs with NotImplementedError."""
    xp = _array_api_for_tests(array_namespace, device)

    real_part = numpy.asarray([2, 0], dtype=dtype_name)
    imag_part = numpy.asarray([4, 3], dtype=dtype_name)
    complex_np = real_part + 1j * imag_part

    complex_type_name = complex_np.dtype.name
    if not hasattr(xp, complex_type_name):
        # This is the case for cupy as of March 2024 for instance.
        pytest.skip(f"{array_namespace} does not support {complex_type_name}")

    complex_xp = xp.asarray(complex_np, device=device)

    expected_message = "Complex floating point values are not supported by average."
    with (
        config_context(array_api_dispatch=True),
        pytest.raises(NotImplementedError, match=expected_message),
    ):
        _average(complex_xp)


@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(include_numpy_namespaces=True),
    ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
    "axis, weights, error, error_msg",
    (
        (
            None,
            [1, 2],
            TypeError,
            "Axis must be specified",
        ),
        (
            0,
            [[1, 2]],
            # NumPy 2 raises ValueError, NumPy 1 raises TypeError
            (ValueError, TypeError),
            "weights",  # the message is different for NumPy 1 and 2...
        ),
        (
            0,
            [1, 2, 3, 4],
            ValueError,
            "weights",
        ),
        (0, [-1, 1], ZeroDivisionError, "Weights sum to zero, can't be normalized"),
    ),
)
def test_average_raises_with_invalid_parameters(
    array_namespace, device, dtype_name, axis, weights, error, error_msg
):
    """Invalid axis/weights combinations passed to _average raise clear errors."""
    xp = _array_api_for_tests(array_namespace, device)

    data_np = numpy.asarray([[1, 2, 3], [4, 5, 6]], dtype=dtype_name)
    weights_np = numpy.asarray(weights, dtype=dtype_name)

    data_xp = xp.asarray(data_np, device=device)
    weights_xp = xp.asarray(weights_np, device=device)

    with (
        config_context(array_api_dispatch=True),
        pytest.raises(error, match=error_msg),
    ):
        _average(data_xp, axis=axis, weights=weights_xp)


def test_device_none_if_no_input():
    """device() returns None when it is given no array-like inputs."""
    # First with no arguments at all, then with only None and a string.
    for non_array_args in [(), (None, "name")]:
        assert device(*non_array_args) is None


@skip_if_array_api_compat_not_configured
def test_device_inspection():
    """Check `device` on mock arrays whose device objects are not hashable."""

    class FakeDevice:
        def __init__(self, name):
            self.name = name

        def __eq__(self, other):
            return self.name == other.name

        def __hash__(self):
            raise TypeError("Device object is not hashable")

        def __str__(self):
            return self.name

    class FakeArray:
        def __init__(self, device_name):
            self.device = FakeDevice(device_name)

    # Sanity check: the mock device really is unhashable. Some array libraries
    # expose non-hashable device objects, so `device` must not rely on hash
    # lookups (in particular, it must not build a `set`).
    with pytest.raises(TypeError):
        hash(FakeArray("device").device)

    on_cpu = FakeArray("cpu")
    on_gpu = FakeArray("mygpu")

    # With array API dispatch disabled, devices are ignored. Failing early on a
    # device mismatch would block the np.asarray conversion; for instance
    # `r2_score(np.ones(5), torch.ones(5))` must work with dispatch disabled.
    assert device(on_cpu, on_gpu) is None

    mismatch_msg = "Input arrays use different devices: cpu, mygpu"
    with config_context(array_api_dispatch=True):
        # With dispatch enabled, a device mismatch is an error.
        with pytest.raises(ValueError, match=mismatch_msg):
            device(on_cpu, on_gpu)

        # When every input shares a device, that device is returned.
        first = FakeArray("device")
        second = FakeArray("device")
        expected_device = first.device

        assert expected_device == device(first)
        assert expected_device == device(first, second)
        assert expected_device == device(first, first, second)


# TODO: add cupy to the list of libraries once the following upstream issue
# has been fixed:
# https://github.com/cupy/cupy/issues/8180
@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize("library", ["numpy", "array_api_strict", "torch"])
@pytest.mark.parametrize(
    "X,reduction,expected",
    [
        ([1, 2, numpy.nan], _nanmin, 1),
        ([1, -2, -numpy.nan], _nanmin, -2),
        ([numpy.inf, numpy.inf], _nanmin, numpy.inf),
        (
            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
            partial(_nanmin, axis=0),
            [1.0, 2.0, 3.0],
        ),
        (
            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
            partial(_nanmin, axis=1),
            [1.0, numpy.nan, 4.0],
        ),
        ([1, 2, numpy.nan], _nanmax, 2),
        # Negative values and a negated NaN, mirroring the _nanmin/_nanmean
        # cases (this entry used to be a duplicate of the one above).
        ([1, -2, -numpy.nan], _nanmax, 1),
        ([-numpy.inf, -numpy.inf], _nanmax, -numpy.inf),
        (
            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
            partial(_nanmax, axis=0),
            [4.0, 5.0, 6.0],
        ),
        (
            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
            partial(_nanmax, axis=1),
            [3.0, numpy.nan, 6.0],
        ),
        ([1, 2, numpy.nan], _nanmean, 1.5),
        ([1, -2, -numpy.nan], _nanmean, -0.5),
        ([-numpy.inf, -numpy.inf], _nanmean, -numpy.inf),
        (
            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
            partial(_nanmean, axis=0),
            [2.5, 3.5, 4.5],
        ),
        (
            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
            partial(_nanmean, axis=1),
            [2.0, numpy.nan, 5.0],
        ),
    ],
)
def test_nan_reductions(library, X, reduction, expected):
    """Check NaN-ignoring reductions (_nanmin, _nanmax, _nanmean).

    An all-NaN slice along the reduced axis is expected to produce NaN.
    """
    xp = pytest.importorskip(library)

    with config_context(array_api_dispatch=True):
        result = reduction(xp.asarray(X))

    result = _convert_to_numpy(result, xp)
    assert_allclose(result, expected)


@pytest.mark.parametrize(
    "namespace, _device, _dtype",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_ravel(namespace, _device, _dtype):
    """Check that _ravel flattens a 2-D array in row-major (C) order."""
    xp = _array_api_for_tests(namespace, _device)

    nested = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
    nested_xp = xp.asarray(nested, device=_device)
    with config_context(array_api_dispatch=True):
        flat = _ravel(nested_xp)

    flat_np = _convert_to_numpy(flat, xp)
    assert_allclose(numpy.ravel(nested, order="C"), flat_np)

    # For NumPy namespaces the output must also be C-contiguous in memory.
    if _is_numpy_namespace(xp):
        assert numpy.asarray(flat_np).flags["C_CONTIGUOUS"]


@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize("library", ["cupy", "torch"])
def test_convert_to_numpy_gpu(library):  # pragma: nocover
    """Check convert_to_numpy for GPU backed libraries.

    Skipped unless a CUDA device is actually usable by `library`.
    """
    xp = pytest.importorskip(library)

    # `torch.backends.cuda.is_built()` only says whether PyTorch was compiled
    # with CUDA. It is True on CUDA builds running without a GPU, where
    # allocating on "cuda" fails. `cuda.is_available()` exists in both torch
    # and cupy and checks that a device can really be used.
    if not xp.cuda.is_available():
        pytest.skip("test requires cuda")

    if library == "torch":
        X_gpu = xp.asarray([1.0, 2.0, 3.0], device="cuda")
    else:
        # CuPy arrays are created on the GPU without an explicit device.
        X_gpu = xp.asarray([1.0, 2.0, 3.0])

    X_cpu = _convert_to_numpy(X_gpu, xp=xp)
    expected_output = numpy.asarray([1.0, 2.0, 3.0])
    assert_allclose(X_cpu, expected_output)


def test_convert_to_numpy_cpu():
    """Check that _convert_to_numpy handles PyTorch tensors living on the CPU."""
    torch = pytest.importorskip("torch")
    values = [1.0, 2.0, 3.0]

    converted = _convert_to_numpy(torch.asarray(values, device="cpu"), xp=torch)
    assert_allclose(converted, numpy.asarray(values))


class SimpleEstimator(BaseEstimator):
    """Minimal estimator that keeps its training data as a fitted attribute.

    Used to exercise `_estimator_with_converted_arrays` on array-valued
    fitted attributes.
    """

    def fit(self, X, y=None):
        """Store `X` and its number of features.

        Parameters
        ----------
        X : 2-D array-like of shape (n_samples, n_features)
            Training data, stored unchanged as `X_`.
        y : None
            Ignored.

        Returns
        -------
        self : SimpleEstimator
            The fitted estimator.
        """
        self.X_ = X
        # Features run along the second axis; the first axis indexes samples.
        self.n_features_ = X.shape[1]
        return self


@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize(
    "array_namespace, converter",
    [
        ("torch", lambda array: array.cpu().numpy()),
        ("array_api_strict", lambda array: numpy.asarray(array)),
        ("cupy", lambda array: array.get()),
    ],
)
def test_convert_estimator_to_ndarray(array_namespace, converter):
    """Fitted array attributes can be converted back to NumPy ndarrays."""
    xp = pytest.importorskip(array_namespace)

    fitted = SimpleEstimator().fit(xp.asarray([[1.3, 4.5]]))
    converted = _estimator_with_converted_arrays(fitted, converter)

    assert isinstance(converted.X_, numpy.ndarray)


@skip_if_array_api_compat_not_configured
def test_convert_estimator_to_array_api():
    """Fitted NumPy attributes can be converted to array API arrays."""
    xp = pytest.importorskip("array_api_strict")

    fitted = SimpleEstimator().fit(numpy.asarray([[1.3, 4.5]]))
    converted = _estimator_with_converted_arrays(
        fitted, lambda value: xp.asarray(value)
    )

    assert hasattr(converted.X_, "__array_namespace__")


@pytest.mark.parametrize(
    "namespace, _device, _dtype",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_indexing_dtype(namespace, _device, _dtype):
    """The indexing dtype matches the platform pointer width."""
    xp = _array_api_for_tests(namespace, _device)

    expected = xp.int32 if _IS_32BIT else xp.int64
    assert indexing_dtype(xp) == expected


@pytest.mark.parametrize(
    "namespace, _device, _dtype",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_max_precision_float_dtype(namespace, _device, _dtype):
    """float64 is the widest float everywhere except the "mps" device (float32)."""
    xp = _array_api_for_tests(namespace, _device)

    if _device == "mps":
        expected_dtype = xp.float32
    else:
        expected_dtype = xp.float64

    assert _max_precision_float_dtype(xp, _device) == expected_dtype


@pytest.mark.parametrize(
    "array_namespace, device, _",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("invert", [True, False])
@pytest.mark.parametrize("assume_unique", [True, False])
@pytest.mark.parametrize("element_size", [6, 10, 14])
@pytest.mark.parametrize("int_dtype", ["int16", "int32", "int64", "uint8"])
def test_isin(
    array_namespace, device, _, invert, assume_unique, element_size, int_dtype
):
    """_isin agrees with numpy.isin for 2-D integer inputs."""
    xp = _array_api_for_tests(array_namespace, device)

    n_rows = element_size // 2
    element = 2 * numpy.arange(element_size).reshape((n_rows, 2)).astype(int_dtype)
    test_elements = numpy.array(numpy.arange(14), dtype=int_dtype)
    shared_kwargs = {"assume_unique": assume_unique, "invert": invert}

    expected = numpy.isin(
        element=element, test_elements=test_elements, **shared_kwargs
    )

    element_xp = xp.asarray(element, device=device)
    test_elements_xp = xp.asarray(test_elements, device=device)
    with config_context(array_api_dispatch=True):
        result = _isin(
            element=element_xp,
            test_elements=test_elements_xp,
            xp=xp,
            **shared_kwargs,
        )

    assert_array_equal(_convert_to_numpy(result, xp=xp), expected)


@pytest.mark.skipif(
    os.environ.get("SCIPY_ARRAY_API") != "1", reason="SCIPY_ARRAY_API not set to 1."
)
def test_get_namespace_and_device():
    # torch is used because its device objects are custom (non-string) objects.
    torch = pytest.importorskip("torch")

    from sklearn.externals.array_api_compat import torch as torch_compat

    tensor = torch.arange(3, device="cpu")
    ndarray = numpy.arange(3)

    # Dispatch disabled: expect the default NumPy wrapper namespace and no
    # device. Such inputs are then handled through the usual __array__
    # interface rather than dispatched via the array API.
    namespace_out, is_array_api, dev = get_namespace_and_device(tensor)
    assert namespace_out is get_namespace(ndarray)[0]
    assert not is_array_api
    assert dev is None

    # Dispatch enabled: the array API compat wrapper exposes torch's namespace
    # and the tensor's device.
    with config_context(array_api_dispatch=True):
        namespace_out, is_array_api, dev = get_namespace_and_device(tensor)
        assert namespace_out is torch_compat
        assert is_array_api
        assert dev == tensor.device


@pytest.mark.parametrize(
    "array_namespace, device_, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
def test_count_nonzero(
    array_namespace, device_, dtype_name, csr_container, axis, sample_weight_type
):
    """Check _count_nonzero against the sparse reference implementation.

    Also checks that the result is placed on the same device as the input.
    """
    from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero

    xp = _array_api_for_tests(array_namespace, device_)
    array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
    if sample_weight_type == "int":
        sample_weight = numpy.asarray([1, 2, 2, 3, 1])
    elif sample_weight_type == "float":
        sample_weight = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name)
    else:
        sample_weight = None
    expected = sparse_count_nonzero(
        csr_container(array), axis=axis, sample_weight=sample_weight
    )
    array_xp = xp.asarray(array, device=device_)

    with config_context(array_api_dispatch=True):
        result = _count_nonzero(
            array_xp, axis=axis, sample_weight=sample_weight, xp=xp, device=device_
        )

        # Devices must be compared while dispatch is enabled: with dispatch
        # disabled `device` returns None for every input, which would make
        # this check pass unconditionally.
        if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
            # NumPy 2.0 has a problem with the device attribute of scalar arrays:
            # https://github.com/numpy/numpy/issues/26850
            assert device(array_xp) == device(result)

    assert_allclose(_convert_to_numpy(result, xp=xp), expected)


@pytest.mark.parametrize(
    "array_namespace, device_, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("wrap", [True, False])
def test_fill_or_add_to_diagonal(array_namespace, device_, dtype_name, wrap):
    """Filling the diagonal matches numpy.fill_diagonal, with and without wrap."""
    xp = _array_api_for_tests(array_namespace, device_)

    reference = numpy.zeros((5, 4), dtype=dtype_name)
    # Hand a copy to xp so the NumPy reference can be modified independently.
    candidate = xp.asarray(reference.copy(), device=device_)

    numpy.fill_diagonal(reference, val=1, wrap=wrap)
    with config_context(array_api_dispatch=True):
        _fill_or_add_to_diagonal(candidate, value=1, xp=xp, add_value=False, wrap=wrap)

    assert_array_equal(_convert_to_numpy(candidate, xp=xp), reference)


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@pytest.mark.parametrize("dispatch", [True, False])
def test_sparse_device(csr_container, dispatch):
    """Sparse inputs never report a device, with or without array API dispatch."""
    if dispatch and os.environ.get("SCIPY_ARRAY_API") != "1":
        # With dispatch enabled, get_namespace raises a RuntimeError unless
        # SCIPY_ARRAY_API is exactly "1". Skip for any other value (e.g. "0"),
        # not only when the variable is unset, consistent with the other
        # SCIPY_ARRAY_API checks in this module.
        raise SkipTest(
            "SCIPY_ARRAY_API is not set to 1: not checking array_api input"
        )

    a, b = csr_container(numpy.array([[1]])), csr_container(numpy.array([[2]]))
    with config_context(array_api_dispatch=dispatch):
        assert device(a, b) is None
        assert device(a, numpy.array([1])) is None
        assert get_namespace_and_device(a, b)[2] is None
        assert get_namespace_and_device(a, numpy.array([1]))[2] is None
