feature/go-handler-refactor #1
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
*.md
|
||||||
|
LICENSE
|
||||||
|
renovate.json
|
||||||
|
*_test.go
|
||||||
|
e2e_test.go
|
||||||
|
__pycache__
|
||||||
|
.env*
|
||||||
@@ -10,7 +10,7 @@ RUN go mod download
|
|||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /chat-handler .
|
RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /chat-handler .
|
||||||
|
|
||||||
# Runtime stage
|
# Runtime stage
|
||||||
FROM scratch
|
FROM scratch
|
||||||
|
|||||||
63
e2e_test.go
63
e2e_test.go
@@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -161,39 +163,38 @@ t.Error("expected timeout error")
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChatPipeline_RequestBuilding(t *testing.T) {
|
func TestChatPipeline_TypedDecoding(t *testing.T) {
|
||||||
// Test the map construction logic from main.go's OnMessage.
|
// Verify typed struct decoding from msgpack (same path as OnTypedMessage).
|
||||||
data := map[string]any{
|
raw := map[string]any{
|
||||||
"request_id": "req-e2e-001",
|
"request_id": "req-e2e-001",
|
||||||
"user_id": "user-1",
|
"user_id": "user-1",
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"premium": true,
|
"premium": true,
|
||||||
"enable_rag": false,
|
"enable_rag": false,
|
||||||
"enable_streaming": false,
|
"enable_streaming": false,
|
||||||
"system_prompt": "Be brief.",
|
"system_prompt": "Be brief.",
|
||||||
}
|
}
|
||||||
|
data, _ := msgpack.Marshal(raw)
|
||||||
|
|
||||||
requestID := strVal(data, "request_id", "unknown")
|
var req messages.ChatRequest
|
||||||
userID := strVal(data, "user_id", "unknown")
|
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||||
query := strVal(data, "message", "")
|
t.Fatal(err)
|
||||||
premium := boolVal(data, "premium", false)
|
}
|
||||||
enableRAG := boolVal(data, "enable_rag", premium)
|
|
||||||
systemPrompt := strVal(data, "system_prompt", "")
|
|
||||||
|
|
||||||
if requestID != "req-e2e-001" {
|
if req.RequestID != "req-e2e-001" {
|
||||||
t.Errorf("requestID = %q", requestID)
|
t.Errorf("RequestID = %q", req.RequestID)
|
||||||
}
|
}
|
||||||
if userID != "user-1" {
|
if req.UserID != "user-1" {
|
||||||
t.Errorf("userID = %q", userID)
|
t.Errorf("UserID = %q", req.UserID)
|
||||||
}
|
}
|
||||||
if query != "hello" {
|
if req.EffectiveQuery() != "hello" {
|
||||||
t.Errorf("query = %q", query)
|
t.Errorf("query = %q", req.EffectiveQuery())
|
||||||
}
|
}
|
||||||
if enableRAG {
|
if req.EnableRAG {
|
||||||
t.Error("enable_rag=false should override premium=true")
|
t.Error("EnableRAG should be false")
|
||||||
}
|
}
|
||||||
if systemPrompt != "Be brief." {
|
if req.SystemPrompt != "Be brief." {
|
||||||
t.Errorf("systemPrompt = %q", systemPrompt)
|
t.Errorf("SystemPrompt = %q", req.SystemPrompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
144
main.go
144
main.go
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,6 +14,8 @@ import (
|
|||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -44,22 +45,44 @@ func main() {
|
|||||||
|
|
||||||
h := handler.New("ai.chat.user.*.message", cfg)
|
h := handler.New("ai.chat.user.*.message", cfg)
|
||||||
|
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
requestID := strVal(data, "request_id", "unknown")
|
req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
|
||||||
userID := strVal(data, "user_id", "unknown")
|
if err != nil {
|
||||||
query := strVal(data, "message", "")
|
slog.Error("decode failed", "error", err)
|
||||||
if query == "" {
|
return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
|
||||||
query = strVal(data, "query", "")
|
}
|
||||||
|
|
||||||
|
query := req.EffectiveQuery()
|
||||||
|
requestID := req.RequestID
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = "unknown"
|
||||||
|
}
|
||||||
|
userID := req.UserID
|
||||||
|
if userID == "" {
|
||||||
|
userID = "unknown"
|
||||||
|
}
|
||||||
|
enableRAG := req.EnableRAG
|
||||||
|
if !enableRAG && req.Premium {
|
||||||
|
enableRAG = true
|
||||||
|
}
|
||||||
|
enableReranker := req.EnableReranker
|
||||||
|
if !enableReranker && enableRAG {
|
||||||
|
enableReranker = true
|
||||||
|
}
|
||||||
|
topK := req.TopK
|
||||||
|
if topK == 0 {
|
||||||
|
topK = ragTopK
|
||||||
|
}
|
||||||
|
collection := req.Collection
|
||||||
|
if collection == "" {
|
||||||
|
collection = ragCollection
|
||||||
|
}
|
||||||
|
reqEnableTTS := req.EnableTTS || enableTTS
|
||||||
|
systemPrompt := req.SystemPrompt
|
||||||
|
responseSubject := req.ResponseSubject
|
||||||
|
if responseSubject == "" {
|
||||||
|
responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
|
||||||
}
|
}
|
||||||
premium := boolVal(data, "premium", false)
|
|
||||||
enableRAG := boolVal(data, "enable_rag", premium)
|
|
||||||
enableReranker := boolVal(data, "enable_reranker", enableRAG)
|
|
||||||
enableStreaming := boolVal(data, "enable_streaming", false)
|
|
||||||
topK := intVal(data, "top_k", ragTopK)
|
|
||||||
collection := strVal(data, "collection", ragCollection)
|
|
||||||
reqEnableTTS := boolVal(data, "enable_tts", enableTTS)
|
|
||||||
systemPrompt := strVal(data, "system_prompt", "")
|
|
||||||
responseSubject := strVal(data, "response_subject", fmt.Sprintf("ai.chat.response.%s", requestID))
|
|
||||||
|
|
||||||
slog.Info("processing request", "request_id", requestID, "query_len", len(query))
|
slog.Info("processing request", "request_id", requestID, "query_len", len(query))
|
||||||
|
|
||||||
@@ -138,15 +161,15 @@ func main() {
|
|||||||
responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
|
responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("LLM generation failed", "error", err)
|
slog.Error("LLM generation failed", "error", err)
|
||||||
return map[string]any{
|
return &messages.ChatResponse{
|
||||||
"user_id": userID,
|
UserID: userID,
|
||||||
"success": false,
|
Success: false,
|
||||||
"error": err.Error(),
|
Error: err.Error(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Stream chunks if requested
|
// 6. Stream chunks if requested
|
||||||
if enableStreaming {
|
if req.EnableStreaming {
|
||||||
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
|
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
|
||||||
words := strings.Fields(responseText)
|
words := strings.Fields(responseText)
|
||||||
chunkSize := 4
|
chunkSize := 4
|
||||||
@@ -156,47 +179,42 @@ func main() {
|
|||||||
end = len(words)
|
end = len(words)
|
||||||
}
|
}
|
||||||
chunk := strings.Join(words[i:end], " ")
|
chunk := strings.Join(words[i:end], " ")
|
||||||
_ = h.NATS.Publish(streamSubject, map[string]any{
|
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
||||||
"request_id": requestID,
|
RequestID: requestID,
|
||||||
"type": "chunk",
|
Type: "chunk",
|
||||||
"content": chunk,
|
Content: chunk,
|
||||||
"done": false,
|
Timestamp: messages.Timestamp(),
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
_ = h.NATS.Publish(streamSubject, map[string]any{
|
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
||||||
"request_id": requestID,
|
RequestID: requestID,
|
||||||
"type": "done",
|
Type: "done",
|
||||||
"content": "",
|
Done: true,
|
||||||
"done": true,
|
Timestamp: messages.Timestamp(),
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Optional TTS
|
// 7. Optional TTS — audio as raw bytes (no base64)
|
||||||
var audioB64 string
|
var audio []byte
|
||||||
if reqEnableTTS && tts != nil {
|
if reqEnableTTS && tts != nil {
|
||||||
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("TTS failed", "error", err)
|
slog.Error("TTS failed", "error", err)
|
||||||
} else {
|
} else {
|
||||||
audioB64 = base64.StdEncoding.EncodeToString(audioBytes)
|
audio = audioBytes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]any{
|
result := &messages.ChatResponse{
|
||||||
"user_id": userID,
|
UserID: userID,
|
||||||
"response": responseText,
|
Response: responseText,
|
||||||
"response_text": responseText,
|
ResponseText: responseText,
|
||||||
"used_rag": usedRAG,
|
UsedRAG: usedRAG,
|
||||||
"rag_sources": ragSources,
|
Success: true,
|
||||||
"success": true,
|
Audio: audio,
|
||||||
}
|
}
|
||||||
if includeSources {
|
if includeSources {
|
||||||
result["rag_sources"] = ragSources
|
result.RAGSources = ragSources
|
||||||
}
|
|
||||||
if audioB64 != "" {
|
|
||||||
result["audio"] = audioB64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish to the response subject the frontend is waiting on
|
// Publish to the response subject the frontend is waiting on
|
||||||
@@ -214,38 +232,6 @@ func main() {
|
|||||||
|
|
||||||
// Helpers
|
// Helpers
|
||||||
|
|
||||||
func strVal(m map[string]any, key, fallback string) string {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
if s, ok := v.(string); ok {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func boolVal(m map[string]any, key string, fallback bool) bool {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
if b, ok := v.(bool); ok {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func intVal(m map[string]any, key string, fallback int) int {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
switch n := v.(type) {
|
|
||||||
case int:
|
|
||||||
return n
|
|
||||||
case int64:
|
|
||||||
return int(n)
|
|
||||||
case float64:
|
|
||||||
return int(n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnv(key, fallback string) string {
|
func getEnv(key, fallback string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return v
|
return v
|
||||||
|
|||||||
71
main_test.go
71
main_test.go
@@ -3,44 +3,59 @@ package main
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStrVal(t *testing.T) {
|
func TestChatRequestDecode(t *testing.T) {
|
||||||
m := map[string]any{"key": "value", "num": 42}
|
// Verify a msgpack-encoded map decodes cleanly into typed struct.
|
||||||
if got := strVal(m, "key", ""); got != "value" {
|
raw := map[string]any{
|
||||||
t.Errorf("strVal(key) = %q", got)
|
"request_id": "req-1",
|
||||||
|
"user_id": "user-1",
|
||||||
|
"message": "hello",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
}
|
}
|
||||||
if got := strVal(m, "missing", "def"); got != "def" {
|
data, _ := msgpack.Marshal(raw)
|
||||||
t.Errorf("strVal(missing) = %q", got)
|
var req messages.ChatRequest
|
||||||
|
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if req.RequestID != "req-1" {
|
||||||
|
t.Errorf("RequestID = %q", req.RequestID)
|
||||||
|
}
|
||||||
|
if req.EffectiveQuery() != "hello" {
|
||||||
|
t.Errorf("EffectiveQuery = %q", req.EffectiveQuery())
|
||||||
|
}
|
||||||
|
if !req.Premium {
|
||||||
|
t.Error("Premium should be true")
|
||||||
|
}
|
||||||
|
if req.TopK != 10 {
|
||||||
|
t.Errorf("TopK = %d", req.TopK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBoolVal(t *testing.T) {
|
func TestChatResponseRoundtrip(t *testing.T) {
|
||||||
m := map[string]any{"flag": true, "str": "not-bool"}
|
resp := &messages.ChatResponse{
|
||||||
if got := boolVal(m, "flag", false); !got {
|
UserID: "user-1",
|
||||||
t.Error("boolVal(flag) should be true")
|
Response: "answer",
|
||||||
|
Success: true,
|
||||||
|
Audio: []byte{0x01, 0x02, 0x03},
|
||||||
}
|
}
|
||||||
if got := boolVal(m, "str", false); got {
|
data, err := msgpack.Marshal(resp)
|
||||||
t.Error("boolVal(str) should be false (not a bool)")
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if got := boolVal(m, "missing", true); !got {
|
var decoded messages.ChatResponse
|
||||||
t.Error("boolVal(missing) should use fallback true")
|
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
if decoded.UserID != "user-1" || !decoded.Success {
|
||||||
|
t.Errorf("decoded = %+v", decoded)
|
||||||
func TestIntVal(t *testing.T) {
|
|
||||||
m := map[string]any{"int": 5, "float": 3.14, "int64": int64(99)}
|
|
||||||
if got := intVal(m, "int", 0); got != 5 {
|
|
||||||
t.Errorf("intVal(int) = %d", got)
|
|
||||||
}
|
}
|
||||||
if got := intVal(m, "float", 0); got != 3 {
|
if len(decoded.Audio) != 3 {
|
||||||
t.Errorf("intVal(float) = %d", got)
|
t.Errorf("audio len = %d", len(decoded.Audio))
|
||||||
}
|
|
||||||
if got := intVal(m, "int64", 0); got != 99 {
|
|
||||||
t.Errorf("intVal(int64) = %d", got)
|
|
||||||
}
|
|
||||||
if got := intVal(m, "missing", 42); got != 42 {
|
|
||||||
t.Errorf("intVal(missing) = %d", got)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user