"""
Experimental support for external memory
========================================

This example is similar to the one in `quantile_data_iterator.py`, but it uses external
memory instead of the :py:class:`~xgboost.QuantileDMatrix`.  The feature is not ready for
production use yet.

    .. versionadded:: 1.5.0


See :doc:`the tutorial </tutorials/external_memory>` for more details.

    .. versionchanged:: 3.0.0

        Added :py:class:`~xgboost.ExtMemQuantileDMatrix`.

To run the example, the following packages are required in addition to XGBoost's native
dependencies:

- scikit-learn

If `device` is `cuda`, the following packages are also needed:

- cupy
- rmm
- cuda-python

.. seealso::

  :ref:`sphx_glr_python_examples_distributed_extmem_basic.py`

Although not shown in this example, you should pay attention to the NUMA configuration,
as discussed in the tutorial.

"""

import argparse
import contextlib
import os
import tempfile
from typing import TYPE_CHECKING, Callable, List, Literal, Tuple

import numpy as np
import xgboost
from sklearn.datasets import make_regression

if TYPE_CHECKING:
    from cuda.bindings.runtime import cudaError_t


def _checkcu(status: "cudaError_t") -> None:
    """Raise a :py:class:`RuntimeError` if a CUDA runtime call did not succeed.

    Parameters
    ----------
    status :
        The error code returned (as the first tuple element) by a
        ``cuda.bindings.runtime`` call.

    Raises
    ------
    RuntimeError
        When ``status`` is not ``cudaSuccess``. The message contains the error code
        and its human-readable description.

    """
    import cuda.bindings.runtime as cudart

    if status != cudart.cudaError_t.cudaSuccess:
        # cuda-python runtime functions return an `(error, value)` tuple, including
        # `cudaGetErrorString`. Unpack it so the message is the description string
        # rather than a tuple repr. Stay tolerant of a plain return value.
        result = cudart.cudaGetErrorString(status)
        msg = result[1] if isinstance(result, tuple) else result
        if isinstance(msg, bytes):
            msg = msg.decode("utf-8", errors="replace")
        raise RuntimeError(f"{status}: {msg}")


def device_mem_total() -> int:
    """Return the GPU's total memory in bytes, as reported by ``cudaMemGetInfo``."""
    import cuda.bindings.runtime as cudart

    err, _free, total_bytes = cudart.cudaMemGetInfo()
    _checkcu(err)
    return total_bytes


def make_batches(
    n_samples_per_batch: int,
    n_features: int,
    n_batches: int,
    work_dir: str,
) -> List[Tuple[str, str]]:
    """Generate synthetic regression data and save it to disk one batch at a time.

    Batch ``i`` is written under ``work_dir`` as the pair ``X-{i}.npy`` /
    ``y-{i}.npy``. A fixed random seed is used, so the generated data is the same on
    every run.

    Returns
    -------
    The ``(X_path, y_path)`` pairs, in batch order.

    """
    paths: List[Tuple[str, str]] = []
    rng = np.random.RandomState(1994)
    for idx in range(n_batches):
        features, target = make_regression(
            n_samples_per_batch, n_features, random_state=rng
        )
        feat_path = os.path.join(work_dir, f"X-{idx}.npy")
        target_path = os.path.join(work_dir, f"y-{idx}.npy")
        np.save(feat_path, features)
        np.save(target_path, target)
        paths.append((feat_path, target_path))
    return paths


class Iterator(xgboost.DataIter):
    """A custom iterator that feeds XGBoost one ``(X, y)`` batch per file pair.

    Parameters
    ----------
    device :
        ``"cpu"`` loads each batch with NumPy; ``"cuda"`` loads it with CuPy.
    file_paths :
        ``(X_path, y_path)`` pairs of ``.npy`` files, one pair per batch.

    """

    def __init__(
        self, device: Literal["cpu", "cuda"], file_paths: List[Tuple[str, str]]
    ) -> None:
        self.device = device

        self._file_paths = file_paths
        # Index of the next batch to hand to XGBoost.
        self._it = 0
        # XGBoost will generate some cache files under the current directory with the
        # prefix "cache".
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load the batch at the current iterator position.

        Returns
        -------
        The feature matrix and the label vector. These are NumPy arrays on ``"cpu"``
        and CuPy arrays on ``"cuda"``.

        Raises
        ------
        ValueError
            If the feature matrix and the label vector have different row counts.

        """
        X_path, y_path = self._file_paths[self._it]
        # When the `ExtMemQuantileDMatrix` is used, the device must match. GPU cannot
        # consume CPU input data and vice-versa.
        if self.device == "cpu":
            X = np.load(X_path)
            y = np.load(y_path)
        else:
            import cupy as cp  # pylint: disable=import-outside-toplevel

            X = cp.load(X_path)
            y = cp.load(y_path)

        # Use an explicit check rather than `assert`, which is stripped under `python -O`.
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"Row count mismatch in batch {self._it}: {X_path} has {X.shape[0]} "
                f"rows but {y_path} has {y.shape[0]}."
            )
        return X, y

    def next(self, input_data: Callable) -> bool:
        """Advance the iterator by one step and pass the batch to XGBoost.

        XGBoost calls this method while constructing the ``DMatrix``.

        Parameters
        ----------
        input_data :
            A keyword-only callback provided by XGBoost. Its signature is similar to
            the ``DMatrix`` constructor.

        Returns
        -------
        ``False`` once all batches have been consumed, otherwise ``True``.

        """
        if self._it == len(self._file_paths):
            # Return False to let XGBoost know this is the end of iteration.
            return False

        X, y = self.load_file()
        input_data(data=X, label=y)
        self._it += 1
        return True

    def reset(self) -> None:
        """Reset the iterator to its first batch."""
        self._it = 0


def hist_train(it: Iterator) -> None:
    """Train with the ``hist`` tree method on an `ExtMemQuantileDMatrix`.

    `ExtMemQuantileDMatrix` is a special data structure available to ``hist``. It
    gives faster initialization and lower memory usage, and is the recommended option.

    .. versionadded:: 3.0.0

    """
    # Non-data arguments such as ``missing`` are passed once here rather than through
    # every call to the iterator's ``next`` method.
    dtrain = xgboost.ExtMemQuantileDMatrix(
        it, missing=np.nan, enable_categorical=False
    )
    params = {"tree_method": "hist", "max_depth": 4, "device": it.device}
    booster = xgboost.train(
        params, dtrain, evals=[(dtrain, "Train")], num_boost_round=10
    )
    booster.predict(dtrain)


def approx_train(it: Iterator) -> None:
    """Train with the ``approx`` tree method on a plain `DMatrix` (not recommended)."""
    # Non-data arguments such as ``missing`` are passed once here rather than through
    # every call to the iterator's ``next`` method.
    dtrain = xgboost.DMatrix(it, missing=np.nan, enable_categorical=False)
    # ``approx`` works too, but sketching makes it less efficient; prefer ``hist``.
    params = {"tree_method": "approx", "max_depth": 4, "device": it.device}
    booster = xgboost.train(
        params, dtrain, evals=[(dtrain, "Train")], num_boost_round=10
    )
    booster.predict(dtrain)


def main(work_dir: str, cli_args: argparse.Namespace) -> None:
    """Write random demo batches under ``work_dir``, then train with both methods."""
    batch_files = make_batches(
        n_samples_per_batch=1024,
        n_features=17,
        n_batches=31,
        work_dir=work_dir,
    )
    data_iter = Iterator(cli_args.device, batch_files)

    # The same iterator is reused; XGBoost resets it before each pass.
    for train_fn in (hist_train, approx_train):
        train_fn(data_iter)


def setup_async_pool() -> None:
    """Set up the CUDA async memory pool. As an alternative, the RMM plugin can be used
    instead (see `setup_rmm`). This is equivalent to RMM's `CudaAsyncMemoryResource`,
    but without the RMM dependency.

    The release threshold of the current device's default memory pool is set to 90% of
    that device's total memory, and CuPy is switched to a `MemoryAsyncPool` allocator.

    .. versionadded:: 3.2.0

    """
    import cuda.bindings.runtime as cudart
    import cupy as cp  # pylint: disable=import-outside-toplevel
    from cuda.bindings import driver
    from cupy.cuda import MemoryAsyncPool

    # Configure the pool of the *current* device instead of a hard-coded device 0.
    # This is the same device that `device_mem_total` (via `cudaMemGetInfo`) measures,
    # so the threshold below is sized against the right GPU.
    status, ordinal = cudart.cudaGetDevice()
    _checkcu(status)
    status, dft_pool = cudart.cudaDeviceGetDefaultMemPool(ordinal)
    _checkcu(status)

    total = device_mem_total()

    # The release threshold is the amount of reserved memory the pool holds onto
    # before trying to release memory back to the OS.
    v = driver.cuuint64_t(int(total * 0.9))
    (status,) = cudart.cudaMemPoolSetAttribute(
        dft_pool,
        cudart.cudaMemPoolAttr.cudaMemPoolAttrReleaseThreshold,
        v,
    )
    _checkcu(status)
    # Set the allocator for cupy as well.
    cp.cuda.set_allocator(MemoryAsyncPool().malloc)


def setup_rmm() -> None:
    """Configure RMM as the device allocator for GPU-based external memory training.

    For GPU-based external memory performance, RMM should be backed by either a
    `CudaAsyncMemoryResource` or an `ArenaMemoryResource`; this function sets up the
    latter. If XGBoost is built without RMM support, the function returns without
    doing anything, and a warning is raised when the `DMatrix` is constructed.

    """

    import rmm
    from rmm.allocators.cupy import rmm_cupy_allocator
    from rmm.mr import ArenaMemoryResource

    # Nothing to configure when this XGBoost build has no RMM integration.
    if not xgboost.build_info()["USE_RMM"]:
        return

    import cupy as cp  # pylint: disable=import-outside-toplevel

    arena_bytes = int(device_mem_total() * 0.9)

    # Arena resource layered over a plain `CudaMemoryResource`, sized to 90% of the GPU.
    upstream: rmm.mr.DeviceMemoryResource = rmm.mr.CudaMemoryResource()
    arena = ArenaMemoryResource(upstream, arena_size=arena_bytes)

    rmm.mr.set_current_device_resource(arena)
    # Route CuPy allocations through RMM too.
    cp.cuda.set_allocator(rmm_cupy_allocator)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    parser.add_argument(
        "--memory_pool",
        choices=["rmm", "cuda"],
        default="rmm",
        help="Use a memory pool for asynchronous memory allocation in XGBoost.",
    )
    args = parser.parse_args()

    # The memory pool only matters for GPU training; on CPU no extra XGBoost
    # configuration is applied.
    xgb_config: contextlib.AbstractContextManager = contextlib.nullcontext()
    if args.device == "cuda":
        if args.memory_pool == "rmm":
            setup_rmm()
        elif args.memory_pool == "cuda":
            setup_async_pool()
        # Tell XGBoost to use the pool selected above: RMM for `--memory_pool=rmm`,
        # or the CUDA async pool for `--memory_pool=cuda`.
        xgb_config = xgboost.config_context(
            use_rmm=args.memory_pool == "rmm",
            use_cuda_async_pool=args.memory_pool == "cuda",
        )

    with xgb_config, tempfile.TemporaryDirectory() as tmpdir:
        main(tmpdir, args)
