feat: scaffold avatar pipeline with ComfyUI driver, MLflow logging, and rclone promotion
- setup.sh: automated desktop env setup (ComfyUI, 3D-Pack, UniRig, Blender, Ray) - ray-join.sh: join Ray cluster as external worker with 3d_gen resource label - vrm_export.py: headless Blender GLB→VRM conversion script - generate.py: ComfyUI API driver (submit workflow JSON, poll, download outputs) - log_mlflow.py: REST-only MLflow experiment tracking (no SDK dependency) - promote.py: rclone promotion of VRM files to gravenhollow S3 - CLI entry points: avatar-generate, avatar-promote - workflows/ placeholder for ComfyUI exported workflow JSONs Implements ADR-0063 (ComfyUI + TRELLIS + UniRig 3D avatar pipeline)
This commit is contained in:
227
avatar_pipeline/generate.py
Normal file
227
avatar_pipeline/generate.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Submit a ComfyUI workflow and collect the output.
|
||||
|
||||
ComfyUI exposes a REST API at http://localhost:8188:
|
||||
POST /prompt — queue a workflow
|
||||
GET /history/{id} — poll execution status
|
||||
GET /view?filename= — download output files
|
||||
|
||||
This module loads a workflow JSON, injects runtime parameters
|
||||
(image path, seed, etc.), submits it, and waits for completion.
|
||||
|
||||
Usage:
|
||||
avatar-generate --workflow workflows/image-to-vrm.json \\
|
||||
--image reference.png \\
|
||||
--name "My Avatar" \\
|
||||
[--seed 42] \\
|
||||
[--comfyui-url http://localhost:8188]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_COMFYUI_URL = "http://localhost:8188"
|
||||
POLL_INTERVAL_S = 2.0
|
||||
MAX_WAIT_S = 600 # 10 minutes
|
||||
|
||||
|
||||
def load_workflow(path: Path) -> dict:
|
||||
"""Load a ComfyUI workflow JSON (API format)."""
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def inject_params(workflow: dict, image_path: str | None, seed: int | None) -> dict:
|
||||
"""Inject runtime parameters into the workflow.
|
||||
|
||||
This is a best-effort approach — it scans nodes for known types
|
||||
and overrides their values. The exact node IDs depend on the
|
||||
exported workflow JSON, so this may need adjustment per workflow.
|
||||
"""
|
||||
for _node_id, node in workflow.items():
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
|
||||
class_type = node.get("class_type", "")
|
||||
inputs = node.get("inputs", {})
|
||||
|
||||
# Override image path in LoadImage nodes
|
||||
if class_type == "LoadImage" and image_path:
|
||||
inputs["image"] = image_path
|
||||
logger.info("Injected image path: %s", image_path)
|
||||
|
||||
# Override seed in any node that has a seed input
|
||||
if seed is not None and "seed" in inputs:
|
||||
inputs["seed"] = seed
|
||||
logger.info("Injected seed=%d into node %s (%s)", seed, _node_id, class_type)
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
def submit_prompt(
|
||||
client: httpx.Client,
|
||||
workflow: dict,
|
||||
client_id: str,
|
||||
) -> str:
|
||||
"""Submit a workflow to ComfyUI and return the prompt ID."""
|
||||
payload = {
|
||||
"prompt": workflow,
|
||||
"client_id": client_id,
|
||||
}
|
||||
resp = client.post("/prompt", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
prompt_id = data["prompt_id"]
|
||||
logger.info("Submitted prompt: %s", prompt_id)
|
||||
return prompt_id
|
||||
|
||||
|
||||
def wait_for_completion(
|
||||
client: httpx.Client,
|
||||
prompt_id: str,
|
||||
timeout: float = MAX_WAIT_S,
|
||||
) -> dict:
|
||||
"""Poll /history/{prompt_id} until the workflow completes."""
|
||||
start = time.monotonic()
|
||||
|
||||
while time.monotonic() - start < timeout:
|
||||
resp = client.get(f"/history/{prompt_id}")
|
||||
resp.raise_for_status()
|
||||
history = resp.json()
|
||||
|
||||
if prompt_id in history:
|
||||
entry = history[prompt_id]
|
||||
status = entry.get("status", {})
|
||||
|
||||
if status.get("completed", False):
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info("Workflow completed in %.1fs", elapsed)
|
||||
return entry
|
||||
|
||||
if status.get("status_str") == "error":
|
||||
logger.error("Workflow failed: %s", status)
|
||||
raise RuntimeError(f"ComfyUI workflow failed: {status}")
|
||||
|
||||
time.sleep(POLL_INTERVAL_S)
|
||||
|
||||
raise TimeoutError(f"Workflow did not complete within {timeout}s")
|
||||
|
||||
|
||||
def collect_outputs(entry: dict) -> list[dict]:
|
||||
"""Extract output file info from a completed history entry."""
|
||||
outputs = []
|
||||
for _node_id, node_output in entry.get("outputs", {}).items():
|
||||
for output_type in ("images", "gltf", "meshes"):
|
||||
for item in node_output.get(output_type, []):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
def download_output(
|
||||
client: httpx.Client,
|
||||
filename: str,
|
||||
subfolder: str,
|
||||
output_dir: Path,
|
||||
file_type: str = "output",
|
||||
) -> Path:
|
||||
"""Download an output file from ComfyUI."""
|
||||
params = {"filename": filename, "subfolder": subfolder, "type": file_type}
|
||||
resp = client.get("/view", params=params)
|
||||
resp.raise_for_status()
|
||||
|
||||
dest = output_dir / filename
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(resp.content)
|
||||
logger.info("Downloaded: %s (%d bytes)", dest, len(resp.content))
|
||||
return dest
|
||||
|
||||
|
||||
def generate(
|
||||
workflow_path: Path,
|
||||
image_path: str | None = None,
|
||||
seed: int | None = None,
|
||||
output_dir: Path | None = None,
|
||||
comfyui_url: str = DEFAULT_COMFYUI_URL,
|
||||
) -> list[Path]:
|
||||
"""Run the full generation pipeline: load → inject → submit → wait → download."""
|
||||
output_dir = output_dir or Path("exports")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
client_id = str(uuid.uuid4())
|
||||
workflow = load_workflow(workflow_path)
|
||||
workflow = inject_params(workflow, image_path, seed)
|
||||
|
||||
with httpx.Client(base_url=comfyui_url, timeout=30.0) as client:
|
||||
prompt_id = submit_prompt(client, workflow, client_id)
|
||||
entry = wait_for_completion(client, prompt_id)
|
||||
outputs = collect_outputs(entry)
|
||||
|
||||
downloaded = []
|
||||
for item in outputs:
|
||||
path = download_output(
|
||||
client,
|
||||
filename=item["filename"],
|
||||
subfolder=item.get("subfolder", ""),
|
||||
output_dir=output_dir,
|
||||
file_type=item.get("type", "output"),
|
||||
)
|
||||
downloaded.append(path)
|
||||
|
||||
return downloaded
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Submit a ComfyUI workflow and collect outputs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workflow",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to ComfyUI workflow JSON (API format)",
|
||||
)
|
||||
parser.add_argument("--image", help="Reference image path to inject into LoadImage nodes")
|
||||
parser.add_argument("--seed", type=int, help="Seed to inject into generation nodes")
|
||||
parser.add_argument("--output-dir", type=Path, default=Path("exports"), help="Output directory")
|
||||
parser.add_argument(
|
||||
"--comfyui-url",
|
||||
default=DEFAULT_COMFYUI_URL,
|
||||
help="ComfyUI server URL (default: %(default)s)",
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
try:
|
||||
outputs = generate(
|
||||
workflow_path=args.workflow,
|
||||
image_path=args.image,
|
||||
seed=args.seed,
|
||||
output_dir=args.output_dir,
|
||||
comfyui_url=args.comfyui_url,
|
||||
)
|
||||
print(f"\nGenerated {len(outputs)} output(s):")
|
||||
for p in outputs:
|
||||
print(f" {p}")
|
||||
except Exception:
|
||||
logger.exception("Generation failed")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user