From df743959827c9c53c3a3be47029797f280b1b0ec Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 20 May 2026 18:06:47 -0700 Subject: [PATCH] Add Trackio trace logging --- docs/features/tracking-metrics.mdx | 23 ++++++- pyproject.toml | 2 + src/art/model.py | 70 +++++++++++++++++++++ tests/unit/test_frontend_logging.py | 97 +++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 1 deletion(-) diff --git a/docs/features/tracking-metrics.mdx b/docs/features/tracking-metrics.mdx index a800569f3..d5dd40d27 100644 --- a/docs/features/tracking-metrics.mdx +++ b/docs/features/tracking-metrics.mdx @@ -6,7 +6,8 @@ icon: "chart-line" --- ART writes a metrics row every time you call `model.log(...)`. Those rows go to -`history.jsonl` in the run directory and, if W&B logging is enabled, to W&B. +`history.jsonl` in the run directory and, if W&B or Trackio logging is enabled, +to the configured tracker. Use this page for three things: @@ -105,6 +106,26 @@ A few useful patterns: ART flushes builder-managed metrics on the next `model.log(...)` call. +## Log trajectories as Trackio traces + +Set `report_metrics=["trackio"]` on your model to log each trajectory as a +`trackio.Trace` when `model.log(...)` receives trajectory groups. ART also sends +the metrics row for the same step to Trackio. + +```python +model = art.Model( + name="my-model", + project="my-project", + report_metrics=["trackio"], +) + +await model.log(train_groups, split="train", step=step) +``` + +Each Trackio trace includes the trajectory's OpenAI-style messages plus metadata +for the split, step, group index, trajectory index, reward, trajectory metrics, +trajectory metadata, and logs. + ## Track judge and API costs Use `@track_api_cost` when a function returns a provider response object with diff --git a/pyproject.toml b/pyproject.toml index 999b25d20..0109c5b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -223,6 +223,8 @@ allowed-unresolved-imports = [ "torch.**", "torchao.**", "transformers.**", + "trackio", + "trackio.**", "trl.**", "unsloth.**", "unsloth_zoo.**", diff --git a/src/art/model.py b/src/art/model.py index 902f337e0..6e2e44e94 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -199,6 +199,7 @@ class Model( _s3_prefix: str | None = None _openai_client: AsyncOpenAI | None = None _wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization + _trackio_initialized: bool _wandb_defined_metrics: set[str] _wandb_config: dict[str, Any] _run_start_time: float @@ -241,6 +242,7 @@ def __init__( def _init_runtime_state(self) -> None: object.__setattr__(self, "_wandb_defined_metrics", set()) + object.__setattr__(self, "_trackio_initialized", False) object.__setattr__(self, "_wandb_config", {}) object.__setattr__(self, "_run_start_time", time.time()) object.__setattr__(self, "_run_start_monotonic", time.monotonic()) @@ -628,6 +630,26 @@ def _get_wandb_run(self) -> Optional["Run"]: self._sync_wandb_config(run) return self._wandb_run + def _get_trackio(self) -> Any: + """Get or initialize the Trackio logger for this model.""" + try: + import trackio + except ImportError as exc: + raise ImportError( + "Trackio logging requires the `trackio` package. " + "Install it with `pip install trackio` or remove 'trackio' " + "from report_metrics." + ) from exc + + if not self._trackio_initialized: + trackio.init( + project=self.project, + name=self.name, + config=self._wandb_config or None, + ) + object.__setattr__(self, "_trackio_initialized", True) + return trackio + def _log_metrics( self, metrics: dict[str, float], @@ -682,6 +704,53 @@ def _log_metrics( # which preserves out-of-order eval logging. run.log(prefixed) + should_log_trackio = ( + self.report_metrics is not None and "trackio" in self.report_metrics + ) + if should_log_trackio: + self._get_trackio().log(prefixed, step=step) + + def _trackio_trace_messages(self, trajectory: Trajectory) -> list[dict[str, Any]]: + messages = [] + for message in trajectory.for_logging()["messages"]: + messages.append( + {key: value for key, value in message.items() if key != "trainable"} + ) + return messages + + def _log_trackio_traces( + self, + trajectory_groups: list[TrajectoryGroup], + *, + split: str, + step: int, + ) -> None: + if self.report_metrics is None or "trackio" not in self.report_metrics: + return + + trackio = self._get_trackio() + traces = [] + for group_index, group in enumerate(trajectory_groups): + for trajectory_index, trajectory in enumerate(group.trajectories): + traces.append( + trackio.Trace( + messages=self._trackio_trace_messages(trajectory), + metadata={ + "split": split, + "step": step, + "group_index": group_index, + "trajectory_index": trajectory_index, + "reward": trajectory.reward, + "metrics": trajectory.metrics, + "metadata": trajectory.metadata, + "logs": trajectory.logs, + }, + ) + ) + + if traces: + trackio.log({f"{split}/trajectories": traces}, step=step) + def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None: run = self._wandb_run if run is None or run._is_finished: @@ -913,6 +982,7 @@ async def log( write_trajectory_groups_parquet( trajectory_groups, f"{trajectories_dir}/{file_name}" ) + self._log_trackio_traces(trajectory_groups, split=split, step=step) # 2. Calculate aggregate metrics (excluding additive costs) reward_key = "reward" diff --git a/tests/unit/test_frontend_logging.py b/tests/unit/test_frontend_logging.py index 9e72f82d3..1517c447f 100644 --- a/tests/unit/test_frontend_logging.py +++ b/tests/unit/test_frontend_logging.py @@ -305,6 +305,103 @@ async def test_step_numbering_format(self, tmp_path: Path): ).exists() +class TestTrackioLogging: + """Test Trackio trace logging integration.""" + + @pytest.mark.asyncio + async def test_model_log_writes_trackio_traces_when_enabled(self, tmp_path: Path): + fake_trackio = MagicMock() + fake_trackio.Trace.side_effect = lambda messages, metadata=None: { + "_type": "trackio.trace", + "messages": messages, + "metadata": metadata or {}, + } + + model = Model( + name="test-model", + project="test-project", + base_path=str(tmp_path), + report_metrics=["trackio"], + ) + trajectory_groups = [ + TrajectoryGroup( + trajectories=[ + Trajectory( + reward=0.8, + metrics={"correct": 1.0}, + metadata={"scenario_id": "scenario-1"}, + messages_and_choices=[ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "4"}, + ], + logs=["finished"], + ) + ], + ) + ] + + with patch.dict("sys.modules", {"trackio": fake_trackio}): + await model.log(trajectory_groups, split="val", step=3) + + fake_trackio.init.assert_called_once_with( + project="test-project", + name="test-model", + config=None, + ) + fake_trackio.Trace.assert_called_once() + trace_kwargs = fake_trackio.Trace.call_args.kwargs + assert trace_kwargs["messages"] == [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "4"}, + ] + assert trace_kwargs["metadata"]["split"] == "val" + assert trace_kwargs["metadata"]["step"] == 3 + assert trace_kwargs["metadata"]["reward"] == 0.8 + assert trace_kwargs["metadata"]["metrics"] == {"correct": 1.0} + assert trace_kwargs["metadata"]["metadata"] == {"scenario_id": "scenario-1"} + assert trace_kwargs["metadata"]["logs"] == ["finished"] + + trace_log_call = next( + call + for call in fake_trackio.log.call_args_list + if "val/trajectories" in call.args[0] + ) + assert trace_log_call.kwargs == {"step": 3} + + @pytest.mark.asyncio + async def test_model_log_does_not_import_trackio_unless_enabled( + self, tmp_path: Path + ): + fake_trackio = MagicMock() + model = Model( + name="test-model", + project="test-project", + base_path=str(tmp_path), + report_metrics=[], + ) + + with patch.dict("sys.modules", {"trackio": fake_trackio}): + await model.log( + [ + TrajectoryGroup( + trajectories=[ + Trajectory( + reward=0.5, + messages_and_choices=[ + {"role": "user", "content": "hello"} + ], + ) + ] + ) + ], + split="val", + step=1, + ) + + fake_trackio.init.assert_not_called() + fake_trackio.log.assert_not_called() + + class TestMetricCalculation: """Test metric calculation and formatting."""