diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7ea5baa --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.git +.gitignore +*.md +LICENSE +renovate.json +*_test.go +e2e_test.go +__pycache__ +.env* diff --git a/Dockerfile b/Dockerfile index c2e2f0c..1a67226 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ RUN go mod download COPY . . # Build static binary -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /pipeline-bridge . +RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /pipeline-bridge . # Runtime stage - scratch for minimal image FROM scratch diff --git a/e2e_test.go b/e2e_test.go index a1900d5..6805a8f 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -6,6 +6,8 @@ import ( "net/http/httptest" "sync/atomic" "testing" + + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" ) // ──────────────────────────────────────────────────────────────────────────── @@ -117,34 +119,28 @@ func TestPipelineDispatchE2E_AllEngines(t *testing.T) { } func TestPipelineDispatchE2E_UnknownPipeline(t *testing.T) { - // Simulate what main.go's OnMessage does for unknown pipeline - data := map[string]any{ - "request_id": "req-bad", - "pipeline": "nonexistent-pipeline", - } - - pipelineName := strVal(data, "pipeline", "") + // Verify unknown pipeline is rejected and available list is provided + pipelineName := "nonexistent-pipeline" _, ok := pipelines[pipelineName] if ok { t.Error("nonexistent pipeline should not be found") } - // Build error response like main.go names := make([]string, 0, len(pipelines)) for k := range pipelines { names = append(names, k) } - resp := map[string]any{ - "request_id": strVal(data, "request_id", ""), - "status": "error", - "error": "Unknown pipeline: nonexistent-pipeline", - "available_pipelines": names, + resp := &messages.PipelineStatus{ + RequestID: "req-bad", + Status: "error", + Error: "Unknown pipeline: nonexistent-pipeline", + AvailablePipelines: names, } - if resp["status"] != "error" { + if resp.Status != "error" { t.Error("expected error status") } - if len(resp["available_pipelines"].([]string)) != len(pipelines) { + if len(resp.AvailablePipelines) != len(pipelines) { t.Errorf("available_pipelines count mismatch") } } diff --git a/main.go b/main.go index 20ebe37..26fa09e 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,8 @@ import ( "git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/handler" + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" + "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" ) // Pipeline definitions — maps pipeline name to engine config. @@ -45,10 +47,21 @@ func main() { h := handler.New("ai.pipeline.trigger", cfg) - h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { - requestID := strVal(data, "request_id", "unknown") - pipelineName := strVal(data, "pipeline", "") - params := mapVal(data, "parameters") + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + req, err := natsutil.Decode[messages.PipelineTrigger](msg.Data) + if err != nil { + return &messages.PipelineStatus{Status: "error", Error: "Invalid request encoding"}, nil + } + + requestID := req.RequestID + if requestID == "" { + requestID = "unknown" + } + pipelineName := req.Pipeline + params := req.Parameters + if params == nil { + params = map[string]any{} + } slog.Info("triggering pipeline", "pipeline", pipelineName, "request_id", requestID) @@ -59,16 +72,15 @@ func main() { for k := range pipelines { names = append(names, k) } - return map[string]any{ - "request_id": requestID, - "status": "error", - "error": fmt.Sprintf("Unknown pipeline: %s", pipelineName), - "available_pipelines": names, + return &messages.PipelineStatus{ + RequestID: requestID, + Status: "error", + Error: fmt.Sprintf("Unknown pipeline: %s", pipelineName), + AvailablePipelines: names, }, nil } var runID string - var err error if pipeline.Engine == "argo" { runID, err = submitArgo(ctx, httpClient, argoHost, argoNamespace, pipeline.Template, params, requestID) @@ -78,20 +90,20 @@ func main() { if err != nil { slog.Error("pipeline submit failed", "pipeline", pipelineName, "error", err) - return map[string]any{ - "request_id": requestID, - "status": "error", - "error": err.Error(), + return &messages.PipelineStatus{ + RequestID: requestID, + Status: "error", + Error: err.Error(), }, nil } - result := map[string]any{ - "request_id": requestID, - "status": "submitted", - "run_id": runID, - "engine": pipeline.Engine, - "pipeline": pipelineName, - "submitted_at": time.Now().UTC().Format(time.RFC3339), + result := &messages.PipelineStatus{ + RequestID: requestID, + Status: "submitted", + RunID: runID, + Engine: pipeline.Engine, + Pipeline: pipelineName, + SubmittedAt: time.Now().UTC().Format(time.RFC3339), } // Publish status update @@ -204,24 +216,6 @@ func submitKubeflow(ctx context.Context, client *http.Client, host, pipelineID s // Helpers -func strVal(m map[string]any, key, fallback string) string { - if v, ok := m[key]; ok { - if s, ok := v.(string); ok { - return s - } - } - return fallback -} - -func mapVal(m map[string]any, key string) map[string]any { - if v, ok := m[key]; ok { - if sub, ok := v.(map[string]any); ok { - return sub - } - } - return map[string]any{} -} - func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v diff --git a/main_test.go b/main_test.go index e63a0b4..640964c 100644 --- a/main_test.go +++ b/main_test.go @@ -5,35 +5,58 @@ import ( "net/http" "net/http/httptest" "testing" + + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" + "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" + "github.com/vmihailenco/msgpack/v5" ) -func TestStrVal(t *testing.T) { - m := map[string]any{"key": "value", "num": 42} - if got := strVal(m, "key", ""); got != "value" { - t.Errorf("strVal(key) = %q, want %q", got, "value") +func TestPipelineTriggerDecode(t *testing.T) { + req := messages.PipelineTrigger{ + RequestID: "req-001", + Pipeline: "document-ingestion", + Parameters: map[string]any{"source": "s3://bucket"}, } - if got := strVal(m, "missing", "default"); got != "default" { - t.Errorf("strVal(missing) = %q, want %q", got, "default") + data, err := msgpack.Marshal(&req) + if err != nil { + t.Fatal(err) } - if got := strVal(m, "num", "fallback"); got != "fallback" { - t.Errorf("strVal(num) = %q, want %q", got, "fallback") + decoded, err := natsutil.Decode[messages.PipelineTrigger](data) + if err != nil { + t.Fatal(err) + } + if decoded.RequestID != "req-001" { + t.Errorf("RequestID = %q", decoded.RequestID) + } + if decoded.Pipeline != "document-ingestion" { + t.Errorf("Pipeline = %q", decoded.Pipeline) + } + if decoded.Parameters["source"] != "s3://bucket" { + t.Errorf("Parameters = %v", decoded.Parameters) } } -func TestMapVal(t *testing.T) { - inner := map[string]any{"a": "b"} - m := map[string]any{"nested": inner, "scalar": "hi"} - got := mapVal(m, "nested") - if got["a"] != "b" { - t.Errorf("mapVal(nested) = %v, want {a:b}", got) +func TestPipelineStatusRoundtrip(t *testing.T) { + status := messages.PipelineStatus{ + RequestID: "req-002", + Status: "submitted", + RunID: "argo-abc123", + Engine: "argo", + Pipeline: "batch-inference", } - got2 := mapVal(m, "missing") - if len(got2) != 0 { - t.Errorf("mapVal(missing) should be empty, got %v", got2) + data, err := msgpack.Marshal(&status) + if err != nil { + t.Fatal(err) } - got3 := mapVal(m, "scalar") - if len(got3) != 0 { - t.Errorf("mapVal(scalar) should be empty, got %v", got3) + var got messages.PipelineStatus + if err := msgpack.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.RunID != "argo-abc123" { + t.Errorf("RunID = %q", got.RunID) + } + if got.Engine != "argo" { + t.Errorf("Engine = %q", got.Engine) } }