diff --git a/core/dispatch.py b/core/dispatch.py index b73dff3..28d55fb 100644 --- a/core/dispatch.py +++ b/core/dispatch.py @@ -8,10 +8,11 @@ from PIL import Image from copy import deepcopy -from .utils import clean_url, get_client_id +from .utils import clean_url, get_client_id, get_auth_headers -def clear_remote_queue(remote_url): - r = requests.get(f"{remote_url}/queue", timeout=4) +def clear_remote_queue(remote_url, remote_bearer_token): + headers = get_auth_headers(remote_bearer_token) + r = requests.get(f"{remote_url}/queue", headers=headers, timeout=4) r.raise_for_status() queue = r.json() @@ -23,6 +24,7 @@ def clear_remote_queue(remote_url): r = requests.post( f"{remote_url}/queue", json = {"delete" : to_cancel}, + headers=headers, timeout = 4, ) r.raise_for_status() @@ -32,29 +34,32 @@ def clear_remote_queue(remote_url): r = requests.post( f"{remote_url}/interrupt", json = {}, + headers=headers, timeout = 4, ) r.raise_for_status() break -def get_remote_os(remote_url): +def get_remote_os(remote_url, remote_bearer_token): + headers = get_auth_headers(remote_bearer_token) url = f"{remote_url}/system_stats" - r = requests.get(url, timeout=4) + r = requests.get(url, headers=headers, timeout=4) r.raise_for_status() data = r.json() return data["system"]["os"] -def get_output_nodes(remote_url): +def get_output_nodes(remote_url, remote_bearer_token): # I'm 90% sure this could just use the # list from the host but better safe than sorry url = f"{remote_url}/object_info" - r = requests.get(url, timeout=4) + headers = get_auth_headers(remote_bearer_token) + r = requests.get(url, headers=headers, timeout=4) r.raise_for_status() data = r.json() out = [k for k, v in data.items() if v.get("output_node")] return out -def dispatch_to_remote(remote_url, prompt, job_id=f"{get_client_id()}-unknown", outputs="final_image"): +def dispatch_to_remote(remote_url, remote_bearer_token, prompt, job_id=f"{get_client_id()}-unknown", outputs="final_image"): ### PROMPT LOGIC ### prompt = deepcopy(prompt) to_del = [] @@ -108,7 +113,7 @@ def recursive_node_deletion(start_node): for i in to_del: del prompt[i] ### OS LOGIC ### - sep_remote = "\\" if get_remote_os(remote_url) == "nt" else "/" + sep_remote = "\\" if get_remote_os(remote_url, remote_bearer_token) == "nt" else "/" sep_local = "\\" if os.name == "nt" else "/" sem_input_map = { # class type : input to replace "CheckpointLoaderSimple" : "ckpt_name", @@ -130,10 +135,11 @@ def recursive_node_deletion(start_node): "job_id": job_id, } } + headers = get_auth_headers(remote_bearer_token) ar = requests.post( f"{remote_url}/prompt", data = json.dumps(data), - headers = {"Content-Type": "application/json"}, + headers = {"Content-Type": "application/json", **headers}, timeout = 4, ) ar.raise_for_status() diff --git a/core/fetch.py b/core/fetch.py index 7c64f21..78ad528 100644 --- a/core/fetch.py +++ b/core/fetch.py @@ -4,6 +4,7 @@ import requests import numpy as np from PIL import Image +from .utils import get_auth_headers POLLING = 0.5 @@ -15,10 +16,11 @@ def get_job_output(inputs, outputs): break return outputs[output_id].get("images", []) -def wait_for_job(remote_url, job_id): +def wait_for_job(remote_url, remote_bearer_token, job_id): fail = 0 + headers = get_auth_headers(remote_bearer_token) while fail <= 3: - r = requests.get(f"{remote_url}/history", timeout=4) + r = requests.get(f"{remote_url}/history", headers=headers, timeout=4) try: r.raise_for_status() except Exception as e: @@ -40,7 +42,7 @@ def wait_for_job(remote_url, job_id): time.sleep(POLLING) raise OSError("Failed to fetch image from remote client!") -def fetch_from_remote(remote_url, job_id): +def fetch_from_remote(remote_url, remote_bearer_token, job_id): def img_to_torch(img): image = img.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -51,10 +53,11 @@ def img_to_torch(img): return None images = [] - for i in wait_for_job(remote_url, job_id): + headers = get_auth_headers(remote_bearer_token) + for i in wait_for_job(remote_url, remote_bearer_token, job_id): img_url = f"{remote_url}/view?filename={i['filename']}&subfolder={i['subfolder']}&type={i['type']}" - ir = requests.get(img_url, stream=True, timeout=16) + ir = requests.get(img_url, headers=headers, stream=True, timeout=16) ir.raise_for_status() img = Image.open(ir.raw) images.append(img_to_torch(img)) diff --git a/core/utils.py b/core/utils.py index 8002a37..8b478a7 100644 --- a/core/utils.py +++ b/core/utils.py @@ -21,3 +21,9 @@ def clean_url(raw, multi=False): raw = raw.replace(' ', ',').replace('\n', ',').replace('\t', ',') urls = [x.rstrip('/') for x in raw.split(',') if x.strip()] return urls if multi else urls[0] + +def get_auth_headers(remote_bearer_token): + headers = {} + if remote_bearer_token and remote_bearer_token.strip(): + headers["Authorization"] = f"Bearer {remote_bearer_token.strip()}" + return headers diff --git a/nodes/advanced.py b/nodes/advanced.py index 18f258a..2220dd3 100644 --- a/nodes/advanced.py +++ b/nodes/advanced.py @@ -73,6 +73,11 @@ def INPUT_TYPES(s): "multiline": False, "default": "http://127.0.0.1:8288/", }), + "remote_bearer_token": ("STRING", { + "multiline": False, + "default": "", + "tooltip": "Optional Bearer token for authenticated remote ComfyUI servers (e.g., for ComfyUI-Login)", + }), "batch_override": ("INT", {"default": 0, "min": 0, "max": 8}), "enabled": (["true", "false", "remote"],{"default": "true"}), "outputs": (["final_image", "any"],{"default":"final_image"}), @@ -85,7 +90,7 @@ def INPUT_TYPES(s): CATEGORY = "remote/advanced" TITLE = "Queue on remote (worker)" - def queue(self, remote_chain, remote_url, batch_override, enabled, outputs): + def queue(self, remote_chain, remote_url, remote_bearer_token, batch_override, enabled, outputs): current_offset = remote_chain["seed_offset"] remote_chain["seed_offset"] += 1 if batch_override == 0 else batch_override if enabled == "false": @@ -98,15 +103,17 @@ def queue(self, remote_chain, remote_url, batch_override, enabled, outputs): return (remote_chain, {}) remote_url = clean_url(remote_url) - clear_remote_queue(remote_url) + clear_remote_queue(remote_url, remote_bearer_token) dispatch_to_remote( remote_url, + remote_bearer_token, remote_chain["prompt"], remote_chain["job_id"], outputs, ) remote_info = { "remote_url" : remote_url, + "remote_bearer_token": remote_bearer_token, "job_id" : remote_chain["job_id"], } return (remote_chain, remote_info) diff --git a/nodes/simple.py b/nodes/simple.py index 872cca4..7d1fb4c 100644 --- a/nodes/simple.py +++ b/nodes/simple.py @@ -28,6 +28,7 @@ def INPUT_TYPES(s): def fetch(self, final_image, remote_info): out = fetch_from_remote( remote_url = remote_info.get("remote_url"), + remote_bearer_token = remote_info.get("remote_bearer_token"), job_id = remote_info.get("job_id"), ) if out is None: @@ -49,6 +50,11 @@ def INPUT_TYPES(s): "multiline": False, "default": "http://127.0.0.1:8288/", }), + "remote_bearer_token": ("STRING", { + "multiline": False, + "default": "", + "tooltip": "Optional Bearer token for authenticated remote ComfyUI servers (e.g., for ComfyUI-Login)", + }), "batch_local": ("INT", {"default": 1, "min": 1, "max": 8}), "batch_remote": ("INT", {"default": 1, "min": 1, "max": 8}), "trigger": (["on_change", "always"],), @@ -66,7 +72,7 @@ def INPUT_TYPES(s): CATEGORY = "remote" TITLE = "Queue on remote (single)" - def queue(self, remote_url, batch_local, batch_remote, trigger, enabled, seed, prompt): + def queue(self, remote_url, remote_bearer_token, batch_local, batch_remote, trigger, enabled, seed, prompt): if enabled == "false": return (seed, batch_local, {}) if enabled == "remote": @@ -74,18 +80,19 @@ def queue(self, remote_url, batch_local, batch_remote, trigger, enabled, seed, p job_id = get_new_job_id() remote_url = clean_url(remote_url) - clear_remote_queue(remote_url) - dispatch_to_remote(remote_url, prompt, job_id) + clear_remote_queue(remote_url, remote_bearer_token) + dispatch_to_remote(remote_url, remote_bearer_token, prompt, job_id) remote_info = { "remote_url" : remote_url, + "remote_bearer_token": remote_bearer_token, "job_id" : job_id, } return (seed, batch_local, remote_info) @classmethod - def IS_CHANGED(self, remote_url, batch_local, batch_remote, trigger, enabled, seed, prompt): - uuid = f"W:{remote_url},B1:{batch_local},B2:{batch_remote},S:{seed},E:{enabled}" + def IS_CHANGED(self, remote_url, remote_bearer_token, batch_local, batch_remote, trigger, enabled, seed, prompt): + uuid = f"W:{remote_url},TOKEN:{remote_bearer_token},B1:{batch_local},B2:{batch_remote},S:{seed},E:{enabled}" return uuid if trigger == "on_change" else str(time.time()) NODE_CLASS_MAPPINGS = {