Source code for pytyche.setup
"""Runtime setup inspection (``pt.check_setup``) and the no-CUDA fit warning.

JAX is imported inside functions, never at module level, so that
``import pytyche`` stays cheap and does not pull in the heavy stack (see
``pytyche/__init__``).
"""
from __future__ import annotations
import dataclasses
import os
import warnings
from importlib import metadata, resources
from typing import Any
# Environment variable that, when set to exactly "1", silences the
# no-CUDA fit warning.
SUPPRESS_GPU_WARNING_VAR = "PYTYCHE_SUPPRESS_GPU_WARNING"

# Process-wide one-shot flag; the first _warn_if_no_cuda() call sets it.
_warned = False

# The opt-out hint is built from SUPPRESS_GPU_WARNING_VAR so the message can
# never advertise a different variable name than the one actually read.
_NO_CUDA_MESSAGE = (
    "No CUDA device detected — this fit runs on CPU, which is orders of "
    "magnitude slower at production scale. Install the GPU extra: "
    "uv add 'pytyche[gpu]' (or pip install 'pytyche[gpu]'). Set "
    f"{SUPPRESS_GPU_WARNING_VAR}=1 to silence this warning."
)
[docs]
class NoCudaWarning(UserWarning):
    """Issued at most once per process when a fit starts with no CUDA device.

    Being a :class:`UserWarning` subclass, it can be filtered with the
    standard :mod:`warnings` machinery; setting the
    ``PYTYCHE_SUPPRESS_GPU_WARNING=1`` environment variable also skips it.
    """
[docs]
@dataclasses.dataclass(frozen=True)
class SetupReport:
    """Immutable snapshot of the pytyche runtime, produced by :func:`check_setup`.

    Attributes:
        pytyche_version: Version string of the installed pytyche.
        jax_devices: String form of every device JAX reports.
        cuda_available: Whether at least one reported device is a GPU.
        bartz_version: Version string of the installed bartz.
        calibration_registry: Names of the calibration artifacts shipped in
            the wheel (an empty list when none are bundled).
        recommended_install: Install command that closes a detected setup
            gap (e.g. the CUDA upgrade), or ``None`` when nothing is missing.
    """

    # NOTE: frozen=True with the default eq=True makes dataclasses derive
    # __hash__ from the fields; because two fields are lists, hash() on an
    # instance raises TypeError, so reports cannot be set members/dict keys.
    pytyche_version: str
    jax_devices: list[str]
    cuda_available: bool
    bartz_version: str
    calibration_registry: list[str]
    recommended_install: str | None

    def __repr__(self) -> str:
        # Same plain-text rendering that check_setup prints to stdout.
        return _format_summary(self)

    def _repr_html_(self) -> str:
        """Render the report as a label/value HTML table for Jupyter."""
        # Imported lazily so that ``import pytyche`` stays light.
        import pandas as pd

        artifacts = ", ".join(self.calibration_registry) or "(none bundled)"
        table = {
            "pytyche": self.pytyche_version,
            "bartz": self.bartz_version,
            "JAX devices": ", ".join(self.jax_devices),
            "CUDA available": str(self.cuda_available),
            "calibration artifacts": artifacts,
        }
        if self.recommended_install is not None:
            table["recommended"] = self.recommended_install
        # dicts keep insertion order, so rows appear in the order above.
        return pd.DataFrame(
            list(table.values()),
            index=list(table.keys()),
            columns=["setup"],
        ).to_html(header=False, border=0)
def _devices() -> list[Any]:
    """Return every device JAX can see, materialized as a list.

    JAX is imported here, not at module level, so importing this module
    never pulls in the heavy stack.
    """
    from jax import devices as jax_devices

    return [*jax_devices()]
def _calibration_data_root():
    """Locate the ``_data/calibration`` directory inside the pytyche package.

    Returns an :mod:`importlib.resources` traversable. The directory may be
    absent, so callers should check ``is_dir()`` before iterating it.
    """
    data_dir = resources.files("pytyche") / "_data"
    return data_dir / "calibration"
def _calibration_registry() -> list[str]:
    """List bundled calibration artifact names in sorted order.

    Every entry directly inside the calibration data directory whose name
    ends in ``.json`` counts as an artifact; its reported name is the entry
    name without that suffix. Returns ``[]`` when the directory is missing.
    """
    root = _calibration_data_root()
    if not root.is_dir():
        return []
    suffix = ".json"
    names = [
        entry.name[: -len(suffix)]
        for entry in root.iterdir()
        if entry.name.endswith(suffix)
    ]
    names.sort()
    return names
def _format_summary(report: SetupReport) -> str:
    """Render *report* as the multi-line plain-text summary.

    Both version lines use the form ``name version``. The other lines use
    the form ``label: value``. A final ``recommended:`` line is added only
    when ``report.recommended_install`` is not ``None``. The result has no
    trailing newline.
    """
    artifacts = ", ".join(report.calibration_registry) or "(none bundled)"
    summary = (
        f"pytyche {report.pytyche_version}\n"
        f"bartz {report.bartz_version}\n"
        f"JAX devices: {', '.join(report.jax_devices)}\n"
        f"CUDA available: {report.cuda_available}\n"
        f"calibration artifacts: {artifacts}"
    )
    if report.recommended_install is None:
        return summary
    return f"{summary}\nrecommended: {report.recommended_install}"
[docs]
def check_setup() -> SetupReport:
    """Inspect the pytyche runtime, print a summary, and return it.

    The human-readable summary goes to stdout for notebook / interactive
    use. The same information is returned as a :class:`SetupReport` for
    programmatic checks. Calling it repeatedly is safe: apart from the
    print it has no side effects, and it never emits the no-CUDA fit
    warning.

    Returns:
        The :class:`SetupReport` describing this process.
    """
    from pytyche import __version__ as pytyche_version

    visible = _devices()
    has_gpu = any(device.platform == "gpu" for device in visible)
    if has_gpu:
        install_hint = None
    else:
        # Same install command the no-CUDA fit warning suggests.
        install_hint = "uv add 'pytyche[gpu]'"
    report = SetupReport(
        pytyche_version=pytyche_version,
        jax_devices=list(map(str, visible)),
        cuda_available=has_gpu,
        bartz_version=metadata.version("bartz"),
        calibration_registry=_calibration_registry(),
        recommended_install=install_hint,
    )
    print(_format_summary(report))
    return report
def _warn_if_no_cuda() -> None:
    """Emit :class:`NoCudaWarning` at most once per process when no GPU is visible.

    Intended to be called first thing in every public fit entry. The first
    call consumes the one-shot check whatever the outcome (suppressed, GPU
    found, or warned), so JAX devices are probed at most once per process.
    The suppression environment variable is only consulted on that first
    call.
    """
    global _warned
    if _warned:
        return
    _warned = True

    suppressed = os.environ.get(SUPPRESS_GPU_WARNING_VAR) == "1"
    if suppressed:
        return
    has_gpu = any(device.platform == "gpu" for device in _devices())
    if not has_gpu:
        # stacklevel=3 skips this helper (1) and the fit entry (2), so the
        # warning is attributed to the user's fit call site — assuming the
        # fit entry calls this helper directly.
        warnings.warn(_NO_CUDA_MESSAGE, NoCudaWarning, stacklevel=3)