diff --git a/source/autosim/autosim/core/types.py b/source/autosim/autosim/core/types.py index 2f3b3bf..bc0b7ec 100644 --- a/source/autosim/autosim/core/types.py +++ b/source/autosim/autosim/core/types.py @@ -46,6 +46,8 @@ class SkillGoal: """The target pose of the skill.""" extra_target_poses: dict[str, torch.Tensor] | None = None """The target poses of the extra end-effectors. dict[link_name, target_pose].""" + info: dict[str, Any] = field(default_factory=dict) + """The information of the skill goal.""" @dataclass @@ -94,6 +96,9 @@ class EnvExtraInfo: object_navigate_sample_range: dict[str, tuple[float, float]] = field(default_factory=dict) """The sample range for the navigate skill. each object can have a tuple of (min_angle, max_angle) in radians.""" + object_relative_reach_axis: dict[str, dict[str, str]] = field(default_factory=dict) + """The relative reach axis for the relative reach skill. dict[object_name, dict[relative_reach_skill_name, axis]].""" + cached_initial_extra_target_offsets: dict[str, tuple[torch.Tensor, torch.Tensor]] | None = None """Cached primary-frame offsets for extra target links, reused across multiple reach-like skills.""" @@ -111,6 +116,9 @@ def get_reach_target_poses(self, object_name: str) -> list[torch.Tensor]: def get_navigate_sample_range(self, object_name: str) -> tuple[float, float]: return self.object_navigate_sample_range.get(object_name, (0.0, 2 * np.pi)) + def get_relative_reach_axis(self, object_name: str, relative_reach_skill_name: str) -> str: + return self.object_relative_reach_axis.get(object_name, {}).get(relative_reach_skill_name) + @dataclass class WorldState: diff --git a/source/autosim/autosim/skills/relative_reach.py b/source/autosim/autosim/skills/relative_reach.py index d843884..b83a59f 100644 --- a/source/autosim/autosim/skills/relative_reach.py +++ b/source/autosim/autosim/skills/relative_reach.py @@ -38,7 +38,7 @@ def __post_init__(self): "z": torch.tensor([0.0, 0.0, 1.0]), } - def get_direction_vector(self) -> torch.Tensor: + def get_direction_vector(self, move_axis: str) -> torch.Tensor: """Parse move_axis and compute the normalized direction vector. This is computed on-demand to support dynamic modification of move_axis. @@ -49,11 +49,11 @@ def get_direction_vector(self) -> torch.Tensor: import re pattern = r"([+-][xyz])" - matches = re.findall(pattern, self.move_axis) + matches = re.findall(pattern, move_axis) if not matches: raise ValueError( - f"Invalid move_axis format: '{self.move_axis}'. Expected format: '+x', '-y', '+x+y', '+x-z', etc." + f"Invalid move_axis format: '{move_axis}'. Expected format: '+x', '-y', '+x+y', '+x-z', etc." ) direction_vector = torch.zeros(3) @@ -61,12 +61,12 @@ def get_direction_vector(self) -> torch.Tensor: sign = 1.0 if match[0] == "+" else -1.0 axis = match[1] if axis not in self._axis_map: - raise ValueError(f"Invalid axis '{axis}' in move_axis: '{self.move_axis}'") + raise ValueError(f"Invalid axis '{axis}' in move_axis: '{move_axis}'") direction_vector += sign * self._axis_map[axis] norm = torch.linalg.norm(direction_vector) if norm < 1e-6: - raise ValueError(f"move_axis '{self.move_axis}' results in zero direction vector") + raise ValueError(f"move_axis '{move_axis}' results in zero direction vector") return direction_vector / norm @@ -91,7 +91,11 @@ def extract_goal_from_info( ) -> SkillGoal: """Return the target object of the relative reach skill.""" - return SkillGoal(target_object=skill_info.target_object) + relative_reach_axis = env_extra_info.get_relative_reach_axis(skill_info.target_object, self.cfg.name) + if relative_reach_axis is None: + relative_reach_axis = self.cfg.extra_cfg.move_axis + + return SkillGoal(target_object=skill_info.target_object, info=dict(move_axis=relative_reach_axis)) def execute_plan(self, state: WorldState, goal: SkillGoal) -> bool: """Execute the plan of the relative reach skill.""" @@ -112,7 +116,9 @@ def execute_plan(self, state: WorldState, goal: SkillGoal) -> bool: # move the eef along the move axis by the move offset based on eef frame, and convert to robot root frame to get target pose isaaclab_device = state.device - move_offset_vector = self.cfg.extra_cfg.get_direction_vector() * self.cfg.extra_cfg.move_offset + move_offset_vector = ( + self.cfg.extra_cfg.get_direction_vector(goal.info["move_axis"]) * self.cfg.extra_cfg.move_offset + ) offset_pos_in_ee = move_offset_vector.to(isaaclab_device).unsqueeze(0) offset_quat_in_ee = torch.tensor([1.0, 0.0, 0.0, 0.0], device=isaaclab_device).unsqueeze(0) ee_pos_in_robot_root, ee_quat_in_robot_root = target_pos.to(isaaclab_device), target_quat.to(isaaclab_device)