Files
pipeline-bridge/main_test.go
Billy D. 66ef758808
Some checks failed
CI / Lint (push) Successful in 3m9s
CI / Test (push) Successful in 2m42s
CI / Release (push) Successful in 1m0s
CI / Notify (push) Successful in 2s
CI / Docker Build & Push (push) Failing after 9m28s
feat: migrate from msgpack to protobuf (handler-base v1.0.0)
- Replace msgpack encoding with protobuf wire format
- Update field names to proto convention
- Change Parameters type from map[string]any to map[string]string
- Rewrite tests for proto round-trips
2026-02-21 15:30:17 -05:00

187 lines
5.2 KiB
Go

package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"google.golang.org/protobuf/proto"
)
func TestPipelineTriggerDecode(t *testing.T) {
req := &messages.PipelineTrigger{
RequestId: "req-001",
Pipeline: "document-ingestion",
Parameters: map[string]string{"source": "s3://bucket"},
}
data, err := proto.Marshal(req)
if err != nil {
t.Fatal(err)
}
var decoded messages.PipelineTrigger
if err := natsutil.Decode(data, &decoded); err != nil {
t.Fatal(err)
}
if decoded.RequestId != "req-001" {
t.Errorf("RequestID = %q", decoded.RequestId)
}
if decoded.Pipeline != "document-ingestion" {
t.Errorf("Pipeline = %q", decoded.Pipeline)
}
if decoded.Parameters["source"] != "s3://bucket" {
t.Errorf("Parameters = %v", decoded.Parameters)
}
}
// TestPipelineStatusRoundtrip marshals a fully populated PipelineStatus and
// unmarshals it back, requiring the decoded message to equal the original.
//
// proto.Equal compares every field, so a regression in any of them
// (RequestId, Status, RunId, Engine, Pipeline) fails the test, not just the
// two fields that were spot-checked previously. Messages are handled through
// pointers, as generated protobuf types are intended to be.
func TestPipelineStatusRoundtrip(t *testing.T) {
	want := &messages.PipelineStatus{
		RequestId: "req-002",
		Status:    "submitted",
		RunId:     "argo-abc123",
		Engine:    "argo",
		Pipeline:  "batch-inference",
	}
	data, err := proto.Marshal(want)
	if err != nil {
		t.Fatal(err)
	}

	got := &messages.PipelineStatus{}
	if err := proto.Unmarshal(data, got); err != nil {
		t.Fatal(err)
	}

	if !proto.Equal(got, want) {
		t.Errorf("round-trip mismatch:\n got: %v\nwant: %v", got, want)
	}
}
// TestGetEnv checks that getEnv returns the environment value when the
// variable is set, and the supplied fallback when it is not.
func TestGetEnv(t *testing.T) {
	const (
		hostKey = "TEST_HOST"
		hostVal = "http://test:8080"
	)
	t.Setenv(hostKey, hostVal)

	got := getEnv(hostKey, "fallback")
	if got != hostVal {
		t.Errorf("getEnv(TEST_HOST) = %q, want %q", got, hostVal)
	}

	got = getEnv("MISSING_VAR_XYZ", "default")
	if got != "default" {
		t.Errorf("getEnv(MISSING) = %q, want %q", got, "default")
	}
}
// TestPipelinesMap checks two things about the package-level pipelines map:
// every expected pipeline name is registered, and selected entries are
// assigned the expected engine.
func TestPipelinesMap(t *testing.T) {
	expected := []string{"document-ingestion", "batch-inference", "rag-query", "voice-pipeline", "model-evaluation"}
	for _, name := range expected {
		if _, ok := pipelines[name]; !ok {
			t.Errorf("pipeline %q not found in pipelines map", name)
		}
	}

	// Engine assignments for a representative subset. Look entries up with
	// comma-ok: indexing a missing key yields the zero value, which would
	// panic if the map stores pointers. It would also duplicate the
	// "not found" failure already reported above.
	engines := []struct{ name, engine string }{
		{"document-ingestion", "argo"},
		{"rag-query", "kubeflow"},
	}
	for _, tc := range engines {
		p, ok := pipelines[tc.name]
		if !ok {
			continue
		}
		if got := p.Engine; got != tc.engine {
			t.Errorf("%s engine = %q, want %s", tc.name, got, tc.engine)
		}
	}
}
// TestSubmitArgo runs submitArgo against a fake Argo Workflows server. It
// verifies the outgoing request and the run ID that submitArgo returns.
//
// The request checks cover the method, the path, the Content-Type header and
// the workflow.metadata.namespace field of the JSON body. The run ID comes
// from the metadata.name field of the server's response.
//
// The handler runs on the server's goroutine, not the test goroutine, so it
// must not call t.Fatal/t.FailNow. Instead it reports with t.Errorf and
// answers 400. That makes submitArgo return an error, and the test goroutine
// then fails on it.
func TestSubmitArgo(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}
		if r.URL.Path != "/api/v1/workflows/ai-ml" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		if ct := r.Header.Get("Content-Type"); ct != "application/json" {
			t.Errorf("expected application/json content type, got %q", ct)
		}

		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("failed to decode body: %v", err)
			http.Error(w, "invalid JSON body", http.StatusBadRequest)
			return
		}
		wf, ok := body["workflow"].(map[string]any)
		if !ok {
			t.Errorf("missing workflow key in body: %v", body)
			http.Error(w, "missing workflow", http.StatusBadRequest)
			return
		}
		// Use a checked assertion. An unchecked one would panic on the server
		// goroutine and show up only as an opaque client-side transport error.
		meta, ok := wf["metadata"].(map[string]any)
		if !ok {
			t.Errorf("missing workflow.metadata in body: %v", wf)
			http.Error(w, "missing workflow.metadata", http.StatusBadRequest)
			return
		}
		if meta["namespace"] != "ai-ml" {
			t.Errorf("unexpected namespace: %v", meta["namespace"])
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]any{
			"metadata": map[string]any{"name": "document-ingestion-abc123"},
		}); err != nil {
			t.Errorf("encoding response: %v", err)
		}
	}))
	defer ts.Close()

	runID, err := submitArgo(t.Context(), ts.Client(), ts.URL, "ai-ml", "document-ingestion", map[string]string{
		"source": "test",
	}, "req-001")
	if err != nil {
		t.Fatalf("submitArgo() error: %v", err)
	}
	if runID != "document-ingestion-abc123" {
		t.Errorf("runID = %q, want %q", runID, "document-ingestion-abc123")
	}
}
// TestSubmitArgoError checks that submitArgo reports an error when the Argo
// server rejects the submission with HTTP 400.
func TestSubmitArgoError(t *testing.T) {
	reject := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		// Write failures are irrelevant here; only the status code matters.
		_, _ = w.Write([]byte(`{"message":"bad request"}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(reject))
	defer srv.Close()

	if _, err := submitArgo(t.Context(), srv.Client(), srv.URL, "ai-ml", "bad-template", nil, "req-err"); err == nil {
		t.Fatal("expected error for 400 response")
	}
}
// TestSubmitKubeflow runs submitKubeflow against a fake Kubeflow Pipelines
// API. It checks that the request is a POST to the runs endpoint and that the
// returned run ID is taken from the response's run.id field.
func TestSubmitKubeflow(t *testing.T) {
	const wantRunID = "kf-run-456"

	handler := func(w http.ResponseWriter, r *http.Request) {
		if got := r.Method; got != http.MethodPost {
			t.Errorf("expected POST, got %s", got)
		}
		if got := r.URL.Path; got != "/apis/v1beta1/runs" {
			t.Errorf("unexpected path: %s", got)
		}
		resp := map[string]any{
			"run": map[string]any{"id": wantRunID},
		}
		w.Header().Set("Content-Type", "application/json")
		// An encode failure would surface as a client-side error below.
		_ = json.NewEncoder(w).Encode(resp)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	params := map[string]string{"query": "test"}
	runID, err := submitKubeflow(t.Context(), srv.Client(), srv.URL, "rag-pipeline", params, "req-002")
	if err != nil {
		t.Fatalf("submitKubeflow() error: %v", err)
	}
	if runID != wantRunID {
		t.Errorf("runID = %q, want %q", runID, wantRunID)
	}
}
// TestSubmitKubeflowError checks that submitKubeflow reports an error when
// the Kubeflow API responds with HTTP 500.
func TestSubmitKubeflowError(t *testing.T) {
	fail := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		// Write failures are irrelevant here; only the status code matters.
		_, _ = w.Write([]byte("internal error"))
	}
	srv := httptest.NewServer(http.HandlerFunc(fail))
	defer srv.Close()

	if _, err := submitKubeflow(t.Context(), srv.Client(), srv.URL, "bad-pipeline", nil, "req-err2"); err == nil {
		t.Fatal("expected error for 500 response")
	}
}