feat: Add ML training and batch inference workflows
- batch-inference: LLM inference with optional RAG
- qlora-training: QLoRA adapter fine-tuning from Milvus
- hybrid-ml-training: Multi-GPU distributed training
- coqui-voice-training: XTTS voice cloning
- document-ingestion: Ingest documents to Milvus
- eventsource-kfp: Argo Events / Kubeflow integration
- kfp-integration: Bridge between Argo and Kubeflow
This commit is contained in:
328
batch-inference.yaml
Normal file
328
batch-inference.yaml
Normal file
@@ -0,0 +1,328 @@
|
||||
# Batch Inference Workflow
# Runs LLM inference on a batch of inputs
# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: batch-inference
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: batch-inference
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: batch-inference
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      # Required: no default, so callers must always supply it.
      - name: input-url
        description: "URL to JSON file with inference requests"
      # Optional: when empty, the upload step is skipped entirely.
      - name: output-url
        description: "URL to store results (S3 path)"
        value: ""
      - name: use-rag
        description: "Whether to use RAG for context"
        value: "true"
      - name: max-tokens
        description: "Maximum tokens per response"
        value: "500"
      - name: temperature
        description: "LLM temperature"
        value: "0.7"
templates:
|
||||
- name: batch-inference
|
||||
dag:
|
||||
tasks:
|
||||
- name: fetch-inputs
|
||||
template: fetch-input-data
|
||||
arguments:
|
||||
parameters:
|
||||
- name: input-url
|
||||
value: "{{workflow.parameters.input-url}}"
|
||||
|
||||
- name: run-inference
|
||||
template: inference
|
||||
dependencies: [fetch-inputs]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: use-rag
|
||||
value: "{{workflow.parameters.use-rag}}"
|
||||
- name: max-tokens
|
||||
value: "{{workflow.parameters.max-tokens}}"
|
||||
- name: temperature
|
||||
value: "{{workflow.parameters.temperature}}"
|
||||
artifacts:
|
||||
- name: inputs
|
||||
from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}"
|
||||
|
||||
- name: upload-results
|
||||
template: upload-output
|
||||
dependencies: [run-inference]
|
||||
when: "{{workflow.parameters.output-url}} != ''"
|
||||
arguments:
|
||||
parameters:
|
||||
- name: output-url
|
||||
value: "{{workflow.parameters.output-url}}"
|
||||
artifacts:
|
||||
- name: results
|
||||
from: "{{tasks.run-inference.outputs.artifacts.results}}"
|
||||
|
||||
- name: fetch-input-data
|
||||
inputs:
|
||||
parameters:
|
||||
- name: input-url
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: inputs
|
||||
path: /tmp/inputs
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import json
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
input_url = "{{inputs.parameters.input-url}}"
|
||||
output_dir = Path("/tmp/inputs")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Fetching inputs from: {input_url}")
|
||||
|
||||
if input_url.startswith("s3://"):
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
import boto3
|
||||
s3 = boto3.client("s3")
|
||||
bucket, key = input_url[5:].split("/", 1)
|
||||
s3.download_file(bucket, key, str(output_dir / "inputs.json"))
|
||||
elif input_url.startswith("http"):
|
||||
urllib.request.urlretrieve(input_url, output_dir / "inputs.json")
|
||||
else:
|
||||
print(f"Unsupported URL scheme: {input_url}")
|
||||
exit(1)
|
||||
|
||||
# Validate JSON structure
|
||||
with open(output_dir / "inputs.json") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if "requests" not in data:
|
||||
print("Error: JSON must contain 'requests' array")
|
||||
exit(1)
|
||||
|
||||
print(f"Loaded {len(data['requests'])} inference requests")
|
||||
resources:
|
||||
requests:
|
||||
memory: 256Mi
|
||||
cpu: 100m
|
||||
|
||||
- name: inference
|
||||
inputs:
|
||||
parameters:
|
||||
- name: use-rag
|
||||
- name: max-tokens
|
||||
- name: temperature
|
||||
artifacts:
|
||||
- name: inputs
|
||||
path: /tmp/inputs
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: results
|
||||
path: /tmp/results
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True)
|
||||
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
|
||||
# Configuration
|
||||
VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000"
|
||||
EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
||||
RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local"
|
||||
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
|
||||
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
use_rag = "{{inputs.parameters.use-rag}}" == "true"
|
||||
max_tokens = int("{{inputs.parameters.max-tokens}}")
|
||||
temperature = float("{{inputs.parameters.temperature}}")
|
||||
|
||||
input_dir = Path("/tmp/inputs")
|
||||
output_dir = Path("/tmp/results")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load inputs
|
||||
with open(input_dir / "inputs.json") as f:
|
||||
data = json.load(f)
|
||||
requests = data["requests"]
|
||||
|
||||
print(f"Processing {len(requests)} requests (RAG: {use_rag})")
|
||||
|
||||
# Initialize Milvus if using RAG
|
||||
collection = None
|
||||
if use_rag:
|
||||
try:
|
||||
from pymilvus import connections, Collection, utility
|
||||
connections.connect(host=MILVUS_HOST, port=19530)
|
||||
if utility.has_collection("knowledge_base"):
|
||||
collection = Collection("knowledge_base")
|
||||
collection.load()
|
||||
print("Milvus connected")
|
||||
except Exception as e:
|
||||
print(f"Milvus connection failed: {e}")
|
||||
use_rag = False
|
||||
|
||||
def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]:
|
||||
response = client.post(
|
||||
f"{EMBEDDINGS_URL}/embeddings",
|
||||
json={"input": texts, "model": "bge"}
|
||||
)
|
||||
result = response.json()
|
||||
return [d["embedding"] for d in result.get("data", [])]
|
||||
|
||||
def search_milvus(embedding: List[float]) -> List[Dict]:
|
||||
results = collection.search(
|
||||
data=[embedding],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"ef": 64}},
|
||||
limit=5,
|
||||
output_fields=["text", "source"]
|
||||
)
|
||||
docs = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
docs.append({
|
||||
"text": hit.entity.get("text", ""),
|
||||
"source": hit.entity.get("source", ""),
|
||||
"score": hit.score
|
||||
})
|
||||
return docs
|
||||
|
||||
def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]:
|
||||
response = client.post(
|
||||
f"{RERANKER_URL}/v1/rerank",
|
||||
json={"query": query, "documents": documents}
|
||||
)
|
||||
return response.json().get("results", [])
|
||||
|
||||
# Process requests
|
||||
results = []
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
for i, req in enumerate(requests):
|
||||
query = req.get("text", req.get("query", ""))
|
||||
req_id = req.get("id", str(i))
|
||||
|
||||
print(f"Processing {i+1}/{len(requests)}: {query[:50]}...")
|
||||
|
||||
context = ""
|
||||
rag_sources = []
|
||||
|
||||
if use_rag and collection:
|
||||
try:
|
||||
# Get embeddings and search
|
||||
embeddings = get_embeddings([query], client)
|
||||
if embeddings:
|
||||
docs = search_milvus(embeddings[0])
|
||||
if docs:
|
||||
doc_texts = [d["text"] for d in docs]
|
||||
reranked = rerank(query, doc_texts, client)
|
||||
sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3]
|
||||
context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs])
|
||||
rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs]
|
||||
except Exception as e:
|
||||
print(f" RAG failed: {e}")
|
||||
|
||||
# Generate response
|
||||
try:
|
||||
messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
|
||||
if context:
|
||||
messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
|
||||
else:
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
response = client.post(
|
||||
f"{VLLM_URL}/v1/chat/completions",
|
||||
json={
|
||||
"model": LLM_MODEL,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature
|
||||
}
|
||||
)
|
||||
result = response.json()
|
||||
answer = result["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
answer = f"Error: {e}"
|
||||
|
||||
results.append({
|
||||
"id": req_id,
|
||||
"query": query,
|
||||
"response": answer,
|
||||
"used_rag": bool(context),
|
||||
"rag_sources": rag_sources
|
||||
})
|
||||
|
||||
# Save results
|
||||
with open(output_dir / "results.json", "w") as f:
|
||||
json.dump({"results": results}, f, indent=2)
|
||||
|
||||
print(f"Completed {len(results)} inferences")
|
||||
|
||||
if collection:
|
||||
from pymilvus import connections
|
||||
connections.disconnect("default")
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-services-config
|
||||
resources:
|
||||
requests:
|
||||
memory: 1Gi
|
||||
cpu: 500m
|
||||
|
||||
- name: upload-output
|
||||
inputs:
|
||||
parameters:
|
||||
- name: output-url
|
||||
artifacts:
|
||||
- name: results
|
||||
path: /tmp/results
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
|
||||
import boto3
|
||||
from pathlib import Path
|
||||
|
||||
output_url = "{{inputs.parameters.output-url}}"
|
||||
results_file = Path("/tmp/results/results.json")
|
||||
|
||||
print(f"Uploading results to: {output_url}")
|
||||
|
||||
if output_url.startswith("s3://"):
|
||||
s3 = boto3.client("s3")
|
||||
bucket, key = output_url[5:].split("/", 1)
|
||||
s3.upload_file(str(results_file), bucket, key)
|
||||
print("Upload complete")
|
||||
else:
|
||||
print(f"Unsupported URL scheme: {output_url}")
|
||||
exit(1)
|
||||
resources:
|
||||
requests:
|
||||
memory: 256Mi
|
||||
cpu: 100m
|
||||
Reference in New Issue
Block a user