feature/go-handler-refactor #1

Merged
billy merged 3 commits from feature/go-handler-refactor into main 2026-02-20 12:33:33 +00:00
5 changed files with 150 additions and 139 deletions
Showing only changes of commit 4175e2070c - Show all commits

9
.dockerignore Normal file
View File

@@ -0,0 +1,9 @@
.git
.gitignore
*.md
LICENSE
renovate.json
*_test.go
e2e_test.go
__pycache__
.env*

View File

@@ -10,7 +10,7 @@ RUN go mod download
COPY . . COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /chat-handler . RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /chat-handler .
# Runtime stage # Runtime stage
FROM scratch FROM scratch

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
"git.daviestechlabs.io/daviestechlabs/handler-base/clients" "git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"github.com/vmihailenco/msgpack/v5"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -161,9 +163,9 @@ t.Error("expected timeout error")
} }
} }
func TestChatPipeline_RequestBuilding(t *testing.T) { func TestChatPipeline_TypedDecoding(t *testing.T) {
// Test the map construction logic from main.go's OnMessage. // Verify typed struct decoding from msgpack (same path as OnTypedMessage).
data := map[string]any{ raw := map[string]any{
"request_id": "req-e2e-001", "request_id": "req-e2e-001",
"user_id": "user-1", "user_id": "user-1",
"message": "hello", "message": "hello",
@@ -172,28 +174,27 @@ data := map[string]any{
"enable_streaming": false, "enable_streaming": false,
"system_prompt": "Be brief.", "system_prompt": "Be brief.",
} }
data, _ := msgpack.Marshal(raw)
requestID := strVal(data, "request_id", "unknown") var req messages.ChatRequest
userID := strVal(data, "user_id", "unknown") if err := msgpack.Unmarshal(data, &req); err != nil {
query := strVal(data, "message", "") t.Fatal(err)
premium := boolVal(data, "premium", false) }
enableRAG := boolVal(data, "enable_rag", premium)
systemPrompt := strVal(data, "system_prompt", "")
if requestID != "req-e2e-001" { if req.RequestID != "req-e2e-001" {
t.Errorf("requestID = %q", requestID) t.Errorf("RequestID = %q", req.RequestID)
} }
if userID != "user-1" { if req.UserID != "user-1" {
t.Errorf("userID = %q", userID) t.Errorf("UserID = %q", req.UserID)
} }
if query != "hello" { if req.EffectiveQuery() != "hello" {
t.Errorf("query = %q", query) t.Errorf("query = %q", req.EffectiveQuery())
} }
if enableRAG { if req.EnableRAG {
t.Error("enable_rag=false should override premium=true") t.Error("EnableRAG should be false")
} }
if systemPrompt != "Be brief." { if req.SystemPrompt != "Be brief." {
t.Errorf("systemPrompt = %q", systemPrompt) t.Errorf("SystemPrompt = %q", req.SystemPrompt)
} }
} }

144
main.go
View File

@@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"encoding/base64"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
@@ -15,6 +14,8 @@ import (
"git.daviestechlabs.io/daviestechlabs/handler-base/clients" "git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler" "git.daviestechlabs.io/daviestechlabs/handler-base/handler"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
) )
func main() { func main() {
@@ -44,22 +45,44 @@ func main() {
h := handler.New("ai.chat.user.*.message", cfg) h := handler.New("ai.chat.user.*.message", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
requestID := strVal(data, "request_id", "unknown") req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
userID := strVal(data, "user_id", "unknown") if err != nil {
query := strVal(data, "message", "") slog.Error("decode failed", "error", err)
if query == "" { return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
query = strVal(data, "query", "") }
query := req.EffectiveQuery()
requestID := req.RequestID
if requestID == "" {
requestID = "unknown"
}
userID := req.UserID
if userID == "" {
userID = "unknown"
}
enableRAG := req.EnableRAG
if !enableRAG && req.Premium {
enableRAG = true
}
enableReranker := req.EnableReranker
if !enableReranker && enableRAG {
enableReranker = true
}
topK := req.TopK
if topK == 0 {
topK = ragTopK
}
collection := req.Collection
if collection == "" {
collection = ragCollection
}
reqEnableTTS := req.EnableTTS || enableTTS
systemPrompt := req.SystemPrompt
responseSubject := req.ResponseSubject
if responseSubject == "" {
responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
} }
premium := boolVal(data, "premium", false)
enableRAG := boolVal(data, "enable_rag", premium)
enableReranker := boolVal(data, "enable_reranker", enableRAG)
enableStreaming := boolVal(data, "enable_streaming", false)
topK := intVal(data, "top_k", ragTopK)
collection := strVal(data, "collection", ragCollection)
reqEnableTTS := boolVal(data, "enable_tts", enableTTS)
systemPrompt := strVal(data, "system_prompt", "")
responseSubject := strVal(data, "response_subject", fmt.Sprintf("ai.chat.response.%s", requestID))
slog.Info("processing request", "request_id", requestID, "query_len", len(query)) slog.Info("processing request", "request_id", requestID, "query_len", len(query))
@@ -138,15 +161,15 @@ func main() {
responseText, err := llm.Generate(ctx, query, contextText, systemPrompt) responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
if err != nil { if err != nil {
slog.Error("LLM generation failed", "error", err) slog.Error("LLM generation failed", "error", err)
return map[string]any{ return &messages.ChatResponse{
"user_id": userID, UserID: userID,
"success": false, Success: false,
"error": err.Error(), Error: err.Error(),
}, nil }, nil
} }
// 6. Stream chunks if requested // 6. Stream chunks if requested
if enableStreaming { if req.EnableStreaming {
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID) streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
words := strings.Fields(responseText) words := strings.Fields(responseText)
chunkSize := 4 chunkSize := 4
@@ -156,47 +179,42 @@ func main() {
end = len(words) end = len(words)
} }
chunk := strings.Join(words[i:end], " ") chunk := strings.Join(words[i:end], " ")
_ = h.NATS.Publish(streamSubject, map[string]any{ _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
"request_id": requestID, RequestID: requestID,
"type": "chunk", Type: "chunk",
"content": chunk, Content: chunk,
"done": false, Timestamp: messages.Timestamp(),
"timestamp": time.Now().Unix(),
}) })
} }
_ = h.NATS.Publish(streamSubject, map[string]any{ _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
"request_id": requestID, RequestID: requestID,
"type": "done", Type: "done",
"content": "", Done: true,
"done": true, Timestamp: messages.Timestamp(),
"timestamp": time.Now().Unix(),
}) })
} }
// 7. Optional TTS // 7. Optional TTS — audio as raw bytes (no base64)
var audioB64 string var audio []byte
if reqEnableTTS && tts != nil { if reqEnableTTS && tts != nil {
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "") audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
if err != nil { if err != nil {
slog.Error("TTS failed", "error", err) slog.Error("TTS failed", "error", err)
} else { } else {
audioB64 = base64.StdEncoding.EncodeToString(audioBytes) audio = audioBytes
} }
} }
result := map[string]any{ result := &messages.ChatResponse{
"user_id": userID, UserID: userID,
"response": responseText, Response: responseText,
"response_text": responseText, ResponseText: responseText,
"used_rag": usedRAG, UsedRAG: usedRAG,
"rag_sources": ragSources, Success: true,
"success": true, Audio: audio,
} }
if includeSources { if includeSources {
result["rag_sources"] = ragSources result.RAGSources = ragSources
}
if audioB64 != "" {
result["audio"] = audioB64
} }
// Publish to the response subject the frontend is waiting on // Publish to the response subject the frontend is waiting on
@@ -214,38 +232,6 @@ func main() {
// Helpers // Helpers
func strVal(m map[string]any, key, fallback string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return fallback
}
func boolVal(m map[string]any, key string, fallback bool) bool {
if v, ok := m[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return fallback
}
func intVal(m map[string]any, key string, fallback int) int {
if v, ok := m[key]; ok {
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
}
return fallback
}
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v

View File

@@ -3,44 +3,59 @@ package main
import ( import (
"os" "os"
"testing" "testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"github.com/vmihailenco/msgpack/v5"
) )
func TestStrVal(t *testing.T) { func TestChatRequestDecode(t *testing.T) {
m := map[string]any{"key": "value", "num": 42} // Verify a msgpack-encoded map decodes cleanly into typed struct.
if got := strVal(m, "key", ""); got != "value" { raw := map[string]any{
t.Errorf("strVal(key) = %q", got) "request_id": "req-1",
"user_id": "user-1",
"message": "hello",
"premium": true,
"top_k": 10,
} }
if got := strVal(m, "missing", "def"); got != "def" { data, _ := msgpack.Marshal(raw)
t.Errorf("strVal(missing) = %q", got) var req messages.ChatRequest
if err := msgpack.Unmarshal(data, &req); err != nil {
t.Fatal(err)
}
if req.RequestID != "req-1" {
t.Errorf("RequestID = %q", req.RequestID)
}
if req.EffectiveQuery() != "hello" {
t.Errorf("EffectiveQuery = %q", req.EffectiveQuery())
}
if !req.Premium {
t.Error("Premium should be true")
}
if req.TopK != 10 {
t.Errorf("TopK = %d", req.TopK)
} }
} }
func TestBoolVal(t *testing.T) { func TestChatResponseRoundtrip(t *testing.T) {
m := map[string]any{"flag": true, "str": "not-bool"} resp := &messages.ChatResponse{
if got := boolVal(m, "flag", false); !got { UserID: "user-1",
t.Error("boolVal(flag) should be true") Response: "answer",
Success: true,
Audio: []byte{0x01, 0x02, 0x03},
} }
if got := boolVal(m, "str", false); got { data, err := msgpack.Marshal(resp)
t.Error("boolVal(str) should be false (not a bool)") if err != nil {
t.Fatal(err)
} }
if got := boolVal(m, "missing", true); !got { var decoded messages.ChatResponse
t.Error("boolVal(missing) should use fallback true") if err := msgpack.Unmarshal(data, &decoded); err != nil {
t.Fatal(err)
} }
if decoded.UserID != "user-1" || !decoded.Success {
t.Errorf("decoded = %+v", decoded)
} }
if len(decoded.Audio) != 3 {
func TestIntVal(t *testing.T) { t.Errorf("audio len = %d", len(decoded.Audio))
m := map[string]any{"int": 5, "float": 3.14, "int64": int64(99)}
if got := intVal(m, "int", 0); got != 5 {
t.Errorf("intVal(int) = %d", got)
}
if got := intVal(m, "float", 0); got != 3 {
t.Errorf("intVal(float) = %d", got)
}
if got := intVal(m, "int64", 0); got != 99 {
t.Errorf("intVal(int64) = %d", got)
}
if got := intVal(m, "missing", 42); got != 42 {
t.Errorf("intVal(missing) = %d", got)
} }
} }