Source code for pytyche.viz._gif
"""Round-by-round experiment evolution GIF."""
from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
from pytyche.viz._cells import plot_cells
from pytyche.viz._policy_tree import plot_policy_tree
[docs]
def _round_tree(round_snapshot: Any) -> Any:
    """Return the policy tree fitted for *round_snapshot*, or ``None``.

    ``None`` when the round has no ``next_recommendation`` or that plan
    carries no fitted ``tree``.
    """
    plan = round_snapshot.next_recommendation
    return plan.tree if plan is not None else None


def experiment_evolution_gif(
    history: Sequence[Any],
    output_path: str | Path,
    fps: float = 0.2,
) -> Path:
    """Render the round-by-round experiment evolution to an animated GIF.

    Each frame shows the round's cell allocation, plus the round's policy
    tree when one was fit.

    Args:
        history: Per-round experiment snapshots — the shipped L1
            ``Experiment`` shape, consumed duck-typed: ``cells_shipped``
            (the cell list ``plot_cells`` accepts) and
            ``next_recommendation`` carrying the round's fitted ``tree``
            (``None`` on a plan with no fit, e.g. real-data cold starts).
        output_path: Destination ``.gif`` path.
        fps: Frames (rounds) per second; must be positive and finite.
            Rounds are dense (a full policy tree per frame); the default
            holds each for five seconds.

    Returns:
        ``output_path`` as a :class:`~pathlib.Path`.

    Raises:
        ValueError: When *history* is empty, or *fps* is not a positive
            finite number.
    """
    if not history:
        raise ValueError("history is empty — no rounds to render.")
    # Validate before the (slow) rendering pass. ``not (fps > 0)`` also
    # catches NaN. Infinite fps would yield a 0 ms frame duration, which
    # GIF viewers do not honour as "as fast as possible".
    if not (fps > 0 and np.isfinite(fps)):
        raise ValueError(f"fps must be a positive finite number, got {fps!r}.")
    path = Path(output_path)

    # Resolve each round's tree once; both the canvas sizing and the
    # per-frame layout below depend on it.
    trees = [_round_tree(snapshot) for snapshot in history]

    # One fixed canvas for every frame (the GIF format requires it),
    # sized for the widest round's tree: leaf labels are multi-line
    # boxes and sklearn spaces leaves evenly across the panel.
    max_leaves = max(
        (len(tree.allocation_map) for tree in trees if tree is not None),
        default=0,
    )
    width = 5.0 + max(6.0, 2.6 * max_leaves)
    fig = plt.figure(figsize=(width, 5.5), dpi=100)
    # Rasterize through an explicit Agg canvas: deterministic headless
    # rendering regardless of the session's interactive backend, and the
    # canvas type that actually exposes the RGBA buffer.
    canvas = FigureCanvasAgg(fig)
    frames: list[Image.Image] = []
    try:
        for i, (round_snapshot, tree) in enumerate(zip(history, trees)):
            fig.clear()
            if tree is not None:
                # Narrow cell panel on the left, wide tree panel on the right.
                grid = fig.add_gridspec(1, 2, width_ratios=[1, 3])
                ax_cells = fig.add_subplot(grid[0, 0])
                plot_policy_tree(tree, ax=fig.add_subplot(grid[0, 1]))
            else:
                ax_cells = fig.add_subplot(1, 1, 1)
            plot_cells(round_snapshot.cells_shipped, ax=ax_cells)
            fig.suptitle(f"round {i + 1}/{len(history)}")
            canvas.draw()
            rgba = np.asarray(canvas.buffer_rgba())
            # ``convert`` returns a new image, so the stored frame does not
            # alias the canvas buffer that the next ``draw`` overwrites.
            frames.append(Image.fromarray(rgba).convert("RGB"))
    finally:
        plt.close(fig)
    # Pillow takes the per-frame duration in milliseconds (the GIF format
    # itself stores centisecond delays); sub-1 fps therefore means 5 s
    # per round at the default 0.2.
    frames[0].save(
        path,
        save_all=True,
        append_images=frames[1:],
        duration=round(1000.0 / fps),
        loop=0,
    )
    return path