diff --git a/frontend/src/components/graph/AddNodeModal.tsx b/frontend/src/components/graph/AddNodeModal.tsx index 56a2e6253..813dc2881 100644 --- a/frontend/src/components/graph/AddNodeModal.tsx +++ b/frontend/src/components/graph/AddNodeModal.tsx @@ -1,4 +1,4 @@ -import { useState, useMemo, useRef, useEffect } from "react"; +import { useState, useMemo, useRef, useEffect, useCallback } from "react"; import { Dialog, DialogContent, @@ -6,6 +6,7 @@ import { DialogTitle, DialogDescription, } from "../ui/dialog"; +import type { FlowNodeData } from "../../lib/graphUtils"; interface AddNodeModalProps { open: boolean; @@ -36,8 +37,10 @@ interface AddNodeModalProps { | "tempo" | "prompt_list" | "prompt_blend" - | "scheduler", - subType?: string + | "scheduler" + | "custom_node", + subType?: string, + extraData?: Partial ) => void; } @@ -67,12 +70,15 @@ interface NodeCatalogItem { | "tempo" | "prompt_list" | "prompt_blend" - | "scheduler"; + | "scheduler" + | "custom_node"; subType?: string; name: string; description: string; color: string; category: string; + /** Full definition for custom nodes (inputs/outputs/params). 
*/ + customNodeDef?: Record; } const NODE_CATALOG: NodeCatalogItem[] = [ @@ -85,8 +91,8 @@ const NODE_CATALOG: NodeCatalogItem[] = [ }, { type: "pipeline", - name: "Pipeline", - description: "Processing pipeline node", + name: "Node", + description: "Video processing node (pick a model after dropping it)", color: "#60a5fa", category: "I/O", }, @@ -294,7 +300,15 @@ const NODE_CATALOG: NodeCatalogItem[] = [ }, ]; -const CATEGORIES = ["All", "I/O", "Values", "Controls", "UI", "Utility"]; +const CATEGORIES = [ + "All", + "I/O", + "Values", + "Controls", + "UI", + "Utility", + "Plugins", +]; interface TooltipState { text: string; @@ -388,10 +402,50 @@ export function AddNodeModal({ }: AddNodeModalProps) { const [searchText, setSearchText] = useState(""); const [activeCategory, setActiveCategory] = useState("All"); + const [customNodes, setCustomNodes] = useState([]); + + useEffect(() => { + if (!open) return; + fetch("/api/v1/nodes/definitions") + .then(r => r.json()) + .then(data => { + // The unified endpoint returns both pipelines (pipeline_meta != null) + // and plain custom nodes. Pipelines are still added via the hardcoded + // "Pipeline" catalog entry (placeholder + dropdown); the scheduler + // has its own catalog entry with a bespoke widget. Filter both out + // of the plugin listing to avoid duplication. + const items: NodeCatalogItem[] = (data.nodes ?? 
[]) + + .filter( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (n: any) => + n.pipeline_meta == null && n.node_type_id !== "scheduler" + ) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .map((n: any) => ({ + type: "custom_node" as const, + subType: n.node_type_id, + name: n.display_name || n.node_type_id, + description: n.description || "", + color: "#9ca3af", + category: "Plugins", + customNodeDef: n, + })); + setCustomNodes(items); + }) + .catch(() => { + /* ignore — custom nodes just won't appear */ + }); + }, [open]); + + const fullCatalog = useMemo( + () => [...NODE_CATALOG, ...customNodes], + [customNodes] + ); const filteredItems = useMemo(() => { const lowerSearch = searchText.toLowerCase(); - return NODE_CATALOG.filter(item => { + return fullCatalog.filter(item => { const matchesSearch = !lowerSearch || item.name.toLowerCase().includes(lowerSearch) || @@ -400,14 +454,37 @@ export function AddNodeModal({ activeCategory === "All" || item.category === activeCategory; return matchesSearch && matchesCategory; }); - }, [searchText, activeCategory]); + }, [searchText, activeCategory, fullCatalog]); - const handleSelect = (item: NodeCatalogItem) => { - onSelectNodeType(item.type, item.subType); - onClose(); - setSearchText(""); - setActiveCategory("All"); - }; + const handleSelect = useCallback( + (item: NodeCatalogItem) => { + if (item.type === "custom_node" && item.customNodeDef) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const def = item.customNodeDef as any; + onSelectNodeType("custom_node", item.subType, { + customNodeTypeId: def.node_type_id, + customNodeDisplayName: def.display_name || def.node_type_id, + customNodeCategory: def.category || "", + customNodeInputs: def.inputs || [], + customNodeOutputs: def.outputs || [], + customNodeParamDefs: def.params || [], + customNodeParams: Object.fromEntries( + (def.params || []) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + 
.filter((p: any) => p.default != null) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .map((p: any) => [p.name, p.default]) + ), + }); + } else { + onSelectNodeType(item.type, item.subType); + } + onClose(); + setSearchText(""); + setActiveCategory("All"); + }, + [onSelectNodeType, onClose] + ); const handleClose = () => { onClose(); diff --git a/frontend/src/components/graph/GraphEditor.tsx b/frontend/src/components/graph/GraphEditor.tsx index dda25008d..a4358f013 100644 --- a/frontend/src/components/graph/GraphEditor.tsx +++ b/frontend/src/components/graph/GraphEditor.tsx @@ -55,6 +55,7 @@ import { TempoNode } from "./nodes/TempoNode"; import { PromptListNode } from "./nodes/PromptListNode"; import { PromptBlendNode } from "./nodes/PromptBlendNode"; import { SchedulerNode } from "./nodes/SchedulerNode"; +import { CustomNode } from "./nodes/CustomNode"; import { CustomEdge } from "./CustomEdge"; import { ContextMenu } from "./ContextMenu"; import { AddNodeModal } from "./AddNodeModal"; @@ -130,6 +131,7 @@ const nodeTypes = { prompt_list: PromptListNode, prompt_blend: PromptBlendNode, scheduler: SchedulerNode, + custom_node: CustomNode, }; const edgeTypes = { diff --git a/frontend/src/components/graph/contextMenuItems.tsx b/frontend/src/components/graph/contextMenuItems.tsx index 0d4903f9e..caeadd698 100644 --- a/frontend/src/components/graph/contextMenuItems.tsx +++ b/frontend/src/components/graph/contextMenuItems.tsx @@ -96,10 +96,10 @@ export function buildPaneMenuItems(deps: { keywords: ["input", "camera", "video"], }, { - label: "Pipeline", + label: "Node", icon: , onClick: () => handleNodeTypeSelect("pipeline"), - keywords: ["process", "effect", "filter"], + keywords: ["process", "effect", "filter", "pipeline"], }, { label: "Sink", diff --git a/frontend/src/components/graph/hooks/graph/useGraphPersistence.ts b/frontend/src/components/graph/hooks/graph/useGraphPersistence.ts index 8a2248d31..f31df93f8 100644 --- 
a/frontend/src/components/graph/hooks/graph/useGraphPersistence.ts +++ b/frontend/src/components/graph/hooks/graph/useGraphPersistence.ts @@ -7,8 +7,8 @@ import { parseHandleId, } from "../../../../lib/graphUtils"; import type { FlowNodeData } from "../../../../lib/graphUtils"; -import type { PluginInfo } from "../../../../lib/api"; -import { resolveWorkflow } from "../../../../lib/api"; +import type { PluginInfo, NodeDefinitionDto } from "../../../../lib/api"; +import { resolveWorkflow, fetchNodeDefinitions } from "../../../../lib/api"; import type { ScopeWorkflow, WorkflowResolutionPlan, @@ -65,6 +65,72 @@ function clearGraphFromLocalStorage(): void { } } +/** + * After loading or importing a workflow, fetch `/api/v1/nodes/definitions` + * and hydrate each custom_node flow node with its inputs/outputs/params/ + * display metadata. Saved workflows only persist `node_type_id` and + * `params`, so the port definitions have to be re-attached before render. + * User-supplied param values override the defaults from the definition. + * + * Accepts an AbortSignal so rapid reloads / unmounts can cancel an + * in-flight fetch before its setNodes callback stomps on newer state. 
+ */ +function hydrateCustomNodeDefinitions( + nodes: Node[], + setNodes: React.Dispatch[]>>, + signal: AbortSignal +): void { + const customFlowNodes = nodes.filter( + n => n.data.nodeType === "custom_node" && !n.data.customNodeInputs + ); + if (customFlowNodes.length === 0) return; + fetchNodeDefinitions({ signal }) + .then(data => { + if (signal.aborted) return; + const defMap = new Map( + data.nodes.map(d => [d.node_type_id, d]) + ); + setNodes(prev => { + if (signal.aborted) return prev; + return prev.map(n => { + if ( + n.data.nodeType !== "custom_node" || + !n.data.customNodeTypeId || + n.data.customNodeInputs + ) { + return n; + } + const def = defMap.get(n.data.customNodeTypeId); + if (!def) return n; + return { + ...n, + data: { + ...n.data, + customNodeDisplayName: def.display_name, + customNodeCategory: def.category, + customNodeInputs: def.inputs ?? [], + customNodeOutputs: def.outputs ?? [], + customNodeParamDefs: def.params ?? [], + customNodeParams: { + ...Object.fromEntries( + (def.params ?? []) + .filter(p => p.default != null) + .map(p => [p.name, p.default] as const) + ), + // User-edited values take precedence over definition defaults + ...(n.data.customNodeParams || {}), + }, + }, + }; + }); + }); + }) + .catch((err: unknown) => { + if (err instanceof DOMException && err.name === "AbortError") return; + // custom nodes just won't be hydrated; render falls back to placeholders + }); +} + interface UseGraphPersistenceArgs { nodes: Node[]; edges: Edge[]; @@ -129,6 +195,25 @@ export function useGraphPersistence({ // to localStorage so we skip the expensive save when nothing changed. const lastSavedJsonRef = useRef(""); + // AbortController for in-flight custom-node hydration fetches. Aborted + // before each new hydrate and on unmount so a stale /api/v1/nodes/definitions + // response can't overwrite newer nodes state. 
+ const hydrateAbortRef = useRef(null); + const startHydrate = useCallback( + (initialNodes: Node[]) => { + hydrateAbortRef.current?.abort(); + const controller = new AbortController(); + hydrateAbortRef.current = controller; + hydrateCustomNodeDefinitions(initialNodes, setNodes, controller.signal); + }, + [setNodes] + ); + useEffect(() => { + return () => { + hydrateAbortRef.current?.abort(); + }; + }, []); + const loadGraph = useCallback(() => { if (Object.keys(portsMap).length === 0) return; resetNavigationRef.current?.(); @@ -182,6 +267,8 @@ export function useGraphPersistence({ }, 0); } + startHydrate(enriched); + // Allow async side-effects (e.g. source mode restore) to settle // before re-enabling change notifications. setTimeout(() => { @@ -202,6 +289,7 @@ export function useGraphPersistence({ setEdges, setNodeParams, setNodes, + startHydrate, ]); useEffect(() => { @@ -374,6 +462,8 @@ export function useGraphPersistence({ } }, 0); } + + startHydrate(enriched); }, [ portsMap, @@ -384,6 +474,7 @@ export function useGraphPersistence({ setNodeParams, enrichDepsRef, resetNavigationRef, + startHydrate, ] ); diff --git a/frontend/src/components/graph/hooks/node/useNodeFactories.ts b/frontend/src/components/graph/hooks/node/useNodeFactories.ts index dfdaeef90..1c18acb22 100644 --- a/frontend/src/components/graph/hooks/node/useNodeFactories.ts +++ b/frontend/src/components/graph/hooks/node/useNodeFactories.ts @@ -45,7 +45,8 @@ type NodeTypeKey = | "tempo" | "prompt_list" | "prompt_blend" - | "scheduler"; + | "scheduler" + | "custom_node"; interface NodeDefaults { /** The React Flow node `type` */ @@ -471,6 +472,15 @@ const NODE_DEFAULTS: Record = { ], }, }, + custom_node: { + type: "custom_node", + idPrefix: "custom", + defaultX: 300, + data: { + label: "Custom Node", + nodeType: "custom_node" as const, + }, + }, }; interface UseNodeFactoriesArgs { @@ -566,8 +576,10 @@ export function useNodeFactories({ | "tempo" | "prompt_list" | "prompt_blend" - | "scheduler", - 
subType?: string + | "scheduler" + | "custom_node", + subType?: string, + extraData?: Partial ) => { if (!pendingNodePosition) return; @@ -598,6 +610,12 @@ export function useNodeFactories({ outputSinkType: defaultType, outputSinkName: defaultNames[defaultType] || "Scope", }); + } else if (type === "custom_node") { + addNode("custom_node", pendingNodePosition, { + customNodeTypeId: subType, + label: subType || "Custom Node", + ...extraData, + }); } else { addNode(type as NodeTypeKey, pendingNodePosition); } diff --git a/frontend/src/components/graph/nodes/CustomNode.tsx b/frontend/src/components/graph/nodes/CustomNode.tsx new file mode 100644 index 000000000..5b1e6a6ef --- /dev/null +++ b/frontend/src/components/graph/nodes/CustomNode.tsx @@ -0,0 +1,238 @@ +import { Handle, Position } from "@xyflow/react"; +import type { NodeProps, Node } from "@xyflow/react"; +import type { FlowNodeData } from "../../../lib/graphUtils"; +import { buildHandleId } from "../../../lib/graphUtils"; +import { useNodeData } from "../hooks/node/useNodeData"; +import { useNodeCollapse } from "../hooks/node/useNodeCollapse"; +import { NodeCard, NodeHeader, NodeBody, collapsedHandleStyle } from "../ui"; + +type CustomNodeType = Node; + +/* Port type -> color mapping for custom types */ +const PORT_COLORS: Record = { + audio: "#22c55e", + video: "#eeeeee", + number: "#38bdf8", + string: "#fbbf24", + boolean: "#34d399", + trigger: "#f97316", + latent: "#a855f7", + model: "#f59e0b", + vae: "#f59e0b", + clip: "#f59e0b", + conditioning: "#3b82f6", + semantic_hints: "#06b6d4", + config: "#6b7280", + curve: "#ec4899", + mask: "#ef4444", + lora: "#f472b6", +}; + +function portColor(portType: string): string { + return PORT_COLORS[portType] ?? "#9ca3af"; +} + +export function CustomNode({ id, data, selected }: NodeProps) { + const { updateData } = useNodeData(id); + const { collapsed, toggleCollapse } = useNodeCollapse(); + + const inputs = data.customNodeInputs ?? 
[]; + const outputs = data.customNodeOutputs ?? []; + const params = data.customNodeParamDefs ?? []; + const displayName = + data.customTitle || + data.customNodeDisplayName || + data.customNodeTypeId || + "Custom Node"; + const category = data.customNodeCategory ?? ""; + + return ( + + updateData({ customTitle: t })} + collapsed={collapsed} + onCollapseToggle={toggleCollapse} + /> + {!collapsed && ( + + {/* Show category badge */} + {category && ( +
+ + {category} + +
+ )} + {/* Show input ports */} + {inputs.length > 0 && ( +
+ {inputs.map(p => ( +
+ + {p.name} + + {p.port_type} + +
+ ))} +
+ )} + {/* Show output ports */} + {outputs.length > 0 && ( +
+ {outputs.map(p => ( +
+ + {p.port_type} + + {p.name} + +
+ ))} +
+ )} + {/* Parameter widgets (ComfyUI-style editable params) */} + {params.length > 0 && ( +
+ {params.map(p => { + const val = data.customNodeParams?.[p.name] ?? p.default ?? ""; + return ( +
+ + {p.description || p.name} + + {p.param_type === "select" && + Array.isArray(p.ui?.options) ? ( + + ) : p.param_type === "boolean" ? ( + + updateData({ + customNodeParams: { + ...data.customNodeParams, + [p.name]: e.target.checked, + }, + }) + } + className="accent-blue-500" + /> + ) : p.param_type === "number" ? ( + + updateData({ + customNodeParams: { + ...data.customNodeParams, + [p.name]: Number(e.target.value), + }, + }) + } + /> + ) : ( + + updateData({ + customNodeParams: { + ...data.customNodeParams, + [p.name]: e.target.value, + }, + }) + } + /> + )} +
+ ); + })} +
+ )} +
+ )} + + {/* Input handles (left side) */} + {inputs.map((p, i) => ( + + ))} + + {/* Output handles (right side) */} + {outputs.map((p, i) => ( + 0 ? inputs.length * 18 + 4 : 0) + i * 18}px`, + width: 8, + height: 8, + ...(collapsed ? collapsedHandleStyle : {}), + }} + /> + ))} +
+ ); +} diff --git a/frontend/src/components/graph/nodes/PipelineNode.tsx b/frontend/src/components/graph/nodes/PipelineNode.tsx index 074468602..d419ae6cd 100644 --- a/frontend/src/components/graph/nodes/PipelineNode.tsx +++ b/frontend/src/components/graph/nodes/PipelineNode.tsx @@ -84,7 +84,7 @@ export function PipelineNode({ const supportsLoRA = data.supportsLoRA ?? false; const isStreaming = data.isStreaming ?? false; - const pipelineName = data.pipelineId || "Pipeline"; + const pipelineName = data.pipelineId || "Node"; // Inject unavailable pipelineId into options const isUnavailable = @@ -187,8 +187,8 @@ export function PipelineNode({ /> {!collapsed && ( - {/* Pipeline selector */} - + {/* Node type selector (video pipeline model) */} + { diff --git a/frontend/src/components/graph/utils/connectionValidation.ts b/frontend/src/components/graph/utils/connectionValidation.ts index 1a3bf183c..475bb3a1e 100644 --- a/frontend/src/components/graph/utils/connectionValidation.ts +++ b/frontend/src/components/graph/utils/connectionValidation.ts @@ -302,9 +302,32 @@ export function validateConnection( return true; } - // Stream ↔ stream always ok - if (sourceParsed.kind === "stream" && targetParsed.kind === "stream") - return true; + // Stream ↔ stream: for custom_node edges, enforce port-type matching + // against the node's declared inputs/outputs. For built-in source / + // pipeline / sink nodes, streams are untyped (video) and always ok. + if (sourceParsed.kind === "stream" && targetParsed.kind === "stream") { + const sourceNode = nodes.find(n => n.id === connection.source); + const targetNode = nodes.find(n => n.id === connection.target); + if (!sourceNode || !targetNode) return true; + const srcIsCustom = sourceNode.data.nodeType === "custom_node"; + const tgtIsCustom = targetNode.data.nodeType === "custom_node"; + if (!srcIsCustom && !tgtIsCustom) return true; + + const srcType = srcIsCustom + ? 
sourceNode.data.customNodeOutputs?.find( + p => p.name === sourceParsed.name + )?.port_type + : undefined; + const tgtType = tgtIsCustom + ? targetNode.data.customNodeInputs?.find( + p => p.name === targetParsed.name + )?.port_type + : undefined; + + // If we can't look up both types, allow (assume compatible). + if (!srcType || !tgtType) return true; + return srcType === tgtType; + } // Param ↔ param if (sourceParsed.kind === "param" && targetParsed.kind === "param") { diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index f419b67b2..c13205c6c 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -913,8 +913,12 @@ export const deleteApiKey = async ( export interface GraphNode { id: string; - type: "source" | "pipeline" | "sink" | "record"; + type: "source" | "pipeline" | "sink" | "record" | "node"; pipeline_id?: string | null; + /** Node type ID (NodeRegistry key) when type is "node" */ + node_type_id?: string | null; + /** Per-node parameter values for custom nodes */ + params?: Record | null; x?: number | null; y?: number | null; w?: number | null; @@ -927,6 +931,63 @@ export interface GraphNode { sink_name?: string | null; } +export interface NodePortDef { + name: string; + port_type: string; + required?: boolean; + description?: string; + default_value?: unknown; +} + +export interface NodeParamDef { + name: string; + param_type: "number" | "string" | "boolean" | "select"; + default?: unknown; + description?: string; + /** + * Free-form widget hints (``min``/``max``/``step`` for number, + * ``options`` for select, …). Keeps widget-specific fields off the + * base schema; the frontend renderer dispatches on ``param_type`` + * and reads whichever ``ui`` keys apply. 
+ */ + ui?: Record | null; + convertible_to_input?: boolean; +} + +export interface NodeDefinitionDto { + node_type_id: string; + display_name: string; + category: string; + description: string; + inputs: NodePortDef[]; + outputs: NodePortDef[]; + params: NodeParamDef[]; + continuous: boolean; + /** + * Rich pipeline-only metadata (config_schema, mode_defaults, + * supports_lora, supports_vace, etc.) populated for entries whose + * underlying class is a Pipeline subclass. ``null`` for plain nodes. + */ + pipeline_meta?: Record | null; +} + +export interface NodeDefinitionsResponse { + nodes: NodeDefinitionDto[]; +} + +export const fetchNodeDefinitions = async ( + options: { signal?: AbortSignal } = {} +): Promise => { + const response = await fetch("/api/v1/nodes/definitions", { + method: "GET", + signal: options.signal, + }); + if (!response.ok) { + throw new Error(`Failed to fetch node definitions: ${response.statusText}`); + } + return response.json(); +}; + export interface GraphEdge { from: string; from_port: string; diff --git a/frontend/src/lib/graphUtils.ts b/frontend/src/lib/graphUtils.ts index fa0c930fe..c187ad002 100644 --- a/frontend/src/lib/graphUtils.ts +++ b/frontend/src/lib/graphUtils.ts @@ -5,6 +5,8 @@ import type { GraphEdge, PipelineSchemaInfo, LoRAFileInfo, + NodePortDef, + NodeParamDef, } from "./api"; import { inferPrimitiveFieldType } from "./schemaSettings"; import { resolveLoRAPath } from "./workflowSettings"; @@ -104,7 +106,8 @@ export interface FlowNodeData { | "prompt_list" | "prompt_blend" | "scheduler" - | "audio"; + | "audio" + | "custom_node"; availablePipelineIds?: string[]; /** Declared input ports for the selected pipeline */ streamInputs?: string[]; @@ -386,6 +389,22 @@ export interface FlowNodeData { /* ── Tempo beat count offset ── */ tempoBeatCountOffset?: number; + /* ── Custom node fields ── */ + /** For custom_node: the node_type_id from the backend registry */ + customNodeTypeId?: string; + /** For custom_node: display 
name from node definition */ + customNodeDisplayName?: string; + /** For custom_node: category from node definition */ + customNodeCategory?: string; + /** For custom_node: input port definitions */ + customNodeInputs?: NodePortDef[]; + /** For custom_node: output port definitions */ + customNodeOutputs?: NodePortDef[]; + /** For custom_node: current parameter values (user-editable) */ + customNodeParams?: Record; + /** For custom_node: parameter definitions from API (widget metadata) */ + customNodeParamDefs?: NodeParamDef[]; + /* ── Node lock / pin / collapse ── */ /** When true, parameter inputs on this node are disabled (read-only). */ locked?: boolean; @@ -782,6 +801,66 @@ export function graphConfigToFlow( }); }); + // Backend custom nodes (type="node"). Port metadata is hydrated later + // from GET /api/v1/nodes/definitions in useGraphPersistence. + const customNodes = graph.nodes.filter( + n => n.type === "node" && !isSubgraphInnerNode(n.id) + ); + customNodes.forEach((n, i) => { + const savedX = n.x ?? undefined; + const savedY = n.y ?? undefined; + const sizeProps = + n.w != null || n.h != null + ? { + width: n.w ?? undefined, + height: n.h ?? undefined, + style: { width: n.w ?? undefined, height: n.h ?? undefined }, + } + : {}; + const position = { + x: savedX !== undefined ? savedX : START_X + COLUMN_GAP * 1.5, + y: savedY !== undefined ? savedY : START_Y + i * (NODE_HEIGHT + ROW_GAP), + }; + // Scheduler has its own React renderer — rehydrate into its native + // nodeType so the bespoke widget (trigger list, transport controls) + // comes back after a round-trip, instead of the generic custom_node UI. + if (n.node_type_id === "scheduler") { + const params = (n.params ?? {}) as Record; + nodes.push({ + id: n.id, + type: "scheduler", + position, + ...sizeProps, + data: { + label: "Scheduler", + nodeType: "scheduler", + schedulerTriggers: + (params.triggers as Array<{ time: number; port_name: string }>) ?? 
+ [], + schedulerLoop: (params.loop as boolean) ?? false, + schedulerDuration: (params.duration as number) ?? 30, + schedulerElapsed: 0, + schedulerIsPlaying: false, + schedulerFireCounts: {}, + schedulerTickCount: 0, + }, + }); + return; + } + nodes.push({ + id: n.id, + type: "custom_node", + position, + ...sizeProps, + data: { + label: n.node_type_id || n.id, + nodeType: "custom_node", + customNodeTypeId: n.node_type_id ?? undefined, + customNodeParams: (n.params as Record) ?? undefined, + }, + }); + }); + // Convert edges - add stream: prefix to handle IDs // Skip edges that reference flattened inner subgraph nodes const edges: Edge[] = graph.edges @@ -926,7 +1005,6 @@ const FRONTEND_ONLY_TYPES = new Set([ "tempo", "prompt_list", "prompt_blend", - "scheduler", ]); /** Fields in FlowNodeData that are non-serializable (functions, streams, etc.) */ @@ -1202,6 +1280,16 @@ export function flowToGraphConfig( n.height ?? n.measured?.height ?? (typeof n.style?.height === "number" ? n.style.height : undefined); + const isBackendNode = + n.data.nodeType === "custom_node" || n.data.nodeType === "scheduler"; + const schedulerParams = + n.data.nodeType === "scheduler" + ? { + triggers: n.data.schedulerTriggers ?? [], + loop: n.data.schedulerLoop ?? false, + duration: n.data.schedulerDuration ?? 30, + } + : undefined; return { id: n.id, type: @@ -1211,11 +1299,23 @@ export function flowToGraphConfig( ? "sink" : n.data.nodeType === "record" ? "record" - : "pipeline", + : isBackendNode + ? "node" + : "pipeline", pipeline_id: n.data.nodeType === "pipeline" ? (n.data.pipelineId ?? null) : undefined, + node_type_id: + n.data.nodeType === "custom_node" + ? (n.data.customNodeTypeId ?? null) + : n.data.nodeType === "scheduler" + ? "scheduler" + : undefined, + params: + n.data.nodeType === "custom_node" && n.data.customNodeParams + ? n.data.customNodeParams + : schedulerParams, x: n.position.x, y: n.position.y, w: w && !Number.isNaN(w) ? 
w : undefined, diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx index c9ab0fe69..c43c9ce0e 100644 --- a/frontend/src/pages/StreamPage.tsx +++ b/frontend/src/pages/StreamPage.tsx @@ -2731,10 +2731,13 @@ export function StreamPage() { ); } - const loadSuccess = await loadPipeline(loadItems); - if (!loadSuccess) { - console.error("Failed to load pipeline, cannot start stream"); - return false; + // Node-only graphs (no pipeline nodes) skip the pipeline load step. + if (loadItems.length > 0) { + const loadSuccess = await loadPipeline(loadItems); + if (!loadSuccess) { + console.error("Failed to load pipeline, cannot start stream"); + return false; + } } // Check video requirements based on input mode. diff --git a/src/scope/core/nodes/__init__.py b/src/scope/core/nodes/__init__.py new file mode 100644 index 000000000..ac4e46d4e --- /dev/null +++ b/src/scope/core/nodes/__init__.py @@ -0,0 +1,31 @@ +"""Backend node system for Scope. + +Provides a base class and registry for defining fine-grained processing +nodes that can be wired into pipeline graphs. Nodes are simpler than full +pipelines — they declare typed input/output ports, editable parameters, +and a small execution contract. Built-in nodes and plugin-provided nodes +are discovered here and rendered generically by the frontend via +``GET /api/v1/nodes/definitions``. 
+""" + +from .base import BaseNode, NodeDefinition, NodeParam, NodePort +from .builtins import AudioSourceNode, SchedulerNode +from .registry import NodeRegistry + + +def register_builtin_nodes() -> None: + """Register all built-in node types shipped with the foundation.""" + NodeRegistry.register(AudioSourceNode) + NodeRegistry.register(SchedulerNode) + + +__all__ = [ + "AudioSourceNode", + "BaseNode", + "NodeDefinition", + "NodeParam", + "NodePort", + "NodeRegistry", + "SchedulerNode", + "register_builtin_nodes", +] diff --git a/src/scope/core/nodes/base.py b/src/scope/core/nodes/base.py new file mode 100644 index 000000000..2280316aa --- /dev/null +++ b/src/scope/core/nodes/base.py @@ -0,0 +1,191 @@ +"""Base classes for the Scope node system. + +Nodes are lightweight, fine-grained processing units that can be wired into +pipeline graphs alongside pipelines. Each node type declares typed input/ +output ports and editable parameters, and subclasses implement their own +execution contract (which may differ between execution backends). + +This module intentionally keeps ``BaseNode`` minimal — only a class-level +identifier and a ``get_definition()`` classmethod are required. Concrete +execution backends (graph executor integration, event-driven runtime, etc.) +layer their own abstract methods on top. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Literal + +from pydantic import BaseModel, Field + + +class NodePort(BaseModel): + """Describes an input or output port on a node.""" + + name: str = Field(..., description="Port identifier (used in edge wiring)") + port_type: str = Field( + ..., + description=( + "Type of data carried by this port. Built-in types: " + "'audio', 'video', 'number', 'string', 'boolean', 'trigger'. " + "Plugins may define custom types (e.g. 'latent', 'model')." 
+ ), + ) + required: bool = Field(default=True, description="Whether this input is required") + description: str = Field(default="", description="Human-readable description") + default_value: Any = Field(default=None, description="Default value for inputs") + + +class NodeParam(BaseModel): + """Describes an editable parameter (widget) on a node. + + Parameters are user-configurable values that live on the node card. + Like ComfyUI widgets, a parameter may be overridden by connecting + an incoming wire to the corresponding input port — the widget then + becomes an input and the default value is ignored. + + Widget-specific hints (number min/max/step, select options, etc.) + go into the free-form ``ui`` dict so the base schema doesn't grow + as new widget kinds are added. The frontend renderer dispatches on + ``param_type`` and reads whichever ``ui`` keys apply. + """ + + name: str = Field(..., description="Parameter identifier") + param_type: Literal["number", "string", "boolean", "select"] = Field( + ..., description="Widget type for the frontend" + ) + default: Any = Field(default=None, description="Default value") + description: str = Field(default="", description="Human-readable label") + ui: dict[str, Any] | None = Field( + default=None, + description=( + "Widget-specific hints consumed by the frontend renderer. " + "Number widgets read ``min``/``max``/``step``; select " + "widgets read ``options``; plugin-defined widget kinds may " + "use any keys they like." + ), + ) + convertible_to_input: bool = Field( + default=True, + description=( + "If True, this parameter can be overridden by connecting an " + "input wire (ComfyUI-style widget-to-input conversion)." 
+ ), + ) + + +class NodeDefinition(BaseModel): + """Static metadata describing a node type.""" + + node_type_id: str = Field(..., description="Unique node type identifier") + display_name: str = Field(..., description="Human-readable name") + category: str = Field(default="general", description="Category for grouping") + description: str = Field(default="", description="What this node does") + inputs: list[NodePort] = Field(default_factory=list) + outputs: list[NodePort] = Field(default_factory=list) + params: list[NodeParam] = Field( + default_factory=list, + description="Editable parameters (widgets) displayed on the node card.", + ) + continuous: bool = Field( + default=False, + description=( + "If True, source nodes (no inputs) re-execute continuously " + "instead of executing once. Useful for streaming generators." + ), + ) + pipeline_meta: dict[str, Any] | None = Field( + default=None, + description=( + "Rich pipeline-only metadata (config_schema, mode_defaults, " + "supports_lora, supports_vace, etc.) for nodes that are " + ":class:`Pipeline` subclasses. ``None`` for plain nodes. " + "Populated by ``Pipeline.get_definition()`` from the config " + "class's ``get_schema_with_metadata()``." + ), + ) + + +class BaseNode(ABC): + """Abstract base class for all backend node types. + + Subclasses must set ``node_type_id`` as a ``ClassVar`` and implement + ``get_definition()`` and ``execute()``. Nodes run inside a + :class:`NodeProcessor` in the pipeline graph: input port values arrive + in the ``inputs`` dict and the return dict fans out to downstream + nodes or pipelines via queue edges. 
+ + Example:: + + class MyNode(BaseNode): + node_type_id: ClassVar[str] = "my-plugin.my-node" + + @classmethod + def get_definition(cls) -> NodeDefinition: + return NodeDefinition( + node_type_id=cls.node_type_id, + display_name="My Node", + inputs=[NodePort(name="audio", port_type="audio")], + outputs=[NodePort(name="audio", port_type="audio")], + ) + + def execute(self, inputs, **kwargs): + return {"audio": inputs["audio"]} + """ + + node_type_id: ClassVar[str] + + def __init__(self, node_id: str = "", config: dict[str, Any] | None = None): + self.node_id = node_id + self.config = config or {} + + @classmethod + @abstractmethod + def get_definition(cls) -> NodeDefinition: + """Return static metadata for this node type.""" + + @classmethod + def get_dynamic_output_ports(cls, params: dict[str, Any]) -> set[str]: + """Return output port names that depend on runtime params. + + Used by the graph executor to accept edges from ports that are + not declared statically in :meth:`get_definition` — e.g. the + scheduler node derives one output port per user-configured + trigger. The default returns an empty set; nodes with dynamic + outputs override. + """ + return set() + + def shutdown(self) -> None: # noqa: B027 — intentional no-op hook + """Release any resources held by the node. + + Called by :class:`NodeProcessor` when the graph is torn down. + The default is a no-op; nodes that start background threads or + hold OS handles override. + """ + + def execute( + self, + inputs: dict[str, Any], + **kwargs, + ) -> dict[str, Any]: + """Execute the node with the given inputs. + + Plain backend nodes (scheduled by ``NodeProcessor``) override + this. :class:`Pipeline` subclasses inherit the default because + they're driven by ``__call__`` from ``PipelineProcessor`` and + never reach this code path. + + Args: + inputs: Dict mapping input port names to their values. + Audio ports receive ``(tensor, sample_rate)`` tuples. + Video ports receive frame tensors. 
+ **kwargs: Additional context (pipeline parameters, etc.) + + Returns: + Dict mapping output port names to their values. + """ + raise NotImplementedError( + f"{type(self).__name__} must override execute() or be invoked " + "via PipelineProcessor (Pipeline subclass) instead of NodeProcessor." + ) diff --git a/src/scope/core/nodes/builtins/__init__.py b/src/scope/core/nodes/builtins/__init__.py new file mode 100644 index 000000000..eec3bf677 --- /dev/null +++ b/src/scope/core/nodes/builtins/__init__.py @@ -0,0 +1,6 @@ +"""Built-in nodes shipped with the foundation abstraction.""" + +from .audio_io import AudioSourceNode +from .scheduler import SchedulerNode + +__all__ = ["AudioSourceNode", "SchedulerNode"] diff --git a/src/scope/core/nodes/builtins/audio_io.py b/src/scope/core/nodes/builtins/audio_io.py new file mode 100644 index 000000000..2845e8638 --- /dev/null +++ b/src/scope/core/nodes/builtins/audio_io.py @@ -0,0 +1,185 @@ +"""Built-in audio I/O nodes: AudioSource (WAV file → audio stream). + +Terminal audio output is handled by the regular Sink node: audio edges +into a Sink are routed straight to the WebRTC audio track via the +session's audio_output_queue, with no intermediate node needed. 
logger = logging.getLogger(__name__)

SAMPLE_RATE = 48000
CHUNK_DURATION = 0.1  # 100ms chunks for streaming
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_DURATION)

# Seconds to wait before retrying a request that failed to resolve or load,
# so a missing/corrupt file is not re-read and re-logged on every tick.
_RETRY_INTERVAL = 1.0


class AudioSourceNode(BaseNode):
    """Load audio from a WAV file and stream it in 100ms chunks, looping.

    The decoded clip is cached as a ``(2, samples)`` float32 array at
    ``SAMPLE_RATE``. It is reloaded whenever the resolved file path or the
    requested ``duration`` changes.
    """

    node_type_id: ClassVar[str] = "audio.AudioSource"

    def __init__(self, node_id: str, config: dict[str, Any] | None = None):
        super().__init__(node_id, config)
        self._audio_data: np.ndarray | None = None
        self._position = 0
        self._loaded_file: str = ""
        # Duration the cached clip was clipped to; None until a load succeeds.
        self._loaded_duration: float | None = None
        self._last_call_time: float | None = None
        # Last (file_id, duration) request that failed, and when to retry it.
        self._failed_request: tuple[str, float] | None = None
        self._retry_at = 0.0

    @classmethod
    def get_definition(cls) -> NodeDefinition:
        """Return the static definition: no inputs, one ``audio`` output."""
        return NodeDefinition(
            node_type_id=cls.node_type_id,
            display_name="Audio Source",
            category="audio",
            description="Load audio from a WAV file at 48kHz stereo.",
            continuous=True,
            inputs=[],
            outputs=[
                NodePort(name="audio", port_type="audio", description="Audio waveform"),
            ],
            params=[
                NodeParam(
                    name="file_id",
                    param_type="string",
                    default="",
                    description="Audio file path",
                ),
                NodeParam(
                    name="duration",
                    param_type="number",
                    default=15.0,
                    description="Duration (s)",
                    ui={"min": 1, "max": 600, "step": 1},
                ),
                NodeParam(
                    name="mode",
                    param_type="select",
                    default="full",
                    description="Output mode",
                    ui={"options": ["full", "stream"]},
                ),
            ],
        )

    @staticmethod
    def _decode_pcm(raw: bytes, sampwidth: int) -> np.ndarray:
        """Convert interleaved little-endian PCM bytes to float32 in [-1, 1).

        Args:
            raw: Frame bytes as returned by ``wave.Wave_read.readframes``.
            sampwidth: Bytes per sample (1, 2, 3 or 4).

        Raises:
            ValueError: If ``sampwidth`` is not a supported PCM width.
        """
        if sampwidth == 1:
            # 8-bit WAV samples are unsigned with a 128 bias.
            return np.frombuffer(raw, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
        if sampwidth == 2:
            return np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
        if sampwidth == 3:
            # 24-bit samples have no native dtype: assemble them from their
            # three bytes, then sign-extend values with the top bit set.
            packed = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
            values = packed[:, 0] | (packed[:, 1] << 8) | (packed[:, 2] << 16)
            values = np.where(values >= (1 << 23), values - (1 << 24), values)
            return values.astype(np.float32) / 8388608.0
        if sampwidth == 4:
            return np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
        raise ValueError(f"Unsupported WAV sample width: {sampwidth} bytes")

    def _load_audio(self, file_path: str, duration: float) -> None:
        """Load, decode, resample to 48kHz stereo, and clip to duration.

        On success the cached clip, playback position, loaded path and loaded
        duration are replaced. Errors from ``wave`` or decoding propagate to
        the caller.
        """
        with wave.open(file_path, "rb") as wf:
            sr = wf.getframerate()
            n_channels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            raw = wf.readframes(wf.getnframes())

        data = self._decode_pcm(raw, sampwidth)

        data = data.reshape(-1, n_channels)
        if n_channels == 1:
            # Duplicate mono into both stereo channels.
            data = np.stack([data[:, 0], data[:, 0]], axis=-1)
        elif n_channels > 2:
            # Keep only the first two channels.
            data = data[:, :2]
        data = data.T  # (channels, samples)

        num_samples = data.shape[1]
        # Linear-interpolation resample; skipped for empty clips because
        # np.interp rejects empty sample arrays.
        if sr != SAMPLE_RATE and sr > 0 and num_samples > 0:
            new_len = int(num_samples * SAMPLE_RATE / sr)
            old_indices = np.linspace(0, num_samples - 1, new_len)
            source_indices = np.arange(num_samples)
            resampled = np.zeros((data.shape[0], new_len), dtype=np.float32)
            for ch in range(data.shape[0]):
                resampled[ch] = np.interp(old_indices, source_indices, data[ch])
            data = resampled

        max_samples = int(duration * SAMPLE_RATE)
        if data.shape[1] > max_samples:
            data = data[:, :max_samples]

        self._audio_data = data
        self._position = 0
        self._loaded_file = file_path
        self._loaded_duration = duration
        logger.info(
            "AudioSource loaded: %s (%.1fs)",
            file_path,
            data.shape[1] / SAMPLE_RATE,
        )

    def _note_failure(self, request: tuple[str, float]) -> None:
        """Remember a failed request so it is retried only after a delay."""
        self._failed_request = request
        self._retry_at = time.monotonic() + _RETRY_INTERVAL

    def execute(self, inputs: dict[str, Any], **kwargs) -> dict[str, Any]:
        """Emit the next audio payload for the configured file.

        Reads ``file_id``, ``duration`` and ``mode`` from ``kwargs``.

        Returns:
            ``{"audio": (tensor, SAMPLE_RATE)}``, or an empty dict when no
            file is configured, the file cannot be resolved or loaded, or the
            decoded clip is empty.
        """
        file_id = kwargs.get("file_id", "")
        raw_duration = kwargs.get("duration", 15.0)
        duration = 15.0 if raw_duration is None else float(raw_duration)
        # "full" = emit entire clip once (for batch DAGs); "stream" = 100ms chunks
        # NOTE(review): get_definition() declares the default mode as "full"
        # while this fallback is "stream" — confirm which one is intended.
        mode = kwargs.get("mode", "stream")

        if not file_id:
            return {}

        request = (file_id, duration)
        if request == self._failed_request and time.monotonic() < self._retry_at:
            # Same request failed very recently; don't hammer disk and logs.
            return {}

        resolved = self._resolve_path(file_id)
        if not resolved:
            self._note_failure(request)
            return {}

        # Reload when either the file or the requested clip length changed.
        if resolved != self._loaded_file or duration != self._loaded_duration:
            try:
                self._load_audio(resolved, duration)
            except Exception as e:
                logger.error("AudioSourceNode failed to load %s: %s", resolved, e)
                self._note_failure(request)
                return {}
        self._failed_request = None

        if self._audio_data is None or self._audio_data.shape[1] == 0:
            return {}

        if mode == "full":
            return self._emit_full()
        return self._emit_chunk()

    @staticmethod
    def _resolve_path(file_id: str) -> str | None:
        """Resolve a file path. Absolute → cwd → ~/.daydream-scope/assets."""
        if os.path.isabs(file_id) and os.path.exists(file_id):
            return file_id
        if os.path.exists(file_id):
            return os.path.abspath(file_id)
        from pathlib import Path

        candidate = Path.home() / ".daydream-scope" / "assets" / file_id
        if candidate.exists():
            return str(candidate)
        logger.warning("AudioSource: file not found: %s", file_id)
        return None

    def _emit_full(self) -> dict[str, Any]:
        """Return a copy of the whole cached clip as one payload."""
        return {"audio": (torch.from_numpy(self._audio_data.copy()), SAMPLE_RATE)}

    def _emit_chunk(self) -> dict[str, Any]:
        """Return the next ``CHUNK_SAMPLES`` samples, wrapping at the clip end."""
        # Pace to real-time
        now = time.monotonic()
        if self._last_call_time is not None:
            elapsed = now - self._last_call_time
            if elapsed < CHUNK_DURATION * 0.8:
                time.sleep(CHUNK_DURATION - elapsed)
        self._last_call_time = time.monotonic()

        total = self._audio_data.shape[1]
        chunk = np.zeros((self._audio_data.shape[0], CHUNK_SAMPLES), dtype=np.float32)
        remaining = CHUNK_SAMPLES
        offset = 0
        # A chunk may span the loop point, so fill it in up to several slices.
        while remaining > 0:
            avail = min(remaining, total - self._position)
            chunk[:, offset : offset + avail] = self._audio_data[
                :, self._position : self._position + avail
            ]
            self._position = (self._position + avail) % total
            offset += avail
            remaining -= avail
        return {"audio": (torch.from_numpy(chunk), SAMPLE_RATE)}
# Internal tick interval — 5 ms gives ~200 Hz resolution while keeping
# CPU usage negligible.
_TICK_INTERVAL = 0.005

# Throttle elapsed-time broadcasts to avoid flooding downstream nodes.
_ELAPSED_BROADCAST_INTERVAL = 0.05  # 50 ms → ~20 Hz


class TriggerSpec(TypedDict):
    """One scheduled firing: a time in seconds and the output port it drives."""

    time: float
    port_name: str


class SchedulerNode(BaseNode):
    """Time-based trigger scheduler.

    Configured via the ``triggers``, ``loop``, and ``duration`` parameters.
    Each trigger has a ``time`` (seconds) and ``port_name`` (output port).
    When elapsed time reaches a trigger's time the corresponding output
    fires with an incrementing counter. Counters keep increasing across
    loop iterations so that consumers comparing against the previous value
    see every firing; only an explicit ``reset`` pulse clears them.
    """

    node_type_id: ClassVar[str] = "scheduler"

    def __init__(self, node_id: str, config: dict[str, Any] | None = None):
        super().__init__(node_id, config)
        # Single lock guards all mutable state (including _pending_outputs)
        # so lock ordering can never deadlock.
        self._lock = threading.Lock()
        self._playing = False
        self._start_time: float | None = None
        self._elapsed = 0.0
        self._tick_count = 0
        self._fire_counts: dict[str, int] = {}
        self._fired_keys: set[tuple[str, float]] = set()
        self._triggers: list[TriggerSpec] = []
        self._loop = False
        self._duration = 0.0
        self._last_elapsed_broadcast = 0.0
        self._timer_thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        # Accumulated outputs drained on next execute(). Only populated
        # when something has actually changed so downstream queues don't
        # get flooded with unchanged state.
        self._pending_outputs: dict[str, Any] = {}
        # Previous counter values on edge-triggered inputs. Any strict
        # increment toggles the corresponding action; a stale value
        # sitting on the queue is ignored. Seeded at 0 so the first
        # positive counter delivered from upstream is treated as a fire.
        self._prev_start_counter: int = 0
        self._prev_reset_counter: int = 0
        # Auto-start fires once per node lifetime, regardless of later
        # stale zeros observed on start/reset inputs.
        self._auto_start_done = False

    @classmethod
    def get_definition(cls) -> NodeDefinition:
        """Return the static definition (dynamic trigger ports are excluded)."""
        return NodeDefinition(
            node_type_id=cls.node_type_id,
            display_name="Scheduler",
            category="timing",
            description=(
                "Time-based trigger scheduler. Add trigger points and "
                "connect them to other nodes to drive timed actions."
            ),
            continuous=True,
            inputs=[
                NodePort(
                    name="start",
                    port_type="trigger",
                    required=False,
                    description="Play / pause toggle",
                ),
                NodePort(
                    name="reset",
                    port_type="trigger",
                    required=False,
                    description="Reset elapsed time and trigger state",
                ),
            ],
            outputs=[
                NodePort(name="tick", port_type="number", description="Tick counter"),
                NodePort(
                    name="elapsed",
                    port_type="number",
                    description="Elapsed time (seconds)",
                ),
                NodePort(
                    name="is_playing",
                    port_type="boolean",
                    description="Whether the scheduler is playing",
                ),
            ],
            params=[
                NodeParam(
                    name="triggers",
                    param_type="string",
                    default=[],
                    description="List of { time, port_name } trigger points",
                    ui={"widget": "trigger_list"},
                    convertible_to_input=False,
                ),
                NodeParam(
                    name="loop",
                    param_type="boolean",
                    default=False,
                    description="Loop when duration is reached",
                ),
                NodeParam(
                    name="duration",
                    param_type="number",
                    default=30.0,
                    description="Total duration (s); 0 disables auto-stop",
                    ui={"min": 0, "max": 3600, "step": 0.1},
                ),
            ],
        )

    @classmethod
    def get_dynamic_output_ports(cls, params: dict[str, Any]) -> set[str]:
        """Return the non-empty ``port_name`` of every dict in ``params["triggers"]``."""
        triggers = params.get("triggers") or []
        if not isinstance(triggers, list):
            return set()
        names: set[str] = set()
        for trig in triggers:
            if isinstance(trig, dict):
                name = trig.get("port_name")
                if isinstance(name, str) and name:
                    names.add(name)
        return names

    # ------------------------------------------------------------------
    # Internal timer
    # ------------------------------------------------------------------

    def _start_timer(self) -> None:
        """Start the background timer thread unless one is already alive."""
        if self._timer_thread is not None and self._timer_thread.is_alive():
            return
        self._stop_event.clear()
        self._timer_thread = threading.Thread(
            target=self._timer_loop,
            daemon=True,
            name=f"SchedulerNode[{self.node_id}]",
        )
        self._timer_thread.start()

    def _stop_timer(self) -> None:
        """Signal the timer thread to exit and wait up to one second for it."""
        self._stop_event.set()
        t = self._timer_thread
        self._timer_thread = None
        # Never join the current thread (join() would raise RuntimeError).
        if t is not None and t is not threading.current_thread():
            t.join(timeout=1.0)

    def _timer_loop(self) -> None:
        """Background thread: checks triggers at ~200 Hz."""
        while not self._stop_event.is_set():
            with self._lock:
                if self._playing and self._start_time is not None:
                    self._elapsed = time.monotonic() - self._start_time
                    self._check_triggers()
            self._stop_event.wait(_TICK_INTERVAL)

    def _check_triggers(self) -> None:
        """Fire any triggers whose time has been reached. Caller holds ``_lock``."""
        for trig in self._triggers:
            t = float(trig["time"])
            port = trig["port_name"]
            key = (port, t)
            if t <= self._elapsed and key not in self._fired_keys:
                self._fired_keys.add(key)
                self._fire_counts[port] = self._fire_counts.get(port, 0) + 1
                self._tick_count += 1
                self._pending_outputs[port] = self._fire_counts[port]
                self._pending_outputs["tick"] = self._tick_count

        # Broadcast elapsed at a throttled rate so downstream nodes see
        # a heartbeat without being flooded.
        now = time.monotonic()
        if now - self._last_elapsed_broadcast >= _ELAPSED_BROADCAST_INTERVAL:
            self._last_elapsed_broadcast = now
            self._pending_outputs["elapsed"] = round(self._elapsed, 3)
            self._pending_outputs["is_playing"] = self._playing

        # Handle loop / auto-stop
        if self._duration > 0 and self._elapsed >= self._duration:
            if self._loop:
                # Re-arm the triggers but keep the counters running: a
                # trigger that fires once per loop must emit 1, 2, 3, …
                # rather than 1, 1, 1, … or strict-increase edge detectors
                # downstream would only ever see the first firing.
                self._reset_state(keep_counters=True)
                self._start_time = time.monotonic()
            else:
                all_fired = all(
                    (t["port_name"], float(t["time"])) in self._fired_keys
                    for t in self._triggers
                )
                if all_fired:
                    self._playing = False
                    self._pending_outputs["is_playing"] = False

    def _reset_state(self, keep_counters: bool = False) -> None:
        """Reset elapsed time and trigger state. Caller holds ``_lock``.

        Args:
            keep_counters: When True, only the elapsed time and the set of
                already-fired triggers are cleared; the per-port fire
                counters and the tick counter are preserved.
        """
        self._elapsed = 0.0
        self._fired_keys.clear()
        if not keep_counters:
            self._fire_counts.clear()
            self._tick_count = 0

    # ------------------------------------------------------------------
    # Node interface
    # ------------------------------------------------------------------

    def execute(self, inputs: dict[str, Any], **kwargs) -> dict[str, Any]:
        """Apply parameters and input pulses, then drain pending outputs.

        ``triggers``, ``loop`` and ``duration`` are read from ``kwargs``,
        falling back to ``self.config``. Malformed trigger entries are
        skipped and an unparsable ``duration`` is treated as 0.

        Returns:
            The outputs accumulated since the previous call, or an empty
            dict when nothing changed.
        """
        triggers = kwargs.get("triggers", self.config.get("triggers", []))
        loop = kwargs.get("loop", self.config.get("loop", False))
        duration = kwargs.get("duration", self.config.get("duration", 0.0))

        # Edge-detect start/reset counters so stale values on the input
        # queue don't retrigger actions. Prev counters seeded at 0 in
        # __init__ so the first positive pulse from upstream fires.
        start_val = _as_counter(inputs.get("start"))
        reset_val = _as_counter(inputs.get("reset"))
        start_fired = start_val is not None and start_val > self._prev_start_counter
        reset_fired = reset_val is not None and reset_val > self._prev_reset_counter
        if start_val is not None:
            self._prev_start_counter = start_val
        if reset_val is not None:
            self._prev_reset_counter = reset_val

        with self._lock:
            normalized_triggers: list[TriggerSpec] = []
            if isinstance(triggers, list):
                for t in triggers:
                    if not isinstance(t, dict):
                        continue
                    port = t.get("port_name")
                    if not (isinstance(port, str) and port):
                        continue
                    try:
                        fire_at = float(t.get("time"))
                    except (TypeError, ValueError):
                        # Missing or non-numeric time: ignore this entry
                        # instead of failing the whole node.
                        continue
                    normalized_triggers.append({"time": fire_at, "port_name": port})
            self._triggers = normalized_triggers
            self._loop = bool(loop)
            try:
                self._duration = float(duration)
            except (TypeError, ValueError):
                # 0 disables auto-stop / looping.
                self._duration = 0.0

            if reset_fired:
                self._reset_state()
                self._start_time = time.monotonic() if self._playing else None

            if start_fired:
                self._playing = not self._playing
                self._pending_outputs["is_playing"] = self._playing
                if self._playing:
                    # Resume from where playback was paused.
                    self._start_time = time.monotonic() - self._elapsed
                    self._start_timer()

            # Auto-start on first execute so wiring the node into a graph
            # drives downstream triggers without requiring an explicit start
            # pulse. Gated by an explicit flag so a later stale zero on
            # start/reset can't accidentally re-arm it.
            if (
                not self._auto_start_done
                and not self._playing
                and self._start_time is None
                and self._triggers
            ):
                self._auto_start_done = True
                self._playing = True
                self._start_time = time.monotonic()
                self._pending_outputs["is_playing"] = True
                self._start_timer()

            if not self._pending_outputs:
                return {}

            outputs = self._pending_outputs
            self._pending_outputs = {}
            return outputs

    def shutdown(self) -> None:
        """Stop the background timer thread."""
        self._stop_timer()


def _as_counter(value: Any) -> int | None:
    """Coerce an input-port value to an int counter, or None if not possible.

    ``None``, non-numeric values, NaN and infinities all yield ``None``.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # int() raises OverflowError for float infinities.
        return None
logger = logging.getLogger(__name__)

SLEEP_TIME = 0.01


class NodeProcessor:
    """Runs a BaseNode in a dedicated thread. Input queues feed the node,
    output queues fan out its results to downstream nodes.

    Source nodes (no inputs) execute once by default; nodes marked
    ``continuous=True`` in their definition re-execute on every tick, which
    is how streaming sources (audio) and sinks (audio loop) stay alive.

    ``parameters`` is never mutated in place: updates rebind it to a new
    dict, so the worker thread can safely unpack the dict it captured while
    another thread calls :meth:`update_parameters`.
    """

    def __init__(
        self,
        node: BaseNode,
        node_id: str,
        initial_parameters: dict | None = None,
    ):
        self.node = node
        self.node_id = node_id
        # Copy so later updates never write into the caller's dict.
        self.parameters = dict(initial_parameters or {})

        # Port-based queues wired by the graph executor
        self.input_queues: dict[str, queue.Queue] = {}
        self.output_queues: dict[str, list[queue.Queue]] = {}
        self.input_queue_lock = threading.Lock()
        self.external_queue_refs: list[tuple[dict, str]] = []

        definition = node.get_definition()
        self.audio_input_ports: set[str] = {
            p.name for p in definition.inputs if p.port_type == "audio"
        }
        self.audio_output_ports: set[str] = {
            p.name for p in definition.outputs if p.port_type == "audio"
        }

        # Audio output queue consumed by FrameProcessor.get_audio() on the sink
        self.audio_output_queue: queue.Queue[tuple[torch.Tensor, int]] = queue.Queue(
            maxsize=10
        )

        self.worker_thread: threading.Thread | None = None
        self.shutdown_event = threading.Event()
        self.running = False

        # Execution state
        self._source_executed = False
        self._has_executed = False
        self._continuous = definition.continuous

        # PipelineProcessor interface compatibility: graph_executor populates
        # this for every processor; kept as an empty dict so that write is safe.
        self.output_consumers: dict[str, list] = {}
        self.paused = False

    @property
    def output_queue(self) -> queue.Queue | None:
        """First downstream queue on the ``video`` port, or None if unwired."""
        qs = self.output_queues.get("video")
        return qs[0] if qs else None

    def start(self) -> None:
        """Start the worker thread; a no-op when already running."""
        if self.running:
            return
        self.running = True
        self.shutdown_event.clear()
        self.worker_thread = threading.Thread(
            target=self._worker_loop, daemon=True, name=f"NodeProcessor[{self.node_id}]"
        )
        self.worker_thread.start()

    def stop(self) -> None:
        """Stop the worker thread (waiting up to 5 s) and shut the node down."""
        if not self.running:
            return
        self.running = False
        self.shutdown_event.set()
        if self.worker_thread is not None:
            self.worker_thread.join(timeout=5.0)
        try:
            self.node.shutdown()
        except Exception:
            logger.exception("Error shutting down node %s", self.node_id)
        logger.info("NodeProcessor stopped: %s", self.node_id)

    def update_parameters(self, parameters: dict[str, Any]) -> None:
        """Merge ``parameters`` over the current ones.

        A new dict is built and swapped in rather than updating in place,
        so a dict the worker thread is currently unpacking never changes
        size underneath it.
        """
        self.parameters = {**self.parameters, **parameters}

    def set_beat_cache_reset_rate(self, rate):  # PipelineProcessor compat
        pass

    def get_fps(self) -> float:
        """Return a fixed 30.0 (PipelineProcessor interface compatibility)."""
        return 30.0

    def _worker_loop(self) -> None:
        """Call ``_process_once`` until shutdown, logging any exception."""
        while not self.shutdown_event.is_set():
            try:
                self._process_once()
            except Exception:
                logger.exception("Error in node processor %s", self.node_id)
                with self.input_queue_lock:
                    is_source = not self.input_queues
                if is_source:
                    # Avoid infinite retry on failing source nodes
                    self._source_executed = True
                    self._continuous = False
                self.shutdown_event.wait(SLEEP_TIME)

    def _process_once(self) -> None:
        """Gather inputs, execute the node once, and route its outputs."""
        if self.paused:
            self.shutdown_event.wait(SLEEP_TIME)
            return

        with self.input_queue_lock:
            all_queues = dict(self.input_queues)

        is_source_node = not all_queues

        # Source nodes execute once; continuous=True nodes re-execute every
        # tick (for streaming I/O like AudioSource chunking).
        if is_source_node and self._source_executed and not self._continuous:
            self.shutdown_event.wait(1.0)
            return

        # Gather inputs. Continuous nodes consume whatever's available
        # (empty inputs stay absent). Non-continuous nodes wait until every
        # input queue has data, so they execute with a complete input set.
        inputs: dict[str, Any] = {}
        if all_queues:
            if self._continuous:
                for port_name, q in all_queues.items():
                    try:
                        inputs[port_name] = q.get_nowait()
                    except queue.Empty:
                        pass
            else:
                if any(q.empty() for q in all_queues.values()):
                    self.shutdown_event.wait(SLEEP_TIME)
                    return
                inputs = {name: q.get_nowait() for name, q in all_queues.items()}

        # Non-continuous nodes skip re-execution when no new inputs arrived
        # and they already have a cached output.
        if self._has_executed and not inputs and not self._continuous:
            self.shutdown_event.wait(SLEEP_TIME)
            return

        # Capture the current parameter dict once; update_parameters() swaps
        # in a new dict instead of mutating this one.
        params = self.parameters
        outputs = self.node.execute(inputs, **params)

        if is_source_node:
            self._source_executed = True

        if not outputs:
            self.shutdown_event.wait(SLEEP_TIME)
            return

        self._has_executed = True
        self._route_outputs(outputs)

    def _route_outputs(self, outputs: dict[str, Any]) -> None:
        """Push each non-None output value to the queues wired to its port."""
        for port_name, value in outputs.items():
            if value is None:
                continue

            # Audio outputs also feed the FrameProcessor's audio path
            if port_name in self.audio_output_ports:
                self._route_audio(value)

            # Fan out to all downstream queues on this port. Block briefly
            # when queues are full so producers throttle to consumer pace
            # and GPU tensors don't pile up in memory.
            out_queues = self.output_queues.get(port_name)
            if out_queues:
                for oq in out_queues:
                    while not self.shutdown_event.is_set():
                        try:
                            oq.put(value, timeout=0.1)
                            break
                        except queue.Full:
                            continue

    def _route_audio(self, value: Any) -> None:
        """Extract audio tensor and push to audio_output_queue for WebRTC.

        Accepts either a ``(tensor, sample_rate)`` tuple or an object with
        ``waveform`` / ``sample_rate`` attributes. The payload is dropped
        when the queue is full.
        """
        if isinstance(value, tuple) and len(value) == 2:
            audio_tensor, audio_sr = value
        else:
            audio_tensor = getattr(value, "waveform", None)
            audio_sr = getattr(value, "sample_rate", 48000)
        if audio_tensor is None:
            return
        if hasattr(audio_tensor, "is_cuda") and audio_tensor.is_cuda:
            audio_tensor = audio_tensor.detach().cpu()
        # VAE decoders (e.g. ACEStep) return (1, C, T); the audio track
        # expects (C, T). Drop a leading singleton batch dim so the
        # channel/interleave path in AudioProcessingTrack doesn't misread
        # the layout and produce slowed-down / garbled playback.
        if hasattr(audio_tensor, "dim") and audio_tensor.dim() == 3:
            if audio_tensor.shape[0] == 1:
                audio_tensor = audio_tensor.squeeze(0)
        try:
            self.audio_output_queue.put_nowait((audio_tensor, audio_sr))
        except queue.Full:
            pass
logger = logging.getLogger(__name__)


def _derive_node_type_id(node_class: type) -> str | None:
    """Return the registry key for a node class, or None if not derivable.

    Plain nodes carry the id as the ``node_type_id`` classvar; pipelines
    keep it on their config class as ``pipeline_id``. An empty or
    non-string ``node_type_id`` is treated as missing, so it can never
    become a registry key.
    """
    node_type_id = getattr(node_class, "node_type_id", None)
    if isinstance(node_type_id, str) and node_type_id:
        return node_type_id
    if not isinstance(node_class, type):
        # issubclass() below would raise TypeError for a non-class argument.
        return None
    # Lazy import: nodes.registry is loaded before pipelines.interface.
    # Narrow to ImportError so real bugs (AttributeError on a broken
    # config class, typos in pipeline_id, etc.) surface instead of being
    # silently swallowed.
    try:
        from scope.core.pipelines.interface import Pipeline
    except ImportError:
        return None

    if issubclass(node_class, Pipeline):
        return node_class.get_config_class().pipeline_id
    return None


class NodeRegistry:
    """Central registry for all available node types.

    All state lives in the class-level ``_nodes`` mapping (type id → class);
    every method is a classmethod and no instance is needed.
    """

    _nodes: dict[str, type[BaseNode]] = {}

    @classmethod
    def register(cls, node_class: type[BaseNode]) -> None:
        """Register a :class:`BaseNode` subclass (plain node or pipeline).

        Registering a second class under an existing id replaces the first
        one and logs a warning.

        Raises:
            ValueError: If no type id can be derived for ``node_class``.
        """
        node_type_id = _derive_node_type_id(node_class)
        if node_type_id is None:
            class_name = getattr(node_class, "__name__", repr(node_class))
            raise ValueError(
                f"Cannot determine node_type_id for {class_name}; "
                "set a ClassVar[str] `node_type_id` on plain nodes or a "
                "`pipeline_id` on the pipeline config class."
            )
        previous = cls._nodes.get(node_type_id)
        if previous is not None and previous is not node_class:
            # Last registration still wins, but an id collision between two
            # different classes should be visible in the logs.
            logger.warning(
                "Node type %s re-registered: %s replaces %s",
                node_type_id,
                getattr(node_class, "__name__", repr(node_class)),
                getattr(previous, "__name__", repr(previous)),
            )
        cls._nodes[node_type_id] = node_class
        logger.debug("Registered node type: %s", node_type_id)

    @classmethod
    def get(cls, node_type_id: str) -> type[BaseNode] | None:
        """Return the class registered under ``node_type_id``, or None."""
        return cls._nodes.get(node_type_id)

    @classmethod
    def is_registered(cls, node_type_id: str) -> bool:
        """Return True when ``node_type_id`` has a registered class."""
        return node_type_id in cls._nodes

    @classmethod
    def list_node_types(cls) -> list[str]:
        """Return all registered type ids in registration order."""
        return list(cls._nodes.keys())

    @classmethod
    def get_all_definitions(cls) -> list[NodeDefinition]:
        """Return ``get_definition()`` of every registered class."""
        return [nc.get_definition() for nc in cls._nodes.values()]

    @classmethod
    def unregister(cls, node_type_id: str) -> bool:
        """Remove a type id; return True if it was registered."""
        if node_type_id in cls._nodes:
            del cls._nodes[node_type_id]
            return True
        return False

    @classmethod
    def clear(cls) -> None:
        """Remove every registration."""
        cls._nodes.clear()
- This enables: - - Validation via model_validate() / model_validate_json() - - JSON Schema generation via model_json_schema() - - Type-safe configuration access - - API introspection and automatic UI generation +class Pipeline(BaseNode, ABC): + """Abstract base class for video-pipeline nodes. - See schema.py for the BasePipelineConfig model and pipeline-specific configs. - For multi-mode pipeline support (text/video), pipelines use helper functions - from defaults.py (resolve_input_mode, apply_mode_defaults_to_state, etc.). + Subclasses implement ``__call__`` (the per-chunk processing + function) and ``get_config_class`` (returning a Pydantic config + that drives validation, JSON-schema generation, the parameter + panel, and parameter defaults). Everything else — registry, + plugin hook, graph editor — is the same as for plain nodes. """ @classmethod def get_config_class(cls) -> type["BasePipelineConfig"]: """Return the Pydantic config class for this pipeline. - The config class should inherit from BasePipelineConfig and define: - - pipeline_id: Unique identifier - - pipeline_name: Human-readable name - - pipeline_description: Capabilities description - - pipeline_version: Version string - - Default parameter values for the pipeline - - Returns: - Pydantic config model class - - Note: - Subclasses should override this method to return their config class. - The default implementation returns BasePipelineConfig. - - Example: - from .schema import LongLiveConfig - - @classmethod - def get_config_class(cls) -> type[BasePipelineConfig]: - return LongLiveConfig + Subclasses override to return their concrete config; the + default returns ``BasePipelineConfig``. """ from .schema import BasePipelineConfig return BasePipelineConfig + @classmethod + def get_definition(cls) -> NodeDefinition: + """Project the pipeline's config class into a :class:`NodeDefinition`. + + Populates the compact node-catalog fields (id, ports, etc.) 
+ and stuffs the full ``get_schema_with_metadata()`` output into + ``pipeline_meta``, which is the rich data ``PipelineNode.tsx`` + renders in the parameter panel. ``params`` is left empty + because the Pydantic schema is too structured to flatten into + ``NodeParam[]`` widgets. + """ + config = cls.get_config_class() + return NodeDefinition( + node_type_id=config.pipeline_id, + display_name=getattr(config, "pipeline_name", config.pipeline_id), + category="pipeline", + description=getattr(config, "pipeline_description", "") or "", + inputs=[ + NodePort(name=name, port_type="video") + for name in (getattr(config, "inputs", ["video"]) or ["video"]) + ], + outputs=[ + NodePort(name=name, port_type="video") + for name in (getattr(config, "outputs", ["video"]) or ["video"]) + ], + params=[], + continuous=False, + pipeline_meta=config.get_schema_with_metadata(), + ) + @abstractmethod def __call__(self, **kwargs) -> dict: """ diff --git a/src/scope/core/pipelines/registry.py b/src/scope/core/pipelines/registry.py index 030f7a3d1..d50e9d2fc 100644 --- a/src/scope/core/pipelines/registry.py +++ b/src/scope/core/pipelines/registry.py @@ -1,8 +1,10 @@ -"""Pipeline registry for centralized pipeline management. +"""Pipeline registry — a filtering view over :class:`NodeRegistry`. -This module provides a registry pattern to eliminate if/elif chains when -accessing pipelines by ID. It enables dynamic pipeline discovery and -metadata retrieval. +Pipelines and plain custom nodes share the same ``NodeRegistry._nodes`` +storage. ``PipelineRegistry`` projects that storage down to entries +whose class is a :class:`Pipeline` subclass and exposes the same API +the rest of the codebase always used, so existing call sites and +plugins keep working unchanged. 
""" import importlib @@ -11,6 +13,8 @@ import torch +from scope.core.nodes.registry import NodeRegistry + if TYPE_CHECKING: from .interface import Pipeline from .schema import BasePipelineConfig @@ -18,83 +22,67 @@ logger = logging.getLogger(__name__) -class PipelineRegistry: - """Registry for managing available pipelines.""" +def _is_pipeline(node_class: object) -> bool: + """Return True when ``node_class`` is a :class:`Pipeline` subclass. + + Lazily imports :class:`Pipeline` to dodge the import cycle between + the pipelines and nodes packages at module load time. + """ + from .interface import Pipeline + + return isinstance(node_class, type) and issubclass(node_class, Pipeline) + - _pipelines: dict[str, type["Pipeline"]] = {} +class PipelineRegistry: + """Filtering view over :class:`NodeRegistry` for pipeline classes.""" @classmethod def register(cls, pipeline_id: str, pipeline_class: type["Pipeline"]) -> None: - """Register a pipeline class with its ID. + """Plant a pipeline class into the unified :class:`NodeRegistry`. - Args: - pipeline_id: Unique identifier for the pipeline - pipeline_class: Pipeline class to register + Delegates to ``NodeRegistry.register`` so the same logging and + id-derivation path runs for built-in pipelines and plugin nodes + alike. The explicit ``pipeline_id`` argument is asserted against + the derived id to catch drift between the registry key and the + config class's ``pipeline_id``. """ - cls._pipelines[pipeline_id] = pipeline_class + config_pipeline_id = pipeline_class.get_config_class().pipeline_id + if pipeline_id != config_pipeline_id: + class_name = getattr(pipeline_class, "__name__", repr(pipeline_class)) + raise ValueError( + f"Pipeline id mismatch: registered as '{pipeline_id}' but " + f"{class_name}.get_config_class().pipeline_id is " + f"'{config_pipeline_id}'." + ) + NodeRegistry.register(pipeline_class) @classmethod def get(cls, pipeline_id: str) -> type["Pipeline"] | None: - """Get a pipeline class by its ID. 
- - Args: - pipeline_id: Pipeline identifier - - Returns: - Pipeline class if found, None otherwise - """ - return cls._pipelines.get(pipeline_id) + node_class = NodeRegistry.get(pipeline_id) + return node_class if _is_pipeline(node_class) else None @classmethod def unregister(cls, pipeline_id: str) -> bool: - """Remove a pipeline from the registry. - - Args: - pipeline_id: Pipeline identifier to remove - - Returns: - True if pipeline was removed, False if not found - """ - if pipeline_id in cls._pipelines: - del cls._pipelines[pipeline_id] - return True - return False + if cls.get(pipeline_id) is None: + return False + return NodeRegistry.unregister(pipeline_id) @classmethod def is_registered(cls, pipeline_id: str) -> bool: - """Check if a pipeline is registered. - - Args: - pipeline_id: Pipeline identifier - - Returns: - True if pipeline is registered, False otherwise - """ - return pipeline_id in cls._pipelines + return cls.get(pipeline_id) is not None @classmethod def get_config_class(cls, pipeline_id: str) -> type["BasePipelineConfig"] | None: - """Get config class for a specific pipeline. - - Args: - pipeline_id: Pipeline identifier - - Returns: - Pydantic config class if found, None otherwise - """ pipeline_class = cls.get(pipeline_id) - if pipeline_class is None: - return None - return pipeline_class.get_config_class() + return pipeline_class.get_config_class() if pipeline_class else None @classmethod def list_pipelines(cls) -> list[str]: - """Get list of all registered pipeline IDs. 
- - Returns: - List of pipeline IDs - """ - return list(cls._pipelines.keys()) + return [ + pid + for pid in NodeRegistry.list_node_types() + if _is_pipeline(NodeRegistry.get(pid)) + ] @classmethod def chain_produces_video(cls, pipeline_ids: list[str]) -> bool: @@ -246,28 +234,41 @@ def _register_pipelines(): def _initialize_registry(): - """Initialize registry with built-in pipelines and plugins.""" + """Initialize registry with built-in pipelines, nodes, and plugins.""" # Register built-in pipelines first _register_pipelines() - # Load and register plugin pipelines + # Register built-in nodes (no-op on the base abstraction branch) + from scope.core.nodes import register_builtin_nodes + + register_builtin_nodes() + + # Load and register plugins. The unified register_plugin_nodes fires + # both register_pipelines and register_nodes hooks, so old and new + # plugins are picked up in one call. try: from scope.core.plugins import ( ensure_plugins_installed, load_plugins, - register_plugin_pipelines, + register_plugin_nodes, ) ensure_plugins_installed() load_plugins() - register_plugin_pipelines(PipelineRegistry) + register_plugin_nodes() except Exception as e: logger.error( f"Failed to load plugins: {e}. Built-in pipelines are still available." 
) + from scope.core.nodes.registry import NodeRegistry + pipeline_count = len(PipelineRegistry.list_pipelines()) - logger.info(f"Registry initialized with {pipeline_count} pipeline(s)") + node_count = len(NodeRegistry.list_node_types()) + logger.info( + f"Registry initialized with {pipeline_count} pipeline(s) and " + f"{node_count} node(s)" + ) # Auto-register pipelines on module import diff --git a/src/scope/core/plugins/__init__.py b/src/scope/core/plugins/__init__.py index a681707d7..652f85883 100644 --- a/src/scope/core/plugins/__init__.py +++ b/src/scope/core/plugins/__init__.py @@ -14,6 +14,7 @@ get_plugin_manager, load_plugins, pm, + register_plugin_nodes, register_plugin_pipelines, ) @@ -22,6 +23,7 @@ "ensure_plugins_installed", "load_plugins", "pm", + "register_plugin_nodes", "register_plugin_pipelines", "get_plugin_manager", "FailedPluginInfo", diff --git a/src/scope/core/plugins/hookspecs.py b/src/scope/core/plugins/hookspecs.py index 62ce4f672..2c0aa78b7 100644 --- a/src/scope/core/plugins/hookspecs.py +++ b/src/scope/core/plugins/hookspecs.py @@ -22,3 +22,17 @@ def register_pipelines(self, register): def register_pipelines(register): register(MyPipeline) """ + + @hookspec + def register_nodes(self, register): + """Register custom node types. + + Args: + register: Callback to register node classes. + Usage: register(NodeClass) + + Example: + @scope.core.hookimpl + def register_nodes(register): + register(MyCustomNode) + """ diff --git a/src/scope/core/plugins/manager.py b/src/scope/core/plugins/manager.py index 87ee5a48a..d18d7118f 100644 --- a/src/scope/core/plugins/manager.py +++ b/src/scope/core/plugins/manager.py @@ -550,30 +550,36 @@ def get_failed_plugins(self) -> list[FailedPluginInfo]: with self._lock: return list(self._failed_plugins) - def register_plugin_pipelines(self, registry: "PipelineRegistry") -> None: - """Call register_pipelines hook for all plugins. 
- - Args: - registry: PipelineRegistry to register pipelines with + def register_plugin_nodes(self, registry: Any = None) -> None: + """Fire ``register_nodes`` and ``register_pipelines`` hooks. + + Both hooks plant into the unified :class:`NodeRegistry` storage, + so existing plugins using ``register_pipelines(register)`` keep + working unchanged alongside new ones using ``register_nodes``. + The ``registry`` argument is accepted for legacy callers but + ignored — the unified storage is always used. """ + from scope.core.nodes.registry import NodeRegistry + from scope.core.pipelines.registry import PipelineRegistry + + del registry # legacy parameter, kept for callsite compat + with self._lock: - # Clear previous mappings self._pipeline_to_plugin.clear() - def register_callback(pipeline_class: Any) -> None: - """Callback function passed to plugins.""" - config_class = pipeline_class.get_config_class() - pipeline_id = config_class.pipeline_id - registry.register(pipeline_id, pipeline_class) - - # Track which plugin owns this pipeline - # We'll update this mapping after the hook call - logger.info(f"Registered plugin pipeline: {pipeline_id}") + def register_callback(node_class: Any) -> None: + NodeRegistry.register(node_class) + node_id = getattr(node_class, "node_type_id", None) or ( + node_class.get_config_class().pipeline_id + ) + logger.info(f"Registered plugin node: {node_id}") + self._pm.hook.register_nodes(register=register_callback) self._pm.hook.register_pipelines(register=register_callback) + self._update_pipeline_plugin_mapping(PipelineRegistry) - # Update pipeline-to-plugin mapping by checking which plugins provide which pipelines - self._update_pipeline_plugin_mapping(registry) + # Backwards-compat alias for internal callers using the legacy name. 
+ register_plugin_pipelines = register_plugin_nodes def _update_pipeline_plugin_mapping(self, registry: "PipelineRegistry") -> None: """Update the mapping of pipeline IDs to plugin names.""" @@ -1658,10 +1664,15 @@ def load_plugins() -> None: get_plugin_manager().load_plugins() -def register_plugin_pipelines(registry: "PipelineRegistry") -> None: - """Call register_pipelines hook for all plugins. +def register_plugin_nodes(registry: Any = None) -> None: + """Fire ``register_nodes`` + ``register_pipelines`` hooks. - Args: - registry: PipelineRegistry to register pipelines with + Both hookspecs plant into the unified :class:`NodeRegistry` storage, + so old and new plugins coexist. The ``registry`` argument is kept + for legacy callers and ignored. """ - get_plugin_manager().register_plugin_pipelines(registry) + get_plugin_manager().register_plugin_nodes(registry) + + +# Backwards-compat alias for internal callers using the legacy name. +register_plugin_pipelines = register_plugin_nodes diff --git a/src/scope/server/app.py b/src/scope/server/app.py index d77939909..315cfd74d 100644 --- a/src/scope/server/app.py +++ b/src/scope/server/app.py @@ -823,47 +823,62 @@ async def get_pipeline_schemas( http_request: Request, cloud_manager: ScopeCloudBackend = Depends(get_scope_cloud), ): - """Get configuration schemas and defaults for all available pipelines. + """Compat alias for the pipeline-rich subset of the unified node catalog. 
- Returns the output of each pipeline's get_schema_with_metadata() method, - which includes: - - Pipeline metadata (id, name, description, version) - - supported_modes: List of supported input modes ("text", "video") - - default_mode: Default input mode for this pipeline - - mode_defaults: Mode-specific default overrides (if any) - - config_schema: Full JSON schema with defaults + Derives its response from the unified :class:`NodeRegistry` — + every entry whose :attr:`NodeDefinition.pipeline_meta` is set is a + pipeline, and ``pipeline_meta`` already holds the full output of + ``get_schema_with_metadata()`` (config_schema, mode_defaults, + supports_lora, supports_vace, etc.). Kept so existing frontend + callers in ``usePipelines.ts`` keep working without migration; new + code should read from ``GET /api/v1/nodes/definitions`` instead. - The frontend should use this as the source of truth for parameter defaults. - - In cloud mode (when connected to cloud), this proxies the request to the - cloud-hosted scope backend to get the available pipelines there. + In cloud mode this proxies to the cloud-hosted scope backend. 
""" global _pipeline_schemas_cache if _pipeline_schemas_cache is not None: return _pipeline_schemas_cache - from scope.core.pipelines.registry import PipelineRegistry + from scope.core.nodes.registry import NodeRegistry from scope.core.plugins import get_plugin_manager plugin_manager = get_plugin_manager() pipelines: dict = {} - for pipeline_id in PipelineRegistry.list_pipelines(): - config_class = PipelineRegistry.get_config_class(pipeline_id) - if config_class: - # get_schema_with_metadata() includes supported_modes, default_mode, - # and mode_defaults directly from the config class - schema_data = config_class.get_schema_with_metadata() - schema_data["plugin_name"] = plugin_manager.get_plugin_for_pipeline( - pipeline_id - ) - pipelines[pipeline_id] = schema_data + for definition in NodeRegistry.get_all_definitions(): + if definition.pipeline_meta is None: + continue + schema_data = dict(definition.pipeline_meta) + schema_data["plugin_name"] = plugin_manager.get_plugin_for_pipeline( + definition.node_type_id + ) + pipelines[definition.node_type_id] = schema_data response = PipelineSchemasResponse(pipelines=pipelines) _pipeline_schemas_cache = response return response +# --------------------------------------------------------------------------- +# Node definitions +# --------------------------------------------------------------------------- + + +@app.get("/api/v1/nodes/definitions") +async def get_node_definitions(): + """Return definitions for every registered node — pipelines included. + + The unified discovery endpoint. Pipelines (Pipeline subclasses) + appear with ``pipeline_meta`` populated; plain custom nodes leave + ``pipeline_meta`` ``None``. Frontend consumers that only want plain + nodes filter on ``pipeline_meta == null`` client-side; consumers + that want rich pipeline data read from ``pipeline_meta`` directly. 
+ """ + from scope.core.nodes.registry import NodeRegistry + + return {"nodes": [d.model_dump() for d in NodeRegistry.get_all_definitions()]} + + # --------------------------------------------------------------------------- # OSC endpoints # --------------------------------------------------------------------------- @@ -1201,13 +1216,31 @@ async def handle_webrtc_offer( logger.info("Using relay mode - video will flow through backend to cloud") return await webrtc_manager.handle_offer_with_relay(request, cloud_manager) - # Local mode: ensure pipeline is loaded before proceeding - status_info = await pipeline_manager.get_status_info_async() - if status_info["status"] != "loaded": - raise HTTPException( - status_code=400, - detail="Pipeline not loaded. Please load pipeline first.", + # Local mode: ensure pipeline is loaded before proceeding. + # Node-only graphs (no pipeline nodes) skip this check — the graph + # executor handles custom nodes directly without loading pipelines. + graph_data = ( + request.initialParameters.graph if request.initialParameters else None + ) + has_graph_nodes = False + if graph_data is not None: + nodes = ( + graph_data.get("nodes", []) + if isinstance(graph_data, dict) + else getattr(graph_data, "nodes", []) + ) + has_graph_nodes = any( + (n.get("type") if isinstance(n, dict) else getattr(n, "type", None)) + == "node" + for n in nodes ) + if not has_graph_nodes: + status_info = await pipeline_manager.get_status_info_async() + if status_info["status"] != "loaded": + raise HTTPException( + status_code=400, + detail="Pipeline not loaded. 
Please load pipeline first.", + ) return await webrtc_manager.handle_offer( request, pipeline_manager, tempo_sync=tempo_sync diff --git a/src/scope/server/frame_processor.py b/src/scope/server/frame_processor.py index 712966274..389d7b83d 100644 --- a/src/scope/server/frame_processor.py +++ b/src/scope/server/frame_processor.py @@ -218,8 +218,25 @@ def start(self): ) return - # Local mode: setup pipeline graph - if not self.pipeline_ids: + # Local mode: setup pipeline graph. + # Node-only graphs (custom nodes, audio-only workflows) are allowed + # to start without any pipeline IDs — the graph executor still runs + # custom nodes via NodeProcessor and the audio path works through + # the standard audio_output_queue. + graph_param = (self.parameters or {}).get("graph") + _has_custom_nodes = False + if graph_param is not None: + _nodes = ( + graph_param.get("nodes", []) + if isinstance(graph_param, dict) + else getattr(graph_param, "nodes", []) + ) + _has_custom_nodes = any( + (n.get("type") if isinstance(n, dict) else getattr(n, "type", None)) + == "node" + for n in _nodes + ) + if not self.pipeline_ids and not _has_custom_nodes: error_msg = "No pipeline IDs provided, cannot start" logger.error(error_msg) self.running = False diff --git a/src/scope/server/graph_executor.py b/src/scope/server/graph_executor.py index 3939f510f..c0f594e43 100644 --- a/src/scope/server/graph_executor.py +++ b/src/scope/server/graph_executor.py @@ -12,6 +12,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from scope.core.nodes.registry import NodeRegistry + from .graph_schema import GraphConfig from .pipeline_processor import PipelineProcessor @@ -92,7 +94,11 @@ def build_graph( # when multiple pipelines fan-in to the same sink. _sink_record_ids = {n.id for n in graph.nodes if n.type in ("sink", "record")} - # 1) Create one queue per edge (all edges are stream; frame-by-frame) + # 1) Create one queue per edge (all edges are stream; frame-by-frame). 
+ # Node→node edges use maxsize=1 so the DAG executes one cycle at a + # time and large tensors don't pile up in memory; pipeline edges use + # the larger default so video chunks can accumulate. + _node_ids = {n.id for n in graph.nodes if n.type == "node"} stream_queues: dict[tuple[str, str], queue.Queue] = {} for e in graph.edges: if e.kind == "stream": @@ -105,10 +111,12 @@ def build_graph( f"node={e.to_node!r}, port={e.to_port!r}. " f"Fan-in to a single port is not supported." ) - stream_queues[key] = queue.Queue(maxsize=DEFAULT_INPUT_QUEUE_MAXSIZE) + both_nodes = e.from_node in _node_ids and e.to_node in _node_ids + size = 1 if both_nodes else DEFAULT_INPUT_QUEUE_MAXSIZE + stream_queues[key] = queue.Queue(maxsize=size) - # 2) Create a processor per pipeline node and wire input_queues per port - node_processors: dict[str, PipelineProcessor] = {} + # 2) Create a processor per pipeline/custom node and wire input_queues + node_processors: dict[str, Any] = {} # PipelineProcessor | NodeProcessor pipeline_ids: list[str] = [] # Per-pipeline tempo: if any node explicitly opts in via tempo_sync=True, @@ -118,24 +126,46 @@ def build_graph( any_node_has_tempo = any(n.tempo_sync for n in graph.nodes if n.type == "pipeline") for node in graph.nodes: - if node.type != "pipeline" or node.pipeline_id is None: + if node.type == "pipeline" and node.pipeline_id is not None: + node_gets_tempo = node.tempo_sync or not any_node_has_tempo + pipeline = pipeline_manager.get_pipeline_by_id(node.id) + processor = PipelineProcessor( + pipeline=pipeline, + pipeline_id=node.pipeline_id, + initial_parameters=initial_parameters.copy(), + session_id=session_id, + user_id=user_id, + connection_id=connection_id, + connection_info=connection_info, + tempo_sync=tempo_sync if node_gets_tempo else None, + modulation_engine=modulation_engine if node_gets_tempo else None, + node_id=node.id, + ) + node_processors[node.id] = processor + pipeline_ids.append(node.pipeline_id) + elif node.type == "node" and 
node.node_type_id is not None: + from scope.core.nodes.processor import NodeProcessor + + node_cls = NodeRegistry.get(node.node_type_id) + if node_cls is None: + raise ValueError( + f"Unknown node type '{node.node_type_id}' for node '{node.id}'" + ) + node_instance = node_cls(node_id=node.id) + # Merge per-node params (from workflow) with global initial params. + # Per-node values take precedence (e.g. "steps": 8 on DiffusionConfig). + node_params = {**initial_parameters} + if node.params: + node_params.update(node.params) + processor = NodeProcessor( + node=node_instance, + node_id=node.id, + initial_parameters=node_params, + ) + node_processors[node.id] = processor + pipeline_ids.append(f"node:{node.node_type_id}") + else: continue - node_gets_tempo = node.tempo_sync or not any_node_has_tempo - pipeline = pipeline_manager.get_pipeline_by_id(node.id) - processor = PipelineProcessor( - pipeline=pipeline, - pipeline_id=node.pipeline_id, - initial_parameters=initial_parameters.copy(), - session_id=session_id, - user_id=user_id, - connection_id=connection_id, - connection_info=connection_info, - tempo_sync=tempo_sync if node_gets_tempo else None, - modulation_engine=modulation_engine if node_gets_tempo else None, - node_id=node.id, - ) - node_processors[node.id] = processor - pipeline_ids.append(node.pipeline_id) for e in graph.edges_to(node.id): if e.kind != "stream": @@ -143,10 +173,13 @@ def build_graph( q = stream_queues.get((node.id, e.to_port)) if q is not None: processor.input_queues[e.to_port] = q + # Mark audio ports so the processor consumes them differently + if e.to_port == "audio": + processor.audio_input_ports.add(e.to_port) # 3) Set each producer's output_queues per port for node in graph.nodes: - if node.type != "pipeline" or node.id not in node_processors: + if node.type not in ("pipeline", "node") or node.id not in node_processors: continue proc = node_processors[node.id] out_by_port: dict[str, list[queue.Queue]] = {} @@ -207,6 +240,11 @@ def 
build_graph( feeder_proc = node_processors.get(e.from_node) if feeder_proc is not None: sink_processors_by_node[sink_id] = feeder_proc + # Audio edges to sinks are served via audio_output_queue, + # not dedicated sink queues — skip queue allocation so + # the feeder isn't blocked on a queue nobody drains. + if e.from_port == "audio" or e.to_port == "audio": + break sink_node = node_by_id[sink_id] sink_mode = sink_node.sink_mode # WebRTC preview reads sink_queues_by_node; NDI/Spout/Syphon threads @@ -334,11 +372,23 @@ def _validate_edge_ports( # Build a map of node_id -> (declared_inputs, declared_outputs) port_map: dict[str, tuple[set[str], set[str]]] = {} for node in graph.nodes: - if node.type != "pipeline" or node.pipeline_id is None: - continue - pipeline = pipeline_manager.get_pipeline_by_id(node.id) - config_class = pipeline.get_config_class() - port_map[node.id] = (set(config_class.inputs), set(config_class.outputs)) + if node.type == "pipeline" and node.pipeline_id is not None: + pipeline = pipeline_manager.get_pipeline_by_id(node.id) + config_class = pipeline.get_config_class() + port_map[node.id] = ( + set(config_class.inputs), + set(config_class.outputs), + ) + elif node.type == "node" and node.node_type_id is not None: + node_cls = NodeRegistry.get(node.node_type_id) + if node_cls is not None: + defn = node_cls.get_definition() + static_outputs = {p.name for p in defn.outputs} + dynamic_outputs = node_cls.get_dynamic_output_ports(node.params or {}) + port_map[node.id] = ( + {p.name for p in defn.inputs}, + static_outputs | dynamic_outputs, + ) errors: list[str] = [] for e in graph.edges: diff --git a/src/scope/server/graph_schema.py b/src/scope/server/graph_schema.py index e15fae357..6dbcb283e 100644 --- a/src/scope/server/graph_schema.py +++ b/src/scope/server/graph_schema.py @@ -29,7 +29,7 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, Field @@ -41,14 +41,25 @@ class 
GraphNode(BaseModel): ..., description="Unique node id (e.g. 'input', 'yolo_plugin', 'longlive', 'output')", ) - type: Literal["source", "pipeline", "sink", "record"] = Field( + type: Literal["source", "pipeline", "sink", "record", "node"] = Field( ..., - description="source = external input, pipeline = pipeline instance, sink = output, record = file recorder", + description=( + "source = external input, pipeline = pipeline instance, " + "sink = output, record = file recorder, node = custom backend node" + ), ) pipeline_id: str | None = Field( default=None, description="Pipeline ID (registry key) when type is 'pipeline'", ) + node_type_id: str | None = Field( + default=None, + description="Node type ID (NodeRegistry key) when type is 'node'", + ) + params: dict[str, Any] | None = Field( + default=None, + description="Per-node parameter values for custom nodes", + ) source_mode: str | None = Field( default=None, description="Video source mode for source nodes: 'video', 'camera', 'spout', 'ndi', 'syphon'", @@ -114,6 +125,10 @@ def get_record_node_ids(self) -> list[str]: """Return node ids that are record nodes.""" return [n.id for n in self.nodes if n.type == "record"] + def get_backend_node_ids(self) -> list[str]: + """Return node ids that are backend (custom) nodes.""" + return [n.id for n in self.nodes if n.type == "node"] + def edges_from(self, node_id: str) -> list[GraphEdge]: """Return edges whose source is the given node.""" return [e for e in self.edges if e.from_node == node_id] @@ -149,10 +164,12 @@ def validate_structure(self) -> list[str]: if not self.get_sink_node_ids(): errors.append("Graph must have at least one sink node") - # Pipeline nodes must have pipeline_id + # Pipeline nodes must have pipeline_id; backend nodes need node_type_id for node in self.nodes: if node.type == "pipeline" and not node.pipeline_id: errors.append(f"Pipeline node '{node.id}' is missing pipeline_id") + if node.type == "node" and not node.node_type_id: + errors.append(f"Node 
'{node.id}' is missing node_type_id") # Edge references must point to existing nodes node_id_set = set(node_ids) diff --git a/src/scope/server/mcp_router.py b/src/scope/server/mcp_router.py index 0d1426e7a..ad2b2119c 100644 --- a/src/scope/server/mcp_router.py +++ b/src/scope/server/mcp_router.py @@ -384,10 +384,11 @@ async def start_stream( detail=f"Invalid graph: {'; '.join(errors)}", ) pipeline_ids = graph_config.get_pipeline_node_ids() - if not pipeline_ids: + backend_node_ids = graph_config.get_backend_node_ids() + if not pipeline_ids and not backend_node_ids: raise HTTPException( status_code=400, - detail="Graph must contain at least one pipeline node", + detail="Graph must contain at least one pipeline or custom node", ) pipeline_tuples = [ @@ -397,8 +398,8 @@ async def start_stream( ] pipeline_id_list = [t[1] for t in pipeline_tuples] - if not use_cloud: - # Local mode: load pipelines locally + if not use_cloud and pipeline_tuples: + # Local mode: load pipelines locally (skip for node-only graphs) await pipeline_manager.load_pipelines(pipeline_tuples) initial_params: dict = { @@ -449,9 +450,25 @@ async def start_stream( detail="FrameProcessor failed to start (check logs for details)", ) + # Custom (backend) nodes can produce audio too — if any audio sink / + # audio-emitting custom node is in the graph, expect audio output. 
+ expect_audio = PipelineRegistry.chain_produces_audio(pipeline_id_list) + if request.graph is not None and not expect_audio: + from scope.core.nodes.registry import NodeRegistry as _NR + + for gnode in graph_config.nodes: + if gnode.type == "node" and gnode.node_type_id: + nc = _NR.get(gnode.node_type_id) + if nc is None: + continue + defn = nc.get_definition() + if any(p.port_type == "audio" for p in defn.outputs): + expect_audio = True + break + session = HeadlessSession( frame_processor=frame_processor, - expect_audio=PipelineRegistry.chain_produces_audio(pipeline_id_list), + expect_audio=expect_audio, ) session.start_frame_consumer() webrtc_manager.add_headless_session(session) diff --git a/src/scope/server/pipeline_processor.py b/src/scope/server/pipeline_processor.py index 701de137b..69528fe7b 100644 --- a/src/scope/server/pipeline_processor.py +++ b/src/scope/server/pipeline_processor.py @@ -89,6 +89,10 @@ def __init__( # record_queues_by_node). Updated by _resize_output_queue so cached # references stay in sync when a queue object is replaced. self.external_queue_refs: list[tuple[dict, str]] = [] + # Input ports that carry audio (tensor, sample_rate) tuples instead + # of video frame tensors. Set by graph_executor when wiring audio + # edges so audio inputs don't get accumulated into a video chunk. + self.audio_input_ports: set[str] = set() # Audio output queue: (audio_tensor, sample_rate) tuples. # Consumed by FrameProcessor.get_audio() on the sink processor. @@ -550,6 +554,16 @@ def process_chunk(self): "Audio output queue full for %s, dropping audio chunk", self.pipeline_id, ) + # Also fan out audio to graph edge queues on the "audio" port + # so downstream custom nodes can consume it. 
+ audio_queues = self.output_queues.get("audio") + if audio_queues: + packed = (audio_cpu, audio_sample_rate) + for q in audio_queues: + try: + q.put_nowait(packed) + except queue.Full: + pass # Extract video from the returned dictionary output = output_dict.get("video") diff --git a/src/scope/server/webrtc.py b/src/scope/server/webrtc.py index 18c8618a8..045f24fec 100644 --- a/src/scope/server/webrtc.py +++ b/src/scope/server/webrtc.py @@ -60,6 +60,39 @@ vpx.MAX_BITRATE = 10000000 +def _graph_produces_audio(graph_data: dict) -> bool: + """Return True when any node in the graph exposes an audio output. + + Inspects custom-node definitions from :class:`NodeRegistry` and + pipeline configs so WebRTC can decide whether to open an audio + track for graphs that produce audio through custom nodes rather + than through a registered pipeline. + """ + from scope.core.nodes.registry import NodeRegistry + + for node in graph_data.get("nodes", []): + node_type = node.get("type") + if node_type == "node": + node_type_id = node.get("node_type_id") + if node_type_id: + node_cls = NodeRegistry.get(node_type_id) + if node_cls is not None: + defn = node_cls.get_definition() + if any(p.port_type == "audio" for p in defn.outputs): + return True + elif node_type == "pipeline": + pid = node.get("pipeline_id") + if pid: + cfg = PipelineRegistry.get_config_class(pid) + if cfg and getattr(cfg, "produces_audio", False): + return True + # Also treat any audio-kind edge as an audio-producing graph. + for edge in graph_data.get("edges", []): + if edge.get("from_port") == "audio" or edge.get("to_port") == "audio": + return True + return False + + def _parse_graph_node_ids( initial_parameters: dict, ) -> tuple[list[str], list[str], list[str], list[str], list[str], bool]: @@ -478,6 +511,11 @@ async def handle_offer( ) produces_audio = PipelineRegistry.chain_produces_audio(pipeline_ids) + # Graphs that emit audio through custom nodes (e.g. 
ACEStep) + # have no pipeline-produced audio, so also inspect the graph. + graph_data_for_audio = initial_parameters.get("graph") + if not produces_audio and graph_data_for_audio: + produces_audio = _graph_produces_audio(graph_data_for_audio) if produces_audio: audio_track = AudioProcessingTrack( frame_processor=frame_processor,