"""Resource monitoring primitives for GPU calibration sweeps.
Three components, each independent:
- RunLog — Append-only JSONL event log (phase timing, events).
- HWMonitor — Background nvidia-smi subprocess + /proc/stat sampler.
- ResourceSummary — Aggregation of HWMonitor output into summary statistics.
Minimal overhead on the Python/JAX process: GPU sampling runs in separate
nvidia-smi OS processes, and CPU sampling reads /proc/stat from a lightweight daemon thread.
"""
from __future__ import annotations
import subprocess
import threading
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# RunLog — JSONL event accumulator
# ---------------------------------------------------------------------------
[docs]
class RunLog:
"""Accumulates timestamped JSONL entries for a single fit run.
Each line is a JSON object with at least ``ts`` and ``type`` keys.
Designed for ``tail -f`` on remote pods and ``jq`` filtering.
"""
def __init__(self) -> None:
self._entries: list[dict[str, Any]] = []
[docs]
def log_phase(self, name: str, seconds: float, **extra: Any) -> None:
"""Record a completed phase with its wall-clock duration."""
entry: dict[str, Any] = {
"ts": _now_iso(),
"type": "phase",
"name": name,
"seconds": round(seconds, 3),
}
entry.update(extra)
self._entries.append(entry)
[docs]
def log_event(self, event_type: str, data: dict[str, Any] | None = None) -> None:
"""Record a point-in-time event (start, error, etc.)."""
entry: dict[str, Any] = {
"ts": _now_iso(),
"type": "event",
"event": event_type,
}
if data:
entry["data"] = data
self._entries.append(entry)
[docs]
def write(self, path: Path) -> None:
"""Write all entries as newline-delimited JSON."""
import json
with open(path, "w") as f:
for entry in self._entries:
f.write(json.dumps(entry, default=str) + "\n")
# ---------------------------------------------------------------------------
# /proc/stat CPU sampler
# ---------------------------------------------------------------------------
def _read_cpu_jiffies() -> tuple[int, int, int] | None:
"""Read aggregate CPU jiffies from /proc/stat.
Returns (user+nice, idle, total) or None if unavailable.
/proc/stat first line format (fixed across all Linux versions):
cpu user nice system idle iowait irq softirq steal [guest] [guest_nice]
"""
try:
with open("/proc/stat") as f:
line = f.readline()
except OSError:
return None
parts = line.split()
if not parts or parts[0] != "cpu":
return None
# All fields after "cpu" are jiffies
vals = [int(x) for x in parts[1:]]
user_nice = vals[0] + vals[1] # user + nice
idle = vals[3] # idle (index 3)
total = sum(vals)
return user_nice, idle, total
class _CpuSampler:
"""Daemon thread that samples /proc/stat at fixed intervals.
Stores (timestamp, usr%, idle%) tuples for later aggregation.
"""
def __init__(self, interval_sec: int = 2) -> None:
self._interval = interval_sec
self._samples: list[tuple[str, float, float]] = []
self._stop = threading.Event()
self._thread: threading.Thread | None = None
def start(self) -> None:
self._stop.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self) -> list[tuple[str, float, float]]:
"""Stop sampling, return collected (timestamp, usr%, idle%) samples."""
self._stop.set()
if self._thread is not None:
self._thread.join(timeout=5)
return list(self._samples)
def _run(self) -> None:
prev = _read_cpu_jiffies()
# Wait one interval before first delta
self._stop.wait(self._interval)
while not self._stop.is_set():
curr = _read_cpu_jiffies()
if prev is not None and curr is not None:
d_user = curr[0] - prev[0]
d_idle = curr[1] - prev[1]
d_total = curr[2] - prev[2]
if d_total > 0:
usr_pct = 100.0 * d_user / d_total
idle_pct = 100.0 * d_idle / d_total
self._samples.append((_now_iso(), usr_pct, idle_pct))
prev = curr
self._stop.wait(self._interval)
# ---------------------------------------------------------------------------
# HWMonitor — background hardware sampling
# ---------------------------------------------------------------------------
[docs]
class HWMonitor:
"""Background hardware sampler via nvidia-smi subprocess + /proc/stat.
GPU: nvidia-smi writes CSV directly to disk (separate process).
GPU dmon: nvidia-smi dmon captures SM/memory-controller utilization.
CPU: daemon thread reads /proc/stat deltas (no external tool needed).
"""
def __init__(self) -> None:
self._gpu_proc: subprocess.Popen | None = None
self._dmon_proc: subprocess.Popen | None = None
self._cpu_sampler: _CpuSampler | None = None
self._gpu_path: Path | None = None
self._dmon_path: Path | None = None
self._cpu_path: Path | None = None
[docs]
def start(self, output_dir: Path, interval_sec: int = 2) -> None:
"""Launch nvidia-smi, dmon, and CPU sampler."""
output_dir.mkdir(parents=True, exist_ok=True)
# GPU metrics via nvidia-smi --query-gpu
self._gpu_path = output_dir / "hw_metrics.csv"
try:
self._gpu_proc = subprocess.Popen(
[
"nvidia-smi",
"--query-gpu=timestamp,utilization.gpu,utilization.memory,"
"memory.used,memory.total,power.draw,temperature.gpu,"
"clocks.current.sm,clocks.current.memory",
"--format=csv,nounits,noheader",
"-l", str(interval_sec),
],
stdout=open(self._gpu_path, "w"),
stderr=subprocess.DEVNULL,
)
except FileNotFoundError:
# nvidia-smi not available (CPU-only environment)
self._gpu_proc = None
# GPU dmon — SM and memory-controller utilization
# -s mu: SM utilization (sm%) + memory utilization (mem%) + FB usage (fb)
# mem% = memory controller busy %, distinct from memory.used (VRAM capacity)
self._dmon_path = output_dir / "hw_dmon.csv"
try:
self._dmon_proc = subprocess.Popen(
[
"nvidia-smi", "dmon",
"-s", "mu",
"-d", str(interval_sec),
],
stdout=open(self._dmon_path, "w"),
stderr=subprocess.DEVNULL,
)
except FileNotFoundError:
self._dmon_proc = None
# CPU metrics via /proc/stat
self._cpu_path = output_dir / "cpu_metrics.csv"
self._cpu_sampler = _CpuSampler(interval_sec=interval_sec)
self._cpu_sampler.start()
[docs]
def stop(self) -> tuple[Path | None, Path | None, Path | None]:
"""Stop monitors, return (gpu_csv_path, cpu_csv_path, dmon_csv_path).
Returns None for paths where the monitor wasn't available.
"""
gpu_path = None
cpu_path = None
dmon_path = None
if self._gpu_proc is not None:
self._gpu_proc.terminate()
self._gpu_proc.wait(timeout=5)
gpu_path = self._gpu_path
if self._dmon_proc is not None:
self._dmon_proc.terminate()
self._dmon_proc.wait(timeout=5)
dmon_path = self._dmon_path
if self._cpu_sampler is not None:
samples = self._cpu_sampler.stop()
if samples and self._cpu_path is not None:
with open(self._cpu_path, "w") as f:
f.write("timestamp,usr_pct,idle_pct\n")
for ts, usr, idle in samples:
f.write(f"{ts},{usr:.1f},{idle:.1f}\n")
cpu_path = self._cpu_path
self._gpu_proc = None
self._dmon_proc = None
self._cpu_sampler = None
return gpu_path, cpu_path, dmon_path
# ---------------------------------------------------------------------------
# ResourceSummary — aggregate hardware metrics
# ---------------------------------------------------------------------------
[docs]
@dataclass
class ResourceSummary:
    """Aggregated resource utilization derived from HWMonitor output files.

    Every field defaults to 0.0, which also stands for "no data collected".
    """

    # Maximum nvidia-smi memory.used seen across samples (MB).
    peak_vram_mb: float = 0.0
    # XLA allocator peak (MB); presumably filled by the caller from
    # probe_jax_peak_memory_mb() — TODO confirm.
    jax_peak_bytes_mb: float = 0.0
    gpu_util_mean_pct: float = 0.0
    gpu_util_max_pct: float = 0.0
    # Memory-controller busy % (nvidia-smi dmon), not VRAM occupancy.
    mem_util_mean_pct: float = 0.0
    mem_util_max_pct: float = 0.0
    power_watts_mean: float = 0.0
    power_watts_max: float = 0.0
    # Mean CPU user(+nice) % and idle % over the /proc/stat samples.
    cpu_usr_mean_pct: float = 0.0
    cpu_idle_mean_pct: float = 0.0

    def to_dict(self) -> dict[str, float]:
        """Return the metrics as a plain dict, each value rounded to one decimal.

        ``jax_peak_bytes_mb`` is included only when positive. The
        ``mem_util_*`` pair is included only when the mean is positive.
        """
        always_reported = (
            "peak_vram_mb",
            "gpu_util_mean_pct",
            "gpu_util_max_pct",
            "power_watts_mean",
            "power_watts_max",
            "cpu_usr_mean_pct",
            "cpu_idle_mean_pct",
        )
        out = {key: round(getattr(self, key), 1) for key in always_reported}
        if self.jax_peak_bytes_mb > 0:
            out["jax_peak_bytes_mb"] = round(self.jax_peak_bytes_mb, 1)
        # Gated on the mean only; the max travels with it as a pair.
        if self.mem_util_mean_pct > 0:
            out.update(
                mem_util_mean_pct=round(self.mem_util_mean_pct, 1),
                mem_util_max_pct=round(self.mem_util_max_pct, 1),
            )
        return out
[docs]
def probe_jax_peak_memory_mb() -> float:
    """Query JAX for the peak device memory used by XLA's allocator.

    Reads ``peak_bytes_in_use`` from ``memory_stats()`` of the first local
    device whose platform is ``"gpu"``. The value is converted to MB as
    bytes / 2**20, and only that first GPU is consulted. This reflects the
    allocator's real working set, not the pre-reserved pool that nvidia-smi
    reports.

    Best-effort: returns 0.0 when JAX cannot be imported, no GPU device
    exists, the device exposes no memory stats, or any other error occurs.
    """
    try:
        import jax

        gpus = [dev for dev in jax.local_devices() if dev.platform == "gpu"]
        stats = gpus[0].memory_stats() if gpus else None
        if stats is None:
            return 0.0
        return stats.get("peak_bytes_in_use", 0) / (1024 * 1024)
    except Exception:
        # Deliberately best-effort: monitoring must never break a fit run.
        return 0.0
[docs]
def build_resource_summary(
    gpu_csv_path: Path | None,
    cpu_csv_path: Path | None,
    dmon_csv_path: Path | None = None,
) -> ResourceSummary:
    """Parse HWMonitor output files into a ResourceSummary.

    Any path that is None or does not exist is skipped, and the matching
    summary fields keep their 0.0 defaults. All parsed rows of a file are
    pooled together before computing means and maxima.

    nvidia-smi ``--query-gpu`` CSV columns (no header, nounits):
        timestamp, gpu_util%, mem_util%, mem_used_mb, mem_total_mb,
        power_w, temp_c, sm_clock_mhz, mem_clock_mhz
    Rows with fewer than 9 columns are ignored. gpu_util, mem_used and power
    are parsed independently. A non-numeric value in one of them (nvidia-smi
    prints placeholders such as ``[N/A]`` for unsupported metrics) drops only
    that reading, not the whole row.

    nvidia-smi dmon output (``-s mu``):
        Header lines start with ``#``. Data lines are whitespace-separated and
        begin ``gpu sm mem``; only ``mem`` (memory-controller busy %) is used.
        Lines with fewer than 4 columns or a non-numeric ``mem`` (dmon prints
        ``-`` for unavailable values) are skipped.

    CPU CSV columns (header line first, as written by HWMonitor.stop):
        timestamp, usr_pct, idle_pct

    Args:
        gpu_csv_path: Output of the ``--query-gpu`` monitor, or None.
        cpu_csv_path: CPU sampler CSV, or None.
        dmon_csv_path: Output of the dmon monitor, or None.

    Returns:
        The aggregated ResourceSummary.
    """
    summary = ResourceSummary()

    # --- GPU metrics (--query-gpu) ---
    if gpu_csv_path is not None and gpu_csv_path.exists():
        gpu_utils, mem_useds, powers = _read_gpu_query_csv(gpu_csv_path)
        if gpu_utils:
            summary.gpu_util_mean_pct = _mean_of(gpu_utils)
            summary.gpu_util_max_pct = max(gpu_utils)
        if mem_useds:
            summary.peak_vram_mb = max(mem_useds)
        if powers:
            summary.power_watts_mean = _mean_of(powers)
            summary.power_watts_max = max(powers)

    # --- GPU dmon (memory controller utilization) ---
    if dmon_csv_path is not None and dmon_csv_path.exists():
        mem_utils = _read_dmon_mem_utils(dmon_csv_path)
        if mem_utils:
            summary.mem_util_mean_pct = _mean_of(mem_utils)
            summary.mem_util_max_pct = max(mem_utils)

    # --- CPU metrics ---
    if cpu_csv_path is not None and cpu_csv_path.exists():
        usr_vals, idle_vals = _read_cpu_metrics_csv(cpu_csv_path)
        if usr_vals:
            summary.cpu_usr_mean_pct = _mean_of(usr_vals)
        if idle_vals:
            summary.cpu_idle_mean_pct = _mean_of(idle_vals)

    return summary


def _mean_of(values: list[float]) -> float:
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


def _parse_metric(text: str) -> float | None:
    """Parse one metric cell as float; None for placeholders like ``[N/A]`` or ``-``."""
    try:
        return float(text)
    except ValueError:
        return None


def _read_gpu_query_csv(path: Path) -> tuple[list[float], list[float], list[float]]:
    """Collect (gpu_util%, mem_used_mb, power_w) readings from a --query-gpu CSV."""
    gpu_utils: list[float] = []
    mem_useds: list[float] = []
    powers: list[float] = []
    # Column index -> destination list; each column is parsed on its own.
    targets = ((1, gpu_utils), (3, mem_useds), (5, powers))
    with open(path) as f:
        for line in f:
            parts = [p.strip() for p in line.split(",")]
            if len(parts) < 9:
                continue
            for column, bucket in targets:
                value = _parse_metric(parts[column])
                if value is not None:
                    bucket.append(value)
    return gpu_utils, mem_useds, powers


def _read_dmon_mem_utils(path: Path) -> list[float]:
    """Collect memory-controller utilization (%) from ``nvidia-smi dmon`` output."""
    mem_utils: list[float] = []
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) < 4:
                continue
            # Column 2 is "mem" (after "gpu" and "sm").
            value = _parse_metric(parts[2])
            if value is not None:
                mem_utils.append(value)
    return mem_utils


def _read_cpu_metrics_csv(path: Path) -> tuple[list[float], list[float]]:
    """Collect (usr_pct, idle_pct) readings from the CPU sampler CSV."""
    usr_vals: list[float] = []
    idle_vals: list[float] = []
    with open(path) as f:
        f.readline()  # skip header
        for line in f:
            parts = line.strip().split(",")
            if len(parts) < 3:
                continue
            usr = _parse_metric(parts[1])
            if usr is not None:
                usr_vals.append(usr)
            idle = _parse_metric(parts[2])
            if idle is not None:
                idle_vals.append(idle)
    return usr_vals, idle_vals
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _now_iso() -> str:
"""Current UTC timestamp in ISO 8601 format."""
return datetime.now(tz=UTC).isoformat(timespec="milliseconds")