Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/art/pipeline_trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from .checkpoint_retention import (
CHECKPOINT_CREATED_AT_METRIC,
CHECKPOINT_EVAL_COMPLETED_METRIC,
CHECKPOINT_SAVED_METRIC,
CheckpointInfo,
CheckpointRetentionContext,
CheckpointRetentionStrategy,
keep_recent_and_top,
)
from .status import StatusReporter
from .trainer import PipelineTrainer, make_group_rollout_fn
from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn

__all__ = [
"CHECKPOINT_CREATED_AT_METRIC",
"CHECKPOINT_EVAL_COMPLETED_METRIC",
"CHECKPOINT_SAVED_METRIC",
"CheckpointInfo",
"CheckpointRetentionContext",
"CheckpointRetentionStrategy",
"PipelineTrainer",
"make_group_rollout_fn",
"keep_recent_and_top",
"StatusReporter",
"RolloutFn",
"SingleRolloutFn",
Expand Down
73 changes: 73 additions & 0 deletions src/art/pipeline_trainer/checkpoint_retention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from collections.abc import Callable, Iterable
from datetime import datetime

from pydantic import BaseModel, Field

# String keys in the "checkpoint/" namespace, exported for use by other modules.
# NOTE(review): judging by the *_METRIC suffix these are metric names recorded
# alongside training metrics; the code that writes/reads them is not in this
# module -- confirm semantics (e.g. that "created_at_unix" is a Unix timestamp)
# against the trainer.
CHECKPOINT_CREATED_AT_METRIC = "checkpoint/created_at_unix"
CHECKPOINT_EVAL_COMPLETED_METRIC = "checkpoint/eval_completed"
CHECKPOINT_SAVED_METRIC = "checkpoint/saved"


class CheckpointInfo(BaseModel):
    """Description of a single checkpoint, as seen by retention strategies."""

    # Step number of the checkpoint; `keep_recent_and_top` uses it both for
    # recency ordering and as the identifier it returns.
    step: int
    # Location of the checkpoint.
    # NOTE(review): presumably a filesystem path -- confirm with the producer.
    path: str
    # Creation time of the checkpoint.
    # NOTE(review): timezone-awareness is not established here -- confirm.
    created_at: datetime
    # Flag marking the checkpoint's step as an eval step; defaults to False.
    is_eval_step: bool = False
    # Metric name -> value. Defaults to a fresh empty dict per instance.
    # `keep_recent_and_top` ranks checkpoints by one key of this mapping.
    metrics: dict[str, float] = Field(default_factory=dict)


class CheckpointRetentionContext(BaseModel):
    """Input handed to a `CheckpointRetentionStrategy` callable."""

    # Step number at the time the strategy is invoked.
    # NOTE(review): presumably the trainer's current training step -- confirm.
    current_step: int
    # Candidate checkpoints the strategy chooses from. Defaults to a fresh
    # empty list per instance.
    checkpoints: list[CheckpointInfo] = Field(default_factory=list)


# Signature of a retention strategy: it takes a CheckpointRetentionContext and
# returns an iterable of checkpoint step numbers.
# Strategies receive only checkpoints that ART has determined are eligible for
# removal and return the subset of those checkpoint steps to keep.
CheckpointRetentionStrategy = Callable[[CheckpointRetentionContext], Iterable[int]]


def keep_recent_and_top(
    *,
    recent: int = 5,
    top: int = 2,
    metric: str = "val/reward",
) -> CheckpointRetentionStrategy:
    """Build a strategy that keeps the newest and the best-scoring checkpoints.

    The returned strategy keeps the union of two selections made from the
    checkpoints in the context it is given:

    * the ``recent`` checkpoints with the highest ``step``;
    * the ``top`` checkpoints with the largest ``metric`` value, ties broken
      in favour of the higher ``step``.

    Args:
        recent: How many of the highest-step checkpoints to keep. ``0``
            disables this rule.
        top: How many of the best-scoring checkpoints to keep. ``0`` disables
            this rule.
        metric: Key looked up in each checkpoint's ``metrics`` mapping.
            Checkpoints that lack the key, or whose value is NaN, are never
            selected by the ``top`` rule (they can still be kept as recent).

    Returns:
        A callable that takes a ``CheckpointRetentionContext`` and returns the
        set of checkpoint steps to keep.

    Raises:
        ValueError: If ``recent`` or ``top`` is negative.
    """
    # Function-scope import so the block does not depend on a file-level one.
    import math

    if recent < 0:
        raise ValueError("recent must be >= 0")
    if top < 0:
        raise ValueError("top must be >= 0")

    def strategy(context: CheckpointRetentionContext) -> set[int]:
        """Return the steps of the checkpoints in ``context`` to keep."""
        checkpoints = context.checkpoints
        keep_steps: set[int] = set()

        # The guard matters: ``[-0:]`` would select *every* checkpoint.
        if recent > 0:
            newest = sorted(checkpoints, key=lambda item: item.step)[-recent:]
            keep_steps.update(item.step for item in newest)

        if top > 0:
            # NaN compares False against everything, so leaving it in the sort
            # key makes the ordering unreliable and can hand a "top" slot to a
            # checkpoint with no usable score. Drop such entries up front.
            scored = [
                item
                for item in checkpoints
                if metric in item.metrics
                and not math.isnan(item.metrics[metric])
            ]
            scored.sort(
                key=lambda item: (item.metrics[metric], item.step),
                reverse=True,
            )
            keep_steps.update(item.step for item in scored[:top])

        # Every kept step was drawn from ``context.checkpoints``, so the result
        # is already a subset of the steps the strategy was offered.
        return keep_steps

    return strategy


# Public names of this module: what `from ... import *` exports.
__all__ = [
    "CHECKPOINT_CREATED_AT_METRIC",
    "CHECKPOINT_EVAL_COMPLETED_METRIC",
    "CHECKPOINT_SAVED_METRIC",
    "CheckpointInfo",
    "CheckpointRetentionContext",
    "CheckpointRetentionStrategy",
    "keep_recent_and_top",
]
1 change: 1 addition & 0 deletions src/art/pipeline_trainer/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class PipelineState:
scenario_offset: int = 0
total_scenarios_consumed: int = 0
last_eval_step: int = 0
completed_eval_steps: set[int] = field(default_factory=set)

# Metrics
discarded_stale_groups: int = 0
Expand Down
Loading
Loading