feat: scaffold avatar pipeline with ComfyUI driver, MLflow logging, and rclone promotion
- setup.sh: automated desktop env setup (ComfyUI, 3D-Pack, UniRig, Blender, Ray)
- ray-join.sh: join Ray cluster as external worker with 3d_gen resource label
- vrm_export.py: headless Blender GLB→VRM conversion script
- generate.py: ComfyUI API driver (submit workflow JSON, poll, download outputs)
- log_mlflow.py: REST-only MLflow experiment tracking (no SDK dependency)
- promote.py: rclone promotion of VRM files to gravenhollow S3
- CLI entry points: avatar-generate, avatar-promote
- workflows/ placeholder for ComfyUI exported workflow JSONs

Implements ADR-0063 (ComfyUI + TRELLIS + UniRig 3D avatar pipeline)
avatar_pipeline/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# ComfyUI image-to-VRM avatar generation pipeline
# with TRELLIS + UniRig on desktop Ray worker
avatar_pipeline/generate.py (new file, 227 lines)
@@ -0,0 +1,227 @@
"""Submit a ComfyUI workflow and collect the output.

ComfyUI exposes a REST API at http://localhost:8188:

    POST /prompt         — queue a workflow
    GET  /history/{id}   — poll execution status
    GET  /view?filename= — download output files

This module loads a workflow JSON, injects runtime parameters
(image path, seed, etc.), submits it, and waits for completion.

Usage:
    avatar-generate --workflow workflows/image-to-vrm.json \\
        --image reference.png \\
        [--seed 42] \\
        [--comfyui-url http://localhost:8188]
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
import time
import uuid
from pathlib import Path

import httpx

logger = logging.getLogger(__name__)

DEFAULT_COMFYUI_URL = "http://localhost:8188"
POLL_INTERVAL_S = 2.0
MAX_WAIT_S = 600  # 10 minutes


def load_workflow(path: Path) -> dict:
    """Load a ComfyUI workflow JSON (API format)."""
    with open(path) as f:
        return json.load(f)


def inject_params(workflow: dict, image_path: str | None, seed: int | None) -> dict:
    """Inject runtime parameters into the workflow.

    This is a best-effort approach — it scans nodes for known types
    and overrides their values. The exact node IDs depend on the
    exported workflow JSON, so this may need adjustment per workflow.
    """
    for _node_id, node in workflow.items():
        if not isinstance(node, dict):
            continue

        class_type = node.get("class_type", "")
        inputs = node.get("inputs", {})

        # Override image path in LoadImage nodes
        if class_type == "LoadImage" and image_path:
            inputs["image"] = image_path
            logger.info("Injected image path: %s", image_path)

        # Override seed in any node that has a seed input
        if seed is not None and "seed" in inputs:
            inputs["seed"] = seed
            logger.info("Injected seed=%d into node %s (%s)", seed, _node_id, class_type)

    return workflow
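

# For reference, the "API format" the loader expects is ComfyUI's
# "Save (API Format)" export: a flat mapping of node IDs to
# {"class_type", "inputs"} dicts. The IDs and node types below are
# illustrative, not taken from a real workflow:
#
#   {
#       "1": {"class_type": "LoadImage", "inputs": {"image": "ref.png"}},
#       "2": {"class_type": "KSampler", "inputs": {"seed": 0, "steps": 20}}
#   }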


def submit_prompt(
    client: httpx.Client,
    workflow: dict,
    client_id: str,
) -> str:
    """Submit a workflow to ComfyUI and return the prompt ID."""
    payload = {
        "prompt": workflow,
        "client_id": client_id,
    }
    resp = client.post("/prompt", json=payload)
    resp.raise_for_status()
    data = resp.json()
    prompt_id = data["prompt_id"]
    logger.info("Submitted prompt: %s", prompt_id)
    return prompt_id


def wait_for_completion(
    client: httpx.Client,
    prompt_id: str,
    timeout: float = MAX_WAIT_S,
) -> dict:
    """Poll /history/{prompt_id} until the workflow completes."""
    start = time.monotonic()

    while time.monotonic() - start < timeout:
        resp = client.get(f"/history/{prompt_id}")
        resp.raise_for_status()
        history = resp.json()

        if prompt_id in history:
            entry = history[prompt_id]
            status = entry.get("status", {})

            if status.get("completed", False):
                elapsed = time.monotonic() - start
                logger.info("Workflow completed in %.1fs", elapsed)
                return entry

            if status.get("status_str") == "error":
                logger.error("Workflow failed: %s", status)
                raise RuntimeError(f"ComfyUI workflow failed: {status}")

        time.sleep(POLL_INTERVAL_S)

    raise TimeoutError(f"Workflow did not complete within {timeout}s")


def collect_outputs(entry: dict) -> list[dict]:
    """Extract output file info from a completed history entry."""
    outputs = []
    for _node_id, node_output in entry.get("outputs", {}).items():
        for output_type in ("images", "gltf", "meshes"):
            for item in node_output.get(output_type, []):
                outputs.append(item)
    return outputs


def download_output(
    client: httpx.Client,
    filename: str,
    subfolder: str,
    output_dir: Path,
    file_type: str = "output",
) -> Path:
    """Download an output file from ComfyUI."""
    params = {"filename": filename, "subfolder": subfolder, "type": file_type}
    resp = client.get("/view", params=params)
    resp.raise_for_status()

    dest = output_dir / filename
    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_bytes(resp.content)
    logger.info("Downloaded: %s (%d bytes)", dest, len(resp.content))
    return dest


def generate(
    workflow_path: Path,
    image_path: str | None = None,
    seed: int | None = None,
    output_dir: Path | None = None,
    comfyui_url: str = DEFAULT_COMFYUI_URL,
) -> list[Path]:
    """Run the full generation pipeline: load → inject → submit → wait → download."""
    output_dir = output_dir or Path("exports")
    output_dir.mkdir(parents=True, exist_ok=True)

    client_id = str(uuid.uuid4())
    workflow = load_workflow(workflow_path)
    workflow = inject_params(workflow, image_path, seed)

    with httpx.Client(base_url=comfyui_url, timeout=30.0) as client:
        prompt_id = submit_prompt(client, workflow, client_id)
        entry = wait_for_completion(client, prompt_id)
        outputs = collect_outputs(entry)

        downloaded = []
        for item in outputs:
            path = download_output(
                client,
                filename=item["filename"],
                subfolder=item.get("subfolder", ""),
                output_dir=output_dir,
                file_type=item.get("type", "output"),
            )
            downloaded.append(path)

    return downloaded


def main() -> None:
    """CLI entry point."""
    parser = argparse.ArgumentParser(
        description="Submit a ComfyUI workflow and collect outputs"
    )
    parser.add_argument(
        "--workflow",
        type=Path,
        required=True,
        help="Path to ComfyUI workflow JSON (API format)",
    )
    parser.add_argument("--image", help="Reference image path to inject into LoadImage nodes")
    parser.add_argument("--seed", type=int, help="Seed to inject into generation nodes")
    parser.add_argument("--output-dir", type=Path, default=Path("exports"), help="Output directory")
    parser.add_argument(
        "--comfyui-url",
        default=DEFAULT_COMFYUI_URL,
        help="ComfyUI server URL (default: %(default)s)",
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        outputs = generate(
            workflow_path=args.workflow,
            image_path=args.image,
            seed=args.seed,
            output_dir=args.output_dir,
            comfyui_url=args.comfyui_url,
        )
        print(f"\nGenerated {len(outputs)} output(s):")
        for p in outputs:
            print(f"  {p}")
    except Exception:
        logger.exception("Generation failed")
        sys.exit(1)


if __name__ == "__main__":
    main()
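For callers that skip the CLI, generate() is importable directly; a minimal sketch, assuming a ComfyUI server on localhost and a hypothetical workflows/image-to-vrm.json export:

    from pathlib import Path

    from avatar_pipeline.generate import generate

    # Downloads every image/gltf/mesh output the workflow produces
    # into exports/ and returns the local paths.
    outputs = generate(
        workflow_path=Path("workflows/image-to-vrm.json"),
        image_path="reference.png",
        seed=42,
        output_dir=Path("exports"),
    )
    for path in outputs:
        print(path)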
avatar_pipeline/log_mlflow.py (new file, 150 lines)
@@ -0,0 +1,150 @@
"""Log avatar generation results to MLflow via REST API.

Uses the same lightweight REST-only approach as ray-serve's
mlflow_logger.py — no heavyweight mlflow SDK dependency.

Usage:
    from avatar_pipeline.log_mlflow import log_generation

    log_generation(
        avatar_name="Silver-Mage",
        params={"trellis_seed": 42, "trellis_simplify": 0.95},
        metrics={"vertex_count": 12345, "face_count": 8000, "duration_s": 45.2},
        artifacts={"vrm": Path("exports/Silver-Mage.vrm")},
    )
"""

from __future__ import annotations

import json
import logging
import os
import time
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path

logger = logging.getLogger(__name__)

DEFAULT_TRACKING_URI = "http://mlflow.mlflow.svc.cluster.local:80"
EXPERIMENT_NAME = "3d-avatar-generation"

def _base_url() -> str:
    return os.environ.get("MLFLOW_TRACKING_URI", DEFAULT_TRACKING_URI).rstrip("/")


def _post(path: str, body: dict) -> dict:
    url = f"{_base_url()}/api/2.0/mlflow/{path}"
    data = json.dumps(body).encode()
    req = urllib.request.Request(
        url, data=data, headers={"Content-Type": "application/json"}, method="POST"
    )
    with urllib.request.urlopen(req, timeout=15) as resp:
        return json.loads(resp.read().decode())


def _get(path: str, params: dict[str, str]) -> dict:
    query = urllib.parse.urlencode(params)
    url = f"{_base_url()}/api/2.0/mlflow/{path}?{query}"
    with urllib.request.urlopen(url, timeout=15) as resp:
        return json.loads(resp.read().decode())


def _get_or_create_experiment(name: str) -> str:
    try:
        # experiments/get-by-name is a GET endpoint in the MLflow REST API;
        # POSTing to it fails, so the lookup always fell through to create.
        resp = _get("experiments/get-by-name", {"experiment_name": name})
        return resp["experiment"]["experiment_id"]
    except urllib.error.HTTPError:
        resp = _post("experiments/create", {"name": name})
        return resp["experiment_id"]


def _create_run(experiment_id: str, run_name: str, tags: dict[str, str]) -> str:
    tag_list = [
        {"key": k, "value": v}
        for k, v in {
            "mlflow.runName": run_name,
            "mlflow.source.type": "LOCAL",
            "mlflow.source.name": "avatar-pipeline",
            "hostname": os.environ.get("HOSTNAME", "desktop"),
            **tags,
        }.items()
    ]
    resp = _post(
        "runs/create",
        {
            "experiment_id": experiment_id,
            "run_name": run_name,
            "start_time": int(time.time() * 1000),
            "tags": tag_list,
        },
    )
    return resp["run"]["info"]["run_id"]


def _log_params(run_id: str, params: dict[str, str | int | float]) -> None:
    param_list = [{"key": k, "value": str(v)[:500]} for k, v in params.items()]
    _post("runs/log-batch", {"run_id": run_id, "params": param_list})


def _log_metrics(run_id: str, metrics: dict[str, float]) -> None:
    ts = int(time.time() * 1000)
    metric_list = [
        {"key": k, "value": float(v), "timestamp": ts, "step": 0} for k, v in metrics.items()
    ]
    _post("runs/log-batch", {"run_id": run_id, "metrics": metric_list})


def _log_artifact(run_id: str, key: str, path: Path) -> None:
    """Log an artifact path as a tag (actual artifact upload requires artifact store).

    For local files, we record the path. For S3-promoted files, the caller
    should include the S3 URI in params instead.
    """
    _post(
        "runs/log-batch",
        {
            "run_id": run_id,
            "tags": [{"key": f"artifact.{key}", "value": str(path)}],
        },
    )


def _end_run(run_id: str) -> None:
    _post(
        "runs/update",
        {
            "run_id": run_id,
            "status": "FINISHED",
            "end_time": int(time.time() * 1000),
        },
    )


def log_generation(
    avatar_name: str,
    params: dict[str, str | int | float] | None = None,
    metrics: dict[str, float] | None = None,
    artifacts: dict[str, Path] | None = None,
    tags: dict[str, str] | None = None,
    experiment_name: str = EXPERIMENT_NAME,
) -> str | None:
    """Log a complete avatar generation run to MLflow.

    Returns the MLflow run_id on success, None on failure.
    """
    try:
        exp_id = _get_or_create_experiment(experiment_name)
        run_id = _create_run(exp_id, run_name=avatar_name, tags=tags or {})

        if params:
            _log_params(run_id, params)

        if metrics:
            _log_metrics(run_id, metrics)

        if artifacts:
            for key, path in artifacts.items():
                _log_artifact(run_id, key, path)

        _end_run(run_id)
        logger.info("Logged to MLflow: run_id=%s, experiment=%s", run_id, experiment_name)
        return run_id

    except Exception:
        logger.warning("MLflow logging failed — generation still succeeded", exc_info=True)
        return None
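A sketch of the intended handoff between the two modules, timing a run and recording it; the avatar name, workflow path, and metric names here are illustrative:

    import time
    from pathlib import Path

    from avatar_pipeline.generate import generate
    from avatar_pipeline.log_mlflow import log_generation

    start = time.monotonic()
    outputs = generate(workflow_path=Path("workflows/image-to-vrm.json"), seed=42)

    # log_generation returns None rather than raising if MLflow is
    # unreachable, so tracking failures never abort the pipeline.
    run_id = log_generation(
        avatar_name="Silver-Mage",
        params={"trellis_seed": 42},
        metrics={"duration_s": time.monotonic() - start, "output_count": len(outputs)},
        artifacts={"vrm": outputs[0]} if outputs else None,
    )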
avatar_pipeline/promote.py (new file, 131 lines)
@@ -0,0 +1,131 @@
"""Promote VRM files to gravenhollow via rclone.

Usage:
    avatar-promote exports/Silver-Mage.vrm
    avatar-promote exports/Silver-Mage.vrm --bucket companion-avatars
    avatar-promote --dry-run exports/*.vrm
"""

from __future__ import annotations

import argparse
import logging
import shutil
import subprocess
import sys
from pathlib import Path

logger = logging.getLogger(__name__)

DEFAULT_REMOTE = "gravenhollow"
DEFAULT_BUCKET = "avatar-models"


def check_rclone(remote: str = DEFAULT_REMOTE) -> bool:
    """Verify rclone is installed and the remote is configured."""
    if not shutil.which("rclone"):
        logger.error("rclone not found. Install: sudo pacman -S rclone")
        return False

    result = subprocess.run(
        ["rclone", "listremotes"],
        capture_output=True,
        text=True,
        check=False,
    )
    remotes = result.stdout.strip().split("\n")

    # Validate the remote actually in use, not just the default.
    if f"{remote}:" not in remotes:
        logger.error(
            "rclone remote '%s' not configured. Run scripts/setup.sh or configure manually.",
            remote,
        )
        return False

    return True


def promote(
    files: list[Path],
    remote: str = DEFAULT_REMOTE,
    bucket: str = DEFAULT_BUCKET,
    dry_run: bool = False,
) -> list[str]:
    """Copy VRM files to gravenhollow S3 via rclone.

    Returns list of promoted remote paths.
    """
    if not check_rclone(remote):
        raise RuntimeError("rclone not available")

    promoted = []
    for file_path in files:
        if not file_path.exists():
            logger.warning("File not found, skipping: %s", file_path)
            continue

        if file_path.suffix.lower() not in (".vrm", ".glb", ".fbx"):
            logger.warning("Unexpected file type, skipping: %s", file_path)
            continue

        dest = f"{remote}:{bucket}/{file_path.name}"
        cmd = ["rclone", "copy", str(file_path), f"{remote}:{bucket}/"]

        if dry_run:
            cmd.append("--dry-run")

        logger.info("Promoting: %s → %s", file_path, dest)

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False,
        )

        if result.returncode != 0:
            logger.error("rclone failed for %s: %s", file_path, result.stderr)
            continue

        if dry_run:
            logger.info("  (dry-run) Would copy %s", file_path.name)
        else:
            logger.info("  Promoted: %s", dest)

        promoted.append(dest)

    return promoted


def main() -> None:
    """CLI entry point."""
    parser = argparse.ArgumentParser(description="Promote VRM files to gravenhollow storage")
    parser.add_argument("files", nargs="+", type=Path, help="VRM/GLB files to promote")
    parser.add_argument("--remote", default=DEFAULT_REMOTE, help="rclone remote name")
    parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="S3 bucket name")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be copied")
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        promoted = promote(
            files=args.files,
            remote=args.remote,
            bucket=args.bucket,
            dry_run=args.dry_run,
        )
        print(f"\nPromoted {len(promoted)} file(s)")
        for p in promoted:
            print(f"  {p}")
    except Exception:
        logger.exception("Promotion failed")
        sys.exit(1)


if __name__ == "__main__":
    main()
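And the final hop from Python rather than the CLI; the file path is illustrative, and dry_run=True makes rclone report what it would copy without copying:

    from pathlib import Path

    from avatar_pipeline.promote import promote

    # Returns remote destinations of the form "gravenhollow:avatar-models/<name>".
    remote_paths = promote(
        files=[Path("exports/Silver-Mage.vrm")],
        bucket="avatar-models",
        dry_run=True,
    )
    print(remote_paths)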