From 4b28c98df176339686f6c22053e546da7a5c052e Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Mon, 18 May 2026 21:27:21 +0000 Subject: [PATCH 1/4] Add pipeline checkpoint retention strategies --- src/art/pipeline_trainer/__init__.py | 10 ++ .../pipeline_trainer/checkpoint_retention.py | 66 ++++++++ src/art/pipeline_trainer/state.py | 1 + src/art/pipeline_trainer/trainer.py | 146 +++++++++++++++++- tests/unit/test_checkpoint_retention.py | 62 ++++++++ .../test_pipeline_trainer_local_backend.py | 87 +++++++++++ 6 files changed, 365 insertions(+), 7 deletions(-) create mode 100644 src/art/pipeline_trainer/checkpoint_retention.py create mode 100644 tests/unit/test_checkpoint_retention.py diff --git a/src/art/pipeline_trainer/__init__.py b/src/art/pipeline_trainer/__init__.py index 3e0d0f9ce..f6fd5227c 100644 --- a/src/art/pipeline_trainer/__init__.py +++ b/src/art/pipeline_trainer/__init__.py @@ -1,10 +1,20 @@ +from .checkpoint_retention import ( + CheckpointInfo, + CheckpointRetentionContext, + CheckpointRetentionStrategy, + keep_recent_and_top, +) from .status import StatusReporter from .trainer import PipelineTrainer, make_group_rollout_fn from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn __all__ = [ + "CheckpointInfo", + "CheckpointRetentionContext", + "CheckpointRetentionStrategy", "PipelineTrainer", "make_group_rollout_fn", + "keep_recent_and_top", "StatusReporter", "RolloutFn", "SingleRolloutFn", diff --git a/src/art/pipeline_trainer/checkpoint_retention.py b/src/art/pipeline_trainer/checkpoint_retention.py new file mode 100644 index 000000000..2bc90c365 --- /dev/null +++ b/src/art/pipeline_trainer/checkpoint_retention.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from collections.abc import Callable, Iterable +from datetime import datetime + +from pydantic import BaseModel, Field + + +class CheckpointInfo(BaseModel): + step: int + path: str + created_at: datetime + is_eval_step: bool = False + metrics: dict[str, float] = Field(default_factory=dict) + + +class CheckpointRetentionContext(BaseModel): + current_step: int + checkpoints: list[CheckpointInfo] = Field(default_factory=list) + + +# Strategies receive only checkpoints that ART has determined are safe to delete +# and return the subset of those checkpoint steps to remove. +CheckpointRetentionStrategy = Callable[[CheckpointRetentionContext], Iterable[int]] + + +def keep_recent_and_top( + *, + recent: int = 5, + top: int = 2, + metric: str = "val/reward", +) -> CheckpointRetentionStrategy: + """Delete eligible checkpoints except the most recent and top metric steps.""" + if recent < 0: + raise ValueError("recent must be >= 0") + if top < 0: + raise ValueError("top must be >= 0") + + def strategy(context: CheckpointRetentionContext) -> set[int]: + eligible_steps = {checkpoint.step for checkpoint in context.checkpoints} + keep_steps: set[int] = set() + if recent > 0: + keep_steps.update( + checkpoint.step + for checkpoint in sorted( + context.checkpoints, key=lambda item: item.step + )[-recent:] + ) + ranked = [ + checkpoint + for checkpoint in context.checkpoints + if checkpoint.is_eval_step and metric in checkpoint.metrics + ] + ranked.sort(key=lambda item: (item.metrics[metric], item.step), reverse=True) + keep_steps.update(checkpoint.step for checkpoint in ranked[:top]) + return eligible_steps - keep_steps + + return strategy + + +__all__ = [ + "CheckpointInfo", + "CheckpointRetentionContext", + "CheckpointRetentionStrategy", + "keep_recent_and_top", +] diff --git a/src/art/pipeline_trainer/state.py b/src/art/pipeline_trainer/state.py index 95569d1c1..dc5653501 100644 --- a/src/art/pipeline_trainer/state.py +++ b/src/art/pipeline_trainer/state.py @@ -16,6 +16,7 @@ class PipelineState: scenario_offset: int = 0 total_scenarios_consumed: int = 0 last_eval_step: int = 0 + completed_eval_steps: set[int] = field(default_factory=set) # Metrics discarded_stale_groups: int = 0 diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 6b44628af..66856af60 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -1,8 +1,12 @@ from __future__ import annotations import asyncio +from collections import Counter from contextlib import asynccontextmanager +from datetime import datetime, timezone +import json import os +from pathlib import Path import signal import time from typing import Any, AsyncIterator, Generic, Iterable, TypeVar, cast @@ -12,6 +16,11 @@ import art from art import TrajectoryGroup +from .checkpoint_retention import ( + CheckpointInfo, + CheckpointRetentionContext, + CheckpointRetentionStrategy, +) from .state import PipelineState from .status import StatusReporter from .types import ConfigT, EvalFn, RolloutFn, ScenarioT, SingleRolloutFn # noqa: F401 @@ -91,6 +100,8 @@ def __init__( eval_every_n_steps: int = 20, eval_at_start: bool = True, save_checkpoint: bool = True, + checkpoint_retention_strategy: CheckpointRetentionStrategy | None = None, + checkpoint_retention_interval: int = 1, # Resumption resume: bool = True, ) -> None: @@ -114,6 +125,8 @@ def __init__( raise ValueError("log_interval_seconds must be > 0") if discard_queue_multiplier <= 0: raise ValueError("discard_queue_multiplier must be > 0") + if checkpoint_retention_interval <= 0: + raise ValueError("checkpoint_retention_interval must be > 0") self.model = model self.backend = backend self.rollout_fn = rollout_fn @@ -137,11 +150,15 @@ def __init__( self.eval_every_n_steps = eval_every_n_steps self.eval_at_start = eval_at_start self.save_checkpoint = save_checkpoint + self.checkpoint_retention_strategy = checkpoint_retention_strategy + self.checkpoint_retention_interval = checkpoint_retention_interval self.resume = resume self.discard_queue_multiplier = discard_queue_multiplier self._discard_queue: list[TrajectoryGroup] = [] self._discard_queue_limit = discard_queue_multiplier * min_batch_size self._collapse_triggered = False + self._checkpoint_lease_counts: Counter[int] = Counter() + self._scheduled_eval_steps: set[int] = set() self.state = PipelineState() self._scenario_lock = asyncio.Lock() @@ -180,6 +197,9 @@ async def train(self, *, handle_signals: bool = True) -> None: pipeline_state.get("total_scenarios_consumed", scenario_offset) or 0 ) self.state.last_eval_step = last_eval_step + self.state.completed_eval_steps = { + int(step) for step in pipeline_state.get("completed_eval_steps", []) or [] + } if scenario_offset > 0 and self._scenario_iter is not None: skipped = await self._skip_scenarios(self._scenario_iter, scenario_offset) @@ -195,6 +215,7 @@ async def train(self, *, handle_signals: bool = True) -> None: self._eval_queue = asyncio.Queue() if self.eval_fn is not None and self.eval_at_start: + self._scheduled_eval_steps.add(start_step) await self._eval_queue.put(start_step) self.state.last_eval_step = start_step self._persist_state(start_step) @@ -352,15 +373,27 @@ async def _wait_for_policy(self) -> None: @asynccontextmanager async def _adapter_lease(self, step: int) -> AsyncIterator[None]: + self._checkpoint_lease_counts[step] += 1 if not hasattr(type(self.backend), "adapter_lease"): - yield - return - lease = getattr(self.backend, "adapter_lease", None) - if lease is None: - yield + try: + yield + finally: + self._release_checkpoint_lease(step) return - async with lease(self.model, step): - yield + try: + lease = getattr(self.backend, "adapter_lease", None) + if lease is None: + yield + return + async with lease(self.model, step): + yield + finally: + self._release_checkpoint_lease(step) + + def _release_checkpoint_lease(self, step: int) -> None: + self._checkpoint_lease_counts[step] -= 1 + if self._checkpoint_lease_counts[step] <= 0: + del self._checkpoint_lease_counts[step] def _retained_adapter_steps(self, current_step: int) -> set[int]: min_step = max(0, current_step - self.max_steps_off_policy) @@ -513,6 +546,7 @@ async def _training_stage(self) -> None: self.state.policy_version = current_step self.state.next_training_step = current_step await self._prune_model_adapters(current_step) + await self._run_checkpoint_retention(current_step) step_seconds = time.monotonic() - step_start self._status.note_training_batch( @@ -542,6 +576,7 @@ async def _training_stage(self) -> None: await self._log_zero_variance_groups(current_step) if self.eval_fn is not None and should_eval_step: + self._scheduled_eval_steps.add(current_step) if self._eval_queue is not None: await self._eval_queue.put(current_step) self.state.last_eval_step = current_step @@ -654,6 +689,7 @@ async def _run_eval(self, step: int) -> None: self._status.note_val_started(step) reward: float | None = None eval_elapsed = 0.0 + eval_completed = False try: token = self.model.activate_metrics_context("eval") eval_started = time.monotonic() @@ -697,11 +733,16 @@ async def _run_eval(self, step: int) -> None: step=step, metrics={"time/step_eval_s": eval_elapsed}, ) + eval_completed = True except asyncio.CancelledError: raise except Exception as exc: print(f"Eval failed at step {step}: {exc}") finally: + self._scheduled_eval_steps.discard(step) + if eval_completed: + self.state.completed_eval_steps.add(step) + self._persist_state(self.state.next_training_step) self._status.note_val_finished(step, reward) @staticmethod @@ -862,9 +903,100 @@ def _persist_state(self, training_step: int) -> None: "total_scenarios_consumed": self.state.total_scenarios_consumed, "training_step": training_step, "last_eval_step": self.state.last_eval_step, + "completed_eval_steps": sorted(self.state.completed_eval_steps), } self.model.merge_state({PIPELINE_STATE_KEY: payload}) + def _checkpoint_metrics_by_step(self) -> dict[int, dict[str, float]]: + history_path = Path(self.model._get_output_dir()) / "history.jsonl" + if not history_path.exists(): + return {} + sums: dict[int, dict[str, float]] = {} + counts: dict[int, dict[str, int]] = {} + with history_path.open("r", encoding="utf-8") as history_file: + for line in history_file: + try: + row = json.loads(line) + except json.JSONDecodeError: + continue + step = row.get("step") + if not isinstance(step, int): + continue + for key, value in row.items(): + if key in {"step", "recorded_at"}: + continue + if isinstance(value, bool) or not isinstance(value, (int, float)): + continue + step_sums = sums.setdefault(step, {}) + step_counts = counts.setdefault(step, {}) + step_sums[key] = step_sums.get(key, 0.0) + float(value) + step_counts[key] = step_counts.get(key, 0) + 1 + return { + step: { + key: value / counts[step][key] + for key, value in step_sums.items() + if counts[step][key] > 0 + } + for step, step_sums in sums.items() + } + + def _checkpoint_infos(self) -> list[CheckpointInfo]: + checkpoint_dir = Path(self.model._get_output_dir()) / "checkpoints" + if not checkpoint_dir.exists(): + return [] + metrics_by_step = self._checkpoint_metrics_by_step() + checkpoints: list[CheckpointInfo] = [] + for path in checkpoint_dir.iterdir(): + if not path.is_dir() or not path.name.isdigit(): + continue + step = int(path.name) + stat = path.stat() + checkpoints.append( + CheckpointInfo( + step=step, + path=str(path), + created_at=datetime.fromtimestamp(stat.st_ctime, timezone.utc), + is_eval_step=step in self.state.completed_eval_steps, + metrics=metrics_by_step.get(step, {}), + ) + ) + return sorted(checkpoints, key=lambda checkpoint: checkpoint.step) + + def _protected_checkpoint_steps(self, current_step: int) -> set[int]: + return ( + {current_step} + | set(self._checkpoint_lease_counts) + | set(self._scheduled_eval_steps) + ) + + async def _run_checkpoint_retention(self, current_step: int) -> None: + strategy = self.checkpoint_retention_strategy + if strategy is None: + return + if current_step % self.checkpoint_retention_interval != 0: + return + all_checkpoints = self._checkpoint_infos() + if not all_checkpoints: + return + protected_steps = self._protected_checkpoint_steps(current_step) + eligible = [ + checkpoint + for checkpoint in all_checkpoints + if checkpoint.step not in protected_steps + ] + if not eligible: + return + context = CheckpointRetentionContext( + current_step=current_step, + checkpoints=eligible, + ) + eligible_steps = {checkpoint.step for checkpoint in eligible} + delete_steps = set(strategy(context)) & eligible_steps + if not delete_steps: + return + keep_steps = {checkpoint.step for checkpoint in all_checkpoints} - delete_steps + await self.backend._delete_checkpoint_files(self.model, sorted(keep_steps)) + @staticmethod def _is_scalar_metadata(value: object) -> bool: return value is None or isinstance(value, (str, int, float, bool)) diff --git a/tests/unit/test_checkpoint_retention.py b/tests/unit/test_checkpoint_retention.py new file mode 100644 index 000000000..75dad5176 --- /dev/null +++ b/tests/unit/test_checkpoint_retention.py @@ -0,0 +1,62 @@ +from datetime import datetime, timezone + +import pytest + +from art.pipeline_trainer import CheckpointInfo, CheckpointRetentionContext +from art.pipeline_trainer.checkpoint_retention import keep_recent_and_top + + +def _checkpoint( + step: int, + *, + is_eval_step: bool = False, + reward: float | None = None, +) -> CheckpointInfo: + metrics = {"val/reward": reward} if reward is not None else {} + return CheckpointInfo( + step=step, + path=f"/tmp/checkpoints/{step:04d}", + created_at=datetime.fromtimestamp(step, timezone.utc), + is_eval_step=is_eval_step, + metrics=metrics, + ) + + +def test_keep_recent_and_top_deletes_everything_else() -> None: + strategy = keep_recent_and_top(recent=2, top=1, metric="val/reward") + context = CheckpointRetentionContext( + current_step=6, + checkpoints=[ + _checkpoint(0), + _checkpoint(1, is_eval_step=True, reward=0.2), + _checkpoint(2), + _checkpoint(3, is_eval_step=True, reward=0.8), + _checkpoint(4), + _checkpoint(5), + ], + ) + + assert set(strategy(context)) == {0, 1, 2} + + +def test_keep_recent_and_top_handles_zero_limits() -> None: + strategy = keep_recent_and_top(recent=0, top=0) + context = CheckpointRetentionContext( + current_step=3, + checkpoints=[_checkpoint(0), _checkpoint(1), _checkpoint(2)], + ) + + assert set(strategy(context)) == {0, 1, 2} + + +@pytest.mark.parametrize( + ("recent", "top", "match"), + [(-1, 0, "recent must be >= 0"), (0, -1, "top must be >= 0")], +) +def test_keep_recent_and_top_rejects_negative_limits( + recent: int, + top: int, + match: str, +) -> None: + with pytest.raises(ValueError, match=match): + keep_recent_and_top(recent=recent, top=top) diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index c02b22ea1..939b0a124 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -1,4 +1,5 @@ import asyncio +import json from pathlib import Path from types import SimpleNamespace from typing import Any, cast @@ -13,6 +14,7 @@ from art.local import LocalBackend from art.megatron import MegatronBackend from art.megatron.train import load_adapter_into_model +from art.pipeline_trainer import CheckpointRetentionContext from art.pipeline_trainer.trainer import PipelineTrainer from art.preprocessing.tokenize import TokenizedResult from art.utils.output_dirs import get_model_dir @@ -235,6 +237,91 @@ async def fake_train_model( assert seen["dev_config"]["scale_rewards"] is False +@pytest.mark.asyncio +async def test_pipeline_trainer_checkpoint_retention_only_passes_unprotected_steps( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-checkpoint-retention", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + checkpoint_dir = Path(model._get_output_dir()) / "checkpoints" + for step in range(6): + (checkpoint_dir / f"{step:04d}").mkdir(parents=True) + history_path = Path(model._get_output_dir()) / "history.jsonl" + history_path.write_text( + "\n".join( + json.dumps(row) + for row in [ + {"step": 2, "val/reward": 1.0}, + {"step": 2, "val/reward": 3.0}, + {"step": 3, "val/reward": 10.0}, + ] + ) + + "\n", + encoding="utf-8", + ) + + backend = MagicMock() + backend._delete_checkpoint_files = AsyncMock() + contexts: list[CheckpointRetentionContext] = [] + + def strategy(context: CheckpointRetentionContext) -> set[int]: + contexts.append(context) + return {0, 2, 4, 99} + + trainer = _make_trainer( + model=model, + backend=backend, + checkpoint_retention_strategy=strategy, + ) + trainer.state.completed_eval_steps = {2, 3} + trainer._checkpoint_lease_counts[3] = 1 + trainer._scheduled_eval_steps.add(4) + + await trainer._run_checkpoint_retention(5) + + assert [checkpoint.step for checkpoint in contexts[0].checkpoints] == [0, 1, 2] + step_two = contexts[0].checkpoints[2] + assert step_two.is_eval_step is True + assert step_two.metrics["val/reward"] == 2.0 + backend._delete_checkpoint_files.assert_awaited_once_with( # type: ignore[attr-defined] + model, + [1, 3, 4, 5], + ) + + +@pytest.mark.asyncio +async def test_pipeline_trainer_checkpoint_retention_honors_interval( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-checkpoint-retention-interval", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + checkpoint_dir = Path(model._get_output_dir()) / "checkpoints" + for step in range(3): + (checkpoint_dir / f"{step:04d}").mkdir(parents=True) + + backend = MagicMock() + backend._delete_checkpoint_files = AsyncMock() + + trainer = _make_trainer( + model=model, + backend=backend, + checkpoint_retention_strategy=lambda _context: {0, 1, 2}, + checkpoint_retention_interval=5, + ) + + await trainer._run_checkpoint_retention(4) + + backend._delete_checkpoint_files.assert_not_awaited() # type: ignore[attr-defined] + + def _make_tokenized_result( trajectory: Trajectory, token_ids: list[int], From 21ce7a5a3dcb6309435a7bc8ee8ca79f2c9e417c Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Mon, 18 May 2026 22:18:25 +0000 Subject: [PATCH 2/4] Record checkpoint retention metadata --- src/art/model.py | 2 + src/art/pipeline_trainer/__init__.py | 6 +++ .../pipeline_trainer/checkpoint_retention.py | 9 +++- src/art/pipeline_trainer/trainer.py | 47 +++++++++++++++++-- src/art/unsloth/service.py | 37 ++++++++++++++- tests/unit/test_checkpoint_retention.py | 14 ++++++ tests/unit/test_multi_checkpoint_inference.py | 47 +++++++++++++++++++ .../test_pipeline_trainer_local_backend.py | 44 ++++++++++++++++- 8 files changed, 200 insertions(+), 6 deletions(-) diff --git a/src/art/model.py b/src/art/model.py index e412002b0..00cd76810 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -132,6 +132,7 @@ def __getattr__(self, name: str) -> Any: "costs", "time", "data", + "checkpoint", } ) METRIC_SPLITS = frozenset({"train", "val", "test"}) @@ -617,6 +618,7 @@ def _get_wandb_run(self) -> Optional["Run"]: run.define_metric("costs/*", step_metric="training_step") run.define_metric("time/*", step_metric="training_step") run.define_metric("data/*", step_metric="training_step") + run.define_metric("checkpoint/*", step_metric="training_step") run.define_metric("train/*", step_metric="training_step") run.define_metric("val/*", step_metric="training_step") run.define_metric("test/*", step_metric="training_step") diff --git a/src/art/pipeline_trainer/__init__.py b/src/art/pipeline_trainer/__init__.py index f6fd5227c..ff10857ef 100644 --- a/src/art/pipeline_trainer/__init__.py +++ b/src/art/pipeline_trainer/__init__.py @@ -1,4 +1,7 @@ from .checkpoint_retention import ( + CHECKPOINT_CREATED_AT_METRIC, + CHECKPOINT_EVAL_COMPLETED_METRIC, + CHECKPOINT_SAVED_METRIC, CheckpointInfo, CheckpointRetentionContext, CheckpointRetentionStrategy, @@ -9,6 +12,9 @@ from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn __all__ = [ + "CHECKPOINT_CREATED_AT_METRIC", + "CHECKPOINT_EVAL_COMPLETED_METRIC", + "CHECKPOINT_SAVED_METRIC", "CheckpointInfo", "CheckpointRetentionContext", "CheckpointRetentionStrategy", diff --git a/src/art/pipeline_trainer/checkpoint_retention.py b/src/art/pipeline_trainer/checkpoint_retention.py index 2bc90c365..54c2232e1 100644 --- a/src/art/pipeline_trainer/checkpoint_retention.py +++ b/src/art/pipeline_trainer/checkpoint_retention.py @@ -5,6 +5,10 @@ from pydantic import BaseModel, Field +CHECKPOINT_CREATED_AT_METRIC = "checkpoint/created_at_unix" +CHECKPOINT_EVAL_COMPLETED_METRIC = "checkpoint/eval_completed" +CHECKPOINT_SAVED_METRIC = "checkpoint/saved" + class CheckpointInfo(BaseModel): step: int @@ -49,7 +53,7 @@ def strategy(context: CheckpointRetentionContext) -> set[int]: ranked = [ checkpoint for checkpoint in context.checkpoints - if checkpoint.is_eval_step and metric in checkpoint.metrics + if metric in checkpoint.metrics ] ranked.sort(key=lambda item: (item.metrics[metric], item.step), reverse=True) keep_steps.update(checkpoint.step for checkpoint in ranked[:top]) @@ -59,6 +63,9 @@ def strategy(context: CheckpointRetentionContext) -> set[int]: __all__ = [ + "CHECKPOINT_CREATED_AT_METRIC", + "CHECKPOINT_EVAL_COMPLETED_METRIC", + "CHECKPOINT_SAVED_METRIC", "CheckpointInfo", "CheckpointRetentionContext", "CheckpointRetentionStrategy", diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 66856af60..1ab41e4af 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -17,6 +17,8 @@ from art import TrajectoryGroup from .checkpoint_retention import ( + CHECKPOINT_CREATED_AT_METRIC, + CHECKPOINT_EVAL_COMPLETED_METRIC, CheckpointInfo, CheckpointRetentionContext, CheckpointRetentionStrategy, @@ -545,6 +547,7 @@ async def _training_stage(self) -> None: current_step = result.step self.state.policy_version = current_step self.state.next_training_step = current_step + await self._log_checkpoint_saved(result) await self._prune_model_adapters(current_step) await self._run_checkpoint_retention(current_step) @@ -733,6 +736,7 @@ async def _run_eval(self, step: int) -> None: step=step, metrics={"time/step_eval_s": eval_elapsed}, ) + await self._log_checkpoint_eval_completed(step) eval_completed = True except asyncio.CancelledError: raise @@ -907,6 +911,32 @@ def _persist_state(self, training_step: int) -> None: } self.model.merge_state({PIPELINE_STATE_KEY: payload}) + async def _log_checkpoint_saved(self, result: Any) -> None: + step = int(result.step) + checkpoint_path = getattr(result, "checkpoint_path", None) + path = ( + Path(checkpoint_path) + if isinstance(checkpoint_path, str) and checkpoint_path + else Path(self.model._get_output_dir()) / "checkpoints" / f"{step:04d}" + ) + if not path.exists(): + return + await self.model.log( + metrics={ + "saved": 1.0, + "created_at_unix": path.stat().st_ctime, + }, + split="checkpoint", + step=step, + ) + + async def _log_checkpoint_eval_completed(self, step: int) -> None: + await self.model.log( + metrics={"eval_completed": 1.0}, + split="checkpoint", + step=step, + ) + def _checkpoint_metrics_by_step(self) -> dict[int, dict[str, float]]: history_path = Path(self.model._get_output_dir()) / "history.jsonl" if not history_path.exists(): @@ -951,13 +981,24 @@ def _checkpoint_infos(self) -> list[CheckpointInfo]: continue step = int(path.name) stat = path.stat() + metrics = metrics_by_step.get(step, {}) + created_at_unix = metrics.get(CHECKPOINT_CREATED_AT_METRIC) + created_at = ( + datetime.fromtimestamp(created_at_unix, timezone.utc) + if created_at_unix is not None + else datetime.fromtimestamp(stat.st_ctime, timezone.utc) + ) checkpoints.append( CheckpointInfo( step=step, path=str(path), - created_at=datetime.fromtimestamp(stat.st_ctime, timezone.utc), - is_eval_step=step in self.state.completed_eval_steps, - metrics=metrics_by_step.get(step, {}), + created_at=created_at, + is_eval_step=( + step in self.state.completed_eval_steps + or metrics.get(CHECKPOINT_EVAL_COMPLETED_METRIC, 0.0) > 0.0 + or any(key.startswith(("val/", "test/")) for key in metrics) + ), + metrics=metrics, ) ) return sorted(checkpoints, key=lambda checkpoint: checkpoint.step) diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index 6d2c02b2e..a90590c04 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -143,6 +143,11 @@ class UnslothService: repr=False, ) _child_processes: ChildProcessSupervisor = field(init=False, repr=False) + _loaded_adapter_steps: set[int] = field( + default_factory=set, + init=False, + repr=False, + ) def __post_init__(self) -> None: self._child_processes = ChildProcessSupervisor(self._on_child_process_exit) @@ -552,7 +557,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: ) async with httpx.AsyncClient() as client: response = await client.post( - f"http://{self._vllm_host}:{self._vllm_port}/v1/load_lora_adapter", + f"{self._vllm_base_url}/v1/load_lora_adapter", json={ "lora_name": lora_name, "lora_path": checkpoint_path, @@ -566,6 +571,33 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: f"[DEDICATED] _reload_adapter DONE: lora_name={lora_name} " f"status={response.status_code}" ) + self._latest_step = step + self._loaded_adapter_steps.add(step) + + async def _unload_adapter(self, step: int) -> None: + import httpx + + self._raise_if_child_failed() + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/v1/unload_lora_adapter", + json={"lora_name": f"{self.model_name}@{step}"}, + **self._runtime_request_kwargs(), + timeout=30.0, + ) + if response.status_code == 404: + self._loaded_adapter_steps.discard(step) + return + response.raise_for_status() + self._loaded_adapter_steps.discard(step) + + async def prune_loaded_adapters(self, *, retain_steps: set[int]) -> None: + if self.rollout_weights_mode != "lora" or self._vllm_port == 0: + return + for step in sorted(self._loaded_adapter_steps - retain_steps): + if step == self._latest_step: + continue + await self._unload_adapter(step) def close(self) -> None: """Terminate vLLM subprocess if running.""" @@ -582,6 +614,7 @@ def close(self) -> None: self._vllm_log_file = None self._vllm_log_path = None self._vllm_nccl_so_path = None + self._loaded_adapter_steps.clear() finally: self._lifecycle.restore_parent_cleanup() @@ -617,6 +650,8 @@ async def start_openai_server( port, config=config, ) + if self.rollout_weights_mode == "lora": + self._loaded_adapter_steps.add(self._latest_step) try: if self.rollout_weights_mode == "merged": _ = self._state diff --git a/tests/unit/test_checkpoint_retention.py b/tests/unit/test_checkpoint_retention.py index 75dad5176..d0d4e8d14 100644 --- a/tests/unit/test_checkpoint_retention.py +++ b/tests/unit/test_checkpoint_retention.py @@ -39,6 +39,20 @@ def test_keep_recent_and_top_deletes_everything_else() -> None: assert set(strategy(context)) == {0, 1, 2} +def test_keep_recent_and_top_uses_metric_presence_for_legacy_history() -> None: + strategy = keep_recent_and_top(recent=0, top=1, metric="val/reward") + context = CheckpointRetentionContext( + current_step=3, + checkpoints=[ + _checkpoint(0), + _checkpoint(1, reward=0.9), + _checkpoint(2, reward=0.2), + ], + ) + + assert set(strategy(context)) == {0, 2} + + def test_keep_recent_and_top_handles_zero_limits() -> None: strategy = keep_recent_and_top(recent=0, top=0) context = CheckpointRetentionContext( diff --git a/tests/unit/test_multi_checkpoint_inference.py b/tests/unit/test_multi_checkpoint_inference.py index cd91e2f31..7faf921e9 100644 --- a/tests/unit/test_multi_checkpoint_inference.py +++ b/tests/unit/test_multi_checkpoint_inference.py @@ -347,6 +347,53 @@ def test_max_loras_can_be_overridden(self, unsloth_service_class): assert engine_args["max_loras"] == 8 + @pytest.mark.asyncio + async def test_prune_loaded_adapters_unloads_non_retained_steps( + self, unsloth_service_class, monkeypatch + ): + """UnslothService should unload old vLLM LoRA adapters like MegatronService.""" + httpx = pytest.importorskip("httpx") + UnslothService = unsloth_service_class + calls = [] + + class FakeResponse: + status_code = 200 + + def raise_for_status(self): + return None + + class FakeAsyncClient: + async def __aenter__(self): + return self + + async def __aexit__(self, *_args): + return None + + async def post(self, url, *, json, **_kwargs): + calls.append((url, json)) + return FakeResponse() + + monkeypatch.setattr(httpx, "AsyncClient", FakeAsyncClient) + service = UnslothService( + model_name="test-model", + base_model="meta-llama/Llama-3.1-8B", + config={"rollout_weights_mode": "lora"}, + output_dir="/tmp/test", + ) + service._vllm_port = 8000 + service._latest_step = 3 + service._loaded_adapter_steps.update({1, 2, 3}) + + await service.prune_loaded_adapters(retain_steps={2}) + + assert calls == [ + ( + "http://127.0.0.1:8000/v1/unload_lora_adapter", + {"lora_name": "test-model@1"}, + ) + ] + assert service._loaded_adapter_steps == {2, 3} + # ============================================================================= # Pipelined Training Usage Example diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index 939b0a124..a845b44fd 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -1,4 +1,5 @@ import asyncio +from datetime import datetime, timezone import json from pathlib import Path from types import SimpleNamespace @@ -14,7 +15,11 @@ from art.local import LocalBackend from art.megatron import MegatronBackend from art.megatron.train import load_adapter_into_model -from art.pipeline_trainer import CheckpointRetentionContext +from art.pipeline_trainer import ( + CHECKPOINT_CREATED_AT_METRIC, + CHECKPOINT_EVAL_COMPLETED_METRIC, + CheckpointRetentionContext, +) from art.pipeline_trainer.trainer import PipelineTrainer from art.preprocessing.tokenize import TokenizedResult from art.utils.output_dirs import get_model_dir @@ -257,6 +262,11 @@ async def test_pipeline_trainer_checkpoint_retention_only_passes_unprotected_ste for row in [ {"step": 2, "val/reward": 1.0}, {"step": 2, "val/reward": 3.0}, + { + "step": 2, + CHECKPOINT_CREATED_AT_METRIC: 123.0, + CHECKPOINT_EVAL_COMPLETED_METRIC: 1.0, + }, {"step": 3, "val/reward": 10.0}, ] ) @@ -286,6 +296,7 @@ def strategy(context: CheckpointRetentionContext) -> set[int]: assert [checkpoint.step for checkpoint in contexts[0].checkpoints] == [0, 1, 2] step_two = contexts[0].checkpoints[2] assert step_two.is_eval_step is True + assert step_two.created_at == datetime.fromtimestamp(123.0, timezone.utc) assert step_two.metrics["val/reward"] == 2.0 backend._delete_checkpoint_files.assert_awaited_once_with( # type: ignore[attr-defined] model, @@ -322,6 +333,37 @@ async def test_pipeline_trainer_checkpoint_retention_honors_interval( backend._delete_checkpoint_files.assert_not_awaited() # type: ignore[attr-defined] +@pytest.mark.asyncio +async def test_pipeline_trainer_logs_checkpoint_retention_metadata( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-checkpoint-retention-metadata", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + report_metrics=[], + ) + checkpoint_path = Path(model._get_output_dir()) / "checkpoints" / "0001" + checkpoint_path.mkdir(parents=True) + trainer = _make_trainer(model=model, backend=MagicMock()) + + await trainer._log_checkpoint_saved( + SimpleNamespace(step=1, checkpoint_path=str(checkpoint_path)) + ) + await trainer._log_checkpoint_eval_completed(1) + + rows = [ + json.loads(line) + for line in (Path(model._get_output_dir()) / "history.jsonl") + .read_text() + .splitlines() + ] + assert rows[0]["checkpoint/saved"] == 1.0 + assert rows[0][CHECKPOINT_CREATED_AT_METRIC] > 0.0 + assert rows[1][CHECKPOINT_EVAL_COMPLETED_METRIC] == 1.0 + + def _make_tokenized_result( trajectory: Trajectory, token_ids: list[int], From 1966c3427d39de8baefb5a16d0be932dff20a8c7 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Mon, 18 May 2026 23:23:00 +0000 Subject: [PATCH 3/4] Keep checkpoint metadata out of metric routing --- src/art/model.py | 2 -- src/art/pipeline_trainer/trainer.py | 37 +++++++++++++++++++++-------- tests/unit/test_metric_routing.py | 2 ++ 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/art/model.py b/src/art/model.py index 00cd76810..e412002b0 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -132,7 +132,6 @@ def __getattr__(self, name: str) -> Any: "costs", "time", "data", - "checkpoint", } ) METRIC_SPLITS = frozenset({"train", "val", "test"}) @@ -618,7 +617,6 @@ def _get_wandb_run(self) -> Optional["Run"]: run.define_metric("costs/*", step_metric="training_step") run.define_metric("time/*", step_metric="training_step") run.define_metric("data/*", step_metric="training_step") - run.define_metric("checkpoint/*", step_metric="training_step") run.define_metric("train/*", step_metric="training_step") run.define_metric("val/*", step_metric="training_step") run.define_metric("test/*", step_metric="training_step") diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 1ab41e4af..bfe0db692 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -19,6 +19,7 @@ from .checkpoint_retention import ( CHECKPOINT_CREATED_AT_METRIC, CHECKPOINT_EVAL_COMPLETED_METRIC, + CHECKPOINT_SAVED_METRIC, CheckpointInfo, CheckpointRetentionContext, CheckpointRetentionStrategy, @@ -911,6 +912,24 @@ def _persist_state(self, training_step: int) -> None: } self.model.merge_state({PIPELINE_STATE_KEY: payload}) + def _log_checkpoint_history(self, step: int, metrics: dict[str, float]) -> None: + row = { + (key if key.startswith("checkpoint/") else f"checkpoint/{key}"): value + for key, value in metrics.items() + if value == value + } + if not row: + return + row["training_step"] = step + row["time/wall_clock_sec"] = time.time() - self.model._run_start_time + row["step"] = step + row["recorded_at"] = datetime.now().isoformat() + + output_dir = self.model._get_output_dir() + os.makedirs(output_dir, exist_ok=True) + with open(Path(output_dir) / "history.jsonl", "a", encoding="utf-8") as f: + f.write(json.dumps(row) + "\n") + async def _log_checkpoint_saved(self, result: Any) -> None: step = int(result.step) checkpoint_path = getattr(result, "checkpoint_path", None) @@ -921,20 +940,18 @@ async def _log_checkpoint_saved(self, result: Any) -> None: ) if not path.exists(): return - await self.model.log( - metrics={ - "saved": 1.0, - "created_at_unix": path.stat().st_ctime, + self._log_checkpoint_history( + step, + { + CHECKPOINT_SAVED_METRIC: 1.0, + CHECKPOINT_CREATED_AT_METRIC: path.stat().st_ctime, }, - split="checkpoint", - step=step, ) async def _log_checkpoint_eval_completed(self, step: int) -> None: - await self.model.log( - metrics={"eval_completed": 1.0}, - split="checkpoint", - step=step, + self._log_checkpoint_history( + step, + {CHECKPOINT_EVAL_COMPLETED_METRIC: 1.0}, ) def _checkpoint_metrics_by_step(self) -> dict[int, dict[str, float]]: diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py index fb5eb65ae..5a290ebfb 100644 --- a/tests/unit/test_metric_routing.py +++ b/tests/unit/test_metric_routing.py @@ -24,6 +24,7 @@ def test_log_metrics_routes_known_sections_without_split_prefix( { "reward/mean": 0.9, "custom": 1.0, + "checkpoint/foo": 1.5, "rewardish/value": 2.0, }, split="train", @@ -36,6 +37,7 @@ def test_log_metrics_routes_known_sections_without_split_prefix( assert entry["reward/mean"] == 0.9 assert entry["train/custom"] == 1.0 + assert entry["train/checkpoint/foo"] == 1.5 assert entry["train/rewardish/value"] == 2.0 assert entry["training_step"] == 7 assert entry["time/wall_clock_sec"] >= 0 From 762fc26f6cd3db397424569f218209c40dd1c5f3 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Tue, 19 May 2026 07:22:50 +0000 Subject: [PATCH 4/4] Make retention strategies return kept steps --- src/art/pipeline_trainer/checkpoint_retention.py | 8 ++++---- src/art/pipeline_trainer/trainer.py | 3 ++- tests/unit/test_checkpoint_retention.py | 8 ++++---- tests/unit/test_pipeline_trainer_local_backend.py | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/art/pipeline_trainer/checkpoint_retention.py b/src/art/pipeline_trainer/checkpoint_retention.py index 54c2232e1..776045f7b 100644 --- a/src/art/pipeline_trainer/checkpoint_retention.py +++ b/src/art/pipeline_trainer/checkpoint_retention.py @@ -23,8 +23,8 @@ class CheckpointRetentionContext(BaseModel): checkpoints: list[CheckpointInfo] = Field(default_factory=list) -# Strategies receive only checkpoints that ART has determined are safe to delete -# and return the subset of those checkpoint steps to remove. +# Strategies receive only checkpoints that ART has determined are eligible for +# removal and return the subset of those checkpoint steps to keep. CheckpointRetentionStrategy = Callable[[CheckpointRetentionContext], Iterable[int]] @@ -34,7 +34,7 @@ def keep_recent_and_top( top: int = 2, metric: str = "val/reward", ) -> CheckpointRetentionStrategy: - """Delete eligible checkpoints except the most recent and top metric steps.""" + """Keep the most recent eligible checkpoints and top metric checkpoints.""" if recent < 0: raise ValueError("recent must be >= 0") if top < 0: @@ -57,7 +57,7 @@ def strategy(context: CheckpointRetentionContext) -> set[int]: ] ranked.sort(key=lambda item: (item.metrics[metric], item.step), reverse=True) keep_steps.update(checkpoint.step for checkpoint in ranked[:top]) - return eligible_steps - keep_steps + return keep_steps & eligible_steps return strategy diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index bfe0db692..721ddd791 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -1049,7 +1049,8 @@ async def _run_checkpoint_retention(self, current_step: int) -> None: checkpoints=eligible, ) eligible_steps = {checkpoint.step for checkpoint in eligible} - delete_steps = set(strategy(context)) & eligible_steps + keep_eligible_steps = set(strategy(context)) & eligible_steps + delete_steps = eligible_steps - keep_eligible_steps if not delete_steps: return keep_steps = {checkpoint.step for checkpoint in all_checkpoints} - delete_steps diff --git a/tests/unit/test_checkpoint_retention.py b/tests/unit/test_checkpoint_retention.py index d0d4e8d14..ee71dd2d2 100644 --- a/tests/unit/test_checkpoint_retention.py +++ b/tests/unit/test_checkpoint_retention.py @@ -22,7 +22,7 @@ def _checkpoint( ) -def test_keep_recent_and_top_deletes_everything_else() -> None: +def test_keep_recent_and_top_returns_kept_steps() -> None: strategy = keep_recent_and_top(recent=2, top=1, metric="val/reward") context = CheckpointRetentionContext( current_step=6, @@ -36,7 +36,7 @@ def test_keep_recent_and_top_deletes_everything_else() -> None: ], ) - assert set(strategy(context)) == {0, 1, 2} + assert set(strategy(context)) == {3, 4, 5} def test_keep_recent_and_top_uses_metric_presence_for_legacy_history() -> None: @@ -50,7 +50,7 @@ def test_keep_recent_and_top_uses_metric_presence_for_legacy_history() -> None: ], ) - assert set(strategy(context)) == {0, 2} + assert set(strategy(context)) == {1} def test_keep_recent_and_top_handles_zero_limits() -> None: @@ -60,7 +60,7 @@ def test_keep_recent_and_top_handles_zero_limits() -> None: checkpoints=[_checkpoint(0), _checkpoint(1), _checkpoint(2)], ) - assert set(strategy(context)) == {0, 1, 2} + assert set(strategy(context)) == set() @pytest.mark.parametrize( diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index a845b44fd..65a9db9d6 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -280,7 +280,7 @@ async def test_pipeline_trainer_checkpoint_retention_only_passes_unprotected_ste def strategy(context: CheckpointRetentionContext) -> set[int]: contexts.append(context) - return {0, 2, 4, 99} + return {1, 4, 99} trainer = _make_trainer( model=model,