diff --git a/data/.lfs/groot.tar.gz b/data/.lfs/groot.tar.gz new file mode 100644 index 0000000000..16602bca84 --- /dev/null +++ b/data/.lfs/groot.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:305ed4b8538cb8083d91518e687b184b374c977afd185adecabc1152cef8d701 +size 3517892 diff --git a/data/.lfs/mujoco_sim.tar.gz b/data/.lfs/mujoco_sim.tar.gz index 57833fbbc6..47d2df201d 100644 --- a/data/.lfs/mujoco_sim.tar.gz +++ b/data/.lfs/mujoco_sim.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d178439569ed81dfad05455419dc51da2c52021313b6d7b9259d9e30946db7c6 -size 60186340 +oid sha256:4232d61e4af19dee0c0e3a8f55f2c1b48b28a70f810ff17039876a48b999d6d2 +size 60251722 diff --git a/dimos/control/tasks/g1_groot_wbc_task/__registry__.py b/dimos/control/tasks/g1_groot_wbc_task/__registry__.py new file mode 100644 index 0000000000..3be9ad92e2 --- /dev/null +++ b/dimos/control/tasks/g1_groot_wbc_task/__registry__.py @@ -0,0 +1,17 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +TASK_FACTORIES = { + "g1_groot_wbc": "dimos.control.tasks.g1_groot_wbc_task.g1_groot_wbc_task:create_task", +} diff --git a/dimos/control/tasks/g1_groot_wbc_task/g1_groot_wbc_task.py b/dimos/control/tasks/g1_groot_wbc_task/g1_groot_wbc_task.py new file mode 100644 index 0000000000..d361029c8c --- /dev/null +++ b/dimos/control/tasks/g1_groot_wbc_task/g1_groot_wbc_task.py @@ -0,0 +1,788 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GR00T whole-body-control task for the Unitree G1 humanoid. + +Runs the two-model GR00T WBC locomotion policy (balance + walk) inside +the coordinator tick loop. Claims the 15 legs+waist joints at high +priority; arm joints are left to lower-priority tasks in the blueprint. + +Reference implementation: g1_control/backends/groot_wbc_backend.py. +Observation, action, and model-selection semantics are preserved +verbatim — changing them drifts us away from the ONNX policies trained +by GR00T-WholeBodyControl. 
+ +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +import threading +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnxruntime as ort # type: ignore[import-untyped] + +from dimos.control.components import make_humanoid_joints +from dimos.control.task import ( + BaseControlTask, + ControlMode, + CoordinatorState, + JointCommandOutput, + ResourceClaim, +) +from dimos.protocol.service.spec import BaseConfig +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.hardware.whole_body.spec import WholeBodyAdapter + from dimos.msgs.geometry_msgs.Twist import Twist + +logger = setup_logger() + + +# The 29 DDS motor names + their kp/kd for the GR00T-trained policies. +# Lifted verbatim from g1-control-api/configs/g1_groot_wbc.yaml, which +# itself copies GR00T-WBC's g1_29dof_gear_wbc.yaml. Diverging from these +# on real hardware risks instability — the ONNX models were trained +# against them. +g1_joints = make_humanoid_joints("g1") +g1_legs_waist = g1_joints[:15] # indices 0..14 — legs (12) + waist (3) +g1_arms = g1_joints[15:] # indices 15..28 — left arm (7) + right arm (7) + +G1_GROOT_KP: list[float] = [ + 150.0, + 150.0, + 150.0, + 200.0, + 40.0, + 40.0, # left leg + 150.0, + 150.0, + 150.0, + 200.0, + 40.0, + 40.0, # right leg + 250.0, + 250.0, + 250.0, # waist + 100.0, + 100.0, + 40.0, + 40.0, + 20.0, + 20.0, + 20.0, # left arm + 100.0, + 100.0, + 40.0, + 40.0, + 20.0, + 20.0, + 20.0, # right arm +] +G1_GROOT_KD: list[float] = [ + 2.0, + 2.0, + 2.0, + 4.0, + 2.0, + 2.0, # left leg + 2.0, + 2.0, + 2.0, + 4.0, + 2.0, + 2.0, # right leg + 5.0, + 5.0, + 5.0, # waist + 5.0, + 5.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, # left arm + 5.0, + 5.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, # right arm +] + +# Relaxed arms-down pose. From g1_control/backends/groot_wbc_backend.py +# DEFAULT_29[15:] (all zeros) — the zero-offset pose the policy was +# trained against. 
Operators can override at runtime by publishing +# joint targets on the arms via the coordinator's joint_command transport. +ARM_DEFAULT_POSE: list[float] = [0.0] * 14 + + +# Default joint angles copied verbatim from +# g1_control/backends/groot_wbc_backend.py DEFAULT_29. Policy was trained +# against these as the zero-offset pose. +_DEFAULT_POSITIONS_29 = [ + -0.1, + 0.0, + 0.0, + 0.3, + -0.2, + 0.0, # left leg + -0.1, + 0.0, + 0.0, + 0.3, + -0.2, + 0.0, # right leg + 0.0, + 0.0, + 0.0, # waist + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, # left arm (not driven by policy) + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, # right arm (not driven by policy) +] + +_SINGLE_OBS_DIM = 86 +_OBS_HISTORY_LEN = 6 +_NUM_ACTIONS = 15 +_NUM_MOTORS = 29 + + +@dataclass +class G1GrootWBCTaskConfig: + """Configuration for the GR00T WBC task. + + Attributes: + balance_onnx: Path to the balance ONNX model. Used when + ``||cmd|| <= cmd_norm_threshold``. + walk_onnx: Path to the walk ONNX model. Used otherwise. + joint_names: The 15 coordinator joint names this task claims + (legs 0-11 + waist 12-14, in DDS order). + all_joint_names: All 29 coordinator joint names in DDS order + (legs 0-11 + waist 12-14 + arms 15-28). Required to build + the observation, which feeds all 29 joint states. + default_positions_29: Default joint angles for all 29 joints + (DDS order). First 15 are the policy's zero-offset pose. + priority: Arbitration priority (higher wins). 50 is the + recommended WBC priority per the task.py conventions. + decimation: Run inference every N ticks. At 500 Hz tick / + 50 Hz policy → decimation=10. + action_scale: Multiplier on raw policy output before adding + defaults. + obs_ang_vel_scale: Scale for base angular velocity in obs. + obs_dof_pos_scale: Scale for joint position offset in obs. + obs_dof_vel_scale: Scale for joint velocity in obs. + cmd_scale: Per-axis scale applied to (vx, vy, wz) in obs. 
+ cmd_norm_threshold: ||cmd|| below this selects the balance + model, otherwise walk. + height_cmd: Fixed height command slot in obs. + timeout: Seconds without a velocity command before zeroing it. + auto_arm: Arm the policy automatically on ``start()``. Default + False — safe for real hardware; the blueprint sets True for + simulation. + auto_dry_run: Enter dry-run mode on ``start()``. Policy still + runs but outputs are not emitted to the adapter — useful for + verifying on real hardware without commanding motors. + default_ramp_seconds: Duration of the arming ramp (current pose + → ``default_15``) when ``arm()`` is called without an + explicit duration. Set to 0 in simulation (no ramp needed); + 10 s on real hardware mirrors the g1-control-api default. + """ + + balance_onnx: str | Path + walk_onnx: str | Path + joint_names: list[str] + all_joint_names: list[str] + default_positions_29: list[float] = field(default_factory=lambda: list(_DEFAULT_POSITIONS_29)) + priority: int = 50 + decimation: int = 10 + action_scale: float = 0.25 + obs_ang_vel_scale: float = 0.5 + obs_dof_pos_scale: float = 1.0 + obs_dof_vel_scale: float = 0.05 + cmd_scale: tuple[float, float, float] = (2.0, 2.0, 0.5) + cmd_norm_threshold: float = 0.05 + height_cmd: float = 0.74 + timeout: float = 1.0 + auto_arm: bool = False + auto_dry_run: bool = False + default_ramp_seconds: float = 10.0 + + +class G1GrootWBCTask(BaseControlTask): + """Runs the GR00T balance / walk ONNX policies inside the coordinator tick loop. 
+ + Observation vector (86 dims, built each inference tick, replicates + ``groot_wbc_backend.GrootWBCBackend._compute_obs`` verbatim): + + [0:3] cmd_vel * cmd_scale # scaled velocity command + [3] height_cmd # fixed slot (0.74) + [4:7] (0, 0, 0) # rpy_cmd, zeros + [7:10] gyro * obs_ang_vel_scale # body-frame ang vel + [10:13] projected_gravity(quat) # gravity in body frame + [13:42] (q_29 - default_29) * dof_pos_scale + [42:71] dq_29 * dof_vel_scale + [71:86] last_action (15 dims) + + The observation is stacked into a 6-frame history buffer (516 dims) + before being fed to ONNX. + + Action (15 dims, legs + waist only): + + target_q_15 = action * action_scale + default_15 + + Arms are NOT driven by this task — the blueprint pairs this task + with a lower-priority servo task scoped to the 14 arm joints. + """ + + def __init__( + self, + name: str, + config: G1GrootWBCTaskConfig, + adapter: WholeBodyAdapter, + ) -> None: + if len(config.joint_names) != _NUM_ACTIONS: + raise ValueError( + f"G1GrootWBCTask '{name}' requires exactly {_NUM_ACTIONS} joint names " + f"(legs + waist), got {len(config.joint_names)}" + ) + if len(config.all_joint_names) != _NUM_MOTORS: + raise ValueError( + f"G1GrootWBCTask '{name}' requires exactly {_NUM_MOTORS} all_joint_names " + f"(full 29-DOF G1), got {len(config.all_joint_names)}" + ) + if len(config.default_positions_29) != _NUM_MOTORS: + raise ValueError( + f"G1GrootWBCTask '{name}' requires exactly {_NUM_MOTORS} " + f"default_positions_29, got {len(config.default_positions_29)}" + ) + if config.decimation < 1: + raise ValueError(f"G1GrootWBCTask '{name}' requires decimation >= 1") + + self._name = name + self._config = config + self._adapter = adapter + self._joint_names_list = list(config.joint_names) + self._joint_names_set = frozenset(config.joint_names) + self._all_joint_names = list(config.all_joint_names) + + providers = ort.get_available_providers() + self._balance_session = ort.InferenceSession(str(config.balance_onnx), 
providers=providers) + self._walk_session = ort.InferenceSession(str(config.walk_onnx), providers=providers) + self._balance_input = self._balance_session.get_inputs()[0].name + self._walk_input = self._walk_session.get_inputs()[0].name + logger.info( + f"G1GrootWBCTask '{name}' loaded balance={config.balance_onnx}, " + f"walk={config.walk_onnx} (providers: {providers})" + ) + + self._default_29 = np.asarray(config.default_positions_29, dtype=np.float32) + self._default_15 = self._default_29[:_NUM_ACTIONS] + self._cmd_scale = np.asarray(config.cmd_scale, dtype=np.float32) + + # Inference state + self._last_action = np.zeros(_NUM_ACTIONS, dtype=np.float32) + self._obs_buf = np.zeros((1, _SINGLE_OBS_DIM * _OBS_HISTORY_LEN), dtype=np.float32) + self._first_inference = True + self._tick_count = 0 + self._last_targets: list[float] | None = None + + # Last-known-good state caches. compute() falls back to these + # whenever a joint is missing from CoordinatorState (transient + # packet drop, late publisher, etc) instead of substituting 0.0 + # — feeding a zero pose to the policy makes it think the robot + # is at the URDF zero (legs straight) and command a snap-back, + # which on real hardware tips the robot over. ``_state_seen`` + # tracks whether we've ever observed a fully-populated state; + # until then compute() returns None rather than running on + # half-cached defaults. 
+ self._cached_q_29 = self._default_29.copy() + self._cached_dq_29 = np.zeros(_NUM_MOTORS, dtype=np.float32) + self._cached_q_15 = self._default_15.copy() + self._state_seen = False + + self._active = False + self._armed = False + self._arming = False + self._arm_pending = False + self._dry_run = bool(config.auto_dry_run) + self._arming_duration = 0.0 + self._arming_start_t = 0.0 + self._ramp_start: np.ndarray | None = None + self._last_dry_run_log_t: float = 0.0 + + self._cmd_lock = threading.Lock() + self._cmd = np.zeros(3, dtype=np.float32) + self._last_cmd_time: float = 0.0 + + @property + def name(self) -> str: + return self._name + + def claim(self) -> ResourceClaim: + return ResourceClaim( + joints=self._joint_names_set, + priority=self._config.priority, + mode=ControlMode.SERVO_POSITION, + ) + + def is_active(self) -> bool: + return self._active + + def _refresh_state_caches(self, state: CoordinatorState) -> bool: + """Pull current q/dq for the 15 claimed joints and the full 29 + from ``CoordinatorState``, updating last-known-good caches and + returning True iff the full 29 came back populated this tick. + + On a missing joint we keep the cached value rather than dropping + in 0.0 — the policy interprets 0.0 as "at URDF zero / legs + straight" and commands a recovery, which tips the robot. + """ + all_present = True + for i, jname in enumerate(self._joint_names_list): + pos = state.joints.get_position(jname) + if pos is None: + all_present = False + else: + self._cached_q_15[i] = pos + for i, jname in enumerate(self._all_joint_names): + pos = state.joints.get_position(jname) + vel = state.joints.get_velocity(jname) + if pos is None or vel is None: + all_present = False + else: + self._cached_q_29[i] = pos + self._cached_dq_29[i] = vel + if all_present: + self._state_seen = True + return all_present + + def compute(self, state: CoordinatorState) -> JointCommandOutput | None: + if not self._active: + return None + + # Refresh the last-known-good state caches. 
If we've never seen + # a fully-populated state and this tick is also incomplete, hold + # off — emitting a command from defaults would snap the robot. + fresh = self._refresh_state_caches(state) + if not self._state_seen and not fresh: + return None + + current_15 = self._cached_q_15.copy() + + # arm() was called — snapshot the ramp start and enter arming / + # armed state (ramp=0 arms immediately). + if self._arm_pending: + self._ramp_start = current_15.copy() + self._arming_start_t = state.t_now + if self._arming_duration > 0.0: + self._arming = True + self._armed = False + logger.info( + f"G1GrootWBCTask '{self._name}' arming: " + f"ramp → default_15 over {self._arming_duration:.1f}s" + ) + else: + self._arming = False + self._armed = True + self._reset_policy_state() + logger.info(f"G1GrootWBCTask '{self._name}' armed (no ramp)") + self._arm_pending = False + + # Unarmed & not arming: echo current joint positions. With the + # component's kp/kd applied downstream, q_tgt == q_actual yields + # pure damping (tau = -kd * dq), which mirrors the reference + # backend's inactive "hold current pose" behaviour. + if not self._armed and not self._arming: + self._last_targets = current_15.tolist() + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + # Arming: lerp ramp_start → default_15 over arming_duration. 
+ if self._arming: + assert self._ramp_start is not None + elapsed = state.t_now - self._arming_start_t + alpha = ( + 1.0 if self._arming_duration <= 0.0 else min(1.0, elapsed / self._arming_duration) + ) + target = self._ramp_start + alpha * (self._default_15 - self._ramp_start) + self._last_targets = target.tolist() + if alpha >= 1.0: + self._arming = False + self._armed = True + self._reset_policy_state() + logger.info( + f"G1GrootWBCTask '{self._name}' ramp complete — policy armed " + f"({'dry-run' if self._dry_run else 'live'})" + ) + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + # Armed: run the policy. In dry-run mode we still compute (so + # the obs buffer stays hot), but return None so no command goes + # downstream. A throttled log line shows what WOULD have been + # sent, which is how g1-control-api lets operators verify pre-go. + self._tick_count += 1 + + # Decimation: only run inference every N ticks. Between inference + # ticks, re-emit the last target so the coordinator keeps driving + # the joints (or nothing, in dry-run). + if self._tick_count % self._config.decimation != 0: + if self._dry_run or self._last_targets is None: + return None + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + # State was refreshed up top (with fall-back-to-last-good on + # missing joints). Snapshot the caches now so concurrent state + # updates don't tear the obs vector. + q_29 = self._cached_q_29.copy() + dq_29 = self._cached_dq_29.copy() + + # Prefer IMU from CoordinatorState (populated by the coordinator + # each tick from every whole-body adapter); fall back to the + # adapter-direct read if state.imu is empty (e.g. unit tests + # that build a bare CoordinatorState). The state path is what + # decouples this task from the WholeBodyAdapter Protocol. 
+ if state.imu: + # Single whole-body adapter is the common case — take any. + imu = next(iter(state.imu.values())) + else: + imu = self._adapter.read_imu() + gyro = np.asarray(imu.gyroscope, dtype=np.float32) + gravity = self._projected_gravity(imu.quaternion) + + # Velocity command (with timeout → zero). + with self._cmd_lock: + if ( + self._config.timeout > 0.0 + and self._last_cmd_time > 0.0 + and (state.t_now - self._last_cmd_time) > self._config.timeout + ): + cmd = np.zeros(3, dtype=np.float32) + else: + cmd = self._cmd.copy() + + obs = self._build_obs(cmd=cmd, gyro=gyro, gravity=gravity, q=q_29, dq=dq_29) + + # History buffer: first inference fills all slots with the current + # obs (warm-start); subsequent ticks roll the window. + if self._first_inference: + tiled = np.tile(obs, _OBS_HISTORY_LEN) + self._obs_buf[0, :] = tiled + self._first_inference = False + else: + self._obs_buf[0, : _SINGLE_OBS_DIM * (_OBS_HISTORY_LEN - 1)] = self._obs_buf[ + 0, _SINGLE_OBS_DIM: + ] + self._obs_buf[0, _SINGLE_OBS_DIM * (_OBS_HISTORY_LEN - 1) :] = obs + + # Model selection: balance when near-stationary, walk otherwise. + cmd_norm = float(np.linalg.norm(cmd)) + if cmd_norm <= self._config.cmd_norm_threshold: + raw = self._balance_session.run(None, {self._balance_input: self._obs_buf})[0] + else: + raw = self._walk_session.run(None, {self._walk_input: self._obs_buf})[0] + + action = raw[0, :_NUM_ACTIONS].astype(np.float32) + self._last_action[:] = action + + target_q_15 = action * self._config.action_scale + self._default_15 + self._last_targets = target_q_15.tolist() + + if self._dry_run: + # Throttled peek at the commanded pose so the operator can + # decide whether it looks sane before flipping dry-run off. 
+ if (state.t_now - self._last_dry_run_log_t) >= 1.0: + max_delta = float(np.max(np.abs(target_q_15 - current_15))) + logger.info( + f"G1GrootWBCTask '{self._name}' DRY-RUN (|Δq|_max={max_delta:.3f} rad, " + f"model={'walk' if cmd_norm > self._config.cmd_norm_threshold else 'balance'})" + ) + self._last_dry_run_log_t = state.t_now + return None + + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: + if joints & self._joint_names_set: + logger.warning(f"G1GrootWBCTask '{self._name}' preempted by {by_task} on {joints}") + + # Velocity command input + + def set_velocity_command(self, vx: float, vy: float, yaw_rate: float, t_now: float) -> None: + """Set the (vx, vy, yaw_rate) commanded to the policy. + + Called by the coordinator's twist_command dispatcher and by + external Python callers. Thread-safe. + """ + with self._cmd_lock: + self._cmd[:] = [vx, vy, yaw_rate] + self._last_cmd_time = t_now + + def on_twist(self, msg: Twist, t_now: float) -> bool: + """Accept a Twist message, e.g. from an LCM cmd_vel transport.""" + self.set_velocity_command( + float(msg.linear.x), + float(msg.linear.y), + float(msg.angular.z), + t_now, + ) + return True + + # Lifecycle + + def start(self) -> None: + """Enter the coordinator tick loop. + + Starts in "active but unarmed" — compute() echoes current joint + positions every tick, which (combined with the component's + kp/kd) produces damping-only behaviour on real hardware (the + robot sits quietly in dev mode). + + If ``config.auto_arm`` is set, schedules an immediate + ``arm()`` using ``config.default_ramp_seconds`` — this is how + the simulation blueprint bypasses the activation ritual. + If ``config.auto_dry_run`` is set, starts in dry-run mode. 
+ """ + self._active = True + self._armed = False + self._arming = False + self._arm_pending = False + self._dry_run = bool(self._config.auto_dry_run) + self._last_targets = None + self._reset_policy_state() + with self._cmd_lock: + self._cmd[:] = 0.0 + self._last_cmd_time = 0.0 + logger.info( + f"G1GrootWBCTask '{self._name}' started (unarmed" + + (", dry-run" if self._dry_run else "") + + ")" + ) + if self._config.auto_arm: + self.arm(self._config.default_ramp_seconds) + + def stop(self) -> None: + """Leave the tick loop. Re-activation resets policy state.""" + self._active = False + self._armed = False + self._arming = False + self._arm_pending = False + self._last_targets = None + logger.info(f"G1GrootWBCTask '{self._name}' stopped") + + # Arming / dry-run (RPC-callable via coordinator.task_invoke) + + def arm(self, ramp_seconds: float | None = None) -> bool: + """Begin the arming sequence. + + ``compute()`` will snapshot the current joint positions on the + next tick, lerp toward ``default_15`` over ``ramp_seconds``, + then flip ``_armed`` true and hand control to the ONNX policy. + A ramp of 0 arms immediately with no interpolation, which is what + sim uses when the MJCF already starts near the policy default pose. + + Safe to call redundantly; calls while already armed or arming + are ignored. No-op if the task is not ``_active``. 
+ """ + if not self._active: + logger.warning(f"G1GrootWBCTask '{self._name}' arm() called before start() — ignoring") + return False + if self._armed: + logger.info(f"G1GrootWBCTask '{self._name}' already armed — arm() ignored") + return False + if self._arming or self._arm_pending: + logger.info(f"G1GrootWBCTask '{self._name}' arm in progress -- arm() ignored") + return False + ramp = ramp_seconds if ramp_seconds is not None else self._config.default_ramp_seconds + self._arming_duration = max(0.0, float(ramp)) + self._arm_pending = True + logger.info( + f"G1GrootWBCTask '{self._name}' arm requested (ramp={self._arming_duration:.1f}s)" + ) + return True + + def disarm(self) -> bool: + """Stop emitting policy outputs; fall back to hold-current-pose. + + Called either from an operator ``Disarm`` button or from + safety watchdogs. Resets obs history so the next ``arm()`` + starts with a clean buffer. + """ + if not self._armed and not self._arming and not self._arm_pending: + return False + self._armed = False + self._arming = False + self._arm_pending = False + self._ramp_start = None + self._reset_policy_state() + logger.info(f"G1GrootWBCTask '{self._name}' disarmed (holding current pose)") + return True + + def set_dry_run(self, enabled: bool) -> None: + """Enable/disable dry-run. + + In dry-run the policy still runs (obs history stays hot) but + ``compute()`` returns ``None``, so the coordinator forwards no + command to the adapter. Use to verify policy sanity on real + hardware before committing motor torques. 
+ """ + new_val = bool(enabled) + if new_val == self._dry_run: + return + self._dry_run = new_val + self._last_dry_run_log_t = 0.0 + logger.info(f"G1GrootWBCTask '{self._name}' dry_run = {new_val}") + + def state_snapshot(self) -> dict[str, object]: + """Return the current state-machine flags for UI / telemetry.""" + return { + "active": self._active, + "armed": self._armed, + "arming": self._arming, + "arm_pending": self._arm_pending, + "dry_run": self._dry_run, + "arming_duration": self._arming_duration, + } + + # Internal helpers + + def _reset_policy_state(self) -> None: + """Clear inference state — obs history, last action, tick count.""" + self._last_action[:] = 0.0 + self._obs_buf[:] = 0.0 + self._first_inference = True + self._tick_count = 0 + + def _build_obs( + self, + cmd: np.ndarray, + gyro: np.ndarray, + gravity: np.ndarray, + q: np.ndarray, + dq: np.ndarray, + ) -> np.ndarray: + """Build the 86-dim GR00T observation. Layout matches + ``groot_wbc_backend.py`` exactly.""" + obs = np.zeros(_SINGLE_OBS_DIM, dtype=np.float32) + obs[0:3] = cmd * self._cmd_scale + obs[3] = self._config.height_cmd + obs[4:7] = 0.0 + obs[7:10] = gyro * self._config.obs_ang_vel_scale + obs[10:13] = gravity + obs[13:42] = (q - self._default_29) * self._config.obs_dof_pos_scale + obs[42:71] = dq * self._config.obs_dof_vel_scale + obs[71:86] = self._last_action + return obs + + @staticmethod + def _projected_gravity(quaternion: tuple[float, ...]) -> np.ndarray: + """Project world gravity into body frame. + + Uses Unitree DDS quaternion order (w, x, y, z). Formula matches + ``groot_wbc_backend._get_gravity_orientation`` and is + algebraically equivalent to the Go2 RLPolicyTask helper. 
+ """ + w, x, y, z = quaternion + gx = 2.0 * (-x * z + w * y) + gy = 2.0 * (-y * z - w * x) + gz = -(w * w - x * x - y * y + z * z) + return np.array([gx, gy, gz], dtype=np.float32) + + +__all__ = [ + "ARM_DEFAULT_POSE", + "G1_GROOT_KD", + "G1_GROOT_KP", + "G1GrootWBCTask", + "G1GrootWBCTaskConfig", + "g1_arms", + "g1_joints", + "g1_legs_waist", +] + + +class G1GrootWBCTaskParams(BaseConfig): + model_path: str | Path + hardware_id: str + auto_arm: bool = False + auto_dry_run: bool = False + default_ramp_seconds: float = 10.0 + decimation: int | None = None + + +def create_task(cfg: Any, hardware: Any) -> G1GrootWBCTask: + from dimos.control.hardware_interface import ConnectedWholeBody + + params = G1GrootWBCTaskParams.model_validate(cfg.params) + hw = hardware.get(params.hardware_id) if hardware else None + if hw is None: + raise ValueError( + f"G1GrootWBCTask {cfg.name!r} references unknown hardware " + f"{params.hardware_id!r}. Declare the hardware before the task " + f"in the blueprint config." + ) + if not isinstance(hw, ConnectedWholeBody): + raise TypeError( + f"G1GrootWBCTask {cfg.name!r} requires a WHOLE_BODY hardware " + f"component for {params.hardware_id!r}, got {type(hw).__name__}. " + f"Set hardware_type=HardwareType.WHOLE_BODY." 
+ ) + + model_dir = Path(params.model_path) + kwargs: dict[str, Any] = dict( + balance_onnx=model_dir / "balance.onnx", + walk_onnx=model_dir / "walk.onnx", + joint_names=cfg.joint_names, + all_joint_names=hw.joint_names, + priority=cfg.priority, + auto_arm=params.auto_arm, + auto_dry_run=params.auto_dry_run, + default_ramp_seconds=params.default_ramp_seconds, + ) + if params.decimation is not None: + kwargs["decimation"] = params.decimation + return G1GrootWBCTask( + cfg.name, + G1GrootWBCTaskConfig(**kwargs), + adapter=hw.adapter, + ) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index c3e0f9fff0..477378bd08 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -88,6 +88,7 @@ "unitree-g1-coordinator": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_coordinator:unitree_g1_coordinator", "unitree-g1-detection": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_detection:unitree_g1_detection", "unitree-g1-full": "dimos.robot.unitree.g1.blueprints.agentic.unitree_g1_full:unitree_g1_full", + "unitree-g1-groot-wbc": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_groot_wbc:unitree_g1_groot_wbc", "unitree-g1-joystick": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_joystick:unitree_g1_joystick", "unitree-g1-nav-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_onboard:unitree_g1_nav_onboard", "unitree-g1-nav-sim": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_sim:unitree_g1_nav_sim", diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_groot_wbc.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_groot_wbc.py new file mode 100644 index 0000000000..71895e959c --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_groot_wbc.py @@ -0,0 +1,238 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree G1 GR00T whole-body-control blueprint. + +One blueprint, ``--simulation`` flag picks the backend: + +Real hardware (default): + G1WholeBodyConnection (DDS rt/lowstate <-> rt/lowcmd) + transport_lcm + whole-body adapter. 500 Hz tick. Safety profile: unarmed + dry-run on + start; activate explicitly through ControlCoordinator RPC after + verifying commands. The policy ramps from the current pose to its + bent-knee default over 10 s before taking torque control. The 14 arm + joints are held at the relaxed GR00T-trained default via a lower-priority + servo task. + +Sim (``--simulation``): + MujocoSimModule (in-process MuJoCo + SHM) + sim_mujoco_g1 adapter. + 50 Hz tick (matches the rate the policy was trained at). No arming + ramp, no dry-run, no servo_arms -- sim physics doesn't gravity-collapse + the arms between trajectories. 
+ +Usage: + dimos run unitree-g1-groot-wbc # real hardware + dimos --simulation mujoco run unitree-g1-groot-wbc # sim + +Overrides (replace the old env-var dance): + dimos run unitree-g1-groot-wbc \\ + -o g1wholebodyconnection.network_interface=enp2s0 +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from dimos.control.components import HardwareComponent, HardwareType +from dimos.control.coordinator import ControlCoordinator, TaskConfig +from dimos.control.tasks.g1_groot_wbc_task.g1_groot_wbc_task import ( + ARM_DEFAULT_POSE, + G1_GROOT_KD, + G1_GROOT_KP, + g1_arms, + g1_joints, + g1_legs_waist, +) +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.core.transport import LCMTransport +from dimos.hardware.whole_body.spec import WholeBodyConfig +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.MotorCommandArray import MotorCommandArray +from dimos.utils.data import LfsPath + +# Lazy data handles. LfsPath only triggers the LFS pull on first +# str()/open(); using ``get_data(...)`` at import time would block the +# whole CLI on a multi-GB download every time the module is imported. +_GROOT_MODEL_DIR = LfsPath("groot") +_MJCF_PATH = LfsPath("mujoco_sim/g1_gear_wbc.xml") + +_adapter_address: str | Path + +if global_config.simulation: + from dimos.simulation.engines.mujoco_sim_module import MujocoSimModule + + # Sim backend: MuJoCo engine via SHM. 
+ _backend = MujocoSimModule.blueprint( + address=_MJCF_PATH, + headless=True, + dof=29, + enable_color=False, + enable_depth=False, + enable_pointcloud=False, + inject_legacy_assets=True, + ) + # MujocoSimModule's ``odom`` Out is the sole producer of ``/odom`` + # now — the coordinator no longer polls the whole-body adapter for + # base pose (read_odom was dropped from the Protocol). autoconnect + # maps ``(odom, PoseStamped)`` to ``/odom`` by default; no override. + _adapter_type = "sim_mujoco_g1" + _adapter_address = _MJCF_PATH + _tick_rate = 50.0 + _auto_arm = True + _auto_dry_run = False + _default_ramp_seconds = 0.0 + _decimation: int | None = 1 + # Sim physics holds the arms between trajectories on its own -- no + # servo task needed. + _arm_holder: TaskConfig | None = None +else: + from dimos.robot.unitree.g1.wholebody_connection import G1WholeBodyConnection + + # Real-hw backend: DDS connection module + transport_lcm adapter. + _backend = G1WholeBodyConnection.blueprint(release_sport_mode=True) + _adapter_type = "transport_lcm" + _adapter_address = "" + _tick_rate = 500.0 + # Real hardware: come up unarmed + dry-run; operator must click + # Activate (10 s ramp) after verifying commands. + _auto_arm = False + _auto_dry_run = True + _default_ramp_seconds = 10.0 + _decimation = None # task default (10) pairs with 500 Hz tick. + # Real hardware needs the arms held -- kd damping alone would let + # them sag toward singular configurations between trajectories. 
+ _arm_holder = TaskConfig( + name="servo_arms", + type="servo", + joint_names=g1_arms, + priority=10, + auto_start=True, + params={"default_positions": ARM_DEFAULT_POSE}, + ) + + +def _g1_groot_rerun_blueprint() -> Any: + import rerun as rr + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView( + origin="world", + name="G1 GR00T WBC", + background=rrb.Background(kind="SolidColor", color=[0, 0, 0]), + line_grid=rrb.LineGrid3D( + plane=rr.components.Plane3D.XY.with_distance(0.0), + ), + ), + rrb.TimePanel(state="collapsed"), + ) + + +def _static_g1_body(rr: Any) -> Any: + return rr.Boxes3D( + half_sizes=[0.25, 0.20, 0.6], + centers=[[0.0, 0.0, 0.6]], + colors=[(0, 255, 127)], + fill_mode="MajorWireframe", + ) + + +_rerun_config = { + "blueprint": _g1_groot_rerun_blueprint, + "static": { + # MujocoSimModule logs odom as a Transform3D at world/odom. + # This body marker inherits that transform, giving dimos-viewer + # a visible robot anchor until a richer joint/URDF view exists. 
+ "world/odom/g1": _static_g1_body, + }, +} + + +def _viewer() -> Any: + if global_config.viewer == "none": + return autoconnect() + if global_config.viewer != "rerun": + raise ValueError(f"Unsupported viewer backend for G1 GR00T WBC: {global_config.viewer}") + + from dimos.visualization.rerun.bridge import RerunBridgeModule + from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + + return autoconnect( + RerunBridgeModule.blueprint( + **_rerun_config, + rerun_open=global_config.rerun_open, + rerun_web=global_config.rerun_web, + ), + RerunWebSocketServer.blueprint(), + ) + + +_coordinator = ControlCoordinator.blueprint( + tick_rate=_tick_rate, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[ + HardwareComponent( + hardware_id="g1", + hardware_type=HardwareType.WHOLE_BODY, + joints=g1_joints, + adapter_type=_adapter_type, + address=_adapter_address, + auto_enable=True, + wb_config=WholeBodyConfig(kp=tuple(G1_GROOT_KP), kd=tuple(G1_GROOT_KD)), + ), + ], + tasks=[ + TaskConfig( + name="groot_wbc", + type="g1_groot_wbc", + joint_names=g1_legs_waist, + priority=50, + auto_start=True, + params={ + "model_path": _GROOT_MODEL_DIR, + "hardware_id": "g1", + "auto_arm": _auto_arm, + "auto_dry_run": _auto_dry_run, + "default_ramp_seconds": _default_ramp_seconds, + "decimation": _decimation, + }, + ), + *([_arm_holder] if _arm_holder is not None else []), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("joint_command", JointState): LCMTransport("/g1/joint_command", JointState), + ("twist_command", Twist): LCMTransport("/g1/cmd_vel", Twist), + ("tele_cmd_vel", Twist): LCMTransport("/g1/cmd_vel", Twist), + # Real-hw only: the transport_lcm adapter speaks to + # G1WholeBodyConnection over these topics. autoconnect already + # matches by (name, type) so sim doesn't need them -- they're + # harmless when the sim engine doesn't expose those ports. 
+ ("motor_states", JointState): LCMTransport("/g1/motor_states", JointState), + ("imu", Imu): LCMTransport("/g1/imu", Imu), + ("motor_command", MotorCommandArray): LCMTransport("/g1/motor_command", MotorCommandArray), + } +) + +unitree_g1_groot_wbc = autoconnect(_backend, _coordinator, _viewer()).global_config( + robot_model="unitree_g1" +) + +__all__ = ["unitree_g1_groot_wbc"] diff --git a/dimos/robot/unitree/g1/wholebody_connection.py b/dimos/robot/unitree/g1/wholebody_connection.py index f4fb762bae..f062f41157 100644 --- a/dimos/robot/unitree/g1/wholebody_connection.py +++ b/dimos/robot/unitree/g1/wholebody_connection.py @@ -51,7 +51,12 @@ _NUM_MOTORS = 29 _NUM_MOTOR_SLOTS = 35 # G1 hg LowCmd has 35 slots; only 29 are used -_MODE_MACHINE_WAIT_S = 10.0 +# mode_machine is a static value identifying the G1 firmware/hardware +# variant. Older code read it back from the first LowState frame and +# echoed it into LowCmd; that callback path is unreliable on macOS +# cyclonedds, and the value never changes for a given robot, so we +# hardcode it. 29-DOF G1 (gear) reports 5. +_MODE_MACHINE_G1: int = 5 # Joint names sourced from the canonical helper. Order matches the motor index # convention above. Single-source-of-truth so any coordinator-side adapter built @@ -84,8 +89,12 @@ def __init__(self, **kwargs: Any) -> None: self._low_cmd: LowCmd_ | None = None self._low_state: LowState_ | None = None self._crc: CRC | None = None - # mode_machine: read from first LowState, echoed back in every LowCmd. + # mode_machine: hardcoded at start() to the static value for the + # 29-DOF G1. We log a one-shot warning if the first LowState we + # read disagrees — that's the early signal of firmware drift on a + # variant that needs a different value. self._mode_machine: int | None = None + self._mode_machine_verified: bool = False # Guards _low_cmd / _low_state / _mode_machine across DDS, publish, and LCM threads. 
self._lock = threading.Lock() self._stop_event = threading.Event() @@ -119,12 +128,19 @@ def start(self) -> None: self._publisher = ChannelPublisher("rt/lowcmd", LowCmd_) self._publisher.Init() + # Passive subscriber — Read() per tick from the publish loop. The + # callback variant (Init(self._on_low_state, 10)) doesn't fire + # reliably under cyclonedds on macOS, which used to leave us + # blocked here forever waiting for a first LowState. self._subscriber = ChannelSubscriber("rt/lowstate", LowState_) - self._subscriber.Init(self._on_low_state, 10) + self._subscriber.Init(None, 0) # POS_STOP/VEL_STOP + zero gains so the robot can't twitch pre-command. self._low_cmd = unitree_hg_msg_dds__LowCmd_() self._low_cmd.mode_pr = 0 # PR (pitch/roll) mode + # mode_machine is a static value (see comment above the constant). + self._mode_machine = _MODE_MACHINE_G1 + self._low_cmd.mode_machine = self._mode_machine for i in range(_NUM_MOTOR_SLOTS): self._low_cmd.motor_cmd[i].mode = 0x01 # enable self._low_cmd.motor_cmd[i].q = POS_STOP @@ -141,16 +157,6 @@ def start(self) -> None: else: logger.info("Skipping sport mode release (release_sport_mode=False)") - logger.info("Waiting for first LowState to capture mode_machine...") - deadline = time.time() + _MODE_MACHINE_WAIT_S - while self._mode_machine is None and time.time() < deadline: - time.sleep(0.1) - if self._mode_machine is None: - raise RuntimeError( - f"Timed out after {_MODE_MACHINE_WAIT_S:.1f}s waiting for " - f"first LowState — mode_machine never captured" - ) - logger.info(f"G1WholeBodyConnection connected (mode_machine={self._mode_machine})") self.register_disposable(Disposable(self.motor_command.subscribe(self._on_motor_command))) @@ -167,6 +173,29 @@ def stop(self) -> None: self._publish_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) self._publish_thread = None + # Final safe-stop lowcmd: disable every motor (mode=0x00, kp=kd=0, + # tau=0). 
Without this, the motors freeze stiffly at whatever + # the last commanded pose was and the next ``dimos run`` opens + # against a robot that's actively fighting its own controllers + # — observed as horrible mechanical noise during sport-mode + # release. Best-effort: any failure is logged, not raised, so + # cleanup still drains the DDS endpoints. + if self._publisher is not None and self._low_cmd is not None and self._crc is not None: + try: + with self._lock: + for i in range(_NUM_MOTOR_SLOTS): + self._low_cmd.motor_cmd[i].mode = 0x00 # disable + self._low_cmd.motor_cmd[i].q = POS_STOP + self._low_cmd.motor_cmd[i].dq = VEL_STOP + self._low_cmd.motor_cmd[i].kp = 0 + self._low_cmd.motor_cmd[i].kd = 0 + self._low_cmd.motor_cmd[i].tau = 0 + self._low_cmd.crc = self._crc.Crc(self._low_cmd) + self._publisher.Write(self._low_cmd) + logger.info("Sent safe-stop lowcmd (motors disabled)") + except (OSError, RuntimeError, AttributeError) as e: + logger.warning(f"Safe-stop lowcmd failed: {e}") + # Close DDS endpoints explicitly — GC-based cleanup races with in-flight # callbacks and segfaults on process exit (mirrors the Go2 adapter). if self._subscriber is not None: @@ -190,56 +219,114 @@ def stop(self) -> None: logger.info("G1WholeBodyConnection disconnected") super().stop() + def _drain_low_state(self) -> None: + """Pull the freshest LowState frame off the subscriber and stash it.""" + sub = self._subscriber + if sub is None: + return + fresh = sub.Read() + if fresh is None: + return + with self._lock: + self._low_state = fresh + self._verify_mode_machine_once(fresh) + + def _verify_mode_machine_once(self, sample: object) -> None: + """One-shot sanity check: log if the hardcoded mode_machine + doesn't match what the firmware reports. 
Commands with a + wrong mode_machine are silently rejected, so this prevents + a confusing "everything looks fine but the robot doesn't + move" failure mode on G1 variants we haven't tested.""" + if self._mode_machine_verified: + return + self._mode_machine_verified = True + actual = int(getattr(sample, "mode_machine", -1)) + if actual != self._mode_machine: + logger.warning( + f"mode_machine mismatch: hardcoded {self._mode_machine}, " + f"robot reports {actual}. Commands may be silently rejected " + f"by firmware — set _MODE_MACHINE_G1 to {actual} for this variant." + ) + + def _snapshot_motor_imu( + self, + ) -> ( + tuple[ + list[float], + list[float], + list[float], + tuple[float, float, float, float], + tuple[float, float, float], + tuple[float, float, float], + ] + | None + ): + """Return the latest real motor/IMU sample, or None before first LowState.""" + with self._lock: + ls = self._low_state + if ls is None: + return None + return ( + [ls.motor_state[i].q for i in range(_NUM_MOTORS)], + [ls.motor_state[i].dq for i in range(_NUM_MOTORS)], + [ls.motor_state[i].tau_est for i in range(_NUM_MOTORS)], + tuple(ls.imu_state.quaternion), + tuple(ls.imu_state.gyroscope), + tuple(ls.imu_state.accelerometer), + ) + + def _publish_motor_state_and_imu( + self, + now: float, + frame_id: str, + positions: list[float], + velocities: list[float], + efforts: list[float], + quat: tuple[float, float, float, float], + gyro: tuple[float, float, float], + accel: tuple[float, float, float], + ) -> None: + self.motor_states.publish( + JointState( + ts=now, + frame_id=frame_id, + name=G1_JOINT_NAMES, + position=positions, + velocity=velocities, + effort=efforts, + ) + ) + # Unitree quat is (w,x,y,z); dimos Quaternion is (x,y,z,w). 
+ self.imu.publish( + Imu( + ts=now, + frame_id=frame_id, + orientation=Quaternion(quat[1], quat[2], quat[3], quat[0]), + angular_velocity=Vector3(gyro[0], gyro[1], gyro[2]), + linear_acceleration=Vector3(accel[0], accel[1], accel[2]), + ) + ) + def _publish_loop(self) -> None: period = 1.0 / float(self.config.publish_rate_hz) next_tick = time.perf_counter() frame_id = self.config.frame_id - # Identity quaternion + zeros while LowState hasn't arrived (start() blocks - # for it, but the publish loop may also see _low_state cleared during stop()). - zero_quat = (1.0, 0.0, 0.0, 0.0) - zero_vec3 = (0.0, 0.0, 0.0) - while not self._stop_event.is_set(): - with self._lock: - ls = self._low_state - if ls is None: - positions: list[float] = [0.0] * _NUM_MOTORS - velocities: list[float] = [0.0] * _NUM_MOTORS - efforts: list[float] = [0.0] * _NUM_MOTORS - quat = zero_quat - gyro = zero_vec3 - accel = zero_vec3 - else: - positions = [ls.motor_state[i].q for i in range(_NUM_MOTORS)] - velocities = [ls.motor_state[i].dq for i in range(_NUM_MOTORS)] - efforts = [ls.motor_state[i].tau_est for i in range(_NUM_MOTORS)] - quat = tuple(ls.imu_state.quaternion) - gyro = tuple(ls.imu_state.gyroscope) - accel = tuple(ls.imu_state.accelerometer) - - now = time.time() - self.motor_states.publish( - JointState( - ts=now, - frame_id=frame_id, - name=G1_JOINT_NAMES, - position=positions, - velocity=velocities, - effort=efforts, - ) - ) - - # Unitree quat is (w,x,y,z); dimos Quaternion is (x,y,z,w). 
- self.imu.publish( - Imu( - ts=now, + self._drain_low_state() + sample = self._snapshot_motor_imu() + if sample is not None: + positions, velocities, efforts, quat, gyro, accel = sample + self._publish_motor_state_and_imu( + now=time.time(), frame_id=frame_id, - orientation=Quaternion(quat[1], quat[2], quat[3], quat[0]), - angular_velocity=Vector3(gyro[0], gyro[1], gyro[2]), - linear_acceleration=Vector3(accel[0], accel[1], accel[2]), + positions=positions, + velocities=velocities, + efforts=efforts, + quat=quat, + gyro=gyro, + accel=accel, ) - ) next_tick += period sleep_for = next_tick - time.perf_counter() @@ -276,15 +363,17 @@ def _on_motor_command(self, msg: MotorCommandArray) -> None: self._low_cmd.crc = self._crc.Crc(self._low_cmd) self._publisher.Write(self._low_cmd) - def _on_low_state(self, msg: Any) -> None: - """rt/lowstate callback — captures mode_machine and the latest snapshot.""" - with self._lock: - self._low_state = msg - if self._mode_machine is None: - self._mode_machine = msg.mode_machine - def _release_sport_mode(self) -> None: - """Loop ReleaseMode until MotionSwitcher reports no active controller.""" + """Loop ReleaseMode until MotionSwitcher reports no active controller. + + Bails early if the first CheckMode reports nothing active. That + matters for back-to-back ``dimos run`` invocations: the first run + already released sport mode, so on a clean second start there's + nothing to release. Calling ReleaseMode anyway opens a window + where motor controllers are mid-handoff while we're already + publishing rt/lowcmd, which has been observed to cause horrible + mechanical noise from the gearboxes. + """ from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import ( MotionSwitcherClient, ) @@ -293,8 +382,14 @@ def _release_sport_mode(self) -> None: msc.SetTimeout(5.0) msc.Init() - # CheckMode returns (status, None) once nothing is active — null-tolerant. 
+ # CheckMode returns (status, None) — or (status, {"name": ""}) on + # some firmwares — once nothing is active. Treat both as "already + # released" and return without poking ReleaseMode. _status, result = msc.CheckMode() + if not result or not result.get("name"): + logger.info("Sport mode already released — skipping ReleaseMode") + return + while result and result.get("name"): msc.ReleaseMode() _status, result = msc.CheckMode() diff --git a/dimos/simulation/adapters/whole_body/g1.py b/dimos/simulation/adapters/whole_body/g1.py new file mode 100644 index 0000000000..b0b8693642 --- /dev/null +++ b/dimos/simulation/adapters/whole_body/g1.py @@ -0,0 +1,205 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo simulation ``WholeBodyAdapter`` for the Unitree G1. + +Pairs with ``MujocoSimModule`` (in-process MuJoCo engine + SHM publisher). +The blueprint composes both modules; they share the same ``MujocoEngine`` +indirectly via SHM keyed on the MJCF path. + +The adapter conforms to the same ``WholeBodyAdapter`` Protocol the real-hw +DDS adapter implements, so ControlCoordinator (and the GR00T WBC task on +top of it) can't tell sim from real. 
+""" + +from __future__ import annotations + +import math +from pathlib import Path +import time +from typing import TYPE_CHECKING, Any + +from dimos.hardware.whole_body.spec import ( + POS_STOP, + IMUState, + MotorCommand, + MotorState, +) +from dimos.simulation.engines.mujoco_shm import ( + ManipShmReader, + shm_key_from_path, +) +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.hardware.whole_body.registry import WholeBodyAdapterRegistry + +logger = setup_logger() + +_NUM_MOTORS = 29 + +_READY_WAIT_TIMEOUT_S = 60.0 +_READY_WAIT_POLL_S = 0.1 +_ATTACH_RETRY_TIMEOUT_S = 30.0 +_ATTACH_RETRY_POLL_S = 0.2 + + +class SimMujocoG1WholeBodyAdapter: + """G1 ``WholeBodyAdapter`` that proxies to a ``MujocoSimModule`` via SHM. + + The sim module owns the engine and publishes joint state + IMU into + SHM each step; this adapter reads them and forwards per-joint + (q, kp, kd, tau) commands back into SHM for the engine's pre-step + PD-with-feedforward hook to apply. + + ``address`` (the MJCF XML path) is the discovery key — both sides + derive the same SHM names from it via ``shm_key_from_path``. + """ + + def __init__( + self, + address: str | Path | None = None, + domain_id: int = 0, + **_: Any, + ) -> None: + if address is None: + raise ValueError( + "SimMujocoG1WholeBodyAdapter: address (MJCF XML path) is required — " + "set HardwareComponent.address to the same MJCF path the " + "MujocoSimModule loads." + ) + self._address = address + self._shm_key = shm_key_from_path(address) + self._shm: ManipShmReader | None = None + self._connected = False + + # Lifecycle + + def connect(self) -> bool: + # Attach with retry — MujocoSimModule may still be starting up. 
+ deadline = time.monotonic() + _ATTACH_RETRY_TIMEOUT_S + while True: + try: + self._shm = ManipShmReader(self._shm_key) + break + except FileNotFoundError: + if time.monotonic() > deadline: + logger.error( + "SimMujocoG1WholeBodyAdapter: SHM buffers not found", + address=self._address, + shm_key=self._shm_key, + timeout_s=_ATTACH_RETRY_TIMEOUT_S, + ) + return False + time.sleep(_ATTACH_RETRY_POLL_S) + + # Wait for the sim to signal ready (engine connected, first + # joint-state packet written). Without this the first + # read_motor_states() returns zeros and the WBC obs is junk. + deadline = time.monotonic() + _READY_WAIT_TIMEOUT_S + while not self._shm.is_ready(): + if time.monotonic() > deadline: + logger.error( + "SimMujocoG1WholeBodyAdapter: sim module not ready", + timeout_s=_READY_WAIT_TIMEOUT_S, + ) + self._shm.cleanup() + self._shm = None + return False + time.sleep(_READY_WAIT_POLL_S) + + self._connected = True + logger.info( + "SimMujocoG1WholeBodyAdapter connected", + num_motors=_NUM_MOTORS, + shm_key=self._shm_key, + ) + return True + + def disconnect(self) -> None: + if self._shm is not None: + try: + self._shm.cleanup() + except Exception as e: # best-effort cleanup + logger.warning(f"SHM cleanup raised: {e}") + self._shm = None + self._connected = False + + def is_connected(self) -> bool: + return self._connected and self._shm is not None + + # IO (WholeBodyAdapter protocol) + + def read_motor_states(self) -> list[MotorState]: + if not self._connected or self._shm is None: + return [MotorState()] * _NUM_MOTORS + positions = self._shm.read_positions(_NUM_MOTORS) + velocities = self._shm.read_velocities(_NUM_MOTORS) + efforts = self._shm.read_efforts(_NUM_MOTORS) + return [ + MotorState(q=positions[i], dq=velocities[i], tau=efforts[i]) for i in range(_NUM_MOTORS) + ] + + def has_motor_states(self) -> bool: + # Sim ground truth is available the moment SHM attaches. 
+ # No ramp-up window like real DDS adapters need before the + # first state msg arrives. + return self._connected and self._shm is not None + + def read_imu(self) -> IMUState: + if not self._connected or self._shm is None: + return IMUState() + quat, gyro, accel = self._shm.read_imu() + # Derive ZYX Euler from the quaternion — matches the real G1 adapter. + w, x, y, z = quat + sinr = 2.0 * (w * x + y * z) + cosr = 1.0 - 2.0 * (x * x + y * y) + roll = math.atan2(sinr, cosr) + sinp = 2.0 * (w * y - z * x) + pitch = math.copysign(math.pi / 2.0, sinp) if abs(sinp) >= 1.0 else math.asin(sinp) + siny = 2.0 * (w * z + x * y) + cosy = 1.0 - 2.0 * (y * y + z * z) + yaw = math.atan2(siny, cosy) + return IMUState( + quaternion=quat, + gyroscope=gyro, + accelerometer=accel, + rpy=(roll, pitch, yaw), + ) + + def write_motor_commands(self, commands: list[MotorCommand]) -> bool: + if not self._connected or self._shm is None: + return False + if len(commands) != _NUM_MOTORS: + logger.error( + f"SimMujocoG1WholeBodyAdapter: expected {_NUM_MOTORS} commands, got {len(commands)}" + ) + return False + # Flatten the per-motor command into per-joint arrays. POS_STOP + # ("no command") is replaced with 0.0 — the engine's PD only + # acts when kp > 0 anyway, so a zeroed q is harmless. 
+ q = [cmd.q if cmd.q != POS_STOP else 0.0 for cmd in commands] + kp = [cmd.kp for cmd in commands] + kd = [cmd.kd for cmd in commands] + tau = [cmd.tau for cmd in commands] + self._shm.write_pd_tau_command(q, kp, kd, tau) + return True + + +def register(registry: WholeBodyAdapterRegistry) -> None: + """Register with the whole-body adapter registry.""" + registry.register("sim_mujoco_g1", SimMujocoG1WholeBodyAdapter) + + +__all__ = ["SimMujocoG1WholeBodyAdapter"] diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py index ada85dd477..9de32d1378 100644 --- a/dimos/simulation/engines/mujoco_engine.py +++ b/dimos/simulation/engines/mujoco_engine.py @@ -86,13 +86,22 @@ def __init__( cameras: list[CameraConfig] | None = None, on_before_step: StepHook | None = None, on_after_step: StepHook | None = None, + assets: dict[str, bytes] | None = None, ) -> None: super().__init__(config_path=config_path, headless=headless) self._on_before_step: StepHook | None = on_before_step self._on_after_step: StepHook | None = on_after_step xml_path = self._resolve_xml_path(config_path) - self._model = mujoco.MjModel.from_xml_path(str(xml_path)) + if assets is not None: + # MJCFs that reference meshes by bare filename (e.g. menagerie + # G1) need the mesh bytes injected by name; from_xml_path can't + # find them on disk. + with open(xml_path) as f: + xml_str = f.read() + self._model = mujoco.MjModel.from_xml_string(xml_str, assets=assets) + else: + self._model = mujoco.MjModel.from_xml_path(str(xml_path)) self._xml_path = xml_path self._data = mujoco.MjData(self._model) @@ -125,6 +134,19 @@ def __init__( self._camera_frames: dict[str, CameraFrame] = {} self._camera_lock = threading.Lock() + def set_step_hooks( + self, + before: StepHook | None = None, + after: StepHook | None = None, + ) -> None: + """Install pre/post step hooks after construction. 
+ + Use when the hooks depend on engine state (joint count, gripper + index) that isn't known until the model is loaded. + """ + self._on_before_step = before + self._on_after_step = after + def _resolve_xml_path(self, config_path: Path) -> Path: if config_path is None: raise ValueError("config_path is required for MuJoCo simulation loading") @@ -334,6 +356,14 @@ def joint_names(self) -> list[str]: def model(self) -> mujoco.MjModel: return self._model + @property + def data(self) -> mujoco.MjData: + """Live MjData. In-process consumers (sensors, PD hooks) read it + directly; physics integration in the sim thread mutates it under + ``self._lock`` so reads inside the same MujocoEngine instance are + coherent without extra locking.""" + return self._data + @property def joint_positions(self) -> list[float]: with self._lock: diff --git a/dimos/simulation/engines/mujoco_shm.py b/dimos/simulation/engines/mujoco_shm.py index c0623c7915..c8d71100cd 100644 --- a/dimos/simulation/engines/mujoco_shm.py +++ b/dimos/simulation/engines/mujoco_shm.py @@ -40,26 +40,38 @@ logger = setup_logger() -# Upper bound on joint count per sim. Arms + gripper are typically <= 10. -MAX_JOINTS = 16 +# Upper bound on joint count per sim. Manipulators use ≤10; humanoids +# (Unitree G1: 29) push higher. 32 leaves headroom while keeping all +# per-joint buffers tiny (32 floats = 256 B). +MAX_JOINTS = 32 _FLOAT_BYTES = 8 # float64 _INT32_BYTES = 4 +# IMU layout: quat (4) + gyro (3) + accel (3) = 10 floats. +_IMU_FLOATS = 10 + _joint_array_size = MAX_JOINTS * _FLOAT_BYTES # float64 array # Element counts for control and sequence arrays. _NUM_CTRL_FIELDS = 4 # [ready, stop, command_mode, num_joints] -_NUM_SEQ_COUNTERS = 8 # one per buffer type +_NUM_SEQ_COUNTERS = 12 # one per buffer type (manipulator + WB additions) # Buffer sizes (in bytes). # Keys are short to stay under macOS PSHMNAMLEN (31 bytes). 
_shm_sizes = { + # Manipulator-shared layout "pos": _joint_array_size, "vel": _joint_array_size, "eff": _joint_array_size, "pos_t": _joint_array_size, "vel_t": _joint_array_size, "grp": 2 * _FLOAT_BYTES, # [gripper_position, gripper_target] + # Whole-body additions (unused by manipulator path). + "imu": _IMU_FLOATS * _FLOAT_BYTES, # [w,x,y,z, gx,gy,gz, ax,ay,az] + "kp_t": _joint_array_size, # per-joint position-gain target + "kd_t": _joint_array_size, # per-joint velocity-gain target + "tau_t": _joint_array_size, # per-joint feedforward torque + # Bookkeeping "seq": _NUM_SEQ_COUNTERS * _FLOAT_BYTES, # int64 counters "ctl": _NUM_CTRL_FIELDS * _INT32_BYTES, # [ready, stop, command_mode, num_joints] } @@ -72,6 +84,11 @@ SEQ_VELOCITY_CMD = 4 SEQ_GRIPPER_STATE = 5 SEQ_GRIPPER_CMD = 6 +# Whole-body additions +SEQ_IMU = 7 +SEQ_KP_CMD = 8 +SEQ_KD_CMD = 9 +SEQ_TAU_CMD = 10 # Control indices. CTRL_READY = 0 @@ -82,6 +99,9 @@ # Command modes. CMD_MODE_POSITION = 0 CMD_MODE_VELOCITY = 1 +# Whole-body PD-with-feedforward: ctrl = kp*(q_t - q) + kd*(0 - dq) + tau_t. +# Per-step kp/kd lets a policy retune gains online if it wants to. +CMD_MODE_PD_TAU = 2 _NAME_PREFIX = "dmjm" @@ -114,7 +134,13 @@ def _unregister(shm: SharedMemory) -> SharedMemory: @dataclass(frozen=True) class ManipShmSet: - """Frozen set of named SharedMemory buffers for manipulator IPC.""" + """Frozen set of named SharedMemory buffers for sim ↔ adapter IPC. + + Despite the name (kept for backward compat with existing manipulator + consumers), the layout now also covers whole-body needs: IMU, per-joint + PD gain commands, and per-joint feedforward torque commands. The + extra buffers are unused by the manipulator path. 
+ """ pos: SharedMemory vel: SharedMemory @@ -122,6 +148,12 @@ class ManipShmSet: pos_t: SharedMemory vel_t: SharedMemory grp: SharedMemory + # Whole-body additions + imu: SharedMemory + kp_t: SharedMemory + kd_t: SharedMemory + tau_t: SharedMemory + # Bookkeeping seq: SharedMemory ctl: SharedMemory @@ -170,6 +202,9 @@ def __init__(self, key: str) -> None: self._last_pos_cmd_seq = 0 self._last_vel_cmd_seq = 0 self._last_gripper_cmd_seq = 0 + self._last_kp_cmd_seq = 0 + self._last_kd_cmd_seq = 0 + self._last_tau_cmd_seq = 0 # Zero everything. for buf in self.shm.as_list(): np.ndarray((buf.size,), dtype=np.uint8, buffer=buf.buf)[:] = 0 @@ -226,6 +261,47 @@ def read_gripper_command(self) -> float | None: def read_command_mode(self) -> int: return int(self._control()[CTRL_COMMAND_MODE]) + # Whole-body additions + + def write_imu( + self, + quaternion: tuple[float, float, float, float], + gyroscope: tuple[float, float, float], + accelerometer: tuple[float, float, float], + ) -> None: + """Write IMU sample. 
Quaternion is (w, x, y, z).""" + arr = self._array(self.shm.imu, _IMU_FLOATS, np.float64) + arr[0:4] = quaternion + arr[4:7] = gyroscope + arr[7:10] = accelerometer + self._increment_seq(SEQ_IMU) + + def read_kp_command(self, num_joints: int) -> NDArray[np.float64] | None: + """Per-joint position-gain target if a new command landed since last call.""" + seq = self._get_seq(SEQ_KP_CMD) + if seq <= self._last_kp_cmd_seq: + return None + self._last_kp_cmd_seq = seq + arr = self._array(self.shm.kp_t, MAX_JOINTS, np.float64) + return arr[:num_joints].copy() + + def read_kd_command(self, num_joints: int) -> NDArray[np.float64] | None: + seq = self._get_seq(SEQ_KD_CMD) + if seq <= self._last_kd_cmd_seq: + return None + self._last_kd_cmd_seq = seq + arr = self._array(self.shm.kd_t, MAX_JOINTS, np.float64) + return arr[:num_joints].copy() + + def read_tau_command(self, num_joints: int) -> NDArray[np.float64] | None: + """Per-joint feedforward torque if a new command landed since last call.""" + seq = self._get_seq(SEQ_TAU_CMD) + if seq <= self._last_tau_cmd_seq: + return None + self._last_tau_cmd_seq = seq + arr = self._array(self.shm.tau_t, MAX_JOINTS, np.float64) + return arr[:num_joints].copy() + def signal_ready(self, num_joints: int) -> None: ctrl = self._control() ctrl[CTRL_NUM_JOINTS] = num_joints @@ -314,6 +390,77 @@ def write_gripper_command(self, position: float) -> None: arr[1] = position self._increment_seq(SEQ_GRIPPER_CMD) + # Whole-body additions + + def read_imu( + self, + ) -> tuple[ + tuple[float, float, float, float], + tuple[float, float, float], + tuple[float, float, float], + ]: + """Read IMU sample: ((qw, qx, qy, qz), (gx, gy, gz), (ax, ay, az)).""" + arr = np.ndarray((_IMU_FLOATS,), dtype=np.float64, buffer=self.shm.imu.buf) + return ( + (float(arr[0]), float(arr[1]), float(arr[2]), float(arr[3])), + (float(arr[4]), float(arr[5]), float(arr[6])), + (float(arr[7]), float(arr[8]), float(arr[9])), + ) + + def write_kp_command(self, kp: list[float]) -> 
None: + """Per-joint position-gain target. Switches command mode to PD+τ.""" + n = min(len(kp), MAX_JOINTS) + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.kp_t.buf) + arr[:n] = kp[:n] + self._set_command_mode(CMD_MODE_PD_TAU) + self._increment_seq(SEQ_KP_CMD) + + def write_kd_command(self, kd: list[float]) -> None: + n = min(len(kd), MAX_JOINTS) + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.kd_t.buf) + arr[:n] = kd[:n] + self._set_command_mode(CMD_MODE_PD_TAU) + self._increment_seq(SEQ_KD_CMD) + + def write_tau_command(self, tau: list[float]) -> None: + """Per-joint feedforward torque, applied on top of PD.""" + n = min(len(tau), MAX_JOINTS) + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.tau_t.buf) + arr[:n] = tau[:n] + self._set_command_mode(CMD_MODE_PD_TAU) + self._increment_seq(SEQ_TAU_CMD) + + def write_pd_tau_command( + self, + positions: list[float], + kp: list[float], + kd: list[float], + tau: list[float], + ) -> None: + """Write a whole-body PD+tau command without transient mode flips. + + The sim engine runs in a different process, so setting position mode + first and PD mode later creates a small but real race. Write all arrays, + publish PD mode once, then bump the sequence counters. 
+ """ + n_pos = min(len(positions), MAX_JOINTS) + n_kp = min(len(kp), MAX_JOINTS) + n_kd = min(len(kd), MAX_JOINTS) + n_tau = min(len(tau), MAX_JOINTS) + np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.pos_t.buf)[:n_pos] = positions[ + :n_pos + ] + np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.kp_t.buf)[:n_kp] = kp[:n_kp] + np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.kd_t.buf)[:n_kd] = kd[:n_kd] + np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.tau_t.buf)[:n_tau] = tau[:n_tau] + self._set_command_mode(CMD_MODE_PD_TAU) + self._increment_seq(SEQ_KP_CMD) + self._increment_seq(SEQ_KD_CMD) + self._increment_seq(SEQ_TAU_CMD) + # Position is the engine-side trigger for latching a new PD target, + # so publish it last after gains/torque are visible. + self._increment_seq(SEQ_POSITION_CMD) + def is_ready(self) -> bool: return bool(self._control()[CTRL_READY] == 1) diff --git a/dimos/simulation/engines/mujoco_sim_module.py b/dimos/simulation/engines/mujoco_sim_module.py index 3d2ff927fe..8bf502b9bd 100644 --- a/dimos/simulation/engines/mujoco_sim_module.py +++ b/dimos/simulation/engines/mujoco_sim_module.py @@ -32,6 +32,8 @@ import time from typing import Any +import mujoco +import numpy as np from pydantic import Field import reactivex as rx from scipy.spatial.transform import Rotation as R @@ -40,11 +42,13 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import DepthCameraConfig, DepthCameraHardware +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.Imu import Imu from dimos.msgs.sensor_msgs.JointState import JointState 
from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.simulation.engines.mujoco_engine import ( @@ -53,6 +57,7 @@ MujocoEngine, ) from dimos.simulation.engines.mujoco_shm import ( + CMD_MODE_PD_TAU, ManipShmWriter, shm_key_from_path, ) @@ -61,6 +66,17 @@ logger = setup_logger() + +def _find_sensor_slice(model: mujoco.MjModel, *names: str, dim: int = 3) -> slice | None: + """Return the first matching MJCF sensor's slice into sensordata, or None.""" + for n in names: + sid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, n) # type: ignore[attr-defined] + if sid >= 0: + adr = int(model.sensor_adr[sid]) + return slice(adr, adr + dim) + return None + + _RX180 = R.from_euler("x", 180, degrees=True) @@ -71,10 +87,101 @@ def _default_identity_transform() -> Transform: ) +class _WholeBodySimHooks: + """Per-step bridge between MuJoCo actuators and whole-body SHM.""" + + def __init__( + self, + shm: ManipShmWriter, + dof: int, + *, + gripper_idx: int | None = None, + gripper_ctrl_range: tuple[float, float] = (0.0, 1.0), + gripper_joint_range: tuple[float, float] = (0.0, 1.0), + ) -> None: + self._shm = shm + self._dof = dof + self._gripper_idx = gripper_idx + self._gripper_ctrl_range = gripper_ctrl_range + self._gripper_joint_range = gripper_joint_range + self._latest_pd_pos_target: np.ndarray | None = None + self._latest_pd_kp: np.ndarray | None = None + self._latest_pd_kd: np.ndarray | None = None + self._latest_pd_tau: np.ndarray | None = None + + def pre_step(self, engine: MujocoEngine) -> None: + shm = self._shm + dof = self._dof + + pos_cmd = shm.read_position_command(dof) + if pos_cmd is not None: + if shm.read_command_mode() == CMD_MODE_PD_TAU: + self._latest_pd_pos_target = pos_cmd + else: + engine.write_joint_command(JointState(position=pos_cmd.tolist())) + + vel_cmd = shm.read_velocity_command(dof) + if vel_cmd is not None: + engine.write_joint_command(JointState(velocity=vel_cmd.tolist())) + + kp_cmd = shm.read_kp_command(dof) + if kp_cmd is 
not None: + self._latest_pd_kp = kp_cmd + kd_cmd = shm.read_kd_command(dof) + if kd_cmd is not None: + self._latest_pd_kd = kd_cmd + tau_cmd = shm.read_tau_command(dof) + if tau_cmd is not None: + self._latest_pd_tau = tau_cmd + + if ( + self._latest_pd_pos_target is not None + and self._latest_pd_kp is not None + and self._latest_pd_kd is not None + ): + q = np.asarray(engine.joint_positions[:dof], dtype=np.float64) + dq = np.asarray(engine.joint_velocities[:dof], dtype=np.float64) + tau_ff = self._latest_pd_tau if self._latest_pd_tau is not None else np.zeros(dof) + tau = ( + self._latest_pd_kp * (self._latest_pd_pos_target - q) + + self._latest_pd_kd * (-dq) + + tau_ff + ) + engine.write_joint_command(JointState(effort=tau.tolist())) + + if self._gripper_idx is not None: + gripper_cmd = shm.read_gripper_command() + if gripper_cmd is not None: + engine.set_position_target( + self._gripper_idx, self._gripper_joint_to_ctrl(gripper_cmd) + ) + + def post_step(self, engine: MujocoEngine) -> None: + shm = self._shm + shm.write_joint_state( + positions=engine.joint_positions, + velocities=engine.joint_velocities, + efforts=engine.joint_efforts, + ) + if self._gripper_idx is not None: + positions = engine.joint_positions + if self._gripper_idx < len(positions): + shm.write_gripper_state(positions[self._gripper_idx]) + + def _gripper_joint_to_ctrl(self, joint_position: float) -> float: + jlo, jhi = self._gripper_joint_range + clo, chi = self._gripper_ctrl_range + clamped = max(jlo, min(jhi, joint_position)) + if jhi == jlo: + return clo + t = (clamped - jlo) / (jhi - jlo) + return chi - t * (chi - clo) + + class MujocoSimModuleConfig(ModuleConfig, DepthCameraConfig): """Configuration for the unified MuJoCo simulation module.""" - address: str = "" + address: str | Path = "" headless: bool = False dof: int = 7 @@ -86,10 +193,37 @@ class MujocoSimModuleConfig(ModuleConfig, DepthCameraConfig): base_frame_id: str = "link7" base_transform: Transform | None = 
Field(default_factory=_default_identity_transform) align_depth_to_color: bool = True + enable_color: bool = True enable_depth: bool = True enable_pointcloud: bool = False pointcloud_fps: float = 5.0 camera_info_fps: float = 1.0 + # Inject menagerie/dimos-bundled mesh bytes (via + # dimos.simulation.mujoco.model.get_assets) into MjModel.from_xml_string. + # MJCFs that reference meshes by bare filename (G1 GR00T, Go2) need this; + # self-contained MJCFs with on-disk meshes (xarm scene.xml) don't. + inject_legacy_assets: bool = False + # MJCF sensor names used to publish IMU. The module probes these in + # order and uses the first that exists in the model; a list with no match + # yields zeros for that reading. IMU stays silent only when neither list + # matches and the root is not a free joint. Defaults cover common humanoid + # pelvis-mounted names (menagerie + dimos bundled MJCFs); override per robot. + imu_gyro_sensor_names: list[str] = Field( + default_factory=lambda: [ + "imu-pelvis-angular-velocity", + "imu-torso-angular-velocity", + "gyro_pelvis", + "imu_gyro", + ] + ) + imu_accel_sensor_names: list[str] = Field( + default_factory=lambda: [ + "imu-pelvis-linear-acceleration", + "imu-torso-linear-acceleration", + "accelerometer_pelvis", + "imu_accel", + ] + ) class MujocoSimModule( @@ -111,17 +245,32 @@ class MujocoSimModule( pointcloud: Out[PointCloud2] camera_info: Out[CameraInfo] depth_camera_info: Out[CameraInfo] + imu: Out[Imu] + # Floating-base pose (qpos[0:7]) for robots whose MJCF has a free + # joint at the root. Published every step; consumers like the viser + # viewer use this to translate the robot in world space. 
+ odom: Out[PoseStamped] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._engine: MujocoEngine | None = None self._shm: ManipShmWriter | None = None + self._sim_hooks: _WholeBodySimHooks | None = None self._gripper_idx: int | None = None self._gripper_ctrl_range: tuple[float, float] = (0.0, 1.0) self._gripper_joint_range: tuple[float, float] = (0.0, 1.0) self._stop_event = threading.Event() self._publish_thread: threading.Thread | None = None self._camera_info_base: CameraInfo | None = None + self._shm_ready_signaled = False + + # IMU sensor slices into MjData.sensordata, resolved once at start. + # None if the MJCF has no recognized IMU sensors (e.g. arm-only sims). + self._imu_gyro_slice: slice | None = None + self._imu_accel_slice: slice | None = None + # Quaternion is read from the floating-base qpos[3:7] when the model + # has a free joint at the root; None otherwise. + self._imu_base_qpos_slice: slice | None = None @property def _camera_link(self) -> str: @@ -167,21 +316,42 @@ def start(self) -> None: # SHM key — adapter derives the same key from the same MJCF path. shm_key = shm_key_from_path(self.config.address) self._shm = ManipShmWriter(shm_key) + self._shm_ready_signaled = False # Build engine with SHM hooks installed. - self._engine = MujocoEngine( - config_path=Path(self.config.address), - headless=self.config.headless, - cameras=[ + engine_assets: dict[str, bytes] | None = None + if self.config.inject_legacy_assets: + from dimos.simulation.mujoco.model import get_assets + + engine_assets = get_assets() + # Compose the camera list. Each registered camera blocks the + # sim thread inside _step_once (mujoco_engine._render_cameras + # does update_scene + GPU render synchronously between physics + # steps — typically 5-30 ms per camera), so registering a camera + # nobody consumes burns the 500 Hz tick deadline for nothing. + # Skip the primary camera entirely when none of color / depth / + # pointcloud is enabled. 
+ cameras: list[CameraConfig] = [] + primary_needed = ( + self.config.enable_color or self.config.enable_depth or self.config.enable_pointcloud + ) + if primary_needed: + cameras.append( CameraConfig( name=self.config.camera_name, width=self.config.width, height=self.config.height, fps=float(self.config.fps), ) - ], - on_before_step=self._apply_shm_commands, - on_after_step=self._publish_shm_state, + ) + + # Hooks are installed via set_step_hooks() after gripper detection + # below, since they depend on the resolved gripper index. + self._engine = MujocoEngine( + config_path=Path(self.config.address), + headless=self.config.headless, + cameras=cameras, + assets=engine_assets, ) # Detect gripper (extra joint beyond dof). @@ -202,12 +372,42 @@ def start(self) -> None: joint_range=joint_range, ) + # Resolve IMU sensors once. Names come from config so robot- + # specific blueprints (G1, H1, Optimus, …) can override; manipulator + # MJCFs typically have neither — we leave the slices as None and + # skip IMU publishing for those. + self._imu_gyro_slice = _find_sensor_slice( + self._engine.model, *self.config.imu_gyro_sensor_names, dim=3 + ) + self._imu_accel_slice = _find_sensor_slice( + self._engine.model, *self.config.imu_accel_sensor_names, dim=3 + ) + # Floating-base orientation is qpos[3:7] (w,x,y,z) when the root + # joint is a free joint. Detect by checking jnt_type[0]. + if self._engine.model.njnt > 0 and int(self._engine.model.jnt_type[0]) == int( + mujoco.mjtJoint.mjJNT_FREE # type: ignore[attr-defined] + ): + self._imu_base_qpos_slice = slice(3, 7) + else: + self._imu_base_qpos_slice = None + + # Wire SHM bridge hooks. 
+ self._sim_hooks = _WholeBodySimHooks( + self._shm, + dof=dof, + gripper_idx=self._gripper_idx, + gripper_ctrl_range=self._gripper_ctrl_range, + gripper_joint_range=self._gripper_joint_range, + ) + self._engine.set_step_hooks( + before=self._sim_hooks.pre_step, + after=self._publish_shm_and_lcm, + ) + # Start physics (sim thread spawned inside engine.connect()). if not self._engine.connect(): raise RuntimeError("MujocoSimModule: engine.connect() failed") - self._shm.signal_ready(num_joints=len(joint_names)) - # Camera intrinsics. self._build_camera_info() @@ -226,7 +426,7 @@ def start(self) -> None: ) ) - # Optional pointcloud generation. + # Optional pointcloud generation: back-projects primary camera depth. if self.config.enable_pointcloud and self.config.enable_depth: pc_interval = 1.0 / self.config.pointcloud_fps self.register_disposable( @@ -268,6 +468,7 @@ def stop(self) -> None: logger.error("SHM cleanup failed", error=str(exc)) errors.append(("shm.cleanup", exc)) + self._sim_hooks = None self._camera_info_base = None super().stop() @@ -275,51 +476,76 @@ def stop(self) -> None: op, err = errors[0] raise RuntimeError(f"MujocoSimModule.stop() failed during {op}: {err}") from err - def _apply_shm_commands(self, engine: MujocoEngine) -> None: - """Pre-step hook: pull command targets from SHM into the engine.""" + def _publish_shm_and_lcm(self, engine: MujocoEngine) -> None: + """Post-step hook: SHM writes + LCM publishes. + + This stays in the module so odom/IMU continue to flow through normal + typed ports while the whole-body adapter consumes joint state via SHM. 
+ """ + if self._sim_hooks is not None: + self._sim_hooks.post_step(engine) shm = self._shm if shm is None: return - dof = self.config.dof - - pos_cmd = shm.read_position_command(dof) - if pos_cmd is not None: - engine.write_joint_command(JointState(position=pos_cmd.tolist())) - - vel_cmd = shm.read_velocity_command(dof) - if vel_cmd is not None: - engine.write_joint_command(JointState(velocity=vel_cmd.tolist())) - - if self._gripper_idx is not None: - gripper_cmd = shm.read_gripper_command() - if gripper_cmd is not None: - ctrl_value = self._gripper_joint_to_ctrl(gripper_cmd) - engine.set_position_target(self._gripper_idx, ctrl_value) + if not self._shm_ready_signaled: + shm.signal_ready(num_joints=len(engine.joint_names)) + self._shm_ready_signaled = True + + # Odom — when the MJCF has a free-joint root, publish base pose + # from qpos[0:7] every step. Without this, downstream consumers + # (viser viewer, nav stack) only see joint articulation, not + # base translation through the world. + data = engine.data # in-process: same MjData the sim thread mutates + if self._imu_base_qpos_slice is not None: + base_pos = data.qpos[0:3] + base_quat = data.qpos[3:7] # (w, x, y, z) per MuJoCo convention + self.odom.publish( + PoseStamped( + ts=time.time(), + frame_id="world", + position=Vector3(float(base_pos[0]), float(base_pos[1]), float(base_pos[2])), + orientation=Quaternion( + float(base_quat[1]), + float(base_quat[2]), + float(base_quat[3]), + float(base_quat[0]), + ), # PoseStamped uses x,y,z,w + ) + ) - def _publish_shm_state(self, engine: MujocoEngine) -> None: - """Post-step hook: publish joint state to SHM.""" - shm = self._shm - if shm is None: + # IMU — skip only if no IMU sensor and no free-joint root; gaps are zero-filled. 
+ if ( + self._imu_gyro_slice is None + and self._imu_accel_slice is None + and self._imu_base_qpos_slice is None + ): return - shm.write_joint_state( - positions=engine.joint_positions, - velocities=engine.joint_velocities, - efforts=engine.joint_efforts, + if self._imu_base_qpos_slice is not None: + q = data.qpos[self._imu_base_qpos_slice] + quat = (float(q[0]), float(q[1]), float(q[2]), float(q[3])) + else: + quat = (1.0, 0.0, 0.0, 0.0) + if self._imu_gyro_slice is not None: + g = data.sensordata[self._imu_gyro_slice] + gyro = (float(g[0]), float(g[1]), float(g[2])) + else: + gyro = (0.0, 0.0, 0.0) + if self._imu_accel_slice is not None: + a = data.sensordata[self._imu_accel_slice] + accel = (float(a[0]), float(a[1]), float(a[2])) + else: + accel = (0.0, 0.0, 0.0) + shm.write_imu(quaternion=quat, gyroscope=gyro, accelerometer=accel) + # Also publish on the stream port for downstream consumers. + self.imu.publish( + Imu( + ts=time.time(), + frame_id="pelvis", + orientation=Quaternion(quat[1], quat[2], quat[3], quat[0]), + angular_velocity=Vector3(gyro[0], gyro[1], gyro[2]), + linear_acceleration=Vector3(accel[0], accel[1], accel[2]), + ) ) - if self._gripper_idx is not None: - positions = engine.joint_positions - if self._gripper_idx < len(positions): - shm.write_gripper_state(positions[self._gripper_idx]) - - def _gripper_joint_to_ctrl(self, joint_position: float) -> float: - """Map joint-space gripper position to actuator control value.""" - jlo, jhi = self._gripper_joint_range - clo, chi = self._gripper_ctrl_range - clamped = max(jlo, min(jhi, joint_position)) - if jhi == jlo: - return clo - t = (clamped - jlo) / (jhi - jlo) - return chi - t * (chi - clo) def _build_camera_info(self) -> None: if self._engine is None: @@ -382,13 +608,14 @@ def _publish_loop(self) -> None: last_timestamp = frame.timestamp ts = time.time() - color_img = Image( - data=frame.rgb, - format=ImageFormat.RGB, - frame_id=self._color_optical_frame, - ts=ts, - ) - 
self.color_image.publish(color_img) + if self.config.enable_color: + color_img = Image( + data=frame.rgb, + format=ImageFormat.RGB, + frame_id=self._color_optical_frame, + ts=ts, + ) + self.color_image.publish(color_img) if self.config.enable_depth: depth_img = Image( @@ -469,7 +696,10 @@ def _publish_tf(self, ts: float, frame: CameraFrame | None) -> None: ) def _generate_pointcloud(self) -> None: - if self._engine is None or self._camera_info_base is None: + if self._engine is None: + return + # Back-project the primary camera's depth image. + if self._camera_info_base is None: return frame = self._engine.read_camera(self.config.camera_name) if frame is None: