feat: scaffold avatar pipeline with ComfyUI driver, MLflow logging, and rclone promotion

- setup.sh: automated desktop env setup (ComfyUI, 3D-Pack, UniRig, Blender, Ray)
- ray-join.sh: join Ray cluster as external worker with 3d_gen resource label
- vrm_export.py: headless Blender GLB→VRM conversion script
- generate.py: ComfyUI API driver (submit workflow JSON, poll, download outputs)
- log_mlflow.py: REST-only MLflow experiment tracking (no SDK dependency)
- promote.py: rclone promotion of VRM files to gravenhollow S3
- CLI entry points: avatar-generate, avatar-promote
- workflows/ placeholder for ComfyUI exported workflow JSONs

Implements ADR-0063 (ComfyUI + TRELLIS + UniRig 3D avatar pipeline)
This commit is contained in:
2026-02-24 05:44:04 -05:00
parent a0c24406bd
commit 202b4e1d61
11 changed files with 1138 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# ComfyUI image-to-VRM avatar generation pipeline
# with TRELLIS + UniRig on desktop Ray worker

227
avatar_pipeline/generate.py Normal file
View File

@@ -0,0 +1,227 @@
"""Submit a ComfyUI workflow and collect the output.
ComfyUI exposes a REST API at http://localhost:8188:
POST /prompt — queue a workflow
GET /history/{id} — poll execution status
GET /view?filename= — download output files
This module loads a workflow JSON, injects runtime parameters
(image path, seed, etc.), submits it, and waits for completion.
Usage:
avatar-generate --workflow workflows/image-to-vrm.json \\
--image reference.png \\
--name "My Avatar" \\
[--seed 42] \\
[--comfyui-url http://localhost:8188]
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time
import uuid
from pathlib import Path
import httpx
logger = logging.getLogger(__name__)
DEFAULT_COMFYUI_URL = "http://localhost:8188"
POLL_INTERVAL_S = 2.0
MAX_WAIT_S = 600 # 10 minutes
def load_workflow(path: Path) -> dict:
    """Read a ComfyUI workflow JSON (API format) from *path* and return it as a dict."""
    return json.loads(path.read_text())
def inject_params(workflow: dict, image_path: str | None, seed: int | None) -> dict:
    """Override runtime inputs (reference image, seed) on matching workflow nodes.

    Best-effort: every node is scanned for known class types / input names and
    patched in place. Node IDs depend on the exported workflow JSON, so new
    workflows may need adjustments here.
    """
    for node_id, node in workflow.items():
        if not isinstance(node, dict):
            continue
        node_class = node.get("class_type", "")
        node_inputs = node.get("inputs", {})
        # Point LoadImage nodes at the supplied reference image.
        if image_path and node_class == "LoadImage":
            node_inputs["image"] = image_path
            logger.info("Injected image path: %s", image_path)
        # Any node exposing a "seed" input receives the requested seed.
        if seed is not None and "seed" in node_inputs:
            node_inputs["seed"] = seed
            logger.info("Injected seed=%d into node %s (%s)", seed, node_id, node_class)
    return workflow
def submit_prompt(
    client: httpx.Client,
    workflow: dict,
    client_id: str,
) -> str:
    """Queue *workflow* on ComfyUI via POST /prompt and return the prompt ID."""
    resp = client.post("/prompt", json={"prompt": workflow, "client_id": client_id})
    resp.raise_for_status()
    prompt_id = resp.json()["prompt_id"]
    logger.info("Submitted prompt: %s", prompt_id)
    return prompt_id
def wait_for_completion(
    client: httpx.Client,
    prompt_id: str,
    timeout: float = MAX_WAIT_S,
) -> dict:
    """Poll GET /history/{prompt_id} until the workflow finishes.

    Returns the history entry on success. Raises RuntimeError when ComfyUI
    reports a failed run, TimeoutError when *timeout* seconds elapse first.
    """
    start = time.monotonic()
    deadline = start + timeout
    while time.monotonic() < deadline:
        resp = client.get(f"/history/{prompt_id}")
        resp.raise_for_status()
        history = resp.json()
        if prompt_id in history:
            entry = history[prompt_id]
            status = entry.get("status", {})
            if status.get("completed", False):
                logger.info("Workflow completed in %.1fs", time.monotonic() - start)
                return entry
            if status.get("status_str") == "error":
                logger.error("Workflow failed: %s", status)
                raise RuntimeError(f"ComfyUI workflow failed: {status}")
        time.sleep(POLL_INTERVAL_S)
    raise TimeoutError(f"Workflow did not complete within {timeout}s")
def collect_outputs(entry: dict) -> list[dict]:
    """Flatten image/gltf/mesh output records out of a completed history entry."""
    return [
        item
        for node_output in entry.get("outputs", {}).values()
        for kind in ("images", "gltf", "meshes")
        for item in node_output.get(kind, [])
    ]
def download_output(
    client: httpx.Client,
    filename: str,
    subfolder: str,
    output_dir: Path,
    file_type: str = "output",
) -> Path:
    """Fetch one output file via GET /view and save it under *output_dir*."""
    resp = client.get(
        "/view",
        params={"filename": filename, "subfolder": subfolder, "type": file_type},
    )
    resp.raise_for_status()
    data = resp.content
    target = output_dir / filename
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(data)
    logger.info("Downloaded: %s (%d bytes)", target, len(data))
    return target
def generate(
    workflow_path: Path,
    image_path: str | None = None,
    seed: int | None = None,
    output_dir: Path | None = None,
    comfyui_url: str = DEFAULT_COMFYUI_URL,
) -> list[Path]:
    """Run the full pipeline: load → inject → submit → wait → download.

    Returns the local paths of all downloaded output files.
    """
    dest_dir = output_dir if output_dir is not None else Path("exports")
    dest_dir.mkdir(parents=True, exist_ok=True)

    session_id = str(uuid.uuid4())
    workflow = inject_params(load_workflow(workflow_path), image_path, seed)

    with httpx.Client(base_url=comfyui_url, timeout=30.0) as client:
        prompt_id = submit_prompt(client, workflow, session_id)
        entry = wait_for_completion(client, prompt_id)
        return [
            download_output(
                client,
                filename=item["filename"],
                subfolder=item.get("subfolder", ""),
                output_dir=dest_dir,
                file_type=item.get("type", "output"),
            )
            for item in collect_outputs(entry)
        ]
def main() -> None:
    """CLI entry point: parse args, run the generation pipeline, print results."""
    parser = argparse.ArgumentParser(
        description="Submit a ComfyUI workflow and collect outputs"
    )
    parser.add_argument(
        "--workflow",
        type=Path,
        required=True,
        help="Path to ComfyUI workflow JSON (API format)",
    )
    parser.add_argument("--image", help="Reference image path to inject into LoadImage nodes")
    parser.add_argument("--seed", type=int, help="Seed to inject into generation nodes")
    parser.add_argument("--output-dir", type=Path, default=Path("exports"), help="Output directory")
    parser.add_argument(
        "--comfyui-url",
        default=DEFAULT_COMFYUI_URL,
        help="ComfyUI server URL (default: %(default)s)",
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        results = generate(
            workflow_path=args.workflow,
            image_path=args.image,
            seed=args.seed,
            output_dir=args.output_dir,
            comfyui_url=args.comfyui_url,
        )
        print(f"\nGenerated {len(results)} output(s):")
        for result in results:
            print(f"  {result}")
    except Exception:
        # Any failure is logged with traceback and mapped to exit code 1.
        logger.exception("Generation failed")
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,150 @@
"""Log avatar generation results to MLflow via REST API.
Uses the same lightweight REST-only approach as ray-serve's
mlflow_logger.py — no heavyweight mlflow SDK dependency.
Usage:
from avatar_pipeline.log_mlflow import log_generation
log_generation(
avatar_name="Silver-Mage",
params={"trellis_seed": 42, "trellis_simplify": 0.95},
metrics={"vertex_count": 12345, "face_count": 8000, "duration_s": 45.2},
artifacts={"vrm": Path("exports/Silver-Mage.vrm")},
)
"""
from __future__ import annotations

import json
import logging
import os
import time
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path
logger = logging.getLogger(__name__)
DEFAULT_TRACKING_URI = "http://mlflow.mlflow.svc.cluster.local:80"
EXPERIMENT_NAME = "3d-avatar-generation"
def _base_url() -> str:
    """Tracking-server root, from $MLFLOW_TRACKING_URI or the in-cluster default."""
    uri = os.environ.get("MLFLOW_TRACKING_URI", DEFAULT_TRACKING_URI)
    return uri.rstrip("/")
def _post(path: str, body: dict) -> dict:
    """POST *body* as JSON to the MLflow REST endpoint *path*; return the parsed reply."""
    request = urllib.request.Request(
        f"{_base_url()}/api/2.0/mlflow/{path}",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=15) as resp:
        return json.loads(resp.read().decode())
def _get_or_create_experiment(name: str) -> str:
    """Return the experiment ID for *name*, creating the experiment if needed.

    The MLflow REST API exposes experiments/get-by-name as a GET endpoint.
    The previous implementation POSTed to it, which always raised HTTPError
    and fell through to experiments/create — and create itself raises
    RESOURCE_ALREADY_EXISTS once the experiment exists, so logging broke
    after the first run. We now issue a proper GET with a query string.
    """
    query = urllib.parse.urlencode({"experiment_name": name})
    url = f"{_base_url()}/api/2.0/mlflow/experiments/get-by-name?{query}"
    try:
        with urllib.request.urlopen(urllib.request.Request(url), timeout=15) as resp:
            return json.loads(resp.read().decode())["experiment"]["experiment_id"]
    except urllib.error.HTTPError:
        # Expected 404 (RESOURCE_DOES_NOT_EXIST) on first use: create it.
        return _post("experiments/create", {"name": name})["experiment_id"]
def _create_run(experiment_id: str, run_name: str, tags: dict[str, str]) -> str:
    """Create an MLflow run and return its run_id.

    Standard mlflow.* tags are pre-populated so the UI shows a run name and
    source; caller-supplied *tags* override any of these defaults.
    """
    merged_tags = {
        "mlflow.runName": run_name,
        "mlflow.source.type": "LOCAL",
        "mlflow.source.name": "avatar-pipeline",
        "hostname": os.environ.get("HOSTNAME", "desktop"),
    }
    merged_tags.update(tags)
    body = {
        "experiment_id": experiment_id,
        "run_name": run_name,
        "start_time": int(time.time() * 1000),
        "tags": [{"key": k, "value": v} for k, v in merged_tags.items()],
    }
    return _post("runs/create", body)["run"]["info"]["run_id"]
def _log_params(run_id: str, params: dict[str, str | int | float]) -> None:
    """Record *params* on the run; values are stringified and truncated to 500 chars."""
    payload = [
        {"key": name, "value": str(value)[:500]} for name, value in params.items()
    ]
    _post("runs/log-batch", {"run_id": run_id, "params": payload})
def _log_metrics(run_id: str, metrics: dict[str, float]) -> None:
    """Record *metrics* on the run, all stamped with the current time at step 0."""
    now_ms = int(time.time() * 1000)
    payload = [
        {"key": name, "value": float(value), "timestamp": now_ms, "step": 0}
        for name, value in metrics.items()
    ]
    _post("runs/log-batch", {"run_id": run_id, "metrics": payload})
def _log_artifact(run_id: str, key: str, path: Path) -> None:
    """Record an artifact's local path as a run tag.

    A true artifact upload requires an artifact store, so we only tag the run
    with the file path here. For S3-promoted files, the caller should include
    the S3 URI in params instead.
    """
    tag = {"key": f"artifact.{key}", "value": str(path)}
    _post("runs/log-batch", {"run_id": run_id, "tags": [tag]})
def _end_run(run_id: str) -> None:
    """Mark the run FINISHED with the current wall-clock end time."""
    body = {
        "run_id": run_id,
        "status": "FINISHED",
        "end_time": int(time.time() * 1000),
    }
    _post("runs/update", body)
def log_generation(
    avatar_name: str,
    params: dict[str, str | int | float] | None = None,
    metrics: dict[str, float] | None = None,
    artifacts: dict[str, Path] | None = None,
    tags: dict[str, str] | None = None,
    experiment_name: str = EXPERIMENT_NAME,
) -> str | None:
    """Log one complete avatar-generation run to MLflow.

    Best-effort by design: any failure is logged (with traceback) and
    swallowed so an unreachable tracking server never fails the generation.

    Returns the MLflow run_id on success, None on failure.
    """
    try:
        experiment_id = _get_or_create_experiment(experiment_name)
        run_id = _create_run(experiment_id, run_name=avatar_name, tags=tags or {})
        if params:
            _log_params(run_id, params)
        if metrics:
            _log_metrics(run_id, metrics)
        for key, path in (artifacts or {}).items():
            _log_artifact(run_id, key, path)
        _end_run(run_id)
        logger.info("Logged to MLflow: run_id=%s, experiment=%s", run_id, experiment_name)
        return run_id
    except Exception:
        logger.warning("MLflow logging failed — generation still succeeded", exc_info=True)
        return None

131
avatar_pipeline/promote.py Normal file
View File

@@ -0,0 +1,131 @@
"""Promote VRM files to gravenhollow via rclone.
Usage:
avatar-promote exports/Silver-Mage.vrm
avatar-promote exports/Silver-Mage.vrm --bucket companion-avatars
avatar-promote --dry-run exports/*.vrm
"""
from __future__ import annotations
import argparse
import logging
import shutil
import subprocess
import sys
from pathlib import Path
logger = logging.getLogger(__name__)
DEFAULT_REMOTE = "gravenhollow"
DEFAULT_BUCKET = "avatar-models"
def check_rclone(remote: str = DEFAULT_REMOTE) -> bool:
    """Verify rclone is installed and the given remote is configured.

    Generalized to accept the remote name so callers targeting a non-default
    remote (e.g. ``promote --remote other``) are checked against the right
    one; the default argument preserves the original no-arg behavior.

    Args:
        remote: rclone remote name to look for in ``rclone listremotes``.

    Returns:
        True when rclone exists and *remote* is configured, else False
        (with an error logged explaining which check failed).
    """
    if not shutil.which("rclone"):
        logger.error("rclone not found. Install: sudo pacman -S rclone")
        return False
    result = subprocess.run(
        ["rclone", "listremotes"],
        capture_output=True,
        text=True,
        check=False,
    )
    # `rclone listremotes` prints one "name:" per line.
    remotes = result.stdout.strip().split("\n")
    if f"{remote}:" not in remotes:
        logger.error(
            "rclone remote '%s' not configured. Run scripts/setup.sh or configure manually.",
            remote,
        )
        return False
    return True
def promote(
    files: list[Path],
    remote: str = DEFAULT_REMOTE,
    bucket: str = DEFAULT_BUCKET,
    dry_run: bool = False,
) -> list[str]:
    """Copy VRM/GLB/FBX files to S3 via rclone.

    Args:
        files: local files to upload; missing files and unexpected extensions
            are skipped with a warning rather than aborting the batch.
        remote: rclone remote name.
        bucket: destination bucket.
        dry_run: pass ``--dry-run`` to rclone and report without copying.

    Returns:
        Remote paths actually promoted (empty in dry-run mode).

    Raises:
        RuntimeError: when rclone is missing or the remote is unconfigured.
    """
    # NOTE(review): check_rclone() validates the default remote only; a custom
    # --remote value is not verified here — confirm before relying on it.
    if not check_rclone():
        raise RuntimeError("rclone not available")
    promoted: list[str] = []
    for file_path in files:
        if not file_path.exists():
            logger.warning("File not found, skipping: %s", file_path)
            continue
        if file_path.suffix.lower() not in (".vrm", ".glb", ".fbx"):
            logger.warning("Unexpected file type, skipping: %s", file_path)
            continue
        dest = f"{remote}:{bucket}/{file_path.name}"
        cmd = ["rclone", "copy", str(file_path), f"{remote}:{bucket}/"]
        if dry_run:
            cmd.append("--dry-run")
        # Fixed format string: the original "%s%s" ran the source path and
        # destination together with no separator.
        logger.info("Promoting: %s -> %s", file_path, dest)
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False,
        )
        if result.returncode != 0:
            logger.error("rclone failed for %s: %s", file_path, result.stderr)
            continue
        if dry_run:
            logger.info("  (dry-run) Would copy %s", file_path.name)
        else:
            logger.info("  Promoted: %s", dest)
            promoted.append(dest)
    return promoted
def main() -> None:
    """CLI entry point: parse args, promote the given files, print results."""
    parser = argparse.ArgumentParser(description="Promote VRM files to gravenhollow storage")
    parser.add_argument("files", nargs="+", type=Path, help="VRM/GLB files to promote")
    parser.add_argument("--remote", default=DEFAULT_REMOTE, help="rclone remote name")
    parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="S3 bucket name")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be copied")
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        results = promote(
            files=args.files,
            remote=args.remote,
            bucket=args.bucket,
            dry_run=args.dry_run,
        )
        print(f"\nPromoted {len(results)} file(s)")
        for result in results:
            print(f"  {result}")
    except Exception:
        # Any failure is logged with traceback and mapped to exit code 1.
        logger.exception("Promotion failed")
        sys.exit(1)


if __name__ == "__main__":
    main()