From e0cdb579696f94f5022c42c2ba2b0c17cf564fc2 Mon Sep 17 00:00:00 2001 From: catmub Date: Tue, 9 Jun 2026 16:31:20 +0000 Subject: [PATCH] VLA Server and TM_VLA integration --- .gitignore | 1 + _meta/docker/features/vla_server/Dockerfile | 15 + .../features/vla_server/docker-compose.yml | 21 + _meta/docker/features/vla_server/main | 55 ++ .../vla_server/model/model_omnivla_edge.py | 435 +++++++++++++++ .../vla_server/model/run_omnivla_edge.py | 495 ++++++++++++++++++ .../features/vla_server/model/utils_policy.py | 120 +++++ .../vla_server/server_omnivla_edge.py | 140 +++++ arena_bringup/arena_bringup/supervisor.py | 68 ++- .../task_generator/constants/__init__.py | 1 + .../task_generator/tasks/robots/__init__.py | 4 +- .../task_generator/tasks/robots/vla/README.md | 40 ++ .../tasks/robots/vla/__init__.py | 16 + .../task_generator/tasks/robots/vla/impl.py | 283 ++++++++++ 14 files changed, 1679 insertions(+), 15 deletions(-) create mode 100644 _meta/docker/features/vla_server/Dockerfile create mode 100644 _meta/docker/features/vla_server/docker-compose.yml create mode 100644 _meta/docker/features/vla_server/main create mode 100644 _meta/docker/features/vla_server/model/model_omnivla_edge.py create mode 100644 _meta/docker/features/vla_server/model/run_omnivla_edge.py create mode 100644 _meta/docker/features/vla_server/model/utils_policy.py create mode 100644 _meta/docker/features/vla_server/server_omnivla_edge.py create mode 100644 task_generator/task_generator/tasks/robots/vla/README.md create mode 100644 task_generator/task_generator/tasks/robots/vla/__init__.py create mode 100644 task_generator/task_generator/tasks/robots/vla/impl.py diff --git a/.gitignore b/.gitignore index 81dcebab..62395c73 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,4 @@ __pycache__/ # agents /arena_training/agents/ +_meta/docker/features/vla_server/model/*.pth diff --git a/_meta/docker/features/vla_server/Dockerfile b/_meta/docker/features/vla_server/Dockerfile new file mode 100644 
index 00000000..88bbbdd5 --- /dev/null +++ b/_meta/docker/features/vla_server/Dockerfile @@ -0,0 +1,15 @@ +ARG TORCH_TAG=2.5.1-cuda12.4-cudnn9-runtime +FROM pytorch/pytorch:${TORCH_TAG} AS base + +USER root + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git curl \ + && rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache-dir --upgrade pip + +RUN pip install fastapi "uvicorn[standard]" openai-clip Pillow numpy requests packaging matplotlib efficientnet_pytorch python-multipart + +RUN sed -i 's/from pkg_resources import packaging/import packaging, packaging.version/' \ + "$(python -c 'import clip, os; print(os.path.dirname(clip.__file__))')/clip.py" \ \ No newline at end of file diff --git a/_meta/docker/features/vla_server/docker-compose.yml b/_meta/docker/features/vla_server/docker-compose.yml new file mode 100644 index 00000000..2cbe0100 --- /dev/null +++ b/_meta/docker/features/vla_server/docker-compose.yml @@ -0,0 +1,21 @@ +services: + vla_server: + image: vla_server:omnivla-edge + pull_policy: "never" + build: + context: . 
+ dockerfile: src/Arena/_meta/docker/features/vla_server/Dockerfile + # args: + volumes: + - "$HOST_ARENA_WS_DIR/src:/opt/arena_ws/src" + environment: + - HOST_ARENA_WS_DIR + - VLA_MODEL_DIR=/opt/arena_ws/src/Arena/_meta/docker/features/vla_server/model + command: ["python3", "/opt/arena_ws/src/Arena/_meta/docker/features/vla_server/server_omnivla_edge.py"] + network_mode: host + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] diff --git a/_meta/docker/features/vla_server/main b/_meta/docker/features/vla_server/main new file mode 100644 index 00000000..fa06256a --- /dev/null +++ b/_meta/docker/features/vla_server/main @@ -0,0 +1,55 @@ +#!/usr/bin/env bash + +name="vla_server" + +update(){ + set -e + arena registry add "$name" + arena_docker_compose build vla_server +} + +uninstall(){ + arena_docker_compose rm -fs vla_server || true + arena registry remove "$name" +} + +launch(){ + # Source only the docker function library — env vars (arena_compose_sudo, docker_env, + # HOST_ARENA_WS_DIR) are already correctly set by the container environment. + # Do NOT source the full setup script; it resets arena_compose_sudo="" and + # overwrites HOST_ARENA_WS_DIR with the in-container path. + # shellcheck source=/dev/null + source "$ARENA_DIR/_meta/docker/lib" + # Register feature so arena_docker_compose includes its docker-compose.yml overlay + grep -qxF "$name" "$ARENA_DIR/.installed" 2>/dev/null || \ + echo "$name" >> "$ARENA_DIR/.installed" + # Stop any previous instance to avoid accumulated log-stream attachments + arena_docker_compose stop vla_server 2>/dev/null || true + # Run detached so only one log stream attaches; stop on process exit + arena_docker_compose up -d --remove-orphans vla_server + trap 'arena_docker_compose stop vla_server 2>/dev/null' EXIT INT TERM + arena_docker_compose logs -f vla_server +} + +help(){ + echo "Usage: $name " +} +if [ $# -lt 1 ]; then + help + exit 1 +fi +case "$1" in + update) + update + exit $? 
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to a token sequence.

    The table is computed once at construction time and stored as the
    ``pos_enc`` buffer with shape ``(1, max_seq_len, d_model)``.  Even feature
    indices hold the sine terms, odd feature indices hold the cosine terms.
    """

    def __init__(self, d_model, max_seq_len=6):
        super().__init__()

        # Column vector of positions 0 .. max_seq_len-1.
        positions = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # One frequency per sin/cos pair of feature channels.
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = positions * inv_freq

        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # A buffer follows the module across devices without being treated as
        # a trainable parameter.
        self.register_buffer('pos_enc', table.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        return x + self.pos_enc[:, :seq_len, :]
class OmniVLA_edge(BaseModel):
    """Multi-modal goal-conditioned navigation policy.

    The model turns the inputs into a sequence of ``context_size + 5`` tokens
    (in this order):

    * ``context_size + 1`` observation tokens (one per RGB frame),
    * 1 goal-pose token          (index ``-4``),
    * 1 map token                (index ``-3``),
    * 1 goal-image token         (index ``-2``),
    * 1 language token           (index ``-1``),

    and feeds them to ``MultiLayerDecoder_mask3``.  ``goal_mask`` selects, per
    sample, which of the four goal tokens are hidden from the transformer.
    """

    # Goal tokens hidden (mask value True) by ``goal_mask_<k>``, as indices
    # counted from the end of the token sequence (see class docstring).
    _MASKED_SLOTS = (
        (-4, -2, -1),  # goal_mask_0: only the map token stays visible
        (-3, -2, -1),  # goal_mask_1: only the pose token stays visible
        (-2, -1),      # goal_mask_2: pose + map
        (-4, -1),      # goal_mask_3: map + image
        (-3, -1),      # goal_mask_4: pose + image
        (-1,),         # goal_mask_5: pose + map + image
        (-4, -3, -1),  # goal_mask_6: image only
        (-4, -3, -2),  # goal_mask_7: language only
        (-3, -2),      # goal_mask_8: pose + language
    )
    # Row ``m`` of ``all_masks`` / ``avg_pool_mask`` (i.e. modality id ``m``)
    # is ``goal_mask_<_MODALITY_ROW_ORDER[m]>``.
    _MODALITY_ROW_ORDER = (0, 2, 3, 5, 1, 4, 6, 7, 8)

    def __init__(
        self,
        context_size: int = 5,
        len_traj_pred: Optional[int] = 5,
        learn_angle: Optional[bool] = True,
        obs_encoder: Optional[str] = "efficientnet-b0",
        obs_encoding_size: Optional[int] = 512,
        late_fusion: Optional[bool] = False,
        mha_num_attention_heads: Optional[int] = 2,
        mha_num_attention_layers: Optional[int] = 2,
        mha_ff_dim_factor: Optional[int] = 4,
    ) -> None:
        """Build the encoders, the transformer decoder and the output heads.

        Args:
            context_size: number of past frames (the current frame is extra).
            len_traj_pred: number of waypoints predicted per sample.
            learn_angle: forwarded to ``BaseModel``.
            obs_encoder: EfficientNet variant name, e.g. ``"efficientnet-b0"``.
            obs_encoding_size: token width used for every modality.
            late_fusion: if True the goal-image encoder sees only the goal
                image (3 channels); otherwise current frame + goal (6).
            mha_num_attention_heads: heads of the transformer decoder.
            mha_num_attention_layers: layers of the transformer decoder.
            mha_ff_dim_factor: feed-forward width multiplier of the decoder.

        Raises:
            NotImplementedError: if ``obs_encoder`` is not an EfficientNet.
        """
        super().__init__(context_size, len_traj_pred, learn_angle)
        self.obs_encoding_size = obs_encoding_size
        self.goal_encoding_size = obs_encoding_size

        self.late_fusion = late_fusion
        if obs_encoder.split("-")[0] == "efficientnet":
            # Per-frame RGB encoder.
            self.obs_encoder = EfficientNet.from_name(obs_encoder, in_channels=3)
            self.num_obs_features = self.obs_encoder._fc.in_features

            # 9-channel encoder applied to ``map_images`` in forward().
            self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=9)
            self.num_obs_features_map = self.goal_encoder._fc.in_features

            # Goal image alone (late fusion) or current frame + goal image.
            goal_img_channels = 3 if self.late_fusion else 6
            self.goal_encoder_img = EfficientNet.from_name(
                "efficientnet-b0", in_channels=goal_img_channels
            )
            self.num_goal_features_img = self.goal_encoder_img._fc.in_features
        else:
            raise NotImplementedError

        # Projections to the common token width (identity when already equal).
        self.compress_obs_enc = self._projection(
            self.num_obs_features, self.obs_encoding_size
        )
        self.compress_obs_enc_map = self._projection(
            self.num_obs_features_map, self.obs_encoding_size
        )
        self.compress_goal_enc_img = self._projection(
            self.num_goal_features_img, self.goal_encoding_size
        )
        # Width of the flattened FiLM feature map fed to the language projection.
        self.num_goal_features_lan = 4096
        self.compress_goal_enc_lan = self._projection(
            self.num_goal_features_lan, self.goal_encoding_size
        )

        seq_len = self.context_size + 5
        self.decoder = MultiLayerDecoder_mask3(
            embed_dim=self.obs_encoding_size,
            seq_len=seq_len,
            output_layers=[256, 128, 64, 32],
            nhead=mha_num_attention_heads,
            num_layers=mha_num_attention_layers,
            ff_dim_factor=mha_ff_dim_factor,
        )

        self.action_predictor = nn.Sequential(
            nn.Linear(32, self.len_trajectory_pred * self.num_action_params),
        )

        # Language-conditioned image encoder (text feature width 512).
        self.film_model = build_film_model(8, 10, 128, 512)

        self.max_linvel = 0.5
        self.max_angvel = 1.0

        self.dist_predictor = nn.Sequential(
            nn.Linear(32, 1),
        )
        self.local_goal = nn.Sequential(
            nn.Linear(4, self.goal_encoding_size),
        )

        # Key-padding masks, one per modality combination.  The individual
        # ``goal_mask_<k>`` attributes are kept for backward compatibility.
        masks = []
        for idx, slots in enumerate(self._MASKED_SLOTS):
            mask = torch.zeros((1, seq_len), dtype=torch.bool)
            mask[:, list(slots)] = True
            setattr(self, f"goal_mask_{idx}", mask)
            masks.append(mask)
        self.all_masks = torch.cat(
            [masks[i] for i in self._MODALITY_ROW_ORDER], dim=0
        )
        self.no_mask = torch.zeros((1, seq_len), dtype=torch.bool)

        # Weights that turn the decoder's plain mean over all tokens into a
        # mean over the visible tokens only: visible tokens get
        # seq_len / n_visible, hidden tokens get 0.
        avg_masks = []
        for mask in masks:
            visible = 1.0 - mask.float()
            avg_masks.append(visible * (seq_len / torch.sum(visible)))
        self.avg_pool_mask = torch.cat(
            [avg_masks[i] for i in self._MODALITY_ROW_ORDER], dim=0
        )

    @staticmethod
    def _projection(in_features: int, out_features: int) -> nn.Module:
        """Return a linear projection, or ``nn.Identity`` if the widths match."""
        if in_features != out_features:
            return nn.Linear(in_features, out_features)
        return nn.Identity()

    def forward(
        self,
        obs_img: torch.Tensor,
        goal_pose: torch.Tensor,
        map_images: torch.Tensor,
        goal_img: torch.Tensor,
        goal_mask: torch.Tensor,
        feat_text: torch.Tensor,
        current_img: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predict a waypoint trajectory and a distance estimate.

        Args:
            obs_img: observation stack with ``3 * (context_size + 1)``
                channels; the last three channels are the current frame.
            goal_pose: goal pose features, last dimension 4.
            map_images: 9-channel input of ``goal_encoder``.
            goal_img: 3-channel goal image.
            goal_mask: per-sample modality id, used as a row index into
                ``all_masks`` / ``avg_pool_mask`` (valid rows are 0..8).
            feat_text: text feature passed to the FiLM model (width 512).
            current_img: 3-channel image passed to the FiLM model.
                NOTE(review): its resolution must make the flattened FiLM
                output 4096 wide - confirm against the caller.

        Returns:
            ``(action_pred, dist_pred, no_goal_mask)`` where ``action_pred``
            has shape ``(batch, len_traj_pred, num_action_params)`` and
            ``no_goal_mask`` is ``goal_mask`` cast to ``long``.

        Raises:
            ValueError: if ``goal_mask`` is None.  The decoder always needs
                the modality id, so the former ``None`` path could only end
                in an unbound-variable error.
        """
        if goal_mask is None:
            raise ValueError("OmniVLA_edge.forward requires goal_mask (modality ids)")

        # --- language token: FiLM-modulated encoding of the current image ---
        inst_encoding = feat_text
        obsgoal_encoding_lan = self.film_model(current_img, inst_encoding)
        obsgoal_encoding_lan_cat = obsgoal_encoding_lan.flatten(start_dim=1)
        obsgoal_encoding_lan = self.compress_goal_enc_lan(obsgoal_encoding_lan_cat)

        if len(obsgoal_encoding_lan.shape) == 2:
            obsgoal_encoding_lan = obsgoal_encoding_lan.unsqueeze(1)
        assert obsgoal_encoding_lan.shape[2] == self.goal_encoding_size
        goal_encoding_lan = obsgoal_encoding_lan

        # --- goal-image token ---
        if self.late_fusion:
            goal_encoding_img = self.goal_encoder_img.extract_features(goal_img)
        else:
            # Early fusion: stack the current frame with the goal image.
            obsgoal_img = torch.cat(
                [obs_img[:, 3 * self.context_size:, :, :], goal_img], dim=1
            )
            goal_encoding_img = self.goal_encoder_img.extract_features(obsgoal_img)
        goal_encoding_img = self.goal_encoder_img._avg_pooling(goal_encoding_img)
        # NOTE(review): this reads the flag from ``goal_encoder`` rather than
        # ``goal_encoder_img``; kept as in the original.
        if self.goal_encoder._global_params.include_top:
            goal_encoding_img = goal_encoding_img.flatten(start_dim=1)
            goal_encoding_img = self.goal_encoder_img._dropout(goal_encoding_img)
        goal_encoding_img = self.compress_goal_enc_img(goal_encoding_img)

        if len(goal_encoding_img.shape) == 2:
            goal_encoding_img = goal_encoding_img.unsqueeze(1)
        assert goal_encoding_img.shape[2] == self.goal_encoding_size

        # ``Tensor.get_device()`` returns -1 for CPU tensors, which ``.to()``
        # rejects; ``Tensor.device`` works on every backend.
        device = obs_img.device

        # --- pose and map tokens ---
        goal_encoding = self.local_goal(goal_pose).unsqueeze(1)
        map_encoding = self.goal_encoder.extract_features(map_images).unsqueeze(1)
        # NOTE(review): pooling uses ``obs_encoder``'s pooling layer on the
        # map features; kept as in the original.
        map_encoding = self.obs_encoder._avg_pooling(map_encoding)

        # --- observation tokens: fold the frames into the batch dimension ---
        obs_img = torch.split(obs_img, 3, dim=1)
        obs_img = torch.concat(obs_img, dim=0)

        obs_encoding = self.obs_encoder.extract_features(obs_img)
        obs_encoding = self.obs_encoder._avg_pooling(obs_encoding)
        if self.obs_encoder._global_params.include_top:
            obs_encoding = obs_encoding.flatten(start_dim=1)
            obs_encoding = self.obs_encoder._dropout(obs_encoding)

        if self.goal_encoder._global_params.include_top:
            map_encoding = map_encoding.flatten(start_dim=1)
            map_encoding = self.goal_encoder._dropout(map_encoding)

        obs_encoding = self.compress_obs_enc(obs_encoding)
        map_encoding = self.compress_obs_enc_map(map_encoding)

        # (frames * batch, dim) -> (frames, batch, dim) -> (batch, frames, dim)
        obs_encoding = obs_encoding.reshape(
            (self.context_size + 1, -1, self.obs_encoding_size)
        )
        obs_encoding = torch.transpose(obs_encoding, 0, 1)

        # Token order matters: the masks index the last four positions.
        tokens = torch.cat(
            (
                obs_encoding,
                goal_encoding,
                map_encoding.unsqueeze(1),
                goal_encoding_img,
                goal_encoding_lan,
            ),
            dim=1,
        )

        no_goal_mask = goal_mask.long()
        src_key_padding_mask = torch.index_select(
            self.all_masks.to(device), 0, no_goal_mask
        )

        final_repr = self.decoder(
            tokens, src_key_padding_mask, self.avg_pool_mask.to(device), no_goal_mask
        )

        action_pred = self.action_predictor(final_repr)
        dist_pred = self.dist_predictor(final_repr)

        # Reshape to (batch, waypoints, params) to match the labels.
        action_pred = action_pred.reshape(
            (action_pred.shape[0], self.len_trajectory_pred, self.num_action_params)
        )
        # The head predicts per-step deltas; accumulate them into positions.
        action_pred[:, :, :2] = torch.cumsum(action_pred[:, :, :2], dim=1)
        # Normalise the remaining components to unit length (clone avoids
        # reading and writing the same storage in one op).
        action_pred[:, :, 2:] = F.normalize(action_pred[:, :, 2:].clone(), dim=-1)

        return action_pred, dist_pred, no_goal_mask
class FiLMNetwork(nn.Module):
    """FiLM-conditioned convolutional encoder.

    An image is encoded by ``InitialFeatureExtractor``, passed through
    ``num_res_blocks`` residual blocks that are modulated by parameters
    generated from ``question`` (the conditioning vector), and finally
    down-sampled by ``IntermediateFeatureExtractor``.

    Args:
        num_res_blocks: number of FiLM residual blocks.
        num_classes: output width of ``final_classifier``.  That head is
            constructed (so its weights exist in the state dict) but it is
            not used by ``forward``.
        num_channels: channel count of the residual blocks.
        question_dim: width of the conditioning vector.
    """

    def __init__(self, num_res_blocks, num_classes, num_channels, question_dim):
        super().__init__()

        # One pair of modulation vectors per residual block.
        self.film_param_generator = nn.Linear(
            question_dim, 2 * num_res_blocks * num_channels
        )
        self.initial_feature_extractor = InitialFeatureExtractor()
        self.residual_blocks = nn.ModuleList()
        self.intermediate_feature_extractor = IntermediateFeatureExtractor()

        # +2 input channels: the x/y coordinate maps concatenated in forward().
        for _ in range(num_res_blocks):
            self.residual_blocks.append(ResidualBlock(num_channels + 2, num_channels))

        self.final_classifier = FinalClassifier(num_channels, num_classes)

        self.num_res_blocks = num_res_blocks
        self.num_channels = num_channels

    def forward(self, x, question):
        """Encode image ``x`` conditioned on ``question``.

        Args:
            x: image batch accepted by ``InitialFeatureExtractor``.
            question: conditioning vectors, one row per image.

        Returns:
            The feature map produced by ``IntermediateFeatureExtractor``.
        """
        batch_size = x.size(0)

        x = self.initial_feature_extractor(x)
        film_params = self.film_param_generator(question).view(
            batch_size, self.num_res_blocks, 2, self.num_channels
        )

        # Coordinate channels in [-1, 1].  ``linspace`` yields exactly one
        # value per pixel, also for 1-pixel or non-square feature maps, where
        # the former ``arange(-1, 1 + eps, 2 / (d - 1))`` divided by zero or
        # could not be expanded.
        height, width = x.size(2), x.size(3)
        coord_x = (
            torch.linspace(-1.0, 1.0, width, device=x.device)
            .view(1, 1, 1, width)
            .expand(batch_size, 1, height, width)
        )
        coord_y = (
            torch.linspace(-1.0, 1.0, height, device=x.device)
            .view(1, 1, height, 1)
            .expand(batch_size, 1, height, width)
        )

        for i, res_block in enumerate(self.residual_blocks):
            # NOTE: ResidualBlock forwards these positionally to
            # FiLMTransform.forward(x, gamma, beta), so slot 0 ends up as the
            # multiplicative factor and slot 1 as the additive term, despite
            # the local names.  Do not reorder: trained weights depend on it.
            beta = film_params[:, i, 0, :]
            gamma = film_params[:, i, 1, :]

            x = torch.cat([x, coord_x, coord_y], 1)
            x = res_block(x, beta, gamma)

        features = self.intermediate_feature_extractor(x)

        return features
def remove_ddp_in_checkpoint(state_dict: dict) -> dict:
    """Return a copy of *state_dict* without the ``module.`` key prefix.

    Only one leading ``module.`` is removed from each key; keys that do not
    start with it are copied unchanged.  The input mapping is not modified.
    """
    ddp_prefix = "module."
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(ddp_prefix):
            key = key[len(ddp_prefix):]
        cleaned[key] = value
    return cleaned
class Inference:
    """Single-robot OmniVLA-edge inference loop.

    The class relies on module-level names that are defined outside of it
    (``device``, ``model``, ``text_encoder``, ``imgsize``, ``imgsize_clip``,
    ``mask_360_pil_96``, ``mask_360_pil_224``, ``goal_image_PIL`` and the
    modality flags ``pose_goal``, ``satellite``, ``image_goal``,
    ``lan_prompt``); they must exist before ``run`` is called.
    """

    # (pose_goal, satellite, image_goal, lan_prompt) -> modality id.
    # NOTE(review): id 9 (image + language) is accepted here, but the mask
    # table built in model_omnivla_edge.py has only rows 0..8 - confirm that
    # this combination is actually supported by the model.
    _MODALITY_IDS = {
        (True, True, True, False): 3,
        (False, True, False, False): 0,
        (True, False, False, False): 4,
        (True, True, False, False): 1,
        (False, True, True, False): 2,
        (True, False, True, False): 5,
        (False, False, True, False): 6,
        (False, False, False, True): 7,
        (True, False, False, True): 8,
        (False, False, True, True): 9,
    }

    def __init__(self, save_dir, lan_inst_prompt, goal_utm, goal_compass, goal_image_PIL):
        """Store the goal description and the output directory.

        Args:
            save_dir: directory that receives the visualisation images.
            lan_inst_prompt: language instruction.
            goal_utm: goal position; indices 0 and 1 are read as x and y.
            goal_compass: goal heading in radians.
            goal_image_PIL: goal image, used for the visualisation.
        """
        self.tick_rate = 3
        self.lan_inst_prompt = lan_inst_prompt
        self.goal_utm = goal_utm
        self.goal_compass = goal_compass
        self.goal_image_PIL = goal_image_PIL
        self.count_id = 0
        self.linear, self.angular = 0.0, 0.0
        self.datastore_path_image = save_dir

    # ----------------------------
    # Static Utility Methods
    # ----------------------------
    @staticmethod
    def calculate_relative_position(x_a, y_a, x_b, y_b):
        """Return the offset ``(dx, dy)`` from point A to point B."""
        return x_b - x_a, y_b - y_a

    @staticmethod
    def rotate_to_local_frame(delta_x, delta_y, heading_a_rad):
        """Rotate a world-frame offset by ``-heading_a_rad`` into the local frame."""
        rel_x = delta_x * math.cos(heading_a_rad) + delta_y * math.sin(heading_a_rad)
        rel_y = -delta_x * math.sin(heading_a_rad) + delta_y * math.cos(heading_a_rad)
        return rel_x, rel_y

    # ----------------------------
    # Main Loop
    # ----------------------------
    def run(self):
        """Wait one tick period, run one inference tick, then return.

        Sleeping replaces the former busy-wait, which kept a CPU core at
        100 % while waiting.
        """
        period = 1 / self.tick_rate
        deadline = time.time() + period
        while True:
            time.sleep(max(0.0, deadline - time.time()))
            self.tick()
            deadline = time.time() + period
            # Single-shot sample; remove the break to keep ticking at tick_rate.
            break

    def tick(self):
        """Run one inference step and store the resulting velocity command."""
        self.linear, self.angular = self.run_omnivla()

    # ----------------------------
    # OmniVLA Inference
    # ----------------------------
    def run_omnivla(self):
        """Build the model inputs, run the policy and derive a velocity command.

        Returns:
            ``(linear_velocity, angular_velocity)`` after clipping and
            velocity limitation.
        """
        thres_dist = 30.0
        metric_waypoint_spacing = 0.1

        # Load current GPS & heading (hard-coded sample values).
        current_lat = 37.87371258374039
        current_lon = -122.26729417226024
        current_compass = 270.0
        cur_utm = utm.from_latlon(current_lat, current_lon)
        cur_compass = -float(current_compass) / 180.0 * math.pi  # inverted compass

        # Local goal position, clamped to at most thres_dist from the robot.
        delta_x, delta_y = self.calculate_relative_position(
            cur_utm[0], cur_utm[1], self.goal_utm[0], self.goal_utm[1]
        )
        relative_x, relative_y = self.rotate_to_local_frame(delta_x, delta_y, cur_compass)
        radius = np.sqrt(relative_x**2 + relative_y**2)
        if radius > thres_dist:
            relative_x *= thres_dist / radius
            relative_y *= thres_dist / radius

        goal_pose_torch = torch.from_numpy(np.array([
            relative_y / metric_waypoint_spacing,
            -relative_x / metric_waypoint_spacing,
            np.cos(self.goal_compass - cur_compass),
            np.sin(self.goal_compass - cur_compass),
        ])).unsqueeze(0).float().to(device)

        # Overwrites "goal_pose_torch" for testing only.  To use the GPS-based
        # value computed above, remove the following block.
        yaw_ang = -90.0
        goal_pose_torch = torch.from_numpy(np.array([
            1.0 / metric_waypoint_spacing,
            -10.0 / metric_waypoint_spacing,
            np.cos(yaw_ang / 180.0 * 3.1415),
            np.sin(yaw_ang / 180.0 * 3.1415),
        ])).unsqueeze(0).float().to(device)

        # Load current image
        current_image_path = "./inference/current_img.jpg"
        current_image_PIL = Image.open(current_image_path).convert("RGB")

        current_image_PIL_96 = current_image_PIL.resize(imgsize)
        current_image_PIL_224 = current_image_PIL.resize(imgsize_clip)

        # The same image fills the whole observation history, i.e. the robot
        # is assumed to be standing still at the current location.
        context_queue = [current_image_PIL_96] * 6
        obs_images = transform_images_PIL_mask(context_queue, mask_360_pil_96)
        obs_images = torch.split(obs_images.to(device), 3, dim=1)
        obs_image_cur = obs_images[-1].to(device)
        obs_images = torch.cat(obs_images, dim=1).to(device)

        cur_large_img = transform_images_PIL_mask(current_image_PIL_224, mask_360_pil_224).to(device)

        # Dummy (all black) satellite images.
        satellite_cur = Image.new("RGB", (352, 352), color=(0, 0, 0))
        satellite_goal = Image.new("RGB", (352, 352), color=(0, 0, 0))
        current_map_image = transform_images_map(satellite_cur)
        goal_map_image = transform_images_map(satellite_goal)
        map_images = torch.cat(
            (current_map_image.to(device), goal_map_image.to(device), obs_image_cur), axis=1
        )

        # Language instruction (placeholder text when language is disabled).
        lan_inst = self.lan_inst_prompt if lan_prompt else "xxxx"
        obj_inst_lan = clip.tokenize(lan_inst, truncate=True).to(device)

        # Egocentric goal image.
        # NOTE(review): this reads the module-level ``goal_image_PIL`` rather
        # than ``self.goal_image_PIL`` - confirm which one is intended.
        goal_image = transform_images_PIL_mask(goal_image_PIL, mask_360_pil_96).to(device)

        batch = {
            "obs_images": obs_images,
            "goal_pose_torch": goal_pose_torch,
            "map_images": map_images,
            "goal_image": goal_image,
            "obj_inst_lan": obj_inst_lan,
            "cur_large_img": cur_large_img,
        }

        # Run forward pass
        actions, modality_id = self.run_forward_pass(
            model=model.eval(),
            batch=batch,
            device_id=device,
            mode="train",
            idrun=self.count_id,
        )
        self.count_id += 1

        waypoints = actions.float().cpu().numpy()

        # Select waypoint and convert its position back to metres.
        waypoint_select = 4
        chosen_waypoint = waypoints[0][waypoint_select].copy()
        chosen_waypoint[:2] *= metric_waypoint_spacing
        dx, dy, hx, hy = chosen_waypoint

        # PD controller
        EPS = 1e-8
        DT = 1 / 3
        if np.abs(dx) < EPS and np.abs(dy) < EPS:
            # No translation requested: turn towards the predicted heading.
            linear_vel_value = 0
            angular_vel_value = 1.0 * clip_angle(np.arctan2(hy, hx)) / DT
        elif np.abs(dx) < EPS:
            # Purely lateral target: rotate in place towards it.
            linear_vel_value = 0
            angular_vel_value = 1.0 * np.sign(dy) * np.pi / (2 * DT)
        else:
            linear_vel_value = dx / DT
            angular_vel_value = np.arctan(dy / dx) / DT

        linear_vel_value = np.clip(linear_vel_value, 0, 0.5)
        angular_vel_value = np.clip(angular_vel_value, -1.0, 1.0)

        # Velocity limitation: scale both commands together so that the
        # ratio linear/angular (rd) is kept while respecting maxv and maxw.
        maxv, maxw = 0.3, 0.3
        if np.abs(linear_vel_value) <= maxv:
            if np.abs(angular_vel_value) <= maxw:
                linear_vel_value_limit = linear_vel_value
                angular_vel_value_limit = angular_vel_value
            else:
                rd = linear_vel_value / angular_vel_value
                linear_vel_value_limit = maxw * np.sign(linear_vel_value) * np.abs(rd)
                angular_vel_value_limit = maxw * np.sign(angular_vel_value)
        else:
            if np.abs(angular_vel_value) <= 0.001:
                linear_vel_value_limit = maxv * np.sign(linear_vel_value)
                angular_vel_value_limit = 0.0
            else:
                rd = linear_vel_value / angular_vel_value
                if np.abs(rd) >= maxv / maxw:
                    linear_vel_value_limit = maxv * np.sign(linear_vel_value)
                    angular_vel_value_limit = maxv * np.sign(angular_vel_value) / np.abs(rd)
                else:
                    linear_vel_value_limit = maxw * np.sign(linear_vel_value) * np.abs(rd)
                    angular_vel_value_limit = maxw * np.sign(angular_vel_value)

        # Save behavior
        self.save_robot_behavior(
            current_image_PIL, self.goal_image_PIL, goal_pose_torch[0].cpu(), waypoints[0],
            linear_vel_value_limit, angular_vel_value_limit, metric_waypoint_spacing,
            modality_id.cpu().numpy(),
        )

        print("linear angular", linear_vel_value_limit, angular_vel_value_limit)
        return linear_vel_value_limit, angular_vel_value_limit

    # ----------------------------
    # Save Robot Behavior Visualization
    # ----------------------------
    def save_robot_behavior(self, cur_img, goal_img, goal_pose, waypoints,
                            linear_vel, angular_vel, metric_waypoint_spacing, mask_number):
        """Write a JPEG showing both images and the predicted trajectory.

        The file is stored as ``<count_id>_ex_omnivla_edge.jpg`` inside
        ``self.datastore_path_image``.  ``linear_vel``, ``angular_vel`` and
        ``metric_waypoint_spacing`` are accepted but not used.
        """
        fig = plt.figure(figsize=(34, 16), dpi=80)
        try:
            gs = fig.add_gridspec(2, 2)
            ax_ob = fig.add_subplot(gs[0, 0])
            ax_goal = fig.add_subplot(gs[1, 0])
            ax_graph_pos = fig.add_subplot(gs[:, 1])

            ax_ob.imshow(np.array(cur_img).astype(np.uint8))
            ax_goal.imshow(np.array(goal_img).astype(np.uint8))

            # The trajectory is in the robot frame: X is front, Y is left.
            x_seq = waypoints[:, 0]
            y_seq_inv = -waypoints[:, 1]
            ax_graph_pos.plot(
                np.insert(y_seq_inv, 0, 0.0), np.insert(x_seq, 0, 0.0),
                linewidth=4.0, markersize=12, marker='o', color='blue',
            )

            # Mask annotation
            mask_type = int(mask_number[0])
            mask_texts = [
                "satellite only", "pose and satellite", "satellite and image", "all",
                "pose only", "pose and image", "image only", "language only", "language and pose"
            ]
            if mask_type < len(mask_texts):
                ax_graph_pos.annotate(
                    mask_texts[mask_type], xy=(1.0, 0.0), xytext=(-20, 20),
                    fontsize=18, textcoords='offset points',
                )

            ax_ob.set_title("Egocentric current image", fontsize=18)
            ax_goal.set_title("Egocentric goal image", fontsize=18)
            ax_graph_pos.tick_params(axis='x', labelsize=15)
            ax_graph_pos.tick_params(axis='y', labelsize=15)

            # Draw the goal pose only for modalities that include the pose.
            if mask_type in (1, 3, 4, 5, 8):
                ax_graph_pos.plot(-goal_pose[1], goal_pose[0], marker='*', color='red', markersize=15)
            ax_graph_pos.set_xlim(-3.0, 3.0)
            ax_graph_pos.set_ylim(-0.1, 10.0)

            ax_graph_pos.set_title("Normalized generated 2D trajectories from OmniVLA", fontsize=18)

            save_path = os.path.join(self.datastore_path_image, f"{self.count_id}_ex_omnivla_edge.jpg")
            fig.savefig(save_path)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each tick leaks one large figure.
            plt.close(fig)

    # ----------------------------
    # Run Forward Pass
    # ----------------------------
    def run_forward_pass(self, model, batch, device_id, mode="vali", idrun=0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the policy once on ``batch``.

        Args:
            model: the policy network.
            batch: dict with the keys ``obs_images``, ``goal_pose_torch``,
                ``map_images``, ``goal_image``, ``obj_inst_lan`` and
                ``cur_large_img``.
            device_id: device for the modality-id tensor.
            mode: accepted but not used.
            idrun: accepted but not used.

        Returns:
            ``(predicted_actions, modality_id_select)``.

        Raises:
            ValueError: if the module-level modality flags form a combination
                that has no modality id (previously an unbound-variable error).
        """
        # Setup masking
        flags = (bool(pose_goal), bool(satellite), bool(image_goal), bool(lan_prompt))
        try:
            modality_id = self._MODALITY_IDS[flags]
        except KeyError:
            raise ValueError(
                "Unsupported modality combination (pose_goal, satellite, "
                f"image_goal, lan_prompt) = {flags}"
            ) from None
        modality_id_select = torch.tensor([modality_id]).to(device_id)

        # Every input is repeated to the batch size of the goal image.
        bimg, _, _, _ = batch["goal_image"].size()
        with torch.no_grad():
            feat_text_lan = text_encoder.encode_text(batch["obj_inst_lan"])
            predicted_actions, distances, mask_number = model(
                batch["obs_images"].repeat(bimg, 1, 1, 1),
                batch["goal_pose_torch"].repeat(bimg, 1),
                batch["map_images"].repeat(bimg, 1, 1, 1),
                batch["goal_image"],
                modality_id_select.repeat(bimg),
                feat_text_lan.repeat(bimg, 1),
                batch["cur_large_img"].repeat(bimg, 1, 1, 1),
            )
        print("Generated action chunk", predicted_actions)
        return predicted_actions, modality_id_select
def define_model(cfg: InferenceConfig) -> tuple:
    """Load the OpenVLA model and its auxiliary heads for inference.

    Args:
        cfg: Inference configuration. ``cfg.vla_path`` is normalised in place
            (trailing slashes removed).

    Returns:
        A 7-tuple ``(vla, action_head, pose_projector, device_id, NUM_PATCHES,
        action_tokenizer, processor)``. ``action_head`` is ``None`` when
        ``cfg.use_l1_regression`` is false.
    """
    cfg.vla_path = cfg.vla_path.rstrip("/")
    print(f"Loading OpenVLA Model `{cfg.vla_path}`")

    # Device setup. The CUDA-only calls are skipped on the CPU fallback:
    # torch.cuda.set_device() raises when handed a CPU device.
    device_id = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if device_id.type == "cuda":
        torch.cuda.set_device(device_id)
        torch.cuda.empty_cache()

    print(
        "Detected constants:\n"
        f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n"
        f"\tACTION_DIM: {ACTION_DIM}\n"
        f"\tPOSE_DIM: {POSE_DIM}\n"
        f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}"
    )

    # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
    AutoConfig.register("openvla", OpenVLAConfig)
    AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
    AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
    AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction_MMNv1)

    # Load processor and VLA
    processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        cfg.vla_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to(device_id)

    vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
    vla.to(dtype=torch.bfloat16, device=device_id)

    pose_projector = init_module(
        ProprioProjector,
        "pose_projector",
        cfg,
        device_id,
        {"llm_dim": vla.llm_dim, "proprio_dim": POSE_DIM},
    )

    # Bind the name on every path so the return below cannot hit an
    # UnboundLocalError when L1 regression is disabled.
    action_head = None
    if cfg.use_l1_regression:
        action_head = init_module(
            L1RegressionActionHead_idcat,
            "action_head",
            cfg,
            device_id,
            {"input_dim": vla.llm_dim, "hidden_dim": vla.llm_dim, "action_dim": ACTION_DIM},
            to_bf16=True,
        )

    # Number of vision patches across all input images, plus one for the goal pose.
    NUM_PATCHES = vla.vision_backbone.get_num_patches() * vla.vision_backbone.get_num_images_in_input()
    NUM_PATCHES += 1

    # Create Action Tokenizer
    action_tokenizer = ActionTokenizer(processor.tokenizer)

    return vla, action_head, pose_projector, device_id, NUM_PATCHES, action_tokenizer, processor
def load_model(
    model_path: str,
    config: dict,
    device: torch.device = torch.device("cpu"),
) -> Tuple[nn.Module, nn.Module, object]:
    """Build the OmniVLA-edge policy and its CLIP text encoder from a checkpoint.

    Args:
        model_path: Path to the checkpoint file holding the policy state dict.
        config: Model hyper-parameters. ``config["model_type"]`` must be
            ``"omnivla-edge"``; ``config["clip_type"]`` names the CLIP variant.
        device: Device the checkpoint tensors and the CLIP model are loaded onto.

    Returns:
        ``(model, text_encoder, preprocess)``: the policy network, the CLIP
        model cast to float32, and the CLIP preprocessing transform.

    Raises:
        ValueError: If ``config["model_type"]`` is not supported.
    """
    model_type = config["model_type"]
    if model_type != "omnivla-edge":
        raise ValueError(f"Invalid model type: {model_type}")

    model = OmniVLA_edge(
        context_size=config["context_size"],
        len_traj_pred=config["len_traj_pred"],
        learn_angle=config["learn_angle"],
        obs_encoder=config["obs_encoder"],
        obs_encoding_size=config["obs_encoding_size"],
        late_fusion=config["late_fusion"],
        mha_num_attention_heads=config["mha_num_attention_heads"],
        mha_num_attention_layers=config["mha_num_attention_layers"],
        mha_ff_dim_factor=config["mha_ff_dim_factor"],
    )

    # Pass the device explicitly: clip.load otherwise picks CUDA whenever it
    # is available, regardless of what the caller asked for.
    text_encoder, preprocess = clip.load(config["clip_type"], device=device)
    text_encoder.to(torch.float32)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)

    return model, text_encoder, preprocess
transform_images_map(pil_imgs: List[PILImage.Image]) -> torch.Tensor: + """Transforms a list of PIL image to a torch tensor.""" + transform_type = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]), + ] + ) + image_size_small = (96, 96) + + if type(pil_imgs) != list: + pil_imgs = [pil_imgs] + transf_imgs = [] + for pil_img in pil_imgs: + w, h = pil_img.size + pil_img = pil_img.resize(image_size_small) + transf_img = transform_type(pil_img) + transf_img = torch.unsqueeze(transf_img, 0) + transf_imgs.append(transf_img) + return torch.cat(transf_imgs, dim=1) diff --git a/_meta/docker/features/vla_server/server_omnivla_edge.py b/_meta/docker/features/vla_server/server_omnivla_edge.py new file mode 100644 index 00000000..985325d8 --- /dev/null +++ b/_meta/docker/features/vla_server/server_omnivla_edge.py @@ -0,0 +1,140 @@ +import os +import sys +from io import BytesIO +import numpy as np +import torch +import clip +from PIL import Image + +import uvicorn +from fastapi import FastAPI, File, Form, UploadFile +from fastapi.responses import JSONResponse + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_MODEL_DIR = os.environ.get( + "VLA_MODEL_DIR", + os.path.join(_SCRIPT_DIR, "model"), +) +sys.path.insert(0, _SCRIPT_DIR) # enables: from model.utils_policy import ... +sys.path.insert(0, _MODEL_DIR) # enables: from model_omnivla_edge import ... inside utils_policy.py +from model.utils_policy import load_model, transform_images_PIL_mask, transform_images_map + +_WEIGHTS_PATH = os.path.join(_MODEL_DIR, "omnivla-edge.pth") + +_IMGSIZE = (96, 96) #for efficientnet encoder +_IMGSIZE_CLIP = (224, 224) #for clip encoder + +_MODALITY_ID = 7 #lang instruction input +_METRIC_WP_SPACING = 0.1 #scaling the waypoint output to meter? 
@app.post("/act")
async def act(
    image: UploadFile = File(...),
    instruction: str = Form(...),
) -> JSONResponse:
    """Run one OmniVLA-edge inference step for an uploaded camera frame.

    The frame is resized to ``_IMGSIZE`` and ``_IMGSIZE_CLIP``. The same frame
    is repeated to fill the observation history, and the satellite maps, goal
    image and goal pose are all-zero dummies because only the language
    modality (``_MODALITY_ID``) is used.

    Args:
        image: Uploaded image file of the robot's current view.
        instruction: Natural-language instruction, tokenised with CLIP
            (truncated if too long).

    Returns:
        ``{"waypoints": [...]}`` with the first predicted trajectory, whose
        first two columns are scaled by ``_METRIC_WP_SPACING``; or a 400
        response with an ``"error"`` field when the upload cannot be decoded.
    """
    raw = await image.read()

    # Reject undecodable uploads with a client error instead of a 500.
    try:
        current_image_PIL = Image.open(BytesIO(raw)).convert("RGB")
    except (OSError, ValueError) as exc:
        return JSONResponse({"error": f"invalid image upload: {exc}"}, status_code=400)

    current_image_PIL_96 = current_image_PIL.resize(_IMGSIZE)
    current_image_PIL_224 = current_image_PIL.resize(_IMGSIZE_CLIP)

    # The same frame fills the whole observation history (context frames plus
    # the current one), i.e. the robot is assumed to be standing still.
    context_queue = [current_image_PIL_96] * (_MODEL_PARAMS["context_size"] + 1)
    obs_images = transform_images_PIL_mask(context_queue, mask_360_pil_96).to(device)
    # Frames are stacked along the channel axis, 3 channels each; the last
    # three channels are the current frame.
    obs_image_cur = obs_images[:, -3:]

    cur_large_img = transform_images_PIL_mask(current_image_PIL_224, mask_360_pil_224).to(device)

    # Dummy (all-black) satellite images for the current and goal locations.
    satellite_cur = Image.new("RGB", (352, 352), color=(0, 0, 0))
    satellite_goal = Image.new("RGB", (352, 352), color=(0, 0, 0))
    current_map_image = transform_images_map(satellite_cur).to(device)
    goal_map_image = transform_images_map(satellite_goal).to(device)
    map_images = torch.cat((current_map_image, goal_map_image, obs_image_cur), dim=1)

    # Dummy egocentric goal image and zero goal pose.
    dummy_goal = Image.new("RGB", _IMGSIZE, color=(0, 0, 0))
    goal_image = transform_images_PIL_mask(dummy_goal, mask_360_pil_96).to(device)
    goal_pose_torch = torch.zeros(1, 4, dtype=torch.float32, device=device)

    # Language instruction
    obj_inst_lan = clip.tokenize(instruction, truncate=True).to(device)
    modality_id = torch.tensor([_MODALITY_ID]).to(device)

    # NOTE: this runs synchronously inside the event loop, so other requests
    # (including /health) wait until inference has finished.
    with torch.no_grad():
        feat_text_lan = text_encoder.encode_text(obj_inst_lan)
        predicted_actions, _distances, _mask_number = model(
            obs_images,
            goal_pose_torch,
            map_images,
            goal_image,
            modality_id,
            feat_text_lan,
            cur_large_img,
        )

    waypoints = predicted_actions.float().cpu().numpy()[0]
    waypoints[:, :2] *= _METRIC_WP_SPACING

    return JSONResponse({"waypoints": waypoints.tolist()})
def _fleet_size(ns: str) -> int:
    """Probe RobotFleet for `ns`; fall back to 1 if unknown.

    Creates a temporary node, subscribes to ``{ns}/state/robots`` with
    transient-local durability, and spins for up to 5 seconds waiting for one
    message. The node and subscription are always released, even if spinning
    is interrupted.

    Returns:
        The number of robots in the first message received, or 1 if none
        arrived before the deadline.
    """
    import rclpy.qos
    from task_generator_msgs.msg import RobotFleet

    seen: list[int] = []
    node = rclpy.create_node('arena_supervisor_fleet_probe')
    try:
        qos = rclpy.qos.QoSProfile(depth=1, durability=rclpy.qos.DurabilityPolicy.TRANSIENT_LOCAL)
        sub = node.create_subscription(
            RobotFleet, f'{ns}/state/robots', lambda m: seen.append(len(m.robots)), qos
        )
        try:
            deadline = time.monotonic() + 5.0
            while not seen and time.monotonic() < deadline:
                rclpy.spin_once(node, timeout_sec=0.1)
        finally:
            node.destroy_subscription(sub)
    finally:
        # Without this an interrupted probe leaves the node registered in the
        # ROS graph.
        node.destroy_node()
    return seen[0] if seen else 1
endpoint: http://localhost:8000/health\n' + ) + if sup.has_service(REGISTER_ENV_SERVICE): if args.sim is not None: existing = sup.get_param_string('/arena', 'sim') @@ -229,7 +271,7 @@ def run(args: argparse.Namespace, sup: Supervisor) -> int: ], ) - if args.viz: + if args.rviz: if not wait_until( lambda: len(sup.viz_namespaces()) >= target_count, runtime_proc=runtime, @@ -237,7 +279,7 @@ def run(args: argparse.Namespace, sup: Supervisor) -> int: ): return 1 for ns in sorted(sup.viz_namespaces() - existing_ns): - for role, cmd in viz_backends.viz_commands(ns, viz_backends.env_idx_from_ns(ns), args.viz_args): + for role, cmd in _viz_spawn_commands(ns, args.viz_args): sup.spawn(role, cmd) if runtime is not None: diff --git a/task_generator/task_generator/constants/__init__.py b/task_generator/task_generator/constants/__init__.py index 7dd89b13..5e36c51b 100644 --- a/task_generator/task_generator/constants/__init__.py +++ b/task_generator/task_generator/constants/__init__.py @@ -41,6 +41,7 @@ class TM_Robots(enum.Enum): RANDOM = "random" SCENARIO = "scenario" DEMO = "demo" + VLA = "vla" @classmethod def prefix(cls, *args: object) -> Namespace: diff --git a/task_generator/task_generator/tasks/robots/__init__.py b/task_generator/task_generator/tasks/robots/__init__.py index afbe3763..f84db69b 100644 --- a/task_generator/task_generator/tasks/robots/__init__.py +++ b/task_generator/task_generator/tasks/robots/__init__.py @@ -7,7 +7,7 @@ from task_generator.tasks.robots._placement import random_placement from task_generator.tasks.robots.request import GoToPhase, TaskRequest -from . import demo, explore, guided, random, scenario +from . 
import demo, explore, guided, random, scenario, vla class TM_Robots(TaskMode): @@ -79,4 +79,4 @@ async def done(self) -> bool: return True -__all__ = ["TM_Robots", "demo", "explore", "guided", "random", "scenario"] +__all__ = ["TM_Robots", "demo", "explore", "guided", "random", "scenario", "vla"] diff --git a/task_generator/task_generator/tasks/robots/vla/README.md b/task_generator/task_generator/tasks/robots/vla/README.md new file mode 100644 index 00000000..7db254ac --- /dev/null +++ b/task_generator/task_generator/tasks/robots/vla/README.md @@ -0,0 +1,40 @@ +# VLA navigation task mode (`TM_VLA`) + +## Changes + +Each cycle (throttled to `_INFERENCE_INTERVAL`): + +1. **Image subscription is automatic** — on reset the task mode finds each robot's camera topic by + itself, preferring any sensor with `"front"` in the name. +2. The latest frame is sent to the inference server over HTTP, along with the instruction. +3. The model runs inference and returns **8 future waypoints**, sent back to the client. +4. The waypoints are filtered through `is_valid_pose` (anything off-map or inside an obstacle is + dropped), and the **last valid one** (the furthest along the predicted trajectory) is submitted as the goal. +5. If the distance from the current pose to that goal is below + `_WAYPOINT_PROXIMITY_THRESHOLD`, control is **handed off to nav2** to finish the approach. + +## Server + +Adds the `vla_server` arena feature. + +Run with: + +```bash +arena feature vla_server update +``` + +```bash +arena vla_server:=omnivla_edge +``` + +`omnivla_edge` is currently the only supported model. + +Download the model with: +```bash +git clone https://huggingface.co/NHirose/omnivla-edge +``` +Then put `omnivla-edge.pth` in the `/arena_ws/src/Arena/_meta/docker/features/vla_server/model` folder (the server loads `model/omnivla-edge.pth`). +## Note + +The instruction is set in `impl.py` for now and has to be edited there directly — this will move to +a proper arg later.
class TM_VLA(TM_Robots):
    """Robot task mode that steers each robot with a remote VLA inference server.

    On every poll of ``done`` (throttled to ``_INFERENCE_INTERVAL`` seconds per
    robot) the latest camera frame is posted to the server together with
    ``_INSTRUCTION``. The returned robot-frame waypoints are converted to map
    poses and filtered with ``is_valid_pose``, and the last valid one is
    submitted as a go-to goal. Once that goal is closer than
    ``_WAYPOINT_PROXIMITY_THRESHOLD`` the mode stops issuing goals and waits
    for the robot to finish.
    """

    _latest_images: dict  # robot name -> latest sensor_msgs Image (None until one arrives)
    _image_subs: list  # camera subscriptions created by reset()

    _timeouts: dict  # robot name -> sim time of the last reset / goal submission
    _started: dict  # robot name -> True once a goal has been submitted
    _last_attempt: dict  # robot name -> time.monotonic() of the last inference attempt

    _near_goal: dict  # robot name -> True once control was handed off
    _invalid_streak: dict  # robot name -> consecutive cycles with no valid waypoint

    _inference_pending: dict  # robot name -> True while a server request is in flight

    # --------------- MAIN LOOP (inspired by the explore task mode) ---------------
    async def reset(self, **kwargs: object) -> None:
        """Start a new episode: respawn the robots, check the server, subscribe to cameras."""
        await super().reset(**kwargs)

        # Destroy the subscriptions from the previous episode; otherwise every
        # reset adds another subscription per camera topic.
        for old_sub in getattr(self, "_image_subs", None) or ():
            self.node.destroy_subscription(old_sub)

        self._latest_images = {}
        self._image_subs = []
        self._timeouts = {}
        self._started = {}
        self._last_attempt = {}
        self._near_goal = {}
        self._inference_pending = {}
        self._invalid_streak = {}

        # Random valid spawn positions, spaced for the largest robot.
        biggest_robot = max((r.safe_distance for r in self._ctx.robots.values()), default=0.5)
        n = len(self._ctx.robots)
        positions = self._ctx.world_manager.get_positions_on_map(n=n, safe_dist=biggest_robot)
        orientations = 2 * math.pi * self.node.conf.General.RNG.value.random(n)
        for name, pos, ori in zip(self._ctx.robots, positions, orientations, strict=False):
            self._start_poses[name] = Pose(pos, Orientation.from_yaw(ori))

        # Best-effort health probe; run it off the event loop because
        # requests blocks for up to the timeout.
        try:
            health = await asyncio.to_thread(requests.get, _VLA_HEALTH, timeout=2.0)
            health.raise_for_status()
        except requests.exceptions.RequestException:
            self.node.get_logger().error(
                f"[TM_VLA]: VLA Server not reachable at {_VLA_BASE}"
            )

        qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )

        for name, robot in self._ctx.robots.items():
            self._latest_images[name] = None
            self._timeouts[name] = self.node.sim_time
            self._started[name] = False

            topic = self.find_image_topic(robot)
            if topic is None:
                self.node.get_logger().warn(
                    f"[TM_VLA:{name}] no image sensor in model_params. Skipping {name}"
                )
                continue
            sub = self.node.create_subscription(
                sensor_msgs.msg.Image,
                topic,
                # Bind the name as a default so each callback keeps its own robot.
                lambda msg, n=name: self.robot_image(n, msg),
                qos,
            )
            self._image_subs.append(sub)
            self.node.get_logger().info(
                f"[TM_VLA]sub to {topic}"
            )

    @property
    async def done(self) -> bool:
        """Advance every robot by one cycle; return True when the episode should end."""
        timeout = self.node.conf.Robot.TIMEOUT.value

        for name, robot in self._ctx.robots.items():

            if self._near_goal.get(name, False):
                self.node.get_logger().warn(
                    f"[TM_VLA:{name}] proximity mode — waiting for nav2"
                )
                # Wait for the robot to finish instead of submitting a new goal.
                if await robot.is_done:
                    return True
                continue

            if (self.node.sim_time.sec - self._timeouts.get(name, Time()).sec) >= timeout:
                self.node.get_logger().warn(f"[TM_VLA:{name}] episode timeout — ending")
                return True

            await self.submit_vla_goal(name, robot)

        return False

    # --------------------- helper functions ---------------------
    def find_image_topic(self, robot) -> str | None:
        """Return the namespaced topic of the robot's image sensor, or None if it has none.

        A sensor whose name contains ``"front"`` is preferred; otherwise the
        first image sensor is used.
        """
        image_sensors = [s for s in robot.robot_view.model_params.sensors if s.type == SensorType.IMAGE]
        if not image_sensors:
            return None

        preferred = next((s for s in image_sensors if "front" in s.name.lower()), image_sensors[0])
        return str(robot.namespace(preferred.topic.removeprefix("${namespace}/")))

    def robot_image(self, robot, msg: sensor_msgs.msg.Image) -> None:
        """Subscription callback: remember the newest frame under the key ``robot``."""
        self._latest_images[robot] = msg

    def vla_inference(self, image: sensor_msgs.msg.Image) -> list[tuple[float, float]] | None:
        """Send one frame to the VLA server and return the predicted waypoints.

        This call blocks (HTTP timeout 10 s) and is meant to run in a worker thread.

        Returns:
            A list of ``(x, y)`` tuples built from the first two values of
            each returned waypoint, or ``None`` if the request failed or the
            reply was malformed.
        """
        # Prepare the image for upload.
        arr = np.frombuffer(image.data, dtype=np.uint8).reshape(
            image.height, image.width, -1
        )
        if image.encoding in ("bgr8", "bgra8"):
            arr = arr[:, :, 2::-1]  # PIL expects RGB channel order
        elif arr.shape[2] == 1:
            arr = arr[:, :, 0]  # single channel: let PIL treat it as grayscale
        else:
            arr = arr[:, :, :3]  # drop alpha; JPEG cannot store it
        pil = Image.fromarray(np.ascontiguousarray(arr)).convert("RGB").resize((512, 512))
        buf = io.BytesIO()
        pil.save(buf, format="JPEG")

        try:
            response = requests.post(
                _VLA_SERVER,
                files={"image": ("frame.jpg", buf.getvalue(), "image/jpeg")},
                data={"instruction": _INSTRUCTION},
                timeout=10.0,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException:
            # The next attempt happens after the regular inference interval.
            self.node.get_logger().warn(
                f"[TM_VLA] server unreachable at {_VLA_BASE}, retrying in {_INFERENCE_INTERVAL:.0f}s"
            )
            return None

        try:
            return [(float(wp[0]), float(wp[1])) for wp in response.json()["waypoints"]]
        except (KeyError, IndexError, TypeError, ValueError):
            self.node.get_logger().warn("[TM_VLA] malformed response from VLA server")
            return None

    async def submit_vla_goal(self, name, robot) -> None:
        """Run one throttled inference cycle for ``robot`` and submit the resulting goal.

        Does nothing if the last attempt was less than ``_INFERENCE_INTERVAL``
        seconds ago, if no image or pose is available, or if a request for this
        robot is still in flight. After ``_MAX_INVALID_STREAK`` consecutive
        cycles without a valid waypoint the robot is forced into proximity mode.
        """
        now = time.monotonic()
        if now - self._last_attempt.get(name, 0.0) < _INFERENCE_INTERVAL:
            return
        self._last_attempt[name] = now

        image = self._latest_images.get(name)
        current_pose = robot.pose
        if image is None or current_pose is None:
            self.node.get_logger().warn(
                "No image or pose(most likely img)"
            )
            return

        if self._inference_pending.get(name, False):
            return
        self._inference_pending[name] = True
        try:
            # Blocking HTTP call: keep it off the event loop.
            waypoints = await asyncio.to_thread(self.vla_inference, image)
        finally:
            self._inference_pending[name] = False
        if waypoints is None:
            self.node.get_logger().warn(
                "waypoint empty"
            )
            return

        env = self._ctx.environment_manager

        # Convert the robot-frame waypoints into goal phases.
        # NOTE(review): ezilear looks like the inverse of realize — confirm.
        phases = [
            GoToPhase(pose=env.ezilear(self.to_pose(current_pose, wp)))
            for wp in waypoints
        ]

        # Drop waypoints that are off-map or inside an obstacle.
        phases = [p for p in phases if self.is_valid_pose(p.pose.position.x, p.pose.position.y)]
        for p in phases:
            realized = env.realize(p.pose)
            self.node.get_logger().warn(
                f"[TM_VLA:WP], goal_valid:({realized.position.x:.2f},{realized.position.y:.2f}) "
            )
        if not phases:  # every waypoint was invalid
            streak = self._invalid_streak.get(name, 0) + 1
            self._invalid_streak[name] = streak
            self.node.get_logger().warn(
                f"[TM_VLA:{name}] all waypoints invalid ({streak}/{_MAX_INVALID_STREAK}) — skipping"
            )
            if streak >= _MAX_INVALID_STREAK:
                self.node.get_logger().warn(
                    f"[TM_VLA:{name}] invalid streak limit reached — forcing proximity handoff"
                )
                self._near_goal[name] = True
            return

        # Only the last valid waypoint is used as the goal.
        phases = [phases[-1]]
        goal = env.realize(phases[-1].pose)
        total_dist = math.hypot(
            goal.position.x - current_pose.position.x,
            goal.position.y - current_pose.position.y,
        )
        self._near_goal[name] = total_dist < _WAYPOINT_PROXIMITY_THRESHOLD
        if self._near_goal[name]:
            self.node.get_logger().warn(
                f"[TM_VLA:{name}] proximity mode — total_dist={total_dist:.2f}m < "
                f"{_WAYPOINT_PROXIMITY_THRESHOLD}m, handing off final goal to nav2"
            )

        self.node.get_logger().warn(
            f"[TM_VLA:{name}] {len(phases)} waypoints from "
            f"robot=({current_pose.position.x:.2f},{current_pose.position.y:.2f}) "
            f"to ({goal.position.x:.2f},{goal.position.y:.2f})"
        )

        self._invalid_streak[name] = 0
        self._started[name] = True
        # NOTE(review): restamping on every submission means the timeout in
        # `done` only fires after TIMEOUT seconds without a submitted goal —
        # confirm that this is the intended meaning of "episode timeout".
        self._timeouts[name] = self.node.sim_time
        await robot.submit_task(TaskRequest(phases=phases))

    def is_valid_pose(self, x: float, y: float) -> bool:
        """Return True if ``(x, y)`` is inside the occupancy grid and the cell is not full."""
        from task_generator.manager.world_manager.utils import WorldOccupancy

        world_map = self._ctx.world_manager.map
        row, col = world_map.tf_pos2grid(Position(x=x, y=y))
        rows, cols = world_map.occupancy.grid.shape
        if not (0 <= row < rows and 0 <= col < cols):
            return False
        return bool(WorldOccupancy.not_full(world_map.occupancy.grid)[int(row), int(col)])

    def to_pose(self, current: Pose, action: tuple[float, float]) -> Pose:
        """Convert a robot-frame offset ``(dx, dy)`` into a pose in the frame of ``current``.

        The offset is rotated by the current yaw and added to the current
        position. The resulting pose faces along the direction of travel; a
        zero offset keeps the current heading.
        """
        dx, dy = action
        yaw = current.orientation.to_yaw()
        move_dx = dx * math.cos(yaw) - dy * math.sin(yaw)
        move_dy = dx * math.sin(yaw) + dy * math.cos(yaw)

        # atan2(0, 0) is 0, which would turn a stationary robot to yaw 0.
        new_yaw = math.atan2(move_dy, move_dx) if (move_dx or move_dy) else yaw
        return Pose(
            Position(current.position.x + move_dx, current.position.y + move_dy),
            Orientation.from_yaw(new_yaw),
        )