Source code for pytyche.viz._gif

"""Round-by-round experiment evolution GIF."""

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image

from pytyche.viz._cells import plot_cells
from pytyche.viz._policy_tree import plot_policy_tree


[docs] def experiment_evolution_gif( history: Sequence[Any], output_path: str | Path, fps: float = 0.2, ) -> Path: """Render the round-by-round experiment evolution to an animated GIF. Each frame shows the round's cell allocation, plus the round's policy tree when one was fit. Args: history: Per-round experiment snapshots — the shipped L1 ``Experiment`` shape, consumed duck-typed: ``cells_shipped`` (the cell list ``plot_cells`` accepts) and ``next_recommendation`` carrying the round's fitted ``tree`` (``None`` on a plan with no fit, e.g. real-data cold starts). output_path: Destination ``.gif`` path. fps: Frames (rounds) per second. Rounds are dense (a full policy tree per frame); the default holds each for five seconds. Returns: ``output_path`` as a :class:`~pathlib.Path`. Raises: ValueError: When *history* is empty. """ if not history: raise ValueError("history is empty — no rounds to render.") path = Path(output_path) # One fixed canvas for every frame (the GIF format requires it), # sized for the widest round's tree: leaf labels are multi-line # boxes and sklearn spaces leaves evenly across the panel. max_leaves = max( ( len(r.next_recommendation.tree.allocation_map) for r in history if r.next_recommendation is not None and r.next_recommendation.tree is not None ), default=0, ) width = 5.0 + max(6.0, 2.6 * max_leaves) fig = plt.figure(figsize=(width, 5.5), dpi=100) # Rasterize through an explicit Agg canvas: deterministic headless # rendering regardless of the session's interactive backend, and the # canvas type that actually exposes the RGBA buffer. canvas = FigureCanvasAgg(fig) frames: list[Image.Image] = [] try: for i, round_snapshot in enumerate(history): fig.clear() plan = round_snapshot.next_recommendation tree = plan.tree if plan is not None else None if tree is not None: grid = fig.add_gridspec(1, 2, width_ratios=[1, 3]) ax_cells = fig.add_subplot(grid[0, 0]) plot_policy_tree(tree, ax=fig.add_subplot(grid[0, 1])) else: ax_cells = fig.add_subplot(1, 1, 1) plot_cells(round_snapshot.cells_shipped, ax=ax_cells) fig.suptitle(f"round {i + 1}/{len(history)}") canvas.draw() rgba = np.asarray(canvas.buffer_rgba()) frames.append(Image.fromarray(rgba).convert("RGB")) finally: plt.close(fig) # GIF timing is per-frame duration in milliseconds — the format's # native model, and what sub-1 fps actually means (5 s per round at # the default 0.2). frames[0].save( path, save_all=True, append_images=frames[1:], duration=round(1000.0 / fps), loop=0, ) return path