Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion py/AILab_BiRefNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from safetensors.torch import load_file
import cv2

device = "cuda" if torch.cuda.is_available() else "cpu"
from AILab_utils import get_device
device = get_device()

# Add model path
folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))
Expand Down
3 changes: 2 additions & 1 deletion py/AILab_BodySegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Imag
mask = mask.resize(image.size, Image.Resampling.LANCZOS)
return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L')))

device = "cuda" if torch.cuda.is_available() else "cpu"
from AILab_utils import get_device
device = get_device()

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))

Expand Down
3 changes: 2 additions & 1 deletion py/AILab_ClothSegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Imag
mask = mask.resize(image.size, Image.Resampling.LANCZOS)
return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L')))

device = "cuda" if torch.cuda.is_available() else "cpu"
from AILab_utils import get_device
device = get_device()

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))

Expand Down
3 changes: 2 additions & 1 deletion py/AILab_FaceSegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def RGB2RGBA(image: Image.Image, mask: Union[Image.Image, torch.Tensor]) -> Imag
mask = mask.resize(image.size, Image.Resampling.LANCZOS)
return Image.merge('RGBA', (*image.convert('RGB').split(), mask.convert('L')))

device = "cuda" if torch.cuda.is_available() else "cpu"
from AILab_utils import get_device
device = get_device()

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))

Expand Down
3 changes: 2 additions & 1 deletion py/AILab_FashionSegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def mask2image(mask: torch.Tensor) -> Image.Image:
mask = mask.unsqueeze(0)
return tensor2pil(mask)

device = "cuda" if torch.cuda.is_available() else "cpu"
from AILab_utils import get_device
device = get_device()

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))

Expand Down
3 changes: 2 additions & 1 deletion py/AILab_RMBG.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
import cv2
import types

device = "cuda" if torch.cuda.is_available() else "cpu"
from AILab_utils import get_device
device = get_device()

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))

Expand Down
3 changes: 2 additions & 1 deletion py/AILab_SegmentV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

from AILab_ImageMaskTools import pil2tensor, tensor2pil
from AILab_utils import get_device

SAM_MODELS = {
"sam_vit_h (2.56GB)": {
Expand Down Expand Up @@ -165,7 +166,7 @@ def __init__(self):
def segment_v2(self, image, prompt, sam_model, dino_model, threshold=0.30,
mask_blur=0, mask_offset=0, background="Alpha",
background_color="#222222", invert_output=False):
device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_device()

batch_size = image.shape[0] if len(image.shape) == 4 else 1
if len(image.shape) == 3:
Expand Down
11 changes: 11 additions & 0 deletions py/AILab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
import torch
from PIL import Image
from comfy.utils import common_upscale
import comfy.model_management as mm


def get_device():
"""Return the torch device ComfyUI selected (CUDA, MPS, or CPU).

Defers to ComfyUI's device management instead of assuming CUDA-or-CPU, so nodes
follow whatever backend ComfyUI chose. Fixes models silently running on CPU on
Apple Silicon (MPS). See issues #200 and #135.
"""
return mm.get_torch_device()


def tensor2pil(image: torch.Tensor) -> Image.Image:
Expand Down