feat: migrate to typed messages, drop base64
- Switch OnMessage → OnTypedMessage with natsutil.Decode[messages.ChatRequest] - Return *messages.ChatResponse / *messages.ChatStreamChunk (not map[string]any) - Audio as raw []byte in msgpack (25% wire savings vs base64) - Remove strVal/boolVal/intVal helpers - Add .dockerignore, GOAMD64=v3 in Dockerfile - Update tests for typed structs (9 tests pass)
This commit is contained in:
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
||||
.git
|
||||
.gitignore
|
||||
*.md
|
||||
LICENSE
|
||||
renovate.json
|
||||
*_test.go
|
||||
e2e_test.go
|
||||
__pycache__
|
||||
.env*
|
||||
@@ -10,7 +10,7 @@ RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /chat-handler .
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /chat-handler .
|
||||
|
||||
# Runtime stage
|
||||
FROM scratch
|
||||
|
||||
63
e2e_test.go
63
e2e_test.go
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
@@ -161,39 +163,38 @@ t.Error("expected timeout error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPipeline_RequestBuilding(t *testing.T) {
|
||||
// Test the map construction logic from main.go's OnMessage.
|
||||
data := map[string]any{
|
||||
"request_id": "req-e2e-001",
|
||||
"user_id": "user-1",
|
||||
"message": "hello",
|
||||
"premium": true,
|
||||
"enable_rag": false,
|
||||
"enable_streaming": false,
|
||||
"system_prompt": "Be brief.",
|
||||
}
|
||||
func TestChatPipeline_TypedDecoding(t *testing.T) {
|
||||
// Verify typed struct decoding from msgpack (same path as OnTypedMessage).
|
||||
raw := map[string]any{
|
||||
"request_id": "req-e2e-001",
|
||||
"user_id": "user-1",
|
||||
"message": "hello",
|
||||
"premium": true,
|
||||
"enable_rag": false,
|
||||
"enable_streaming": false,
|
||||
"system_prompt": "Be brief.",
|
||||
}
|
||||
data, _ := msgpack.Marshal(raw)
|
||||
|
||||
requestID := strVal(data, "request_id", "unknown")
|
||||
userID := strVal(data, "user_id", "unknown")
|
||||
query := strVal(data, "message", "")
|
||||
premium := boolVal(data, "premium", false)
|
||||
enableRAG := boolVal(data, "enable_rag", premium)
|
||||
systemPrompt := strVal(data, "system_prompt", "")
|
||||
var req messages.ChatRequest
|
||||
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if requestID != "req-e2e-001" {
|
||||
t.Errorf("requestID = %q", requestID)
|
||||
}
|
||||
if userID != "user-1" {
|
||||
t.Errorf("userID = %q", userID)
|
||||
}
|
||||
if query != "hello" {
|
||||
t.Errorf("query = %q", query)
|
||||
}
|
||||
if enableRAG {
|
||||
t.Error("enable_rag=false should override premium=true")
|
||||
}
|
||||
if systemPrompt != "Be brief." {
|
||||
t.Errorf("systemPrompt = %q", systemPrompt)
|
||||
if req.RequestID != "req-e2e-001" {
|
||||
t.Errorf("RequestID = %q", req.RequestID)
|
||||
}
|
||||
if req.UserID != "user-1" {
|
||||
t.Errorf("UserID = %q", req.UserID)
|
||||
}
|
||||
if req.EffectiveQuery() != "hello" {
|
||||
t.Errorf("query = %q", req.EffectiveQuery())
|
||||
}
|
||||
if req.EnableRAG {
|
||||
t.Error("EnableRAG should be false")
|
||||
}
|
||||
if req.SystemPrompt != "Be brief." {
|
||||
t.Errorf("SystemPrompt = %q", req.SystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
144
main.go
144
main.go
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -44,22 +45,44 @@ func main() {
|
||||
|
||||
h := handler.New("ai.chat.user.*.message", cfg)
|
||||
|
||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||
requestID := strVal(data, "request_id", "unknown")
|
||||
userID := strVal(data, "user_id", "unknown")
|
||||
query := strVal(data, "message", "")
|
||||
if query == "" {
|
||||
query = strVal(data, "query", "")
|
||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||
req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
|
||||
if err != nil {
|
||||
slog.Error("decode failed", "error", err)
|
||||
return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
|
||||
}
|
||||
|
||||
query := req.EffectiveQuery()
|
||||
requestID := req.RequestID
|
||||
if requestID == "" {
|
||||
requestID = "unknown"
|
||||
}
|
||||
userID := req.UserID
|
||||
if userID == "" {
|
||||
userID = "unknown"
|
||||
}
|
||||
enableRAG := req.EnableRAG
|
||||
if !enableRAG && req.Premium {
|
||||
enableRAG = true
|
||||
}
|
||||
enableReranker := req.EnableReranker
|
||||
if !enableReranker && enableRAG {
|
||||
enableReranker = true
|
||||
}
|
||||
topK := req.TopK
|
||||
if topK == 0 {
|
||||
topK = ragTopK
|
||||
}
|
||||
collection := req.Collection
|
||||
if collection == "" {
|
||||
collection = ragCollection
|
||||
}
|
||||
reqEnableTTS := req.EnableTTS || enableTTS
|
||||
systemPrompt := req.SystemPrompt
|
||||
responseSubject := req.ResponseSubject
|
||||
if responseSubject == "" {
|
||||
responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
|
||||
}
|
||||
premium := boolVal(data, "premium", false)
|
||||
enableRAG := boolVal(data, "enable_rag", premium)
|
||||
enableReranker := boolVal(data, "enable_reranker", enableRAG)
|
||||
enableStreaming := boolVal(data, "enable_streaming", false)
|
||||
topK := intVal(data, "top_k", ragTopK)
|
||||
collection := strVal(data, "collection", ragCollection)
|
||||
reqEnableTTS := boolVal(data, "enable_tts", enableTTS)
|
||||
systemPrompt := strVal(data, "system_prompt", "")
|
||||
responseSubject := strVal(data, "response_subject", fmt.Sprintf("ai.chat.response.%s", requestID))
|
||||
|
||||
slog.Info("processing request", "request_id", requestID, "query_len", len(query))
|
||||
|
||||
@@ -138,15 +161,15 @@ func main() {
|
||||
responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
|
||||
if err != nil {
|
||||
slog.Error("LLM generation failed", "error", err)
|
||||
return map[string]any{
|
||||
"user_id": userID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
return &messages.ChatResponse{
|
||||
UserID: userID,
|
||||
Success: false,
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 6. Stream chunks if requested
|
||||
if enableStreaming {
|
||||
if req.EnableStreaming {
|
||||
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
|
||||
words := strings.Fields(responseText)
|
||||
chunkSize := 4
|
||||
@@ -156,47 +179,42 @@ func main() {
|
||||
end = len(words)
|
||||
}
|
||||
chunk := strings.Join(words[i:end], " ")
|
||||
_ = h.NATS.Publish(streamSubject, map[string]any{
|
||||
"request_id": requestID,
|
||||
"type": "chunk",
|
||||
"content": chunk,
|
||||
"done": false,
|
||||
"timestamp": time.Now().Unix(),
|
||||
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
||||
RequestID: requestID,
|
||||
Type: "chunk",
|
||||
Content: chunk,
|
||||
Timestamp: messages.Timestamp(),
|
||||
})
|
||||
}
|
||||
_ = h.NATS.Publish(streamSubject, map[string]any{
|
||||
"request_id": requestID,
|
||||
"type": "done",
|
||||
"content": "",
|
||||
"done": true,
|
||||
"timestamp": time.Now().Unix(),
|
||||
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
||||
RequestID: requestID,
|
||||
Type: "done",
|
||||
Done: true,
|
||||
Timestamp: messages.Timestamp(),
|
||||
})
|
||||
}
|
||||
|
||||
// 7. Optional TTS
|
||||
var audioB64 string
|
||||
// 7. Optional TTS — audio as raw bytes (no base64)
|
||||
var audio []byte
|
||||
if reqEnableTTS && tts != nil {
|
||||
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
||||
if err != nil {
|
||||
slog.Error("TTS failed", "error", err)
|
||||
} else {
|
||||
audioB64 = base64.StdEncoding.EncodeToString(audioBytes)
|
||||
audio = audioBytes
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"user_id": userID,
|
||||
"response": responseText,
|
||||
"response_text": responseText,
|
||||
"used_rag": usedRAG,
|
||||
"rag_sources": ragSources,
|
||||
"success": true,
|
||||
result := &messages.ChatResponse{
|
||||
UserID: userID,
|
||||
Response: responseText,
|
||||
ResponseText: responseText,
|
||||
UsedRAG: usedRAG,
|
||||
Success: true,
|
||||
Audio: audio,
|
||||
}
|
||||
if includeSources {
|
||||
result["rag_sources"] = ragSources
|
||||
}
|
||||
if audioB64 != "" {
|
||||
result["audio"] = audioB64
|
||||
result.RAGSources = ragSources
|
||||
}
|
||||
|
||||
// Publish to the response subject the frontend is waiting on
|
||||
@@ -214,38 +232,6 @@ func main() {
|
||||
|
||||
// Helpers
|
||||
|
||||
func strVal(m map[string]any, key, fallback string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func boolVal(m map[string]any, key string, fallback bool) bool {
|
||||
if v, ok := m[key]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func intVal(m map[string]any, key string, fallback int) int {
|
||||
if v, ok := m[key]; ok {
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
case float64:
|
||||
return int(n)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
|
||||
71
main_test.go
71
main_test.go
@@ -3,44 +3,59 @@ package main
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
func TestStrVal(t *testing.T) {
|
||||
m := map[string]any{"key": "value", "num": 42}
|
||||
if got := strVal(m, "key", ""); got != "value" {
|
||||
t.Errorf("strVal(key) = %q", got)
|
||||
func TestChatRequestDecode(t *testing.T) {
|
||||
// Verify a msgpack-encoded map decodes cleanly into typed struct.
|
||||
raw := map[string]any{
|
||||
"request_id": "req-1",
|
||||
"user_id": "user-1",
|
||||
"message": "hello",
|
||||
"premium": true,
|
||||
"top_k": 10,
|
||||
}
|
||||
if got := strVal(m, "missing", "def"); got != "def" {
|
||||
t.Errorf("strVal(missing) = %q", got)
|
||||
data, _ := msgpack.Marshal(raw)
|
||||
var req messages.ChatRequest
|
||||
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if req.RequestID != "req-1" {
|
||||
t.Errorf("RequestID = %q", req.RequestID)
|
||||
}
|
||||
if req.EffectiveQuery() != "hello" {
|
||||
t.Errorf("EffectiveQuery = %q", req.EffectiveQuery())
|
||||
}
|
||||
if !req.Premium {
|
||||
t.Error("Premium should be true")
|
||||
}
|
||||
if req.TopK != 10 {
|
||||
t.Errorf("TopK = %d", req.TopK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolVal(t *testing.T) {
|
||||
m := map[string]any{"flag": true, "str": "not-bool"}
|
||||
if got := boolVal(m, "flag", false); !got {
|
||||
t.Error("boolVal(flag) should be true")
|
||||
func TestChatResponseRoundtrip(t *testing.T) {
|
||||
resp := &messages.ChatResponse{
|
||||
UserID: "user-1",
|
||||
Response: "answer",
|
||||
Success: true,
|
||||
Audio: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
if got := boolVal(m, "str", false); got {
|
||||
t.Error("boolVal(str) should be false (not a bool)")
|
||||
data, err := msgpack.Marshal(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := boolVal(m, "missing", true); !got {
|
||||
t.Error("boolVal(missing) should use fallback true")
|
||||
var decoded messages.ChatResponse
|
||||
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntVal(t *testing.T) {
|
||||
m := map[string]any{"int": 5, "float": 3.14, "int64": int64(99)}
|
||||
if got := intVal(m, "int", 0); got != 5 {
|
||||
t.Errorf("intVal(int) = %d", got)
|
||||
if decoded.UserID != "user-1" || !decoded.Success {
|
||||
t.Errorf("decoded = %+v", decoded)
|
||||
}
|
||||
if got := intVal(m, "float", 0); got != 3 {
|
||||
t.Errorf("intVal(float) = %d", got)
|
||||
}
|
||||
if got := intVal(m, "int64", 0); got != 99 {
|
||||
t.Errorf("intVal(int64) = %d", got)
|
||||
}
|
||||
if got := intVal(m, "missing", 42); got != 42 {
|
||||
t.Errorf("intVal(missing) = %d", got)
|
||||
if len(decoded.Audio) != 3 {
|
||||
t.Errorf("audio len = %d", len(decoded.Audio))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user