diff --git a/py/AILab_BiRefNet.py b/py/AILab_BiRefNet.py index 94e2a7e..8362331 100644 --- a/py/AILab_BiRefNet.py +++ b/py/AILab_BiRefNet.py @@ -22,7 +22,8 @@ from safetensors.torch import load_file import cv2 -device = "cuda" if torch.cuda.is_available() else "cpu" +from AILab_utils import get_device +device = get_device() # Add model path folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) diff --git a/py/AILab_BodySegment.py b/py/AILab_BodySegment.py index 8e0082c..f10eed1 100644 --- a/py/AILab_BodySegment.py +++ b/py/AILab_BodySegment.py @@ -45,7 +45,8 @@ def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Imag mask = mask.resize(image.size, Image.Resampling.LANCZOS) return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L'))) -device = "cuda" if torch.cuda.is_available() else "cpu" +from AILab_utils import get_device +device = get_device() folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) diff --git a/py/AILab_ClothSegment.py b/py/AILab_ClothSegment.py index 62bd5bd..96813d6 100644 --- a/py/AILab_ClothSegment.py +++ b/py/AILab_ClothSegment.py @@ -44,7 +44,8 @@ def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Imag mask = mask.resize(image.size, Image.Resampling.LANCZOS) return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L'))) -device = "cuda" if torch.cuda.is_available() else "cpu" +from AILab_utils import get_device +device = get_device() folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) diff --git a/py/AILab_FaceSegment.py b/py/AILab_FaceSegment.py index 50c9e75..2e9c512 100644 --- a/py/AILab_FaceSegment.py +++ b/py/AILab_FaceSegment.py @@ -42,7 +42,8 @@ def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Imag mask = mask.resize(image.size, Image.Resampling.LANCZOS) return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L'))) -device = "cuda" if torch.cuda.is_available() else "cpu" +from AILab_utils import get_device +device = get_device() folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) diff --git a/py/AILab_FashionSegment.py b/py/AILab_FashionSegment.py index 0fcb829..2278dc6 100644 --- a/py/AILab_FashionSegment.py +++ b/py/AILab_FashionSegment.py @@ -38,7 +38,8 @@ def mask2image(mask: torch.Tensor) -> Image.Image: mask = mask.unsqueeze(0) return tensor2pil(mask) -device = "cuda" if torch.cuda.is_available() else "cpu" +from AILab_utils import get_device +device = get_device() folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) diff --git a/py/AILab_RMBG.py b/py/AILab_RMBG.py index 06c5a63..c2d8adb 100644 --- a/py/AILab_RMBG.py +++ b/py/AILab_RMBG.py @@ -31,7 +31,8 @@ import cv2 import types -device = "cuda" if torch.cuda.is_available() else "cpu" +from AILab_utils import get_device +device = get_device() folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) diff --git a/py/AILab_SegmentV2.py b/py/AILab_SegmentV2.py index 1d4ea58..1b8bbcf 100644 --- a/py/AILab_SegmentV2.py +++ b/py/AILab_SegmentV2.py @@ -15,6 +15,7 @@ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection from AILab_ImageMaskTools import pil2tensor, tensor2pil +from AILab_utils import get_device SAM_MODELS = { "sam_vit_h (2.56GB)": { @@ -165,7 +166,7 @@ def __init__(self): def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30, mask_blur=0, mask_offset=0, background="Alpha", background_color="#222222", invert_output=False): - device = "cuda" if torch.cuda.is_available() else "cpu" + device = get_device() batch_size = image.shape[0] if len(image.shape) == 4 else 1 if len(image.shape) == 3: diff --git a/py/AILab_utils.py b/py/AILab_utils.py index acc2e47..4b66c81 100644 --- a/py/AILab_utils.py +++ b/py/AILab_utils.py @@ -3,6 +3,17 @@ import torch from PIL import Image from comfy.utils import common_upscale +import comfy.model_management as mm + + +def get_device(): + """Return the torch device ComfyUI selected (CUDA, MPS, or CPU). + + Defers to ComfyUI's device management instead of assuming CUDA-or-CPU, so nodes + follow whatever backend ComfyUI chose. Fixes models silently running on CPU on + Apple Silicon (MPS). See issues #200 and #135. + """ + return mm.get_torch_device() def tensor2pil(image: torch.Tensor) -> Image.Image: