Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion docs/features/tracking-metrics.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ icon: "chart-line"
---

ART writes a metrics row every time you call `model.log(...)`. Those rows go to
`history.jsonl` in the run directory and, if W&B logging is enabled, to W&B.
`history.jsonl` in the run directory and, if W&B or Trackio logging is enabled,
to the configured tracker.

Use this page for three things:

Expand Down Expand Up @@ -105,6 +106,26 @@ A few useful patterns:

ART flushes builder-managed metrics on the next `model.log(...)` call.

## Log trajectories as Trackio traces

Set `report_metrics=["trackio"]` on your model to log each trajectory as a
`trackio.Trace` when `model.log(...)` receives trajectory groups. ART also sends
the metrics row for the same step to Trackio.

```python
model = art.Model(
name="my-model",
project="my-project",
report_metrics=["trackio"],
)

await model.log(train_groups, split="train", step=step)
```

Each Trackio trace includes the trajectory's OpenAI-style messages plus metadata
for the split, step, group index, trajectory index, reward, trajectory metrics,
trajectory metadata, and logs.

## Track judge and API costs

Use `@track_api_cost` when a function returns a provider response object with
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ allowed-unresolved-imports = [
"torch.**",
"torchao.**",
"transformers.**",
"trackio",
"trackio.**",
"trl.**",
"unsloth.**",
"unsloth_zoo.**",
Expand Down
70 changes: 70 additions & 0 deletions src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class Model(
_s3_prefix: str | None = None
_openai_client: AsyncOpenAI | None = None
_wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization
_trackio_initialized: bool
_wandb_defined_metrics: set[str]
_wandb_config: dict[str, Any]
_run_start_time: float
Expand Down Expand Up @@ -241,6 +242,7 @@ def __init__(

def _init_runtime_state(self) -> None:
object.__setattr__(self, "_wandb_defined_metrics", set())
object.__setattr__(self, "_trackio_initialized", False)
object.__setattr__(self, "_wandb_config", {})
object.__setattr__(self, "_run_start_time", time.time())
object.__setattr__(self, "_run_start_monotonic", time.monotonic())
Expand Down Expand Up @@ -628,6 +630,26 @@ def _get_wandb_run(self) -> Optional["Run"]:
self._sync_wandb_config(run)
return self._wandb_run

def _get_trackio(self) -> Any:
"""Get or initialize the Trackio logger for this model."""
try:
import trackio
except ImportError as exc:
raise ImportError(
"Trackio logging requires the `trackio` package. "
"Install it with `pip install trackio` or remove 'trackio' "
"from report_metrics."
) from exc

if not self._trackio_initialized:
trackio.init(
project=self.project,
name=self.name,
config=self._wandb_config or None,
)
object.__setattr__(self, "_trackio_initialized", True)
return trackio

def _log_metrics(
self,
metrics: dict[str, float],
Expand Down Expand Up @@ -682,6 +704,53 @@ def _log_metrics(
# which preserves out-of-order eval logging.
run.log(prefixed)

should_log_trackio = (
self.report_metrics is not None and "trackio" in self.report_metrics
)
if should_log_trackio:
self._get_trackio().log(prefixed, step=step)

def _trackio_trace_messages(self, trajectory: Trajectory) -> list[dict[str, Any]]:
messages = []
for message in trajectory.for_logging()["messages"]:
messages.append(
{key: value for key, value in message.items() if key != "trainable"}
)
return messages

def _log_trackio_traces(
self,
trajectory_groups: list[TrajectoryGroup],
*,
split: str,
step: int,
) -> None:
if self.report_metrics is None or "trackio" not in self.report_metrics:
return

trackio = self._get_trackio()
traces = []
for group_index, group in enumerate(trajectory_groups):
for trajectory_index, trajectory in enumerate(group.trajectories):
traces.append(
trackio.Trace(
messages=self._trackio_trace_messages(trajectory),
metadata={
"split": split,
"step": step,
"group_index": group_index,
"trajectory_index": trajectory_index,
"reward": trajectory.reward,
"metrics": trajectory.metrics,
"metadata": trajectory.metadata,
"logs": trajectory.logs,
},
)
)

if traces:
trackio.log({f"{split}/trajectories": traces}, step=step)

def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None:
run = self._wandb_run
if run is None or run._is_finished:
Expand Down Expand Up @@ -913,6 +982,7 @@ async def log(
write_trajectory_groups_parquet(
trajectory_groups, f"{trajectories_dir}/{file_name}"
)
self._log_trackio_traces(trajectory_groups, split=split, step=step)

# 2. Calculate aggregate metrics (excluding additive costs)
reward_key = "reward"
Expand Down
97 changes: 97 additions & 0 deletions tests/unit/test_frontend_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,103 @@ async def test_step_numbering_format(self, tmp_path: Path):
).exists()


class TestTrackioLogging:
"""Test Trackio trace logging integration."""

@pytest.mark.asyncio
async def test_model_log_writes_trackio_traces_when_enabled(self, tmp_path: Path):
fake_trackio = MagicMock()
fake_trackio.Trace.side_effect = lambda messages, metadata=None: {
"_type": "trackio.trace",
"messages": messages,
"metadata": metadata or {},
}

model = Model(
name="test-model",
project="test-project",
base_path=str(tmp_path),
report_metrics=["trackio"],
)
trajectory_groups = [
TrajectoryGroup(
trajectories=[
Trajectory(
reward=0.8,
metrics={"correct": 1.0},
metadata={"scenario_id": "scenario-1"},
messages_and_choices=[
{"role": "user", "content": "What is 2 + 2?"},
{"role": "assistant", "content": "4"},
],
logs=["finished"],
)
],
)
]

with patch.dict("sys.modules", {"trackio": fake_trackio}):
await model.log(trajectory_groups, split="val", step=3)

fake_trackio.init.assert_called_once_with(
project="test-project",
name="test-model",
config=None,
)
fake_trackio.Trace.assert_called_once()
trace_kwargs = fake_trackio.Trace.call_args.kwargs
assert trace_kwargs["messages"] == [
{"role": "user", "content": "What is 2 + 2?"},
{"role": "assistant", "content": "4"},
]
assert trace_kwargs["metadata"]["split"] == "val"
assert trace_kwargs["metadata"]["step"] == 3
assert trace_kwargs["metadata"]["reward"] == 0.8
assert trace_kwargs["metadata"]["metrics"] == {"correct": 1.0}
assert trace_kwargs["metadata"]["metadata"] == {"scenario_id": "scenario-1"}
assert trace_kwargs["metadata"]["logs"] == ["finished"]

trace_log_call = next(
call
for call in fake_trackio.log.call_args_list
if "val/trajectories" in call.args[0]
)
assert trace_log_call.kwargs == {"step": 3}

@pytest.mark.asyncio
async def test_model_log_does_not_import_trackio_unless_enabled(
self, tmp_path: Path
):
fake_trackio = MagicMock()
model = Model(
name="test-model",
project="test-project",
base_path=str(tmp_path),
report_metrics=[],
)

with patch.dict("sys.modules", {"trackio": fake_trackio}):
await model.log(
[
TrajectoryGroup(
trajectories=[
Trajectory(
reward=0.5,
messages_and_choices=[
{"role": "user", "content": "hello"}
],
)
]
)
],
split="val",
step=1,
)

fake_trackio.init.assert_not_called()
fake_trackio.log.assert_not_called()


class TestMetricCalculation:
"""Test metric calculation and formatting."""

Expand Down