# Batch Inference Workflow
# Runs LLM inference on a batch of inputs
# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: batch-inference
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: batch-inference
    app.kubernetes.io/part-of: ai-ml-pipelines
spec:
  entrypoint: batch-inference
  serviceAccountName: argo-workflow
  arguments:
    parameters:
      - name: input-url
        description: "URL to JSON file with inference requests"
      - name: output-url
        description: "URL to store results (S3 path)"
        value: ""
      - name: use-rag
        value: "true"
        description: "Whether to use RAG for context"
      - name: max-tokens
        value: "500"
        description: "Maximum tokens per response"
      - name: temperature
        value: "0.7"
        description: "LLM temperature"
  templates:
    # DAG: fetch inputs -> run inference -> (optionally) upload results.
    - name: batch-inference
      dag:
        tasks:
          - name: fetch-inputs
            template: fetch-input-data
            arguments:
              parameters:
                - name: input-url
                  value: "{{workflow.parameters.input-url}}"
          - name: run-inference
            template: inference
            dependencies: [fetch-inputs]
            arguments:
              parameters:
                - name: use-rag
                  value: "{{workflow.parameters.use-rag}}"
                - name: max-tokens
                  value: "{{workflow.parameters.max-tokens}}"
                - name: temperature
                  value: "{{workflow.parameters.temperature}}"
              artifacts:
                - name: inputs
                  from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}"
          - name: upload-results
            template: upload-output
            dependencies: [run-inference]
            # The substituted parameter must be quoted inside the expression:
            # output-url defaults to "", and an unquoted empty substitution
            # renders the condition as ` != ''`, which fails expression
            # parsing and errors the workflow.
            when: "'{{workflow.parameters.output-url}}' != ''"
            arguments:
              parameters:
                - name: output-url
                  value: "{{workflow.parameters.output-url}}"
              artifacts:
                - name: results
                  from: "{{tasks.run-inference.outputs.artifacts.results}}"

    # Downloads the inference request batch (S3 or HTTP) to /tmp/inputs and
    # validates that the JSON payload contains a "requests" array.
    - name: fetch-input-data
      inputs:
        parameters:
          - name: input-url
      outputs:
        artifacts:
          - name: inputs
            path: /tmp/inputs
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import json
            import urllib.request
            from pathlib import Path

            input_url = "{{inputs.parameters.input-url}}"
            output_dir = Path("/tmp/inputs")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Fetching inputs from: {input_url}")

            if input_url.startswith("s3://"):
                import subprocess
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, key = input_url[5:].split("/", 1)
                s3.download_file(bucket, key, str(output_dir / "inputs.json"))
            elif input_url.startswith("http"):
                urllib.request.urlretrieve(input_url, output_dir / "inputs.json")
            else:
                print(f"Unsupported URL scheme: {input_url}")
                exit(1)

            # Validate JSON structure
            with open(output_dir / "inputs.json") as f:
                data = json.load(f)
            if "requests" not in data:
                print("Error: JSON must contain 'requests' array")
                exit(1)
            print(f"Loaded {len(data['requests'])} inference requests")
        resources:
          requests:
            memory: 256Mi
            cpu: 100m

    # Runs chat-completion inference against the in-cluster vLLM service,
    # optionally augmenting each query with RAG context retrieved from
    # Milvus (embed -> vector search -> rerank -> top-3 context).
    - name: inference
      inputs:
        parameters:
          - name: use-rag
          - name: max-tokens
          - name: temperature
        artifacts:
          - name: inputs
            path: /tmp/inputs
      outputs:
        artifacts:
          - name: results
            path: /tmp/results
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True)

            import json
            import httpx
            from pathlib import Path
            from typing import List, Dict

            # Configuration
            VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000"
            EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
            RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local"
            MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
            LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"

            use_rag = "{{inputs.parameters.use-rag}}" == "true"
            max_tokens = int("{{inputs.parameters.max-tokens}}")
            temperature = float("{{inputs.parameters.temperature}}")

            input_dir = Path("/tmp/inputs")
            output_dir = Path("/tmp/results")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Load inputs
            with open(input_dir / "inputs.json") as f:
                data = json.load(f)
            requests = data["requests"]
            print(f"Processing {len(requests)} requests (RAG: {use_rag})")

            # Initialize Milvus if using RAG; fall back to plain inference
            # when the connection or collection lookup fails.
            collection = None
            if use_rag:
                try:
                    from pymilvus import connections, Collection, utility
                    connections.connect(host=MILVUS_HOST, port=19530)
                    if utility.has_collection("knowledge_base"):
                        collection = Collection("knowledge_base")
                        collection.load()
                        print("Milvus connected")
                except Exception as e:
                    print(f"Milvus connection failed: {e}")
                    use_rag = False

            def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]:
                response = client.post(
                    f"{EMBEDDINGS_URL}/embeddings",
                    json={"input": texts, "model": "bge"}
                )
                result = response.json()
                return [d["embedding"] for d in result.get("data", [])]

            def search_milvus(embedding: List[float]) -> List[Dict]:
                results = collection.search(
                    data=[embedding],
                    anns_field="embedding",
                    param={"metric_type": "COSINE", "params": {"ef": 64}},
                    limit=5,
                    output_fields=["text", "source"]
                )
                docs = []
                for hits in results:
                    for hit in hits:
                        docs.append({
                            "text": hit.entity.get("text", ""),
                            "source": hit.entity.get("source", ""),
                            "score": hit.score
                        })
                return docs

            def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]:
                response = client.post(
                    f"{RERANKER_URL}/v1/rerank",
                    json={"query": query, "documents": documents}
                )
                return response.json().get("results", [])

            # Process requests
            results = []
            with httpx.Client(timeout=120.0) as client:
                for i, req in enumerate(requests):
                    query = req.get("text", req.get("query", ""))
                    req_id = req.get("id", str(i))
                    print(f"Processing {i+1}/{len(requests)}: {query[:50]}...")

                    context = ""
                    rag_sources = []
                    if use_rag and collection:
                        try:
                            # Get embeddings and search
                            embeddings = get_embeddings([query], client)
                            if embeddings:
                                docs = search_milvus(embeddings[0])
                                if docs:
                                    doc_texts = [d["text"] for d in docs]
                                    reranked = rerank(query, doc_texts, client)
                                    sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3]
                                    context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs])
                                    rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs]
                        except Exception as e:
                            # RAG is best-effort: fall through to plain inference.
                            print(f"  RAG failed: {e}")

                    # Generate response
                    try:
                        messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
                        if context:
                            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
                        else:
                            messages.append({"role": "user", "content": query})
                        response = client.post(
                            f"{VLLM_URL}/v1/chat/completions",
                            json={
                                "model": LLM_MODEL,
                                "messages": messages,
                                "max_tokens": max_tokens,
                                "temperature": temperature
                            }
                        )
                        result = response.json()
                        answer = result["choices"][0]["message"]["content"]
                    except Exception as e:
                        answer = f"Error: {e}"

                    results.append({
                        "id": req_id,
                        "query": query,
                        "response": answer,
                        "used_rag": bool(context),
                        "rag_sources": rag_sources
                    })

            # Save results
            with open(output_dir / "results.json", "w") as f:
                json.dump({"results": results}, f, indent=2)
            print(f"Completed {len(results)} inferences")

            if collection:
                from pymilvus import connections
                connections.disconnect("default")
        envFrom:
          - configMapRef:
              name: ai-services-config
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    # Uploads /tmp/results/results.json to the S3 destination given in
    # output-url; only scheduled when output-url is non-empty (see `when`).
    - name: upload-output
      inputs:
        parameters:
          - name: output-url
        artifacts:
          - name: results
            path: /tmp/results
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "boto3", "-q"], check=True)

            import boto3
            from pathlib import Path

            output_url = "{{inputs.parameters.output-url}}"
            results_file = Path("/tmp/results/results.json")

            print(f"Uploading results to: {output_url}")

            if output_url.startswith("s3://"):
                s3 = boto3.client("s3")
                bucket, key = output_url[5:].split("/", 1)
                s3.upload_file(str(results_file), bucket, key)
                print("Upload complete")
            else:
                print(f"Unsupported URL scheme: {output_url}")
                exit(1)
        resources:
          requests:
            memory: 256Mi
            cpu: 100m