Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions source/autosim/autosim/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class SkillGoal:
"""The target pose of the skill."""
extra_target_poses: dict[str, torch.Tensor] | None = None
"""The target poses of the extra end-effectors. dict[link_name, target_pose]."""
info: dict[str, Any] = field(default_factory=dict)
"""The information of the skill goal."""


@dataclass
Expand Down Expand Up @@ -94,6 +96,9 @@ class EnvExtraInfo:
object_navigate_sample_range: dict[str, tuple[float, float]] = field(default_factory=dict)
"""The sample range for the navigate skill. each object can have a tuple of (min_angle, max_angle) in radians."""

object_relative_reach_axis: dict[str, dict[str, str]] = field(default_factory=dict)
"""The relative reach axis for the relative reach skill. dict[object_name, dict[relative_reach_skill_name, axis]]."""

cached_initial_extra_target_offsets: dict[str, tuple[torch.Tensor, torch.Tensor]] | None = None
"""Cached primary-frame offsets for extra target links, reused across multiple reach-like skills."""

Expand All @@ -111,6 +116,9 @@ def get_reach_target_poses(self, object_name: str) -> list[torch.Tensor]:
def get_navigate_sample_range(self, object_name: str) -> tuple[float, float]:
return self.object_navigate_sample_range.get(object_name, (0.0, 2 * np.pi))

def get_relative_reach_axis(self, object_name: str, relative_reach_skill_name: str) -> str:
return self.object_relative_reach_axis.get(object_name, {}).get(relative_reach_skill_name)


@dataclass
class WorldState:
Expand Down
20 changes: 13 additions & 7 deletions source/autosim/autosim/skills/relative_reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __post_init__(self):
"z": torch.tensor([0.0, 0.0, 1.0]),
}

def get_direction_vector(self) -> torch.Tensor:
def get_direction_vector(self, move_axis: str) -> torch.Tensor:
"""Parse move_axis and compute the normalized direction vector.

This is computed on-demand to support dynamic modification of move_axis.
Expand All @@ -49,24 +49,24 @@ def get_direction_vector(self) -> torch.Tensor:
import re

pattern = r"([+-][xyz])"
matches = re.findall(pattern, self.move_axis)
matches = re.findall(pattern, move_axis)

if not matches:
raise ValueError(
f"Invalid move_axis format: '{self.move_axis}'. Expected format: '+x', '-y', '+x+y', '+x-z', etc."
f"Invalid move_axis format: '{move_axis}'. Expected format: '+x', '-y', '+x+y', '+x-z', etc."
)

direction_vector = torch.zeros(3)
for match in matches:
sign = 1.0 if match[0] == "+" else -1.0
axis = match[1]
if axis not in self._axis_map:
raise ValueError(f"Invalid axis '{axis}' in move_axis: '{self.move_axis}'")
raise ValueError(f"Invalid axis '{axis}' in move_axis: '{move_axis}'")
direction_vector += sign * self._axis_map[axis]

norm = torch.linalg.norm(direction_vector)
if norm < 1e-6:
raise ValueError(f"move_axis '{self.move_axis}' results in zero direction vector")
raise ValueError(f"move_axis '{move_axis}' results in zero direction vector")

return direction_vector / norm

Expand All @@ -91,7 +91,11 @@ def extract_goal_from_info(
) -> SkillGoal:
"""Return the target object of the relative reach skill."""

return SkillGoal(target_object=skill_info.target_object)
relative_reach_axis = env_extra_info.get_relative_reach_axis(skill_info.target_object, self.cfg.name)
if relative_reach_axis is None:
relative_reach_axis = self.cfg.extra_cfg.move_axis

return SkillGoal(target_object=skill_info.target_object, info=dict(move_axis=relative_reach_axis))

def execute_plan(self, state: WorldState, goal: SkillGoal) -> bool:
"""Execute the plan of the relative reach skill."""
Expand All @@ -112,7 +116,9 @@ def execute_plan(self, state: WorldState, goal: SkillGoal) -> bool:

# move the eef along the move axis by the move offset based on eef frame, and convert to robot root frame to get target pose
isaaclab_device = state.device
move_offset_vector = self.cfg.extra_cfg.get_direction_vector() * self.cfg.extra_cfg.move_offset
move_offset_vector = (
self.cfg.extra_cfg.get_direction_vector(goal.info["move_axis"]) * self.cfg.extra_cfg.move_offset
)
offset_pos_in_ee = move_offset_vector.to(isaaclab_device).unsqueeze(0)
offset_quat_in_ee = torch.tensor([1.0, 0.0, 0.0, 0.0], device=isaaclab_device).unsqueeze(0)
ee_pos_in_robot_root, ee_quat_in_robot_root = target_pos.to(isaaclab_device), target_quat.to(isaaclab_device)
Expand Down
Loading