feat: migrate to typed messages, drop base64
- Switch OnMessage → OnTypedMessage with natsutil.Decode[messages.ChatRequest]
- Return *messages.ChatResponse / *messages.ChatStreamChunk (not map[string]any)
- Audio as raw []byte in msgpack (25% wire savings vs base64)
- Remove strVal/boolVal/intVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (9 tests pass)
This commit is contained in:
144
main.go
144
main.go
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -44,22 +45,44 @@ func main() {
|
||||
|
||||
h := handler.New("ai.chat.user.*.message", cfg)
|
||||
|
||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||
requestID := strVal(data, "request_id", "unknown")
|
||||
userID := strVal(data, "user_id", "unknown")
|
||||
query := strVal(data, "message", "")
|
||||
if query == "" {
|
||||
query = strVal(data, "query", "")
|
||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||
req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
|
||||
if err != nil {
|
||||
slog.Error("decode failed", "error", err)
|
||||
return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
|
||||
}
|
||||
|
||||
query := req.EffectiveQuery()
|
||||
requestID := req.RequestID
|
||||
if requestID == "" {
|
||||
requestID = "unknown"
|
||||
}
|
||||
userID := req.UserID
|
||||
if userID == "" {
|
||||
userID = "unknown"
|
||||
}
|
||||
enableRAG := req.EnableRAG
|
||||
if !enableRAG && req.Premium {
|
||||
enableRAG = true
|
||||
}
|
||||
enableReranker := req.EnableReranker
|
||||
if !enableReranker && enableRAG {
|
||||
enableReranker = true
|
||||
}
|
||||
topK := req.TopK
|
||||
if topK == 0 {
|
||||
topK = ragTopK
|
||||
}
|
||||
collection := req.Collection
|
||||
if collection == "" {
|
||||
collection = ragCollection
|
||||
}
|
||||
reqEnableTTS := req.EnableTTS || enableTTS
|
||||
systemPrompt := req.SystemPrompt
|
||||
responseSubject := req.ResponseSubject
|
||||
if responseSubject == "" {
|
||||
responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
|
||||
}
|
||||
premium := boolVal(data, "premium", false)
|
||||
enableRAG := boolVal(data, "enable_rag", premium)
|
||||
enableReranker := boolVal(data, "enable_reranker", enableRAG)
|
||||
enableStreaming := boolVal(data, "enable_streaming", false)
|
||||
topK := intVal(data, "top_k", ragTopK)
|
||||
collection := strVal(data, "collection", ragCollection)
|
||||
reqEnableTTS := boolVal(data, "enable_tts", enableTTS)
|
||||
systemPrompt := strVal(data, "system_prompt", "")
|
||||
responseSubject := strVal(data, "response_subject", fmt.Sprintf("ai.chat.response.%s", requestID))
|
||||
|
||||
slog.Info("processing request", "request_id", requestID, "query_len", len(query))
|
||||
|
||||
@@ -138,15 +161,15 @@ func main() {
|
||||
responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
|
||||
if err != nil {
|
||||
slog.Error("LLM generation failed", "error", err)
|
||||
return map[string]any{
|
||||
"user_id": userID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
return &messages.ChatResponse{
|
||||
UserID: userID,
|
||||
Success: false,
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 6. Stream chunks if requested
|
||||
if enableStreaming {
|
||||
if req.EnableStreaming {
|
||||
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
|
||||
words := strings.Fields(responseText)
|
||||
chunkSize := 4
|
||||
@@ -156,47 +179,42 @@ func main() {
|
||||
end = len(words)
|
||||
}
|
||||
chunk := strings.Join(words[i:end], " ")
|
||||
_ = h.NATS.Publish(streamSubject, map[string]any{
|
||||
"request_id": requestID,
|
||||
"type": "chunk",
|
||||
"content": chunk,
|
||||
"done": false,
|
||||
"timestamp": time.Now().Unix(),
|
||||
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
||||
RequestID: requestID,
|
||||
Type: "chunk",
|
||||
Content: chunk,
|
||||
Timestamp: messages.Timestamp(),
|
||||
})
|
||||
}
|
||||
_ = h.NATS.Publish(streamSubject, map[string]any{
|
||||
"request_id": requestID,
|
||||
"type": "done",
|
||||
"content": "",
|
||||
"done": true,
|
||||
"timestamp": time.Now().Unix(),
|
||||
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
||||
RequestID: requestID,
|
||||
Type: "done",
|
||||
Done: true,
|
||||
Timestamp: messages.Timestamp(),
|
||||
})
|
||||
}
|
||||
|
||||
// 7. Optional TTS
|
||||
var audioB64 string
|
||||
// 7. Optional TTS — audio as raw bytes (no base64)
|
||||
var audio []byte
|
||||
if reqEnableTTS && tts != nil {
|
||||
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
||||
if err != nil {
|
||||
slog.Error("TTS failed", "error", err)
|
||||
} else {
|
||||
audioB64 = base64.StdEncoding.EncodeToString(audioBytes)
|
||||
audio = audioBytes
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"user_id": userID,
|
||||
"response": responseText,
|
||||
"response_text": responseText,
|
||||
"used_rag": usedRAG,
|
||||
"rag_sources": ragSources,
|
||||
"success": true,
|
||||
result := &messages.ChatResponse{
|
||||
UserID: userID,
|
||||
Response: responseText,
|
||||
ResponseText: responseText,
|
||||
UsedRAG: usedRAG,
|
||||
Success: true,
|
||||
Audio: audio,
|
||||
}
|
||||
if includeSources {
|
||||
result["rag_sources"] = ragSources
|
||||
}
|
||||
if audioB64 != "" {
|
||||
result["audio"] = audioB64
|
||||
result.RAGSources = ragSources
|
||||
}
|
||||
|
||||
// Publish to the response subject the frontend is waiting on
|
||||
@@ -214,38 +232,6 @@ func main() {
|
||||
|
||||
// Helpers
|
||||
|
||||
// strVal looks up key in m and returns the stored value when it is a string.
// A missing key, a nil map, or a value of any other type yields fallback.
func strVal(m map[string]any, key, fallback string) string {
	// A comma-ok assertion on a missing key (nil interface) reports false,
	// so a single check covers both "absent" and "wrong type".
	text, isString := m[key].(string)
	if !isString {
		return fallback
	}
	return text
}
|
||||
|
||||
// boolVal looks up key in m and returns the stored value when it is a bool.
// A missing key, a nil map, or a value of any other type yields fallback.
func boolVal(m map[string]any, key string, fallback bool) bool {
	// Asserting directly on the lookup result handles the absent-key case too:
	// a nil interface never satisfies the bool assertion.
	if flag, isBool := m[key].(bool); isBool {
		return flag
	}
	return fallback
}
|
||||
|
||||
// intVal looks up key in m and returns the stored value converted to int.
//
// Every built-in integer width and both float widths are accepted, because a
// number decoded into an `any` may arrive as whichever concrete type the
// decoder picked (for example float64 from encoding/json, or a narrow
// int8/uint16 from a compact binary codec). Floats are truncated toward zero.
//
// fallback is returned when the key is absent, the map is nil, the value is
// not numeric, or an unsigned value is too large to be represented as an int.
func intVal(m map[string]any, key string, fallback int) int {
	// Largest value an int can hold on this platform, computed without
	// importing math so the helper stays dependency-free.
	const maxInt = int(^uint(0) >> 1)

	// A missing key yields a nil interface, which matches no case below and
	// therefore falls through to the fallback.
	switch n := m[key].(type) {
	case int:
		return n
	case int8:
		return int(n)
	case int16:
		return int(n)
	case int32:
		return int(n)
	case int64:
		return int(n)
	case uint8:
		return int(n)
	case uint16:
		return int(n)
	case uint32:
		// Only overflows where int is 32 bits wide.
		if uint64(n) <= uint64(maxInt) {
			return int(n)
		}
	case uint:
		if uint64(n) <= uint64(maxInt) {
			return int(n)
		}
	case uint64:
		// Reject rather than wrap to a negative number.
		if n <= uint64(maxInt) {
			return int(n)
		}
	case float32:
		return int(n)
	case float64:
		return int(n)
	}
	return fallback
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
|
||||
Reference in New Issue
Block a user