Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from PIL import Image
from copy import deepcopy

from .utils import clean_url, get_client_id
from .utils import clean_url, get_client_id, get_auth_headers

def clear_remote_queue(remote_url):
r = requests.get(f"{remote_url}/queue", timeout=4)
def clear_remote_queue(remote_url, remote_bearer_token):
headers = get_auth_headers(remote_bearer_token)
r = requests.get(f"{remote_url}/queue", headers=headers, timeout=4)
r.raise_for_status()
queue = r.json()

Expand All @@ -23,6 +24,7 @@ def clear_remote_queue(remote_url):
r = requests.post(
f"{remote_url}/queue",
json = {"delete" : to_cancel},
headers=headers,
timeout = 4,
)
r.raise_for_status()
Expand All @@ -32,29 +34,32 @@ def clear_remote_queue(remote_url):
r = requests.post(
f"{remote_url}/interrupt",
json = {},
headers=headers,
timeout = 4,
)
r.raise_for_status()
break

def get_remote_os(remote_url):
def get_remote_os(remote_url, remote_bearer_token):
headers = get_auth_headers(remote_bearer_token)
url = f"{remote_url}/system_stats"
r = requests.get(url, timeout=4)
r = requests.get(url, headers=headers, timeout=4)
r.raise_for_status()
data = r.json()
return data["system"]["os"]

def get_output_nodes(remote_url):
def get_output_nodes(remote_url, remote_bearer_token):
# I'm 90% sure this could just use the
# list from the host but better safe than sorry
url = f"{remote_url}/object_info"
r = requests.get(url, timeout=4)
headers = get_auth_headers(remote_bearer_token)
r = requests.get(url, headers=headers, timeout=4)
r.raise_for_status()
data = r.json()
out = [k for k, v in data.items() if v.get("output_node")]
return out

def dispatch_to_remote(remote_url, prompt, job_id=f"{get_client_id()}-unknown", outputs="final_image"):
def dispatch_to_remote(remote_url, remote_bearer_token, prompt, job_id=f"{get_client_id()}-unknown", outputs="final_image"):
### PROMPT LOGIC ###
prompt = deepcopy(prompt)
to_del = []
Expand Down Expand Up @@ -108,7 +113,7 @@ def recursive_node_deletion(start_node):
for i in to_del: del prompt[i]

### OS LOGIC ###
sep_remote = "\\" if get_remote_os(remote_url) == "nt" else "/"
sep_remote = "\\" if get_remote_os(remote_url, remote_bearer_token) == "nt" else "/"
sep_local = "\\" if os.name == "nt" else "/"
sem_input_map = { # class type : input to replace
"CheckpointLoaderSimple" : "ckpt_name",
Expand All @@ -130,10 +135,11 @@ def recursive_node_deletion(start_node):
"job_id": job_id,
}
}
headers = get_auth_headers(remote_bearer_token)
ar = requests.post(
f"{remote_url}/prompt",
data = json.dumps(data),
headers = {"Content-Type": "application/json"},
headers = {"Content-Type": "application/json", **headers},
timeout = 4,
)
ar.raise_for_status()
Expand Down
13 changes: 8 additions & 5 deletions core/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import requests
import numpy as np
from PIL import Image
from .utils import get_auth_headers

POLLING = 0.5

Expand All @@ -15,10 +16,11 @@ def get_job_output(inputs, outputs):
break
return outputs[output_id].get("images", [])

def wait_for_job(remote_url, job_id):
def wait_for_job(remote_url, remote_bearer_token, job_id):
fail = 0
headers = get_auth_headers(remote_bearer_token)
while fail <= 3:
r = requests.get(f"{remote_url}/history", timeout=4)
r = requests.get(f"{remote_url}/history", headers=headers, timeout=4)
try:
r.raise_for_status()
except Exception as e:
Expand All @@ -40,7 +42,7 @@ def wait_for_job(remote_url, job_id):
time.sleep(POLLING)
raise OSError("Failed to fetch image from remote client!")

def fetch_from_remote(remote_url, job_id):
def fetch_from_remote(remote_url, remote_bearer_token, job_id):
def img_to_torch(img):
image = img.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
Expand All @@ -51,10 +53,11 @@ def img_to_torch(img):
return None

images = []
for i in wait_for_job(remote_url, job_id):
headers = get_auth_headers(remote_bearer_token)
for i in wait_for_job(remote_url, remote_bearer_token, job_id):
img_url = f"{remote_url}/view?filename={i['filename']}&subfolder={i['subfolder']}&type={i['type']}"

ir = requests.get(img_url, stream=True, timeout=16)
ir = requests.get(img_url, headers=headers, stream=True, timeout=16)
ir.raise_for_status()
img = Image.open(ir.raw)
images.append(img_to_torch(img))
Expand Down
6 changes: 6 additions & 0 deletions core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ def clean_url(raw, multi=False):
raw = raw.replace(' ', ',').replace('\n', ',').replace('\t', ',')
urls = [x.rstrip('/') for x in raw.split(',') if x.strip()]
return urls if multi else urls[0]

def get_auth_headers(remote_bearer_token):
headers = {}
if remote_bearer_token and remote_bearer_token.strip():
headers["Authorization"] = f"Bearer {remote_bearer_token.strip()}"
return headers
11 changes: 9 additions & 2 deletions nodes/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def INPUT_TYPES(s):
"multiline": False,
"default": "http://127.0.0.1:8288/",
}),
"remote_bearer_token": ("STRING", {
"multiline": False,
"default": "",
"tooltip": "Optional Bearer token for authenticated remote ComfyUI servers (e.g., for ComfyUI-Login)",
}),
"batch_override": ("INT", {"default": 0, "min": 0, "max": 8}),
"enabled": (["true", "false", "remote"],{"default": "true"}),
"outputs": (["final_image", "any"],{"default":"final_image"}),
Expand All @@ -85,7 +90,7 @@ def INPUT_TYPES(s):
CATEGORY = "remote/advanced"
TITLE = "Queue on remote (worker)"

def queue(self, remote_chain, remote_url, batch_override, enabled, outputs):
def queue(self, remote_chain, remote_url, remote_bearer_token, batch_override, enabled, outputs):
current_offset = remote_chain["seed_offset"]
remote_chain["seed_offset"] += 1 if batch_override == 0 else batch_override
if enabled == "false":
Expand All @@ -98,15 +103,17 @@ def queue(self, remote_chain, remote_url, batch_override, enabled, outputs):
return (remote_chain, {})

remote_url = clean_url(remote_url)
clear_remote_queue(remote_url)
clear_remote_queue(remote_url, remote_bearer_token)
dispatch_to_remote(
remote_url,
remote_bearer_token,
remote_chain["prompt"],
remote_chain["job_id"],
outputs,
)
remote_info = {
"remote_url" : remote_url,
"remote_bearer_token": remote_bearer_token,
"job_id" : remote_chain["job_id"],
}
return (remote_chain, remote_info)
Expand Down
17 changes: 12 additions & 5 deletions nodes/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def INPUT_TYPES(s):
def fetch(self, final_image, remote_info):
out = fetch_from_remote(
remote_url = remote_info.get("remote_url"),
remote_bearer_token = remote_info.get("remote_bearer_token"),
job_id = remote_info.get("job_id"),
)
if out is None:
Expand All @@ -49,6 +50,11 @@ def INPUT_TYPES(s):
"multiline": False,
"default": "http://127.0.0.1:8288/",
}),
"remote_bearer_token": ("STRING", {
"multiline": False,
"default": "",
"tooltip": "Optional Bearer token for authenticated remote ComfyUI servers (e.g., for ComfyUI-Login)",
}),
"batch_local": ("INT", {"default": 1, "min": 1, "max": 8}),
"batch_remote": ("INT", {"default": 1, "min": 1, "max": 8}),
"trigger": (["on_change", "always"],),
Expand All @@ -66,26 +72,27 @@ def INPUT_TYPES(s):
CATEGORY = "remote"
TITLE = "Queue on remote (single)"

def queue(self, remote_url, batch_local, batch_remote, trigger, enabled, seed, prompt):
def queue(self, remote_url, remote_bearer_token, batch_local, batch_remote, trigger, enabled, seed, prompt):
if enabled == "false":
return (seed, batch_local, {})
if enabled == "remote":
return (seed+batch_local, batch_remote, {})

job_id = get_new_job_id()
remote_url = clean_url(remote_url)
clear_remote_queue(remote_url)
dispatch_to_remote(remote_url, prompt, job_id)
clear_remote_queue(remote_url, remote_bearer_token)
dispatch_to_remote(remote_url, remote_bearer_token, prompt, job_id)

remote_info = {
"remote_url" : remote_url,
"remote_bearer_token": remote_bearer_token,
"job_id" : job_id,
}
return (seed, batch_local, remote_info)

@classmethod
def IS_CHANGED(self, remote_url, batch_local, batch_remote, trigger, enabled, seed, prompt):
uuid = f"W:{remote_url},B1:{batch_local},B2:{batch_remote},S:{seed},E:{enabled}"
def IS_CHANGED(self, remote_url, remote_bearer_token, batch_local, batch_remote, trigger, enabled, seed, prompt):
uuid = f"W:{remote_url},TOKEN:{remote_bearer_token},B1:{batch_local},B2:{batch_remote},S:{seed},E:{enabled}"
return uuid if trigger == "on_change" else str(time.time())

NODE_CLASS_MAPPINGS = {
Expand Down