diff --git a/README.md b/README.md index 40864b9..c3659ca 100644 --- a/README.md +++ b/README.md @@ -331,7 +331,6 @@ invoked over a protocol boundary. - [Remote Agent](examples/remote_agent) implements AX's native `AgentService` directly. - [ADK Agent (Python)](examples/adk_agent) runs a Google ADK agent as a remote agent. - [A2A Agent](examples/a2a_agent) connects agents that speak the [A2A protocol](https://github.com/a2aproject/A2A) through AX's A2A bridge. -- [Colab Agents (Experimental)](examples/colab_agent) runs Python scripts or notebooks in a remote Google Colab session. Please note that AX is actively developing its resumable streaming and agent communication protocols; these interfaces will change before a stable release. diff --git a/ax.yaml b/ax.yaml index c9c4cb4..8a34b93 100644 --- a/ax.yaml +++ b/ax.yaml @@ -56,20 +56,3 @@ registry: # credential_env: "CODING_AGENT_AUTH_TOKEN" # metadata: # version: "1.0" - - # colab_agents: - # - id: "plotter" - # name: "Function Plotter" - # description: "Plots mathematical functions on a Colab session." - # local_file: "./examples/colab_agent/plot.py" - # requirements: "./examples/colab_agent/requirements.txt" - # input_flag: "input" - # output_image: "./examples/colab_agent/plot.png" - # output_drive_path: "MyDrive/notebooks/plot.ipynb" - # - # - id: "data-analysis" - # name: "Data Analysis" - # description: "Analyzes data using a Colab notebook on Google Drive." - # drive_file: "MyDrive/notebooks/data_analysis.ipynb" - # input_flag: "query" - # output_image: "./examples/colab_agent/chart.png" diff --git a/cmd/ax/internal/cliutil/cliutil.go b/cmd/ax/internal/cliutil/cliutil.go index f2dc728..a755803 100644 --- a/cmd/ax/internal/cliutil/cliutil.go +++ b/cmd/ax/internal/cliutil/cliutil.go @@ -103,12 +103,6 @@ func NewControllerFromConfig(ctx context.Context, cfg *Config) (*controller.Cont } } - for _, agentCfg := range cfg.Registry.ColabAgents { - if err := c.Registry().RegisterColab(agentCfg); err != nil { - return nil, fmt.Errorf("failed to register colab agent %s: %w", agentCfg.ID, err) - } - } - return c, nil } diff --git a/examples/colab_agent/README.md b/examples/colab_agent/README.md deleted file mode 100644 index 52813bb..0000000 --- a/examples/colab_agent/README.md +++ /dev/null @@ -1,162 +0,0 @@ -# Colab Agents - -AX supports executing Python scripts and Jupyter notebooks on Google Colab sessions via the [colab CLI](https://github.com/googlecolab/google-colab-cli). Colab agents provision ephemeral sessions with optional GPU/TPU accelerators, run agent code remotely, stream output back in real time, and tear down the session on completion. - -## Prerequisites - -- The `colab` CLI installed and available in your `PATH`. -- Application Default Credentials (ADC) authenticated. - - ```sh - gcloud auth application-default login \ - --scopes=openid,https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email,https://www.googleapis.com/auth/colaboratory - ``` - -## Configuration - -Colab agents are configured in `ax.yaml` under `registry.colab_agents`. Two modes are supported: - -### Python script (local file) - -A local `.py` file is uploaded to the Colab VM and executed via `!python`. - -```yaml -registry: - colab_agents: - - id: "plotter" - name: "Function Plotter" - description: "Plots mathematical functions on a Colab session." - local_file: "./examples/colab_agent/plot.py" - accelerator: "tpu-v5e1" - requirements: "./examples/colab_agent/requirements.txt" - input_flag: "input" # passed as --input to the script - output_image: "./plot.png" # downloaded from /content/plot.png on the VM - output_drive_path: "MyDrive/notebooks/plot.ipynb" # .py converted to .ipynb, saved to Drive -``` - -### Jupyter notebook (file on Google Drive) - -A notebook on Google Drive is executed via `%run` after mounting Drive. -The input is set as a Python variable in the kernel before the notebook runs. - -```yaml -registry: - colab_agents: - - id: "data-analysis" - name: "Data Analysis" - description: "Analyzes data using a Colab notebook on Google Drive." - drive_file: "MyDrive/notebooks/data_analysis.ipynb" # Drive-relative path - input_flag: "query" # set as: query = '' - output_image: "./chart.png" -``` - -### Configuration reference - -| Field | Required | Description | -|-------|----------|-------------| -| `id` | Yes | Unique agent identifier. | -| `name` | Yes | Human-readable name shown in the planner. | -| `description` | Yes | Description of the agent's capabilities (used by the planner to select agents). | -| `local_file` | One of `local_file` or `drive_file` | Path to a local `.py` or `.ipynb` file on your machine. Uploaded to `/content/` on the VM. | -| `drive_file` | One of `local_file` or `drive_file` | Drive-relative path to a file on Google Drive (e.g. `MyDrive/notebooks/nb.ipynb`). Requires `drive_mount_path`. | -| `accelerator` | No | Hardware accelerator, e.g. `tpu-v5e1` or `gpu-A100`. | -| `drive_mount_path` | No | Path to mount Google Drive on the VM. Defaults to the Colab CLI's standard path (`/content/drive`). Only needed if you want a non-standard mount point. Drive is mounted automatically when `drive_file` or `output_drive_path` is used. Prompts for OAuth authorization on first use. | -| `requirements` | No | Path to a local `requirements.txt`. Packages are installed on the VM before execution. | -| `input_flag` | No | Name of the input parameter. For `.py` files, passed as `--`. For `.ipynb` files, set as a Python variable before `%run`. | -| `output_image` | No | Local path to download an output image to. The remote path is `/content/` + basename (e.g. `./plot.png` downloads from `/content/plot.png`). | -| `output_drive_path` | No | Drive-relative path to save the script converted to a `.ipynb` notebook (e.g. `MyDrive/notebooks/out.ipynb`). The `.py` source is placed in a single code cell. Only supported with `local_file`. | -| `metadata` | No | Optional key-value metadata. | - -## Execution flow - -### Python scripts (.py) - -``` -1. colab new -s [-tpu|-gpu ] -2. colab drivemount -s (if drive_mount_path set) -3. colab install -s -r (if requirements set) -4. colab upload /content/ -5. echo "!python -u /content/ -- ''" | colab exec -s -6. colab download /content/ (if output_image set) -7. colab exec: convert .py to .ipynb, save to Drive (if output_drive_path set) -8. colab stop -s -``` - -### Jupyter notebooks (.ipynb) - -``` -1. colab new -s [-tpu|-gpu ] -2. colab drivemount -s (if drive_mount_path set) -3. colab install -s -r (if requirements set) -4. colab upload /content/ (skipped for drive_file) -5. echo " = ''" | colab exec -s -6. echo "%run " | colab exec -s -7. colab download /content/ (if output_image set) -8. colab stop -s -``` - -### Session timeout retry - -If a Colab session is terminated due to idle timeout (e.g. while waiting for Drive authorization), AX automatically recreates the session and retries once. - -## Examples - -AX includes two examples in `examples/colab_agent/`: - -### Function plotter (Python script) - -`examples/colab_agent/plot.py` plots mathematical expressions using numpy and matplotlib. - -```bash -ax exec --agent plotter --input "sin(x) * exp(-x/10)" -``` - -### Data analysis (Jupyter notebook) - -`examples/colab_agent/data_analysis.ipynb` generates synthetic revenue data and produces a chart. - -```bash -ax exec --agent data-analysis --input "Show monthly revenue trend for 2024" -``` - -## Writing a Colab agent - -### Python script - -Create a `.py` file that accepts input via `argparse`: - -```python -import argparse - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--input", required=True) - parser.add_argument("--output", default="./plot.png") - args = parser.parse_args() - - # Your agent logic here. - print(f"Processing: {args.input}") - -if __name__ == "__main__": - main() -``` - -AX runs scripts with `python -u` (unbuffered stdout). - -### Jupyter notebook - -Create an `.ipynb` notebook that reads the input variable: - -```python -# The AX colab agent sets this variable before %run. -# Fall back to a default for standalone use. -try: - input -except NameError: - input = "default query" - -# Your notebook logic here. -print(f"Processing: {input}") -``` - -Note: Notebooks run via `%run` in the IPython kernel, not as a subprocess with `-u`. If you need real-time streaming from a notebook, use `flush=True` on `print()` calls. diff --git a/examples/colab_agent/data_analysis.ipynb b/examples/colab_agent/data_analysis.ipynb deleted file mode 100644 index 556c07a..0000000 --- a/examples/colab_agent/data_analysis.ipynb +++ /dev/null @@ -1,164 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data Analysis Agent\n", - "\n", - "This notebook is an example Colab agent for AX. It reads a query from\n", - "the `input` variable (set by the AX colab agent before `%run`), generates\n", - "synthetic data, performs analysis, and produces a chart.\n", - "\n", - "## Usage via AX\n", - "\n", - "```yaml\n", - "registry:\n", - " colab_agents:\n", - " - id: \"data-analysis\"\n", - " name: \"Data Analysis\"\n", - " description: \"Analyzes data using a Colab notebook on Google Drive.\"\n", - " remote_file: \"/content/drive/MyDrive/notebooks/data_analysis.ipynb\"\n", - " drive_mount_path: \"/content/drive\"\n", - " input_flag: \"query\"\n", - " output_image: \"./examples/colab_agent/chart.png\"\n", - "```\n", - "\n", - "Then run:\n", - "```\n", - "ax exec --agent data-analysis --input \"Show monthly revenue trend for 2024\"\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The `query` variable is set by the AX colab agent before %run.\n", - "# When running standalone, set it manually.\n", - "query = \"Show monthly revenue trend for 2024\"\n", - "try:\n", - " query\n", - "except NameError:\n", - " query = \"Show monthly revenue trend for 2024\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(f\"Query: {input}\", flush=True)\n", - "print(\"\", flush=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "import numpy as np\n", - "import matplotlib\n", - "matplotlib.use('Agg')\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 1: Generate synthetic monthly data.\n", - "print(\"[1/3] Generating data...\", flush=True)\n", - "time.sleep(1)\n", - "\n", - "months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',\n", - " 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']\n", - "\n", - "np.random.seed(42)\n", - "base = np.linspace(100, 180, 12)\n", - "noise = np.random.normal(0, 10, 12)\n", - "revenue = base + noise\n", - "revenue = np.maximum(revenue, 0)\n", - "\n", - "print(f\" Generated 12 months of revenue data\", flush=True)\n", - "print(f\" Range: ${revenue.min():.0f}K - ${revenue.max():.0f}K\", flush=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 2: Compute statistics.\n", - "print(\"[2/3] Analyzing...\", flush=True)\n", - "time.sleep(2)\n", - "\n", - "avg = np.mean(revenue)\n", - "growth = ((revenue[-1] - revenue[0]) / revenue[0]) * 100\n", - "best_month = months[np.argmax(revenue)]\n", - "worst_month = months[np.argmin(revenue)]\n", - "\n", - "print(f\" Average monthly revenue: ${avg:.0f}K\", flush=True)\n", - "print(f\" Year-over-year growth: {growth:.1f}%\", flush=True)\n", - "print(f\" Best month: {best_month} (${revenue.max():.0f}K)\", flush=True)\n", - "print(f\" Worst month: {worst_month} (${revenue.min():.0f}K)\", flush=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 3: Generate chart.\n", - "print(\"[3/3] Generating chart...\", flush=True)\n", - "time.sleep(2)\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 6))\n", - "bars = ax.bar(months, revenue, color='#2563eb', alpha=0.8)\n", - "ax.plot(months, revenue, color='#1d4ed8', linewidth=2, marker='o', markersize=6)\n", - "ax.axhline(y=avg, color='#dc2626', linestyle='--', linewidth=1, label=f'Average: ${avg:.0f}K')\n", - "ax.set_xlabel('Month', fontsize=12)\n", - "ax.set_ylabel('Revenue ($K)', fontsize=12)\n", - "ax.set_title('Monthly Revenue Trend - 2024', fontsize=14)\n", - "ax.legend(fontsize=11)\n", - "ax.grid(axis='y', alpha=0.3)\n", - "\n", - "output_path = './chart.png'\n", - "fig.savefig(output_path, dpi=150, bbox_inches='tight')\n", - "plt.close(fig)\n", - "\n", - "print(f\" Chart saved to {output_path}\", flush=True)\n", - "print(\"\", flush=True)\n", - "print(\"Done.\", flush=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from IPython.display import Image, display\n", - "\n", - "# # Display the chart saved at ./chart.png\n", - "# display(Image(filename='./chart.png'))" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/colab_agent/plot.py b/examples/colab_agent/plot.py deleted file mode 100644 index e1e8d2e..0000000 --- a/examples/colab_agent/plot.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Example Colab agent for AX: Function Plotter. - -Given a mathematical expression, this agent evaluates it over a range -and generates a plot using matplotlib. It demonstrates a Colab agent -that requires external packages (numpy, matplotlib) and produces a -graph saved to the Colab VM. - -Supported functions: all numpy functions (sin, cos, exp, log, sqrt, -abs, pi, e, etc.) and standard operators (+, -, *, /, **, etc.). -The variable is `x`. - -Usage (standalone): - pip install numpy matplotlib - python plot.py --input "sin(x) * exp(-x/10)" - -Usage (via AX): - Configure in ax.yaml: - - registry: - colab_agents: - - id: "plotter" - name: "Function Plotter" - description: "Plots mathematical functions." - local_file: "./examples/colab_agent/plot.py" - requirements: "./examples/colab_agent/requirements.txt" - input_flag: "input" - output_image: "./examples/colab_agent/plot.png" - - Then run: - ax exec --agent plotter --input "sin(x) * exp(-x/10)" -""" - -import argparse - -import time - -import matplotlib -import numpy as np - -matplotlib.use("Agg") -import matplotlib.pyplot as plt # noqa: E402 - - -def plot(expression: str, output_path: str = "./plot.png") -> None: - """Evaluate a math expression over x and save a plot.""" - print(f"Expression: y = {expression}") - print() - - # Step 1: Generate x values. - print("[1/4] Generating sample points...") - time.sleep(1) - x = np.linspace(-10, 10, 2000) - print(f" x range: [-10, 10], 2000 points") - - # Step 2: Evaluate the expression. - print("[2/4] Evaluating expression...") - time.sleep(2) - # Expose all numpy functions (sin, cos, exp, log, sqrt, pi, e, etc.) - safe_ns = {name: getattr(np, name) for name in dir(np) if not name.startswith("_")} - safe_ns["x"] = x - try: - y = eval(expression, {"__builtins__": {}}, safe_ns) # noqa: S307 - except Exception as e: - print(f" Error: {e}") - return - - y = np.asarray(y, dtype=float) - finite = y[np.isfinite(y)] - if len(finite) == 0: - print(" Error: expression produced no finite values") - return - print(f" y range: [{finite.min():.4f}, {finite.max():.4f}]") - - # Step 3: Create the plot. - print("[3/4] Rendering plot...") - time.sleep(2) - fig, ax = plt.subplots(figsize=(10, 6)) - ax.plot(x, y, color="#2563eb", linewidth=2) - ax.set_xlabel("x", fontsize=12) - ax.set_ylabel("y", fontsize=12) - ax.set_title(f"y = {expression}", fontsize=14) - ax.grid(True, alpha=0.3) - ax.axhline(y=0, color="black", linewidth=0.5) - ax.axvline(x=0, color="black", linewidth=0.5) - - # Clamp y-axis to avoid extreme values from singularities. - margin = (finite.max() - finite.min()) * 0.1 or 1.0 - ax.set_ylim(finite.min() - margin, finite.max() + margin) - print(f" Plot size: 10x6 inches, 150 dpi") - - # Step 4: Save. - print("[4/4] Saving...") - time.sleep(1) - fig.savefig(output_path, dpi=150, bbox_inches="tight") - plt.close(fig) - print(f" Saved to {output_path}") - print() - print("Done.") - - -def main(): - parser = argparse.ArgumentParser(description="AX Colab function plotter") - parser.add_argument( - "--input", - required=True, - help="Mathematical expression to plot (variable: x)", - ) - parser.add_argument( - "--output", - default="./plot.png", - help="Path to save the output image (default: ./plot.png)", - ) - args = parser.parse_args() - plot(args.input, args.output) - - -if __name__ == "__main__": - main() diff --git a/examples/colab_agent/requirements.txt b/examples/colab_agent/requirements.txt deleted file mode 100644 index aa094d9..0000000 --- a/examples/colab_agent/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy -matplotlib diff --git a/internal/config/config.go b/internal/config/config.go index 50479a4..7491cff 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,7 +37,6 @@ type Config struct { // RegistryConfig allows registring agents. type RegistryConfig struct { RemoteAgents []RemoteAgentConfig `yaml:"remote_agents,omitempty"` - ColabAgents []ColabAgentConfig `yaml:"colab_agents,omitempty"` } // SubstrateConfig configures the Substrate integration. @@ -115,21 +114,7 @@ type A2AConfig struct { Stateless bool `yaml:"stateless,omitempty"` // Send full history each turn (default: stateful) } -// ColabAgentConfig configures a Colab agent to register on startup. -type ColabAgentConfig struct { - ID string `yaml:"id"` // Unique agent identifier - Name string `yaml:"name"` // Human-readable name - Description string `yaml:"description"` // Description of agent capabilities - LocalFile string `yaml:"local_file,omitempty"` // Path to local .py or .ipynb file (uploaded to VM) - DriveFile string `yaml:"drive_file,omitempty"` // Path to .ipynb file in Google Drive (e.g. MyDrive/notebooks/nb.ipynb) - Accelerator string `yaml:"accelerator,omitempty"` // Accelerator type (optional), e.g. "tpu-v5e1", "gpu-A100" - DriveMountPath string `yaml:"drive_mount_path,omitempty"` // Path to mount Google Drive (optional), default: "/content/drive" - Requirements string `yaml:"requirements,omitempty"` // Path to requirements.txt (optional) - InputFlag string `yaml:"input_flag,omitempty"` // Input parameter name (optional). For .py, passed as --. For .ipynb, set as a variable before %run - OutputImage string `yaml:"output_image,omitempty"` // Local path to download the output image to - OutputDrivePath string `yaml:"output_drive_path,omitempty"` // Google Drive path to save converted .ipynb (e.g. MyDrive/notebooks/out.ipynb) - Metadata map[string]string `yaml:"metadata,omitempty"` // Optional metadata -} + type LocalAgentConfig struct { ID string `yaml:"id"` // Unique agent identifier diff --git a/internal/controller/registry.go b/internal/controller/registry.go index c96a1f1..e1fae30 100644 --- a/internal/controller/registry.go +++ b/internal/controller/registry.go @@ -142,45 +142,6 @@ func (r *Registry) registerA2A(ctx context.Context, cfg config.RemoteAgentConfig return nil } -// RegisterColab registers a Colab agent that executes a local Python file -// on a remote Colab session via the colab CLI. -func (r *Registry) RegisterColab(cfg config.ColabAgentConfig) error { - r.mu.Lock() - defer r.mu.Unlock() - - if err := validateID(cfg.ID); err != nil { - return err - } - - if _, ok := r.agents[cfg.ID]; ok { - return fmt.Errorf("agent %s already registered", cfg.ID) - } - - colabAgent, err := expagent.NewColabAgent(expagent.ColabAgentConfig{ - ID: cfg.ID, - LocalFile: cfg.LocalFile, - DriveFile: cfg.DriveFile, - Accelerator: cfg.Accelerator, - DriveMountPath: cfg.DriveMountPath, - Requirements: cfg.Requirements, - InputFlag: cfg.InputFlag, - OutputImage: cfg.OutputImage, - OutputDrivePath: cfg.OutputDrivePath, - }) - if err != nil { - return fmt.Errorf("failed to create colab agent: %w", err) - } - - r.agents[cfg.ID] = colabAgent - r.agentInfo[cfg.ID] = &agent.AgentInfo{ - ID: cfg.ID, - Name: cfg.Name, - Description: cfg.Description, - Metadata: cfg.Metadata, - } - - return nil -} // Get retrieves an agent by ID. func (r *Registry) Get(id string) (agent.Agent, error) { diff --git a/internal/experimental/agent/colab.go b/internal/experimental/agent/colab.go deleted file mode 100644 index a6e86f5..0000000 --- a/internal/experimental/agent/colab.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package agent - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "log" - "os" - "os/exec" - "path/filepath" - "regexp" - "strings" - "sync" - "time" - - "github.com/google/ax/internal/agent" - "github.com/google/ax/proto" -) - -// ColabAgent implements the Agent interface by executing a Python file or -// Jupyter notebook on a remote Google Colab session via the colab CLI. -// -// Each Connect() call provisions a new ephemeral Colab session, sets up the -// environment (drive mount, package installation), uploads the file (if local), -// executes it with user input, and tears the session down on completion. -// -// Two execution modes are supported: -// - Python scripts (.py): executed via !python with input passed as a CLI flag. -// - Notebooks (.ipynb): executed via %run with input set as a Python variable. -type ColabAgent struct { - config ColabAgentConfig - acceleratorArgs []string // parsed from config.Accelerator, e.g. ["-tpu", "v5e1"] - notebook bool // true if the file is a .ipynb notebook - mu sync.Mutex - activeSessions map[string]struct{} // all currently running sessions, for Close() cleanup -} - -// ColabAgentConfig configures a Colab agent. -type ColabAgentConfig struct { - ID string - LocalFile string // Path to a local .py or .ipynb file (uploaded to VM) - DriveFile string // Path to .ipynb file in Google Drive (e.g. MyDrive/notebooks/nb.ipynb) - Accelerator string // Accelerator type (optional), e.g. "tpu-v5e1", "gpu-A100" - DriveMountPath string // Path to mount Google Drive (optional), default: "/content/drive" - Requirements string // Path to requirements.txt (optional) - InputFlag string // Name of the input parameter (optional). For .py files, passed as --. For .ipynb, set as a variable before %run - OutputImage string // Local path to download the output image to - OutputDrivePath string // Google Drive path to save converted .ipynb (e.g. MyDrive/notebooks/out.ipynb) -} - -// validIdentifier matches a valid Python identifier (also valid as a CLI flag name). -// Used to validate InputFlag to prevent code injection via malformed ax.yaml. -var validIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) - -const ( - // maxRetries is the number of times Connect will retry if the Colab session - // is terminated due to idle timeout. - maxRetries = 1 - - // defaultDriveMountPath is the standard mount path used by the Colab CLI - // when no path is specified. Used as a fallback for filepath.Join when - // drive_mount_path is omitted in the config. - defaultDriveMountPath = "/content/drive" -) - -// NewColabAgent creates a new ColabAgent. It validates that the colab CLI -// binary is available in PATH and that exactly one of LocalFile or RemoteFile -// is set. -func NewColabAgent(config ColabAgentConfig) (*ColabAgent, error) { - // Validate that the colab CLI is installed. - if _, err := exec.LookPath("colab"); err != nil { - return nil, fmt.Errorf("colab CLI not found in PATH: %w", err) - } - - // Exactly one of LocalFile or DriveFile must be set. - if config.LocalFile == "" && config.DriveFile == "" { - return nil, fmt.Errorf("one of local_file or drive_file must be set") - } - if config.LocalFile != "" && config.DriveFile != "" { - return nil, fmt.Errorf("only one of local_file or drive_file can be set, not both") - } - - // OutputDrivePath is only supported for local files. - if config.OutputDrivePath != "" && config.LocalFile == "" { - return nil, fmt.Errorf("output_drive_path is only supported with local_file") - } - - // Validate the local file exists (Drive files are on the VM, can't check locally). - if config.LocalFile != "" { - if _, err := os.Stat(config.LocalFile); err != nil { - return nil, fmt.Errorf("local file %q not found: %w", config.LocalFile, err) - } - } - - // Parse the accelerator string into CLI flags. - accelArgs, err := parseAccelerator(config.Accelerator) - if err != nil { - return nil, err - } - - // Validate InputFlag is a safe identifier to prevent code injection. - // It's used directly in Python variable assignments and shell command flags. - if config.InputFlag != "" && !validIdentifier.MatchString(config.InputFlag) { - return nil, fmt.Errorf("invalid input_flag %q: must be a valid identifier (letters, digits, underscores)", config.InputFlag) - } - - // Detect notebook mode from file extension. - notebook := isNotebook(config.LocalFile) || isNotebook(config.DriveFile) - - return &ColabAgent{ - config: config, - acceleratorArgs: accelArgs, - notebook: notebook, - activeSessions: make(map[string]struct{}), - }, nil -} - -// isNotebook returns true if the file path has a .ipynb extension. -func isNotebook(path string) bool { - return strings.HasSuffix(strings.ToLower(path), ".ipynb") -} - -// parseAccelerator converts an accelerator string like "tpu-v5e1" or "gpu-A100" -// into colab CLI flags like ["-tpu", "v5e1"] or ["-gpu", "A100"]. -// An empty string returns nil (no accelerator flags). -func parseAccelerator(accel string) ([]string, error) { - if accel == "" { - return nil, nil - } - parts := strings.SplitN(accel, "-", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid accelerator format %q: expected \"tpu-\" or \"gpu-\"", accel) - } - kind := strings.ToLower(parts[0]) - if kind != "tpu" && kind != "gpu" { - return nil, fmt.Errorf("invalid accelerator kind %q: must be \"tpu\" or \"gpu\"", kind) - } - return []string{"-" + kind, parts[1]}, nil -} - -// Connect provisions a Colab session, runs user code, and tears it down, -// retrying once if the session is lost mid-run (e.g. idle timeout). -func (a *ColabAgent) Connect(ctx context.Context, conversationID string, execID string, start *proto.AgentStart, e agent.Executor, o agent.OutputHandler) error { - sessionName := colabSessionName(a.config.ID, execID) - for attempt := 0; attempt <= maxRetries; attempt++ { - runErr, sessionDied := a.runWithSession(ctx, sessionName, start, o) - if runErr == nil { - break - } - if !sessionDied || attempt == maxRetries { - return runErr - } - log.Printf("Colab session %s timed out, retrying...", sessionName) - } - return nil -} - -// runWithSession creates a Colab session, runs the user code, and stops the -// session on the way out. -func (a *ColabAgent) runWithSession(ctx context.Context, sessionName string, start *proto.AgentStart, o agent.OutputHandler) (runErr error, sessionDied bool) { - if err := a.createSession(ctx, sessionName); err != nil { - return err, false - } - defer func() { - if stopErr := a.stopSession(sessionName); stopErr != nil { - log.Printf("Warning: %v", stopErr) - } - }() - - runErr = a.run(ctx, sessionName, start, o) - if runErr != nil { - // Check whether the session died on its own (idle timeout). - sessionDied = !a.isSessionAlive(ctx, sessionName) - } - return runErr, sessionDied -} - -// createSession provisions a fresh Colab session and records it in the -// active sessions tracker. -func (a *ColabAgent) createSession(ctx context.Context, sessionName string) error { - args := append([]string{"new", "-s", sessionName}, a.acceleratorArgs...) - if _, err := a.runColab(ctx, args...); err != nil { - return fmt.Errorf("failed to create colab session: %w", err) - } - a.mu.Lock() - a.activeSessions[sessionName] = struct{}{} - a.mu.Unlock() - return nil -} - -// run executes the setup and agent code on an existing Colab session. -// This is called by Connect and may be retried if the session times out. -func (a *ColabAgent) run(ctx context.Context, sessionName string, start *proto.AgentStart, o agent.OutputHandler) error { - // Mount Google Drive if any Drive feature is used (drive_file, - // output_drive_path, or explicit drive_mount_path). This step is - // interactive because drivemount may prompt for OAuth authorization - // (the user must open a URL in their browser, authorize, then press Enter). - needsDrive := a.config.DriveMountPath != "" || a.config.DriveFile != "" || a.config.OutputDrivePath != "" - if needsDrive { - mountCtx, mountCancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer mountCancel() - log.Println("Mounting Google Drive (follow the authorization prompt, then press Enter)...") - - // If drive_mount_path is set, pass it to the CLI. Otherwise let - // the Colab CLI use its default mount path. - args := []string{"drivemount", "-s", sessionName} - if a.config.DriveMountPath != "" { - args = append(args, a.config.DriveMountPath) - } - if err := a.runColabInteractive(mountCtx, args...); err != nil { - return fmt.Errorf("failed to mount drive: %w", err) - } - - // Use the Colab CLI's default mount path for filepath.Join - // when constructing VM paths for DriveFile and OutputDrivePath. - if a.config.DriveMountPath == "" { - a.config.DriveMountPath = defaultDriveMountPath - } - } - - // Install requirements if configured. - if a.config.Requirements != "" { - if _, err := a.runColab(ctx, "install", "-s", sessionName, "-r", a.config.Requirements); err != nil { - return fmt.Errorf("failed to install requirements: %w", err) - } - } - - // Determine the remote path for execution. If LocalFile is set, upload - // it to the VM first. If RemoteFile is set, use it directly (e.g. a - // notebook on Google Drive that is accessible after drivemount). - var remotePath string - if a.config.LocalFile != "" { - remotePath = "/content/" + filepath.Base(a.config.LocalFile) - if _, err := a.runColab(ctx, "upload", "-s", sessionName, a.config.LocalFile, remotePath); err != nil { - return fmt.Errorf("failed to upload %s: %w", a.config.LocalFile, err) - } - } else { - remotePath = filepath.Join(a.config.DriveMountPath, a.config.DriveFile) - } - - // Extract the latest user text from the message history. - userText := lastUserText(start.Messages) - pyInput, _ := json.Marshal(userText) - shellEscaped := strings.ReplaceAll(userText, "'", "'\\''") - - // The VM path for the output image. - var remoteImagePath string - if a.config.OutputImage != "" { - remoteImagePath = "/content/" + filepath.Base(a.config.OutputImage) - } - - if a.notebook { - // Notebook execution. - // If input_flag is set, set the input variable in the kernel first - // (output suppressed), then run the notebook. - if a.config.InputFlag != "" { - setVarCmd := fmt.Sprintf("%s = %s", a.config.InputFlag, pyInput) - if _, err := a.runColabExecBatch(ctx, sessionName, setVarCmd); err != nil { - return fmt.Errorf("failed to set input variable: %w", err) - } - } - - // Run the notebook (output streamed). - runCmd := fmt.Sprintf("%%run %s", remotePath) - if err := a.runColabExec(ctx, sessionName, runCmd, o); err != nil { - return err - } - } else { - // Python script execution (output streamed). - // -u disables stdout buffering so output streams line-by-line - // (Python block-buffers when stdout is not a TTY). - command := fmt.Sprintf("!python -u %s", remotePath) - - // If input_flag is set, pass user input as a CLI flag. - if a.config.InputFlag != "" { - command += fmt.Sprintf(" --%s '%s'", a.config.InputFlag, shellEscaped) - } - - // Pass the output image path to the script if configured. - if remoteImagePath != "" { - command += fmt.Sprintf(" --output '%s'", remoteImagePath) - } - - // Execute and stream output line-by-line. - if err := a.runColabExec(ctx, sessionName, command, o); err != nil { - return err - } - } - - // Download the output image from the Colab VM if configured. - if remoteImagePath != "" { - if _, err := a.runColab(ctx, "download", "-s", sessionName, remoteImagePath, a.config.OutputImage); err != nil { - return fmt.Errorf("failed to download output image: %w", err) - } - } - - // Convert the .py file to a .ipynb notebook and save it to Google Drive - // (local_file only). The output path is constructed from DriveMountPath + - // OutputDrivePath (e.g. /content/drive + MyDrive/nb.ipynb). Uses nbformat - // (pre-installed in Colab) to create a notebook with the script's source - // code as a single code cell. - if a.config.OutputDrivePath != "" && a.config.LocalFile != "" { - outputNotebookPath := filepath.Join(a.config.DriveMountPath, a.config.OutputDrivePath) - remotePathLit, _ := json.Marshal(remotePath) - outputNbLit, _ := json.Marshal(outputNotebookPath) - convertCmd := fmt.Sprintf( - "import nbformat; nb = nbformat.v4.new_notebook(); nb.cells.append(nbformat.v4.new_code_cell(open(%s).read())); nbformat.write(nb, %s)", - remotePathLit, outputNbLit, - ) - if _, err := a.runColabExecBatch(ctx, sessionName, convertCmd); err != nil { - return fmt.Errorf("failed to convert script to notebook: %w", err) - } - } - - return nil -} - -// stopSession stops a Colab session and removes it from the active sessions tracker. -func (a *ColabAgent) stopSession(sessionName string) error { - stopCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - if _, err := a.runColab(stopCtx, "stop", "-s", sessionName); err != nil { - return fmt.Errorf("failed to stop colab session %s: %w", sessionName, err) - } - a.mu.Lock() - delete(a.activeSessions, sessionName) - a.mu.Unlock() - return nil -} - -// isSessionAlive checks whether a Colab session still exists by running -// colab status. Returns false if the session was terminated (e.g. due to -// idle timeout). -func (a *ColabAgent) isSessionAlive(ctx context.Context, sessionName string) bool { - _, err := a.runColab(ctx, "status", "-s", sessionName) - return err == nil -} - -// newColabCmd builds an exec.Cmd for the colab CLI with the standard -// environment plus any extras. -func newColabCmd(ctx context.Context, extraEnv []string, args ...string) *exec.Cmd { - cmd := exec.CommandContext(ctx, "colab", args...) - cmd.Env = append(os.Environ(), "PYTHONWARNINGS=ignore") - cmd.Env = append(cmd.Env, extraEnv...) - return cmd -} - -// runColab executes a colab CLI command and returns its stdout. -// Only the exit code determines success or failure; stderr is included -// in the error message on failure but is not treated as an error by itself -// (setup commands like install write progress to stderr). -func (a *ColabAgent) runColab(ctx context.Context, args ...string) (string, error) { - if len(args) == 0 { - return "", fmt.Errorf("runColab requires at least one argument (subcommand)") - } - cmd := newColabCmd(ctx, nil, args...) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("colab %s: %w\nstderr: %s", args[0], err, stderr.String()) - } - return stdout.String(), nil -} - -// runColabInteractive executes a colab CLI command with the user's terminal -// connected for interactive I/O. This is required for commands like drivemount -// that may prompt for OAuth authorization (the user must open a URL, authorize, -// then press Enter). -// -// Both stdout and stderr from the colab process are routed to os.Stderr rather -// than os.Stdout. This is necessary because the ax CLI's display system -// (Display.DisplayOutput in cmd/ax/internal/display.go) writes to stdout and -// would interfere with the colab process's multi-line auth message -- only a -// small portion would be visible. Routing to stderr bypasses the display layer -// and ensures the full auth prompt is shown to the user. -func (a *ColabAgent) runColabInteractive(ctx context.Context, args ...string) error { - // Set a request timeout for the colab CLI. The default is 60s, which is too - // short for interactive commands like drivemount where the user needs time - // to authorize in the browser. - cmd := newColabCmd(ctx, []string{"REQUEST_TIMEOUT=600"}, args...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stderr // Use stderr to bypass ax CLI display layer (see comment above). - cmd.Stderr = os.Stderr - return cmd.Run() -} - -// runColabExecBatch pipes Python code to colab exec via stdin and waits for -// completion. Output is discarded. Used for setup commands like setting -// variables in the kernel before running a notebook. -func (a *ColabAgent) runColabExecBatch(ctx context.Context, sessionName, command string) (string, error) { - cmd := newColabCmd(ctx, nil, "exec", "-s", sessionName) - cmd.Stdin = strings.NewReader(command) - var stdoutBuf, stderrBuf bytes.Buffer - cmd.Stdout = &stdoutBuf - cmd.Stderr = &stderrBuf - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("colab exec: %w\nstderr: %s", err, stderrBuf.String()) - } - return stdoutBuf.String(), nil -} - -// runColabExec pipes Python code to colab exec via stdin and streams stdout -// line-by-line to the OutputHandler as the script runs. Stderr is buffered -// and treated as an error after the process exits. -func (a *ColabAgent) runColabExec(ctx context.Context, sessionName, command string, o agent.OutputHandler) error { - cmd := newColabCmd(ctx, nil, "exec", "-s", sessionName) - cmd.Stdin = strings.NewReader(command) - - stdout, err := cmd.StdoutPipe() - if err != nil { - return fmt.Errorf("stdout pipe: %w", err) - } - var stderrBuf bytes.Buffer - cmd.Stderr = &stderrBuf - - if err := cmd.Start(); err != nil { - return fmt.Errorf("start colab exec: %w", err) - } - - // Stream stdout line-by-line to the output handler. - // Skip empty lines to avoid extra blank lines from colab exec - // (IPython's ! command often emits a trailing empty line). - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - line := scanner.Text() - if line == "" { - continue - } - if err := o(&proto.AgentOutputs{ - Messages: []*proto.Message{{ - Role: "assistant", - Content: &proto.Content{ - Type: &proto.Content_Text{ - Text: &proto.TextContent{Text: line}, - }, - }, - }}, - }); err != nil { - return fmt.Errorf("output handler: %w", err) - } - } - if err := scanner.Err(); err != nil { - return fmt.Errorf("reading stdout: %w", err) - } - - if err := cmd.Wait(); err != nil { - return fmt.Errorf("%w\nstderr: %s", err, stderrBuf.String()) - } - return nil -} - -// colabSessionName builds a deterministic session name from the agent ID -// and execution ID. Colab session names should be short and safe. -func colabSessionName(agentID, execID string) string { - safeID := strings.ReplaceAll(execID, "-", "") - if len(safeID) > 20 { - safeID = safeID[:20] - } - return fmt.Sprintf("ax-%s-%s", agentID, safeID) -} - -// lastUserText extracts the text from the most recent user message -// in the message history. -func lastUserText(messages []*proto.Message) string { - for i := len(messages) - 1; i >= 0; i-- { - msg := messages[i] - if msg.Role == "user" { - if t := msg.GetContent().GetText(); t != nil { - return t.Text - } - } - } - return "" -} - -// Close stops all currently active Colab sessions. This handles the case where -// Close() is called (e.g. via SIGINT) while one or more Connect() calls are -// mid-execution. In ax serve mode, multiple concurrent Connect() calls may -// each have their own session. -func (a *ColabAgent) Close() error { - a.mu.Lock() - sessions := make([]string, 0, len(a.activeSessions)) - for s := range a.activeSessions { - sessions = append(sessions, s) - } - a.mu.Unlock() - var firstErr error - for _, s := range sessions { - if err := a.stopSession(s); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} diff --git a/internal/experimental/agent/colab_test.go b/internal/experimental/agent/colab_test.go deleted file mode 100644 index c2ca610..0000000 --- a/internal/experimental/agent/colab_test.go +++ /dev/null @@ -1,562 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package agent - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - baseagent "github.com/google/ax/internal/agent" - "github.com/google/ax/proto" -) - -// --------------------------------------------------------------------------- -// Pure function tests -// --------------------------------------------------------------------------- - -func TestParseAccelerator(t *testing.T) { - tests := []struct { - input string - want []string - wantErr bool - }{ - {input: "tpu-v5e1", want: []string{"-tpu", "v5e1"}}, - {input: "gpu-A100", want: []string{"-gpu", "A100"}}, - {input: "", want: nil}, - {input: "cpu-x86", wantErr: true}, - {input: "tpu", wantErr: true}, - } - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got, err := parseAccelerator(tt.input) - if tt.wantErr { - if err == nil { - t.Fatalf("expected error for %q", tt.input) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(got) != len(tt.want) { - t.Fatalf("got %v, want %v", got, tt.want) - } - for i := range got { - if got[i] != tt.want[i] { - t.Errorf("got[%d] = %q, want %q", i, got[i], tt.want[i]) - } - } - }) - } -} - -func TestLastUserText(t *testing.T) { - msg := func(role, s string) *proto.Message { - return &proto.Message{ - Role: role, - Content: &proto.Content{Type: &proto.Content_Text{Text: &proto.TextContent{Text: s}}}, - } - } - - tests := []struct { - name string - messages []*proto.Message - want string - }{ - {name: "empty", messages: nil, want: ""}, - {name: "single user", messages: []*proto.Message{msg("user", "hello")}, want: "hello"}, - { - name: "returns last user message", - messages: []*proto.Message{msg("user", "first"), msg("assistant", "reply"), msg("user", "second")}, - want: "second", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := lastUserText(tt.messages); got != tt.want { - t.Errorf("lastUserText() = %q, want %q", got, tt.want) - } - }) - } -} - -func TestRunColab_NoArgs(t *testing.T) { - // runColab indexes args[0] when formatting the error message on - // failure, so calling it with no arguments would panic. Guard - // against that by returning an error up front. - a := &ColabAgent{} - _, err := a.runColab(context.Background()) - if err == nil { - t.Fatal("expected error when called with no args") - } - if !strings.Contains(err.Error(), "at least one argument") { - t.Errorf("error message should mention required argument, got: %v", err) - } -} - -// --------------------------------------------------------------------------- -// NewColabAgent validation tests -// --------------------------------------------------------------------------- - -func TestNewColabAgent_Validation(t *testing.T) { - binDir := t.TempDir() - writeFakeColab(t, binDir) - t.Setenv("PATH", binDir) - - pyFile := writeTempFile(t, "agent.py", "# test") - nbFile := writeTempFile(t, "notebook.ipynb", `{"cells":[]}`) - - tests := []struct { - name string - cfg ColabAgentConfig - wantErr string // empty = no error expected - notebook bool - }{ - { - name: "py file defaults", - cfg: ColabAgentConfig{ID: "t", LocalFile: pyFile}, - }, - { - name: "ipynb detected as notebook", - cfg: ColabAgentConfig{ID: "t", LocalFile: nbFile}, - notebook: true, - }, - { - name: "drive ipynb detected as notebook", - cfg: ColabAgentConfig{ID: "t", DriveFile: "MyDrive/nb.ipynb"}, - notebook: true, - }, - { - name: "both local and drive", - cfg: ColabAgentConfig{ID: "t", LocalFile: pyFile, DriveFile: "MyDrive/nb.ipynb"}, - wantErr: "only one of", - }, - { - name: "neither local nor drive", - cfg: ColabAgentConfig{ID: "t"}, - wantErr: "must be set", - }, - { - name: "missing local file", - cfg: ColabAgentConfig{ID: "t", LocalFile: "/nonexistent.py"}, - wantErr: "not found", - }, - { - name: "output_drive_path with drive_file", - cfg: ColabAgentConfig{ID: "t", DriveFile: "MyDrive/nb.ipynb", OutputDrivePath: "MyDrive/out.ipynb"}, - wantErr: "output_drive_path is only supported with local_file", - }, - { - name: "malicious input_flag rejected", - cfg: ColabAgentConfig{ID: "t", LocalFile: pyFile, InputFlag: "x'; rm -rf /; echo '"}, - wantErr: "invalid input_flag", - }, - { - name: "input_flag with dashes rejected", - cfg: ColabAgentConfig{ID: "t", LocalFile: pyFile, InputFlag: "--input"}, - wantErr: "invalid input_flag", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a, err := NewColabAgent(tt.cfg) - if tt.wantErr != "" { - if err == nil || !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("want error containing %q, got: %v", tt.wantErr, err) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if a.notebook != tt.notebook { - t.Errorf("notebook = %v, want %v", a.notebook, tt.notebook) - } - }) - } -} - -func TestNewColabAgent_MissingBinary(t *testing.T) { - t.Setenv("PATH", t.TempDir()) - _, err := NewColabAgent(ColabAgentConfig{LocalFile: "anything.py"}) - if err == nil || !strings.Contains(err.Error(), "colab CLI not found") { - t.Fatalf("want colab CLI error, got: %v", err) - } -} - -// --------------------------------------------------------------------------- -// Connect() tests with fake colab script -// --------------------------------------------------------------------------- - -func TestConnect_FullSequence(t *testing.T) { - setup := newFakeColabEnv(t, ColabAgentConfig{ - ID: "myagent", - Accelerator: "tpu-v5e1", - DriveMountPath: "/content/drive", - Requirements: "requirements.txt", - InputFlag: "query", - OutputImage: filepath.Join(t.TempDir(), "plot.png"), - OutputDrivePath: "MyDrive/session.ipynb", - }) - t.Setenv("COLAB_EXEC_STDOUT", "line 1\nline 2\nline 3") - - var outputs []string - handler := baseagent.OutputHandler(func(resp *proto.AgentOutputs) error { - for _, m := range resp.Messages { - if t := m.GetContent().GetText(); t != nil { - outputs = append(outputs, t.Text) - } - } - return nil - }) - - start := &proto.AgentStart{ - Messages: []*proto.Message{userText("hello world")}, - } - - if err := setup.agent.Connect(context.Background(), "test-conv", "exec-123", start, nil, handler); err != nil { - t.Fatalf("Connect: %v", err) - } - - // Verify streaming output. - wantOutputs := []string{"line 1", "line 2", "line 3"} - if len(outputs) != len(wantOutputs) { - t.Fatalf("got %d outputs %v, want %d", len(outputs), outputs, len(wantOutputs)) - } - for i, want := range wantOutputs { - if outputs[i] != want { - t.Errorf("outputs[%d] = %q, want %q", i, outputs[i], want) - } - } - - // Verify CLI call sequence. - cmds := setup.loggedCommands(t) - wantCmds := []string{"new", "drivemount", "install", "upload", "exec", "download", "exec", "stop"} - if len(cmds) != len(wantCmds) { - t.Fatalf("got commands %v, want %v", cmds, wantCmds) - } - for i, want := range wantCmds { - if cmds[i] != want { - t.Errorf("cmd[%d] = %q, want %q", i, cmds[i], want) - } - } - - // Verify exec stdin contains the command with flag and user text. - stdin := setup.readStdinLog(t) - for _, want := range []string{"!python", "--query", "hello world"} { - if !strings.Contains(stdin, want) { - t.Errorf("exec stdin missing %q: %q", want, stdin) - } - } -} - -func TestConnect_SessionCreationFails(t *testing.T) { - setup := newFakeColabEnv(t, ColabAgentConfig{ID: "fail-new"}) - t.Setenv("COLAB_FAIL_CMD", "new") - - start := &proto.AgentStart{Messages: []*proto.Message{userText("test")}} - err := setup.agent.Connect(context.Background(), "test-conv", "e1", start, nil, noopHandler) - if err == nil || !strings.Contains(err.Error(), "failed to create colab session") { - t.Fatalf("want session creation error, got: %v", err) - } - - // Stop should not be called since no session was created. - for _, cmd := range setup.loggedCommands(t) { - if cmd == "stop" { - t.Error("stop should not be called when session creation fails") - } - } -} - -func TestConnect_ExecFailure(t *testing.T) { - setup := newFakeColabEnv(t, ColabAgentConfig{ID: "fail-exec"}) - t.Setenv("COLAB_FAIL_CMD", "exec") - - start := &proto.AgentStart{Messages: []*proto.Message{userText("test")}} - err := setup.agent.Connect(context.Background(), "test-conv", "e1", start, nil, noopHandler) - if err == nil { - t.Fatal("expected error when exec fails") - } - - // Stop should still run despite exec failure. - cmds := setup.loggedCommands(t) - if cmds[len(cmds)-1] != "stop" { - t.Errorf("last command should be stop, got %v", cmds) - } -} - -func TestConnect_NotebookLocalFile(t *testing.T) { - setup := newFakeColabEnv(t, ColabAgentConfig{ - ID: "nb-local", - InputFlag: "query", - }, withNotebook()) - t.Setenv("COLAB_EXEC_STDOUT", "analysis complete") - - start := &proto.AgentStart{Messages: []*proto.Message{userText("analyze data")}} - if err := setup.agent.Connect(context.Background(), "test-conv", "e1", start, nil, noopHandler); err != nil { - t.Fatalf("Connect: %v", err) - } - - // Notebook: new, upload, exec (set var), exec (%run), stop. - cmds := setup.loggedCommands(t) - wantCmds := []string{"new", "upload", "exec", "exec", "stop"} - if len(cmds) != len(wantCmds) { - t.Fatalf("got commands %v, want %v", cmds, wantCmds) - } - - // Verify stdin contains variable assignment and %run. - stdin := setup.readStdinLog(t) - for _, want := range []string{"query = ", "analyze data", "%run"} { - if !strings.Contains(stdin, want) { - t.Errorf("exec stdin missing %q: %q", want, stdin) - } - } -} - -func TestConnect_NotebookDriveFile(t *testing.T) { - // drive_mount_path is omitted -- the colab CLI uses its default, - // and the code falls back to defaultDriveMountPath for filepath.Join. - setup := newFakeColabEnv(t, ColabAgentConfig{ - ID: "nb-drive", - DriveFile: "MyDrive/notebooks/analysis.ipynb", - InputFlag: "query", - }) - t.Setenv("COLAB_EXEC_STDOUT", "done") - - start := &proto.AgentStart{Messages: []*proto.Message{userText("analyze trends")}} - if err := setup.agent.Connect(context.Background(), "test-conv", "e1", start, nil, noopHandler); err != nil { - t.Fatalf("Connect: %v", err) - } - - // Drive notebook: new, drivemount, exec (set var), exec (%run), stop. No upload. - cmds := setup.loggedCommands(t) - wantCmds := []string{"new", "drivemount", "exec", "exec", "stop"} - if len(cmds) != len(wantCmds) { - t.Fatalf("got commands %v, want %v", cmds, wantCmds) - } - - // Verify %run uses the default mount path + drive_file. - stdin := setup.readStdinLog(t) - if !strings.Contains(stdin, "%run /content/drive/MyDrive/notebooks/analysis.ipynb") { - t.Errorf("exec stdin missing %%run with drive path: %q", stdin) - } -} - -func TestConnect_RetryOnSessionTimeout(t *testing.T) { - binDir := t.TempDir() - writeScript(t, binDir, "colab", fakeColabTimeoutScript) - - pyFile := writeTempFile(t, "agent.py", "# test") - logFile := filepath.Join(t.TempDir(), "colab.log") - - t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH")) - t.Setenv("COLAB_TEST_LOG", logFile) - t.Setenv("COLAB_EXEC_COUNTER", filepath.Join(t.TempDir(), "count")) - t.Setenv("COLAB_EXEC_STDOUT", "success on retry") - - agent, err := NewColabAgent(ColabAgentConfig{ - ID: "timeout-test", LocalFile: pyFile, InputFlag: "input", - }) - if err != nil { - t.Fatalf("NewColabAgent: %v", err) - } - - var output string - handler := baseagent.OutputHandler(func(resp *proto.AgentOutputs) error { - for _, m := range resp.Messages { - if txt := m.GetContent().GetText(); txt != nil { - output = txt.Text - } - } - return nil - }) - - start := &proto.AgentStart{Messages: []*proto.Message{userText("test")}} - if err := agent.Connect(context.Background(), "test-conv", "e1", start, nil, handler); err != nil { - t.Fatalf("Connect should succeed after retry: %v", err) - } - if output != "success on retry" { - t.Errorf("output = %q, want %q", output, "success on retry") - } - - // Attempt 1: new, upload, exec (fail), status (dead), stop - // Attempt 2: new, upload, exec (ok), stop - env := &fakeColabEnv{logFile: logFile} - cmds := env.loggedCommands(t) - wantCmds := []string{"new", "upload", "exec", "status", "stop", "new", "upload", "exec", "stop"} - if len(cmds) != len(wantCmds) { - t.Fatalf("got commands %v, want %v", cmds, wantCmds) - } - for i, want := range wantCmds { - if cmds[i] != want { - t.Errorf("cmd[%d] = %q, want %q", i, cmds[i], want) - } - } -} - -// --------------------------------------------------------------------------- -// Test helpers -// --------------------------------------------------------------------------- - -var noopHandler = baseagent.OutputHandler(func(*proto.AgentOutputs) error { return nil }) - -const fakeColabScript = `#!/bin/sh -echo "$*" >> "$COLAB_TEST_LOG" -if [ "$1" = "$COLAB_FAIL_CMD" ]; then - echo "simulated failure for $1" >&2 - exit 1 -fi -if [ "$1" = "exec" ]; then - cat >> "${COLAB_TEST_LOG}.stdin" - if [ -n "$COLAB_EXEC_STDOUT" ]; then printf '%s\n' "$COLAB_EXEC_STDOUT"; fi - if [ -n "$COLAB_EXEC_STDERR" ]; then printf '%s' "$COLAB_EXEC_STDERR" >&2; fi - exit 0 -fi -exit 0 -` - -const fakeColabTimeoutScript = `#!/bin/sh -echo "$*" >> "$COLAB_TEST_LOG" -if [ "$1" = "status" ]; then exit 1; fi -if [ "$1" = "exec" ]; then - cat >> "${COLAB_TEST_LOG}.stdin" - count=0 - if [ -f "$COLAB_EXEC_COUNTER" ]; then count=$(cat "$COLAB_EXEC_COUNTER"); fi - count=$((count + 1)) - echo "$count" > "$COLAB_EXEC_COUNTER" - if [ "$count" -eq 1 ]; then echo "session not found" >&2; exit 1; fi - if [ -n "$COLAB_EXEC_STDOUT" ]; then printf '%s\n' "$COLAB_EXEC_STDOUT"; fi - exit 0 -fi -exit 0 -` - -type envOpt func(cfg *ColabAgentConfig) - -func withNotebook() envOpt { - return func(cfg *ColabAgentConfig) { - cfg.LocalFile = "" // will be set to notebook below - } -} - -type fakeColabEnv struct { - agent *ColabAgent - logFile string -} - -func newFakeColabEnv(t *testing.T, cfg ColabAgentConfig, opts ...envOpt) *fakeColabEnv { - t.Helper() - - binDir := t.TempDir() - writeFakeColab(t, binDir) - logFile := filepath.Join(t.TempDir(), "colab.log") - - t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH")) - t.Setenv("COLAB_TEST_LOG", logFile) - - isNb := false - for _, opt := range opts { - opt(&cfg) - isNb = true - } - - // Create a temp file if LocalFile is needed and not already set. - if cfg.LocalFile == "" && cfg.DriveFile == "" { - if isNb { - cfg.LocalFile = writeTempFile(t, "test.ipynb", `{"cells":[]}`) - } else { - cfg.LocalFile = writeTempFile(t, "test.py", "# test") - } - } - - if cfg.ID == "" { - cfg.ID = "test-agent" - } - - agent, err := NewColabAgent(cfg) - if err != nil { - t.Fatalf("NewColabAgent: %v", err) - } - return &fakeColabEnv{agent: agent, logFile: logFile} -} - -func (e *fakeColabEnv) readLog(t *testing.T) []string { - t.Helper() - data, err := os.ReadFile(e.logFile) - if err != nil { - if os.IsNotExist(err) { - return nil - } - t.Fatalf("readLog: %v", err) - } - trimmed := strings.TrimSpace(string(data)) - if trimmed == "" { - return nil - } - return strings.Split(trimmed, "\n") -} - -func (e *fakeColabEnv) loggedCommands(t *testing.T) []string { - t.Helper() - lines := e.readLog(t) - cmds := make([]string, len(lines)) - for i, line := range lines { - cmds[i] = strings.Fields(line)[0] - } - return cmds -} - -func (e *fakeColabEnv) readStdinLog(t *testing.T) string { - t.Helper() - data, err := os.ReadFile(e.logFile + ".stdin") - if err != nil { - if os.IsNotExist(err) { - return "" - } - t.Fatalf("readStdinLog: %v", err) - } - return strings.TrimSpace(string(data)) -} - -func writeFakeColab(t *testing.T, dir string) { - t.Helper() - writeScript(t, dir, "colab", fakeColabScript) -} - -func writeScript(t *testing.T, dir, name, content string) { - t.Helper() - if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0755); err != nil { - t.Fatalf("write script %s: %v", name, err) - } -} - -func writeTempFile(t *testing.T, name, content string) string { - t.Helper() - path := filepath.Join(t.TempDir(), name) - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("write %s: %v", name, err) - } - return path -} - -func userText(s string) *proto.Message { - return &proto.Message{ - Role: "user", - Content: &proto.Content{Type: &proto.Content_Text{Text: &proto.TextContent{Text: s}}}, - } -}