feature/go-handler-refactor #1
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
||||
.git
|
||||
.gitignore
|
||||
*.md
|
||||
LICENSE
|
||||
renovate.json
|
||||
*_test.go
|
||||
e2e_test.go
|
||||
__pycache__
|
||||
.env*
|
||||
@@ -14,7 +14,7 @@ RUN go mod download
|
||||
COPY . .
|
||||
|
||||
# Build static binary
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /pipeline-bridge .
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /pipeline-bridge .
|
||||
|
||||
# Runtime stage - scratch for minimal image
|
||||
FROM scratch
|
||||
|
||||
26
e2e_test.go
26
e2e_test.go
@@ -6,6 +6,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
)
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
@@ -117,34 +119,28 @@ func TestPipelineDispatchE2E_AllEngines(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPipelineDispatchE2E_UnknownPipeline(t *testing.T) {
|
||||
// Simulate what main.go's OnMessage does for unknown pipeline
|
||||
data := map[string]any{
|
||||
"request_id": "req-bad",
|
||||
"pipeline": "nonexistent-pipeline",
|
||||
}
|
||||
|
||||
pipelineName := strVal(data, "pipeline", "")
|
||||
// Verify unknown pipeline is rejected and available list is provided
|
||||
pipelineName := "nonexistent-pipeline"
|
||||
_, ok := pipelines[pipelineName]
|
||||
if ok {
|
||||
t.Error("nonexistent pipeline should not be found")
|
||||
}
|
||||
|
||||
// Build error response like main.go
|
||||
names := make([]string, 0, len(pipelines))
|
||||
for k := range pipelines {
|
||||
names = append(names, k)
|
||||
}
|
||||
resp := map[string]any{
|
||||
"request_id": strVal(data, "request_id", ""),
|
||||
"status": "error",
|
||||
"error": "Unknown pipeline: nonexistent-pipeline",
|
||||
"available_pipelines": names,
|
||||
resp := &messages.PipelineStatus{
|
||||
RequestID: "req-bad",
|
||||
Status: "error",
|
||||
Error: "Unknown pipeline: nonexistent-pipeline",
|
||||
AvailablePipelines: names,
|
||||
}
|
||||
|
||||
if resp["status"] != "error" {
|
||||
if resp.Status != "error" {
|
||||
t.Error("expected error status")
|
||||
}
|
||||
if len(resp["available_pipelines"].([]string)) != len(pipelines) {
|
||||
if len(resp.AvailablePipelines) != len(pipelines) {
|
||||
t.Errorf("available_pipelines count mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
72
main.go
72
main.go
@@ -15,6 +15,8 @@ import (
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
)
|
||||
|
||||
// Pipeline definitions — maps pipeline name to engine config.
|
||||
@@ -45,10 +47,21 @@ func main() {
|
||||
|
||||
h := handler.New("ai.pipeline.trigger", cfg)
|
||||
|
||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||
requestID := strVal(data, "request_id", "unknown")
|
||||
pipelineName := strVal(data, "pipeline", "")
|
||||
params := mapVal(data, "parameters")
|
||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||
req, err := natsutil.Decode[messages.PipelineTrigger](msg.Data)
|
||||
if err != nil {
|
||||
return &messages.PipelineStatus{Status: "error", Error: "Invalid request encoding"}, nil
|
||||
}
|
||||
|
||||
requestID := req.RequestID
|
||||
if requestID == "" {
|
||||
requestID = "unknown"
|
||||
}
|
||||
pipelineName := req.Pipeline
|
||||
params := req.Parameters
|
||||
if params == nil {
|
||||
params = map[string]any{}
|
||||
}
|
||||
|
||||
slog.Info("triggering pipeline", "pipeline", pipelineName, "request_id", requestID)
|
||||
|
||||
@@ -59,16 +72,15 @@ func main() {
|
||||
for k := range pipelines {
|
||||
names = append(names, k)
|
||||
}
|
||||
return map[string]any{
|
||||
"request_id": requestID,
|
||||
"status": "error",
|
||||
"error": fmt.Sprintf("Unknown pipeline: %s", pipelineName),
|
||||
"available_pipelines": names,
|
||||
return &messages.PipelineStatus{
|
||||
RequestID: requestID,
|
||||
Status: "error",
|
||||
Error: fmt.Sprintf("Unknown pipeline: %s", pipelineName),
|
||||
AvailablePipelines: names,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var runID string
|
||||
var err error
|
||||
|
||||
if pipeline.Engine == "argo" {
|
||||
runID, err = submitArgo(ctx, httpClient, argoHost, argoNamespace, pipeline.Template, params, requestID)
|
||||
@@ -78,20 +90,20 @@ func main() {
|
||||
|
||||
if err != nil {
|
||||
slog.Error("pipeline submit failed", "pipeline", pipelineName, "error", err)
|
||||
return map[string]any{
|
||||
"request_id": requestID,
|
||||
"status": "error",
|
||||
"error": err.Error(),
|
||||
return &messages.PipelineStatus{
|
||||
RequestID: requestID,
|
||||
Status: "error",
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"request_id": requestID,
|
||||
"status": "submitted",
|
||||
"run_id": runID,
|
||||
"engine": pipeline.Engine,
|
||||
"pipeline": pipelineName,
|
||||
"submitted_at": time.Now().UTC().Format(time.RFC3339),
|
||||
result := &messages.PipelineStatus{
|
||||
RequestID: requestID,
|
||||
Status: "submitted",
|
||||
RunID: runID,
|
||||
Engine: pipeline.Engine,
|
||||
Pipeline: pipelineName,
|
||||
SubmittedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Publish status update
|
||||
@@ -204,24 +216,6 @@ func submitKubeflow(ctx context.Context, client *http.Client, host, pipelineID s
|
||||
|
||||
// Helpers
|
||||
|
||||
func strVal(m map[string]any, key, fallback string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func mapVal(m map[string]any, key string) map[string]any {
|
||||
if v, ok := m[key]; ok {
|
||||
if sub, ok := v.(map[string]any); ok {
|
||||
return sub
|
||||
}
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
|
||||
63
main_test.go
63
main_test.go
@@ -5,35 +5,58 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
func TestStrVal(t *testing.T) {
|
||||
m := map[string]any{"key": "value", "num": 42}
|
||||
if got := strVal(m, "key", ""); got != "value" {
|
||||
t.Errorf("strVal(key) = %q, want %q", got, "value")
|
||||
func TestPipelineTriggerDecode(t *testing.T) {
|
||||
req := messages.PipelineTrigger{
|
||||
RequestID: "req-001",
|
||||
Pipeline: "document-ingestion",
|
||||
Parameters: map[string]any{"source": "s3://bucket"},
|
||||
}
|
||||
if got := strVal(m, "missing", "default"); got != "default" {
|
||||
t.Errorf("strVal(missing) = %q, want %q", got, "default")
|
||||
data, err := msgpack.Marshal(&req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := strVal(m, "num", "fallback"); got != "fallback" {
|
||||
t.Errorf("strVal(num) = %q, want %q", got, "fallback")
|
||||
decoded, err := natsutil.Decode[messages.PipelineTrigger](data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if decoded.RequestID != "req-001" {
|
||||
t.Errorf("RequestID = %q", decoded.RequestID)
|
||||
}
|
||||
if decoded.Pipeline != "document-ingestion" {
|
||||
t.Errorf("Pipeline = %q", decoded.Pipeline)
|
||||
}
|
||||
if decoded.Parameters["source"] != "s3://bucket" {
|
||||
t.Errorf("Parameters = %v", decoded.Parameters)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapVal(t *testing.T) {
|
||||
inner := map[string]any{"a": "b"}
|
||||
m := map[string]any{"nested": inner, "scalar": "hi"}
|
||||
got := mapVal(m, "nested")
|
||||
if got["a"] != "b" {
|
||||
t.Errorf("mapVal(nested) = %v, want {a:b}", got)
|
||||
func TestPipelineStatusRoundtrip(t *testing.T) {
|
||||
status := messages.PipelineStatus{
|
||||
RequestID: "req-002",
|
||||
Status: "submitted",
|
||||
RunID: "argo-abc123",
|
||||
Engine: "argo",
|
||||
Pipeline: "batch-inference",
|
||||
}
|
||||
got2 := mapVal(m, "missing")
|
||||
if len(got2) != 0 {
|
||||
t.Errorf("mapVal(missing) should be empty, got %v", got2)
|
||||
data, err := msgpack.Marshal(&status)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got3 := mapVal(m, "scalar")
|
||||
if len(got3) != 0 {
|
||||
t.Errorf("mapVal(scalar) should be empty, got %v", got3)
|
||||
var got messages.PipelineStatus
|
||||
if err := msgpack.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.RunID != "argo-abc123" {
|
||||
t.Errorf("RunID = %q", got.RunID)
|
||||
}
|
||||
if got.Engine != "argo" {
|
||||
t.Errorf("Engine = %q", got.Engine)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user