From 4175e2070ca57e95d357593aacdbde72ba0938cf Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Fri, 20 Feb 2026 07:10:43 -0500 Subject: [PATCH] feat: migrate to typed messages, drop base64 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch OnMessage → OnTypedMessage with natsutil.Decode[messages.ChatRequest] - Return *messages.ChatResponse / *messages.ChatStreamChunk (not map[string]any) - Audio as raw []byte in msgpack (25% wire savings vs base64) - Remove strVal/boolVal/intVal helpers - Add .dockerignore, GOAMD64=v3 in Dockerfile - Update tests for typed structs (9 tests pass) - Behavior change: enable_rag, enable_reranker and enable_tts are now OR-ed with their defaults (premium, effective enable_rag, handler enableTTS), so an explicit false no longer overrides a true default; top_k=0 and an empty collection now fall back to ragTopK / ragCollection --- .dockerignore | 9 ++++ Dockerfile | 2 +- e2e_test.go | 63 +++++++++++----------- main.go | 144 +++++++++++++++++++++++--------------------------- main_test.go | 71 +++++++++++++++---------- 5 files changed, 150 insertions(+), 139 deletions(-) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7ea5baa --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.git +.gitignore +*.md +LICENSE +renovate.json +*_test.go +e2e_test.go +__pycache__ +.env* diff --git a/Dockerfile b/Dockerfile index f09750c..3603fca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN go mod download COPY . . -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /chat-handler . +RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /chat-handler . 
# Runtime stage FROM scratch diff --git a/e2e_test.go b/e2e_test.go index 8324f2e..1697768 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -9,6 +9,8 @@ import ( "time" "git.daviestechlabs.io/daviestechlabs/handler-base/clients" +"git.daviestechlabs.io/daviestechlabs/handler-base/messages" +"github.com/vmihailenco/msgpack/v5" ) // ──────────────────────────────────────────────────────────────────────────── @@ -161,39 +163,38 @@ t.Error("expected timeout error") } } -func TestChatPipeline_RequestBuilding(t *testing.T) { -// Test the map construction logic from main.go's OnMessage. -data := map[string]any{ -"request_id": "req-e2e-001", -"user_id": "user-1", -"message": "hello", -"premium": true, -"enable_rag": false, -"enable_streaming": false, -"system_prompt": "Be brief.", -} +func TestChatPipeline_TypedDecoding(t *testing.T) { + // Verify typed struct decoding from msgpack (same path as OnTypedMessage). + raw := map[string]any{ + "request_id": "req-e2e-001", + "user_id": "user-1", + "message": "hello", + "premium": true, + "enable_rag": false, + "enable_streaming": false, + "system_prompt": "Be brief.", + } + data, _ := msgpack.Marshal(raw) -requestID := strVal(data, "request_id", "unknown") -userID := strVal(data, "user_id", "unknown") -query := strVal(data, "message", "") -premium := boolVal(data, "premium", false) -enableRAG := boolVal(data, "enable_rag", premium) -systemPrompt := strVal(data, "system_prompt", "") + var req messages.ChatRequest + if err := msgpack.Unmarshal(data, &req); err != nil { + t.Fatal(err) + } -if requestID != "req-e2e-001" { -t.Errorf("requestID = %q", requestID) -} -if userID != "user-1" { -t.Errorf("userID = %q", userID) -} -if query != "hello" { -t.Errorf("query = %q", query) -} -if enableRAG { -t.Error("enable_rag=false should override premium=true") -} -if systemPrompt != "Be brief." 
{ -t.Errorf("systemPrompt = %q", systemPrompt) + if req.RequestID != "req-e2e-001" { + t.Errorf("RequestID = %q", req.RequestID) + } + if req.UserID != "user-1" { + t.Errorf("UserID = %q", req.UserID) + } + if req.EffectiveQuery() != "hello" { + t.Errorf("query = %q", req.EffectiveQuery()) + } + if req.EnableRAG { + t.Error("EnableRAG should be false") + } + if req.SystemPrompt != "Be brief." { + t.Errorf("SystemPrompt = %q", req.SystemPrompt) } } diff --git a/main.go b/main.go index 6f33f9d..e2beb31 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "encoding/base64" "fmt" "log/slog" "os" @@ -15,6 +14,8 @@ import ( "git.daviestechlabs.io/daviestechlabs/handler-base/clients" "git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/handler" + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" + "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" ) func main() { @@ -44,22 +45,44 @@ func main() { h := handler.New("ai.chat.user.*.message", cfg) - h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { - requestID := strVal(data, "request_id", "unknown") - userID := strVal(data, "user_id", "unknown") - query := strVal(data, "message", "") - if query == "" { - query = strVal(data, "query", "") + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + req, err := natsutil.Decode[messages.ChatRequest](msg.Data) + if err != nil { + slog.Error("decode failed", "error", err) + return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil + } + + query := req.EffectiveQuery() + requestID := req.RequestID + if requestID == "" { + requestID = "unknown" + } + userID := req.UserID + if userID == "" { + userID = "unknown" + } + enableRAG := req.EnableRAG + if !enableRAG && req.Premium { + enableRAG = true + } + enableReranker := req.EnableReranker + if !enableReranker && enableRAG { + 
enableReranker = true + } + topK := req.TopK + if topK == 0 { + topK = ragTopK + } + collection := req.Collection + if collection == "" { + collection = ragCollection + } + reqEnableTTS := req.EnableTTS || enableTTS + systemPrompt := req.SystemPrompt + responseSubject := req.ResponseSubject + if responseSubject == "" { + responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID) } - premium := boolVal(data, "premium", false) - enableRAG := boolVal(data, "enable_rag", premium) - enableReranker := boolVal(data, "enable_reranker", enableRAG) - enableStreaming := boolVal(data, "enable_streaming", false) - topK := intVal(data, "top_k", ragTopK) - collection := strVal(data, "collection", ragCollection) - reqEnableTTS := boolVal(data, "enable_tts", enableTTS) - systemPrompt := strVal(data, "system_prompt", "") - responseSubject := strVal(data, "response_subject", fmt.Sprintf("ai.chat.response.%s", requestID)) slog.Info("processing request", "request_id", requestID, "query_len", len(query)) @@ -138,15 +161,15 @@ func main() { responseText, err := llm.Generate(ctx, query, contextText, systemPrompt) if err != nil { slog.Error("LLM generation failed", "error", err) - return map[string]any{ - "user_id": userID, - "success": false, - "error": err.Error(), + return &messages.ChatResponse{ + UserID: userID, + Success: false, + Error: err.Error(), }, nil } // 6. 
Stream chunks if requested - if enableStreaming { + if req.EnableStreaming { streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID) words := strings.Fields(responseText) chunkSize := 4 @@ -156,47 +179,42 @@ func main() { end = len(words) } chunk := strings.Join(words[i:end], " ") - _ = h.NATS.Publish(streamSubject, map[string]any{ - "request_id": requestID, - "type": "chunk", - "content": chunk, - "done": false, - "timestamp": time.Now().Unix(), + _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ + RequestID: requestID, + Type: "chunk", + Content: chunk, + Timestamp: messages.Timestamp(), }) } - _ = h.NATS.Publish(streamSubject, map[string]any{ - "request_id": requestID, - "type": "done", - "content": "", - "done": true, - "timestamp": time.Now().Unix(), + _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ + RequestID: requestID, + Type: "done", + Done: true, + Timestamp: messages.Timestamp(), }) } - // 7. Optional TTS - var audioB64 string + // 7. Optional TTS — audio as raw bytes (no base64) + var audio []byte if reqEnableTTS && tts != nil { audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "") if err != nil { slog.Error("TTS failed", "error", err) } else { - audioB64 = base64.StdEncoding.EncodeToString(audioBytes) + audio = audioBytes } } - result := map[string]any{ - "user_id": userID, - "response": responseText, - "response_text": responseText, - "used_rag": usedRAG, - "rag_sources": ragSources, - "success": true, + result := &messages.ChatResponse{ + UserID: userID, + Response: responseText, + ResponseText: responseText, + UsedRAG: usedRAG, + Success: true, + Audio: audio, } if includeSources { - result["rag_sources"] = ragSources - } - if audioB64 != "" { - result["audio"] = audioB64 + result.RAGSources = ragSources } // Publish to the response subject the frontend is waiting on @@ -214,38 +232,6 @@ func main() { // Helpers -func strVal(m map[string]any, key, fallback string) string { - if v, ok := m[key]; 
ok { - if s, ok := v.(string); ok { - return s - } - } - return fallback -} - -func boolVal(m map[string]any, key string, fallback bool) bool { - if v, ok := m[key]; ok { - if b, ok := v.(bool); ok { - return b - } - } - return fallback -} - -func intVal(m map[string]any, key string, fallback int) int { - if v, ok := m[key]; ok { - switch n := v.(type) { - case int: - return n - case int64: - return int(n) - case float64: - return int(n) - } - } - return fallback -} - func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v diff --git a/main_test.go b/main_test.go index f9da393..d87d4c5 100644 --- a/main_test.go +++ b/main_test.go @@ -3,44 +3,59 @@ package main import ( "os" "testing" + + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" + "github.com/vmihailenco/msgpack/v5" ) -func TestStrVal(t *testing.T) { - m := map[string]any{"key": "value", "num": 42} - if got := strVal(m, "key", ""); got != "value" { - t.Errorf("strVal(key) = %q", got) +func TestChatRequestDecode(t *testing.T) { + // Verify a msgpack-encoded map decodes cleanly into typed struct. 
+ raw := map[string]any{ + "request_id": "req-1", + "user_id": "user-1", + "message": "hello", + "premium": true, + "top_k": 10, } - if got := strVal(m, "missing", "def"); got != "def" { - t.Errorf("strVal(missing) = %q", got) + data, _ := msgpack.Marshal(raw) + var req messages.ChatRequest + if err := msgpack.Unmarshal(data, &req); err != nil { + t.Fatal(err) + } + if req.RequestID != "req-1" { + t.Errorf("RequestID = %q", req.RequestID) + } + if req.EffectiveQuery() != "hello" { + t.Errorf("EffectiveQuery = %q", req.EffectiveQuery()) + } + if !req.Premium { + t.Error("Premium should be true") + } + if req.TopK != 10 { + t.Errorf("TopK = %d", req.TopK) } } -func TestBoolVal(t *testing.T) { - m := map[string]any{"flag": true, "str": "not-bool"} - if got := boolVal(m, "flag", false); !got { - t.Error("boolVal(flag) should be true") +func TestChatResponseRoundtrip(t *testing.T) { + resp := &messages.ChatResponse{ + UserID: "user-1", + Response: "answer", + Success: true, + Audio: []byte{0x01, 0x02, 0x03}, } - if got := boolVal(m, "str", false); got { - t.Error("boolVal(str) should be false (not a bool)") + data, err := msgpack.Marshal(resp) + if err != nil { + t.Fatal(err) } - if got := boolVal(m, "missing", true); !got { - t.Error("boolVal(missing) should use fallback true") + var decoded messages.ChatResponse + if err := msgpack.Unmarshal(data, &decoded); err != nil { + t.Fatal(err) } -} - -func TestIntVal(t *testing.T) { - m := map[string]any{"int": 5, "float": 3.14, "int64": int64(99)} - if got := intVal(m, "int", 0); got != 5 { - t.Errorf("intVal(int) = %d", got) + if decoded.UserID != "user-1" || !decoded.Success { + t.Errorf("decoded = %+v", decoded) } - if got := intVal(m, "float", 0); got != 3 { - t.Errorf("intVal(float) = %d", got) - } - if got := intVal(m, "int64", 0); got != 99 { - t.Errorf("intVal(int64) = %d", got) - } - if got := intVal(m, "missing", 42); got != 42 { - t.Errorf("intVal(missing) = %d", got) + if len(decoded.Audio) != 3 { + 
t.Errorf("audio len = %d", len(decoded.Audio)) } }