// Change notes:
//   - Add error checks for unchecked return values (errcheck).
//   - Remove unused struct fields (unused).
//   - Fix gofmt formatting issues.
//
// Viewer metadata: 225 lines, 6.3 KiB, Go.

package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
)
|
|
|
|
// pipelineDef describes how one named pipeline is launched.
//
// Engine selects the backend. Only the identifier field that belongs to
// that engine is populated in the pipelines table below.
type pipelineDef struct {
	Engine     string // "argo" or "kubeflow"
	Template   string // Argo WorkflowTemplate name
	PipelineID string // Kubeflow pipeline ID
}

// pipelines maps every accepted pipeline name to its engine configuration.
// A trigger naming anything not listed here is rejected.
var pipelines = map[string]pipelineDef{
	// Argo Workflows entries: launched from the named WorkflowTemplate.
	"document-ingestion": {Engine: "argo", Template: "document-ingestion"},
	"batch-inference":    {Engine: "argo", Template: "batch-inference"},
	"model-evaluation":   {Engine: "argo", Template: "model-evaluation"},

	// Kubeflow Pipelines entries: launched by pipeline ID.
	"rag-query":      {Engine: "kubeflow", PipelineID: "rag-pipeline"},
	"voice-pipeline": {Engine: "kubeflow", PipelineID: "voice-pipeline"},
}
|
|
|
|
func main() {
|
|
cfg := config.Load()
|
|
cfg.ServiceName = "pipeline-bridge"
|
|
cfg.NATSQueueGroup = "pipeline-bridges"
|
|
|
|
kubeflowHost := getEnv("KUBEFLOW_HOST", "http://ml-pipeline.kubeflow.svc.cluster.local:8888")
|
|
argoHost := getEnv("ARGO_HOST", "http://argo-server.argo.svc.cluster.local:2746")
|
|
argoNamespace := getEnv("ARGO_NAMESPACE", "ai-ml")
|
|
|
|
httpClient := &http.Client{Timeout: 60 * time.Second}
|
|
|
|
h := handler.New("ai.pipeline.trigger", cfg)
|
|
|
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
|
req, err := natsutil.Decode[messages.PipelineTrigger](msg.Data)
|
|
if err != nil {
|
|
return &messages.PipelineStatus{Status: "error", Error: "Invalid request encoding"}, nil
|
|
}
|
|
|
|
requestID := req.RequestID
|
|
if requestID == "" {
|
|
requestID = "unknown"
|
|
}
|
|
pipelineName := req.Pipeline
|
|
params := req.Parameters
|
|
if params == nil {
|
|
params = map[string]any{}
|
|
}
|
|
|
|
slog.Info("triggering pipeline", "pipeline", pipelineName, "request_id", requestID)
|
|
|
|
// Validate pipeline
|
|
pipeline, ok := pipelines[pipelineName]
|
|
if !ok {
|
|
names := make([]string, 0, len(pipelines))
|
|
for k := range pipelines {
|
|
names = append(names, k)
|
|
}
|
|
return &messages.PipelineStatus{
|
|
RequestID: requestID,
|
|
Status: "error",
|
|
Error: fmt.Sprintf("Unknown pipeline: %s", pipelineName),
|
|
AvailablePipelines: names,
|
|
}, nil
|
|
}
|
|
|
|
var runID string
|
|
|
|
if pipeline.Engine == "argo" {
|
|
runID, err = submitArgo(ctx, httpClient, argoHost, argoNamespace, pipeline.Template, params, requestID)
|
|
} else {
|
|
runID, err = submitKubeflow(ctx, httpClient, kubeflowHost, pipeline.PipelineID, params, requestID)
|
|
}
|
|
|
|
if err != nil {
|
|
slog.Error("pipeline submit failed", "pipeline", pipelineName, "error", err)
|
|
return &messages.PipelineStatus{
|
|
RequestID: requestID,
|
|
Status: "error",
|
|
Error: err.Error(),
|
|
}, nil
|
|
}
|
|
|
|
result := &messages.PipelineStatus{
|
|
RequestID: requestID,
|
|
Status: "submitted",
|
|
RunID: runID,
|
|
Engine: pipeline.Engine,
|
|
Pipeline: pipelineName,
|
|
SubmittedAt: time.Now().UTC().Format(time.RFC3339),
|
|
}
|
|
|
|
// Publish status update
|
|
_ = h.NATS.Publish(fmt.Sprintf("ai.pipeline.status.%s", requestID), result)
|
|
|
|
slog.Info("pipeline submitted", "pipeline", pipelineName, "run_id", runID)
|
|
return result, nil
|
|
})
|
|
|
|
if err := h.Run(); err != nil {
|
|
slog.Error("handler failed", "error", err)
|
|
}
|
|
}
|
|
|
|
// submitArgo starts an Argo workflow from the WorkflowTemplate named template.
// It POSTs a Workflow object to the Argo Server REST API at
// {host}/api/v1/workflows/{namespace}.
//
// Each params entry becomes a workflow argument parameter. Its value is the
// fmt %v rendering of the map value. Parameter order follows map iteration
// order, so it is not deterministic. The workflow is created with
// generateName "<template>-" and the label request-id=<requestID>.
//
// NOTE(review): Kubernetes label values have a restricted length and charset.
// A request ID outside that set will presumably be rejected by the server and
// surface as an "argo 4xx" error. Confirm.
//
// It returns the workflow name from the response's metadata.name. An error is
// returned when:
//   - the payload cannot be encoded;
//   - the request cannot be built or sent;
//   - the server answers with status >= 400 (the message quotes the body);
//   - a successful response cannot be read or decoded.
func submitArgo(ctx context.Context, client *http.Client, host, namespace, template string, params map[string]any, requestID string) (string, error) {
	argoParams := make([]map[string]string, 0, len(params))
	for k, v := range params {
		argoParams = append(argoParams, map[string]string{"name": k, "value": fmt.Sprintf("%v", v)})
	}

	workflow := map[string]any{
		"apiVersion": "argoproj.io/v1alpha1",
		"kind":       "Workflow",
		"metadata": map[string]any{
			"generateName": template + "-",
			"namespace":    namespace,
			"labels":       map[string]string{"request-id": requestID},
		},
		"spec": map[string]any{
			"workflowTemplateRef": map[string]string{"name": template},
			"arguments":           map[string]any{"parameters": argoParams},
		},
	}

	body, err := json.Marshal(map[string]any{"workflow": workflow})
	if err != nil {
		return "", fmt.Errorf("argo encode workflow: %w", err)
	}
	endpoint := fmt.Sprintf("%s/api/v1/workflows/%s", host, namespace)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("argo build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("argo request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	respBody, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode >= 400 {
		// The HTTP status is the primary failure. Report it (with whatever
		// body was read) even if reading the body also failed.
		return "", fmt.Errorf("argo %d: %s", resp.StatusCode, string(respBody))
	}
	if readErr != nil {
		return "", fmt.Errorf("argo read response: %w", readErr)
	}

	var result struct {
		Metadata struct {
			Name string `json:"name"`
		} `json:"metadata"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", fmt.Errorf("argo decode response: %w", err)
	}
	return result.Metadata.Name, nil
}
|
|
|
|
// submitKubeflow creates a Kubeflow Pipelines run of the pipeline pipelineID.
// It POSTs a run request to {host}/apis/v1beta1/runs.
//
// The run is named "<pipelineID>-<first 8 characters of requestID>".
// Truncation counts runes, not bytes, so a multi-byte request ID is never cut
// mid-character. Each params entry becomes a pipeline parameter whose value is
// the fmt %v rendering of the map value. Parameter order follows map iteration
// order, so it is not deterministic.
//
// It returns the run ID from the response's run.id field. An error is
// returned when:
//   - the payload cannot be encoded;
//   - the request cannot be built or sent;
//   - the server answers with status >= 400 (the message quotes the body);
//   - a successful response cannot be read or decoded.
func submitKubeflow(ctx context.Context, client *http.Client, host, pipelineID string, params map[string]any, requestID string) (string, error) {
	kfParams := make([]map[string]string, 0, len(params))
	for k, v := range params {
		kfParams = append(kfParams, map[string]string{"name": k, "value": fmt.Sprintf("%v", v)})
	}

	// Slicing the string by bytes could split a UTF-8 sequence and produce an
	// invalid run name, so truncate on rune boundaries instead.
	shortID := []rune(requestID)
	if len(shortID) > 8 {
		shortID = shortID[:8]
	}

	runRequest := map[string]any{
		"name": fmt.Sprintf("%s-%s", pipelineID, string(shortID)),
		"pipeline_spec": map[string]any{
			"pipeline_id": pipelineID,
			"parameters":  kfParams,
		},
	}

	body, err := json.Marshal(runRequest)
	if err != nil {
		return "", fmt.Errorf("kubeflow encode run: %w", err)
	}
	endpoint := fmt.Sprintf("%s/apis/v1beta1/runs", host)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("kubeflow build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("kubeflow request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	respBody, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode >= 400 {
		// The HTTP status is the primary failure. Report it (with whatever
		// body was read) even if reading the body also failed.
		return "", fmt.Errorf("kubeflow %d: %s", resp.StatusCode, string(respBody))
	}
	if readErr != nil {
		return "", fmt.Errorf("kubeflow read response: %w", readErr)
	}

	var result struct {
		Run struct {
			ID string `json:"id"`
		} `json:"run"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", fmt.Errorf("kubeflow decode response: %w", err)
	}
	return result.Run.ID, nil
}
|
|
|
|
// Helpers
|
|
|
|
// getEnv returns the value of the environment variable key. It returns
// fallback when the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|