Source code for pytyche.setup

"""Runtime setup inspection (``pt.check_setup``) and the no-CUDA fit warning.

JAX is imported inside functions, never at module level — ``import pytyche``
must stay importable without pulling the heavy stack (see ``pytyche/__init__``).
"""

from __future__ import annotations

import dataclasses
import os
import warnings
from importlib import metadata, resources
from typing import Any

# Name of the environment variable that silences the no-CUDA warning; the
# warning helper only honours the exact value "1".
SUPPRESS_GPU_WARNING_VAR: str = "PYTYCHE_SUPPRESS_GPU_WARNING"

# Process-wide latch flipped by ``_warn_if_no_cuda`` on its first call.
_warned: bool = False

# Text of the warning, one sentence group per literal.
_NO_CUDA_MESSAGE: str = (
    "No CUDA device detected — this fit runs on CPU, "
    "which is orders of magnitude slower at production scale. "
    "Install the GPU extra: uv add 'pytyche[gpu]' "
    "(or pip install 'pytyche[gpu]'). "
    "Set PYTYCHE_SUPPRESS_GPU_WARNING=1 to silence this warning."
)


class NoCudaWarning(UserWarning):
    """Warning category for fits that run without a CUDA device.

    Issued at most once per process, on the first fit that finds no GPU.
    """
@dataclasses.dataclass(frozen=True)
class SetupReport:
    """Immutable description of the pytyche runtime, built by :func:`check_setup`.

    Attributes:
        pytyche_version: Version of the installed pytyche package.
        jax_devices: ``str()`` of each device reported by JAX.
        cuda_available: ``True`` when at least one reported device is a GPU.
        bartz_version: Version of the installed bartz package.
        calibration_registry: Names of the calibration artifacts shipped in
            the wheel; empty when none are bundled.
        recommended_install: Install command that closes a detected setup
            gap (e.g. the CUDA upgrade), or ``None`` when nothing is missing.
    """

    pytyche_version: str
    jax_devices: list[str]
    cuda_available: bool
    bartz_version: str
    calibration_registry: list[str]
    recommended_install: str | None

    def __repr__(self) -> str:
        """Return the multi-line plain-text summary of this report."""
        return _format_summary(self)

    def _repr_html_(self) -> str:
        """Return the report as a headerless, borderless HTML table."""
        # pandas is imported lazily so that `import pytyche` stays light
        # (see the module docstring).
        import pandas as pd

        artifacts = ", ".join(self.calibration_registry)
        if not artifacts:
            artifacts = "(none bundled)"

        # Parallel lists: labels become the row index, values the one column.
        labels = [
            "pytyche",
            "bartz",
            "JAX devices",
            "CUDA available",
            "calibration artifacts",
        ]
        values = [
            self.pytyche_version,
            self.bartz_version,
            ", ".join(self.jax_devices),
            str(self.cuda_available),
            artifacts,
        ]
        if self.recommended_install is not None:
            labels.append("recommended")
            values.append(self.recommended_install)

        table = pd.DataFrame(values, index=labels, columns=["setup"])
        return table.to_html(header=False, border=0)
def _devices() -> list[Any]:
    """Return every device JAX reports, as a list."""
    # Function-scope import: the module must stay importable without JAX
    # being loaded (see the module docstring).
    import jax

    return [*jax.devices()]


def _calibration_data_root():
    """Return the ``_data/calibration`` resource directory of the pytyche package."""
    package_root = resources.files("pytyche")
    data_dir = package_root / "_data"
    return data_dir / "calibration"


def _calibration_registry() -> list[str]:
    """Return the sorted names of bundled calibration artifacts.

    A name is the file name of a ``*.json`` entry in the calibration data
    directory with the ``.json`` suffix removed. Returns an empty list when
    that directory does not exist.
    """
    data_root = _calibration_data_root()
    if not data_root.is_dir():
        return []

    names = []
    for item in data_root.iterdir():
        filename = item.name
        if filename.endswith(".json"):
            names.append(filename.removesuffix(".json"))
    names.sort()
    return names


def _format_summary(report: SetupReport) -> str:
    """Render ``report`` as the multi-line plain-text setup summary.

    Args:
        report: The report to render.

    Returns:
        One line per field, joined with newlines; the ``recommended`` line
        is present only when ``report.recommended_install`` is not ``None``.
    """
    devices_text = ", ".join(report.jax_devices)
    artifacts_text = ", ".join(report.calibration_registry)
    if not artifacts_text:
        artifacts_text = "(none bundled)"

    summary = [
        f"pytyche {report.pytyche_version}",
        f"bartz {report.bartz_version}",
        f"JAX devices: {devices_text}",
        f"CUDA available: {report.cuda_available}",
        "calibration artifacts: " + artifacts_text,
    ]
    hint = report.recommended_install
    if hint is not None:
        summary.append(f"recommended: {hint}")
    return "\n".join(summary)
def check_setup() -> SetupReport:
    """Inspect the pytyche runtime, print a summary, and return it.

    The human-readable summary goes to stdout (handy in notebooks and
    interactive sessions); the same information comes back as a structured
    :class:`SetupReport` for programmatic checks. Printing is the only side
    effect — in particular, this function does not call the no-CUDA fit
    warning helper.

    Returns:
        The :class:`SetupReport` describing this process.
    """
    from pytyche import __version__

    found_devices = _devices()
    has_gpu = any(device.platform == "gpu" for device in found_devices)
    device_names = [str(device) for device in found_devices]

    # Query bartz before scanning bundled artifacts, matching the order in
    # which the report fields are declared.
    bartz_version = metadata.version("bartz")
    registry = _calibration_registry()

    if has_gpu:
        install_hint = None
    else:
        install_hint = "uv add 'pytyche[gpu]'"

    report = SetupReport(
        pytyche_version=__version__,
        jax_devices=device_names,
        cuda_available=has_gpu,
        bartz_version=bartz_version,
        calibration_registry=registry,
        recommended_install=install_hint,
    )
    print(_format_summary(report))
    return report
def _warn_if_no_cuda() -> None:
    """Emit :class:`NoCudaWarning` at most once per process when no GPU is present.

    Intended to be called at the top of every public fit entry point. The
    first call consumes the check whatever its outcome, so the device probe
    runs at most once per process.
    """
    global _warned
    if _warned:
        return
    # Latch before probing: later calls return immediately even if this
    # one ends up not warning.
    _warned = True

    silenced = os.environ.get(SUPPRESS_GPU_WARNING_VAR) == "1"
    # Short-circuit keeps the device probe from running when silenced.
    if silenced or any(device.platform == "gpu" for device in _devices()):
        return

    # stacklevel=3 skips this helper and the fit entry that called it, so
    # the warning is attributed to the user's fit call site.
    warnings.warn(_NO_CUDA_MESSAGE, NoCudaWarning, stacklevel=3)