
Add explicit batched inference for PyTorch pi0 policies #941

Open
taivu1998 wants to merge 1 commit into Physical-Intelligence:main from taivu1998:tdv/issue-765-vectorized-inference


@taivu1998


Summary

Adds an explicit batched inference path for policy users and websocket clients. The new Policy.infer_batch([...]) path applies existing input/output transforms per observation, stacks the transformed batch, and calls the model once with a leading batch dimension. Existing infer(obs) behavior and shapes are preserved.
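The flow described above (transform each observation, stack, call the model once, then split the outputs) can be sketched roughly as follows. This is a minimal illustration and not the PR's actual code: `infer_batch_sketch`, `input_transform`, `sample_actions`, and `output_transform` are hypothetical names, and the toy model stands in for the real pi0 sampler.

```python
import numpy as np


def infer_batch_sketch(observations, input_transform, sample_actions, output_transform):
    """Hypothetical sketch: per-observation transforms around one batched model call."""
    if not observations:
        raise ValueError("observations must be a non-empty list")

    # 1. Apply the input transform to each observation independently.
    transformed = [input_transform(obs) for obs in observations]

    # 2. Stack every key along a new leading batch dimension: [B, ...].
    batch = {
        key: np.stack([np.asarray(item[key]) for item in transformed], axis=0)
        for key in transformed[0]
    }

    # 3. Call the model exactly once on the stacked batch.
    outputs = sample_actions(batch)

    # 4. Slice out each sample and apply the output transform per sample.
    return [
        output_transform({key: value[i] for key, value in outputs.items()})
        for i in range(len(observations))
    ]


# Toy stand-ins (assumptions for illustration only).
calls = []


def toy_model(batch):
    calls.append(batch["state"].shape)  # record the shape seen by the model
    return {"actions": batch["state"] * 2.0}


results = infer_batch_sketch(
    [{"state": [1.0, 2.0]}, {"state": [3.0, 4.0]}],
    input_transform=lambda obs: {"state": np.asarray(obs["state"], dtype=np.float32)},
    sample_actions=toy_model,
    output_transform=lambda out: out,
)
```

In this sketch the model sees a single `(2, 2)` batch, and each returned result has the same per-sample shape a single-observation call would produce.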

This also adds WebsocketClientPolicy.infer_batch([...]) using an explicit {"batch": [...]} request payload, plus server-side routing for that payload.
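As a rough illustration of how a server might route the `{"batch": [...]}` payload, the sketch below dispatches on the presence of a top-level `batch` key. The names `route_request` and `EchoPolicy` are hypothetical, and the response format is an assumption; the real server's serialization and response layout are not shown in this description.

```python
class EchoPolicy:
    """Hypothetical stand-in policy that reports which path handled the request."""

    def infer(self, obs):
        return {"path": "single"}

    def infer_batch(self, observations):
        return [{"path": "batch", "index": i} for i in range(len(observations))]


def route_request(policy, request):
    """Send explicit batch payloads to infer_batch, everything else to infer."""
    if isinstance(request, dict) and "batch" in request:
        # Explicit batch request: one entry per observation.
        return policy.infer_batch(list(request["batch"]))
    # Any other payload is treated as a single observation.
    return policy.infer(request)


policy = EchoPolicy()
single_response = route_request(policy, {"state": [0.0, 1.0]})
batch_response = route_request(policy, {"batch": [{"state": [0.0]}, {"state": [1.0]}]})
```

One consequence of keying on `batch` in this sketch is that a single observation containing a top-level `batch` field would be misrouted, so an explicit wrapper key only works if observations never use that name.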

Root Cause

Issue #765 asks whether the PyTorch pi0 / pi0.5 models support vectorized inference. The model core is batch-shaped, but the public policy/server path always treated incoming data as one observation: it transformed one dict, added a singleton batch dimension, sampled, and stripped index 0 from outputs. Passing already-batched observations therefore produced shapes like [1, B, ...] instead of true [B, ...] inference.
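The shape problem can be reproduced with plain NumPy. The sketch below is illustrative only: the batch size and state dimension are arbitrary assumptions, and the arrays stand in for one field of an observation.

```python
import numpy as np

B, state_dim = 4, 7  # arbitrary sizes chosen for illustration

# A caller passes data that is already batched: [B, state_dim].
already_batched = np.zeros((B, state_dim), dtype=np.float32)

# The single-observation path adds a singleton batch dimension on top,
# so the model would see [1, B, state_dim] rather than [B, state_dim].
singleton_path = already_batched[np.newaxis, ...]

# An explicit batch path instead stacks B separate observations.
per_observation = [np.zeros(state_dim, dtype=np.float32) for _ in range(B)]
true_batch = np.stack(per_observation, axis=0)

print(singleton_path.shape)  # (1, 4, 7)
print(true_batch.shape)      # (4, 7)
```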

Changes

  • Add an optional BasePolicy.infer_batch that raises by default instead of silently falling back to an unsafe sequential loop.
  • Implement optimized Policy.infer_batch for true batched sampling while keeping Policy.infer singleton-compatible.
  • Preserve per-sample transform semantics by applying input and output transforms around the model batch boundary.
  • Add explicit websocket batch routing and client support.
  • Forward batched calls through PolicyRecorder.
  • Document local and remote vectorized inference usage.
  • Add focused tests for batch shape preservation, one model call, transform behavior, noise validation, wrapper behavior, websocket request routing, and client payload handling.
  • Add typing-extensions to openpi-client metadata because the client package already imports it.
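The first item in the list, a base-class `infer_batch` that raises rather than looping over `infer`, might look something like the sketch below. The class names are hypothetical, and the use of `NotImplementedError` is an assumption; the description only says the default raises.

```python
import abc


class BasePolicySketch(abc.ABC):
    """Hypothetical base class: infer is required, infer_batch is opt-in."""

    @abc.abstractmethod
    def infer(self, obs: dict) -> dict:
        """Run inference on a single observation."""

    def infer_batch(self, observations: list) -> list:
        # Assumed exception type. Raising avoids a hidden sequential loop,
        # which could look like batching without sharing one model call.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement batched inference"
        )


class SingleOnlyPolicy(BasePolicySketch):
    """A policy that only supports single-observation inference."""

    def infer(self, obs: dict) -> dict:
        return {"actions": obs["state"]}


policy = SingleOnlyPolicy()
try:
    policy.infer_batch([{"state": [0.0]}])
    batch_supported = True
except NotImplementedError:
    batch_supported = False
```

Making the default raise means callers find out immediately that a policy lacks a batched path, instead of getting results whose latency and sampling behavior differ from a true batched call.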

Validation

  • uv run --project packages/openpi-client pytest packages/openpi-client/src/openpi_client -> 24 passed
  • uvx ruff check src/openpi/policies/policy.py src/openpi/policies/policy_test.py src/openpi/serving/websocket_policy_server.py src/openpi/serving/websocket_policy_server_test.py packages/openpi-client/src/openpi_client/base_policy.py packages/openpi-client/src/openpi_client/websocket_client_policy.py packages/openpi-client/src/openpi_client/websocket_client_policy_test.py
  • uvx ruff format --check src/openpi/policies/policy.py src/openpi/policies/policy_test.py src/openpi/serving/websocket_policy_server.py src/openpi/serving/websocket_policy_server_test.py packages/openpi-client/src/openpi_client/base_policy.py packages/openpi-client/src/openpi_client/websocket_client_policy.py packages/openpi-client/src/openpi_client/websocket_client_policy_test.py
  • python -m py_compile src/openpi/policies/policy.py src/openpi/policies/policy_test.py src/openpi/serving/websocket_policy_server.py src/openpi/serving/websocket_policy_server_test.py packages/openpi-client/src/openpi_client/base_policy.py packages/openpi-client/src/openpi_client/websocket_client_policy.py packages/openpi-client/src/openpi_client/websocket_client_policy_test.py
  • uv lock --check
  • git diff --check

Note: the root-level command uv run pytest src/openpi/policies/policy_test.py src/openpi/serving/websocket_policy_server_test.py cannot run on macOS arm64 because jax-cuda12-plugin==0.5.3 has no compatible wheel for that platform; the command fails before test collection.

Addresses #765.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:37
@jimmyt857 jimmyt857 removed their request for review May 11, 2026 04:08

@wadeKeith wadeKeith left a comment

Contributor

Solid addition - the batched inference path addresses a real need. The Policy.infer_batch() API is clean, with good test coverage for the sync and websocket paths. The README and docs are updated. LGTM! Reviewed by Hermes Agent.
