Files
chat-handler/main.go
Billy D. 87d0545d2c
Some checks failed
CI / Lint (push) Successful in 3m0s
CI / Test (push) Successful in 3m23s
CI / Docker Build & Push (push) Failing after 4m55s
CI / Release (push) Successful in 1m4s
CI / Notify (push) Successful in 1s
feat: replace fake streaming with real SSE StreamGenerate
Use handler-base StreamGenerate() to publish real token-by-token
ChatStreamChunk messages to NATS as they arrive from Ray Serve,
instead of calling Generate() and splitting into 4-word chunks.

Add 8 streaming tests: happy path, system prompt, RAG context,
nil callback, timeout, HTTP error, context canceled, fallback.
2026-02-21 09:23:57 -05:00

250 lines
6.4 KiB
Go

package main
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"
"github.com/nats-io/nats.go"
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
)
// maxSourceSnippetBytes caps the length of a document excerpt used as a
// source label when a document carries no explicit "source" field.
const maxSourceSnippetBytes = 80

// maxRAGSources is the maximum number of source labels attached to a response.
const maxRAGSources = 3

// main loads configuration, builds the service clients, creates a handler for
// the "ai.chat.user.*.message" subject pattern, registers the chat callback on
// it, and runs the handler. If the handler returns an error the process exits
// with status 1.
func main() {
	cfg := config.Load()
	cfg.ServiceName = "chat-handler"
	cfg.NATSQueueGroup = "chat-handlers"

	// Chat-specific settings
	ragTopK := getEnvInt("RAG_TOP_K", 10)
	ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
	ragCollection := getEnv("RAG_COLLECTION", "documents")
	includeSources := getEnvBool("INCLUDE_SOURCES", true)
	enableTTS := getEnvBool("ENABLE_TTS", false)
	ttsLanguage := getEnv("TTS_LANGUAGE", "en")

	// Service clients
	timeout := 60 * time.Second
	embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
	reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
	llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
	milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)

	// The TTS client only exists when ENABLE_TTS is set; a per-request
	// EnableTTS flag cannot turn synthesis on if this stays nil.
	var tts *clients.TTSClient
	if enableTTS {
		tts = clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
	}

	h := handler.New("ai.chat.user.*.message", cfg)

	h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
		req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
		if err != nil {
			slog.Error("decode failed", "error", err)
			return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
		}

		// Resolve per-request settings, falling back to service defaults.
		query := req.EffectiveQuery()
		requestID := req.RequestID
		if requestID == "" {
			requestID = "unknown"
		}
		userID := req.UserID
		if userID == "" {
			userID = "unknown"
		}
		// Premium requests always get RAG, and RAG always implies reranking.
		enableRAG := req.EnableRAG || req.Premium
		enableReranker := req.EnableReranker || enableRAG
		topK := req.TopK
		if topK == 0 {
			topK = ragTopK
		}
		collection := req.Collection
		if collection == "" {
			collection = ragCollection
		}
		reqEnableTTS := req.EnableTTS || enableTTS
		systemPrompt := req.SystemPrompt
		responseSubject := req.ResponseSubject
		if responseSubject == "" {
			responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
		}

		slog.Info("processing request", "request_id", requestID, "query_len", len(query))

		contextText := ""
		var ragSources []string
		usedRAG := false

		// RAG pipeline. Failures here are logged and the request continues
		// without retrieved context rather than failing outright.
		if enableRAG {
			// 1. Embed query
			embedding, embedErr := embeddings.EmbedSingle(ctx, query)
			if embedErr != nil {
				slog.Error("embedding failed", "error", embedErr)
			} else {
				// 2. Search Milvus
				// The blank assignments keep these values "used" until the
				// search call below is enabled.
				_ = milvus
				_ = collection
				_ = topK
				_ = embedding
				// NOTE: Milvus search uses the gRPC SDK (requires milvus-sdk-go)
				// For now, we pass through without search; Milvus client will be
				// connected when the SDK is integrated.
				// documents := milvus.Search(ctx, embedding, topK)
				var documents []map[string]any // placeholder for Milvus results

				// 3. Rerank
				if enableReranker && len(documents) > 0 {
					texts := make([]string, len(documents))
					for i, d := range documents {
						texts[i] = documentText(d)
					}
					reranked, rerankErr := reranker.Rerank(ctx, query, texts, ragRerankTopK)
					if rerankErr != nil {
						// Keep the un-reranked documents; documentText and
						// sourceLabel understand both document shapes.
						slog.Error("rerank failed", "error", rerankErr)
					} else {
						documents = make([]map[string]any, len(reranked))
						for i, r := range reranked {
							documents[i] = map[string]any{"document": r.Document, "score": r.Score}
						}
					}
				}

				// 4. Build context
				if len(documents) > 0 {
					parts := make([]string, 0, len(documents))
					for i, d := range documents {
						parts = append(parts, fmt.Sprintf("[%d] %s", i+1, documentText(d)))
					}
					contextText = strings.Join(parts, "\n\n")

					for _, d := range documents {
						if len(ragSources) >= maxRAGSources {
							break
						}
						// Skip documents that yield no label instead of
						// reporting an empty source string.
						if src := sourceLabel(d); src != "" {
							ragSources = append(ragSources, src)
						}
					}
					usedRAG = true
				}
			}
		}

		// 5. Generate LLM response (streaming when requested)
		var responseText string
		if req.EnableStreaming {
			streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
			responseText, err = llm.StreamGenerate(ctx, query, contextText, systemPrompt, func(token string) {
				// Best effort: a failed chunk publish must not abort
				// generation, and the text returned by StreamGenerate is
				// still sent in the final response below.
				_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
					RequestID: requestID,
					Type:      "chunk",
					Content:   token,
					Timestamp: messages.Timestamp(),
				})
			})
			// The done marker is published whether or not generation
			// succeeded; the error, if any, is handled just below.
			if pubErr := h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
				RequestID: requestID,
				Type:      "done",
				Done:      true,
				Timestamp: messages.Timestamp(),
			}); pubErr != nil {
				slog.Warn("stream done publish failed", "request_id", requestID, "error", pubErr)
			}
		} else {
			responseText, err = llm.Generate(ctx, query, contextText, systemPrompt)
		}
		if err != nil {
			slog.Error("LLM generation failed", "error", err)
			// NOTE(review): unlike the success path, this failure response is
			// only returned to the handler and not published to
			// responseSubject — confirm that is what the frontend expects.
			return &messages.ChatResponse{
				UserID:  userID,
				Success: false,
				Error:   err.Error(),
			}, nil
		}

		// 6. Optional TTS — audio as raw bytes (no base64)
		// A synthesis failure is logged and the text response is still sent.
		var audio []byte
		if reqEnableTTS && tts != nil {
			audioBytes, ttsErr := tts.Synthesize(ctx, responseText, ttsLanguage, "")
			if ttsErr != nil {
				slog.Error("TTS failed", "error", ttsErr)
			} else {
				audio = audioBytes
			}
		}

		result := &messages.ChatResponse{
			UserID:       userID,
			Response:     responseText,
			ResponseText: responseText,
			UsedRAG:      usedRAG,
			Success:      true,
			Audio:        audio,
		}
		if includeSources {
			result.RAGSources = ragSources
		}

		// Publish to the response subject the frontend is waiting on
		if pubErr := h.NATS.Publish(responseSubject, result); pubErr != nil {
			slog.Error("response publish failed", "request_id", requestID, "subject", responseSubject, "error", pubErr)
		}

		slog.Info("completed request", "request_id", requestID, "rag", usedRAG)
		return result, nil
	})

	if err := h.Run(); err != nil {
		slog.Error("handler failed", "error", err)
		os.Exit(1)
	}
}

// documentText returns the text of a retrieved document. Reranked documents
// store it under "document"; un-reranked results are read via "text". It
// returns "" when neither key holds a string.
func documentText(doc map[string]any) string {
	if s, ok := doc["document"].(string); ok {
		return s
	}
	if s, ok := doc["text"].(string); ok {
		return s
	}
	return ""
}

// sourceLabel returns a label for a retrieved document: its non-empty
// "source" field when present, otherwise the leading part of its text capped
// at maxSourceSnippetBytes. It returns "" when the document has neither.
func sourceLabel(doc map[string]any) string {
	if s, ok := doc["source"].(string); ok && s != "" {
		return s
	}
	return truncateUTF8(documentText(doc), maxSourceSnippetBytes)
}

// truncateUTF8 returns s cut to at most maxBytes bytes without splitting a
// multi-byte UTF-8 sequence. Strings already within the limit are returned
// unchanged.
func truncateUTF8(s string, maxBytes int) string {
	if len(s) <= maxBytes {
		return s
	}
	cut := 0
	// Ranging over a string yields the byte offset at which each rune
	// starts, so cut ends up as the last rune boundary <= maxBytes.
	for i := range s {
		if i > maxBytes {
			break
		}
		cut = i
	}
	return s[:cut]
}
// Helpers
// getEnv returns the value of the environment variable named key, or
// fallback when the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable named key parsed as a base-10
// int. It returns fallback when the variable is unset, empty, or not a valid
// integer.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
// getEnvBool returns the environment variable named key interpreted as a
// boolean. Matching ignores case and surrounding whitespace:
//
//	true:  "1", "t", "true", "y", "yes", "on"
//	false: "0", "f", "false", "n", "no", "off"
//
// It returns fallback when the variable is unset, empty, or holds any other
// value, so a typo cannot silently override a true default with false.
func getEnvBool(key string, fallback bool) bool {
	switch strings.ToLower(strings.TrimSpace(os.Getenv(key))) {
	case "1", "t", "true", "y", "yes", "on":
		return true
	case "0", "f", "false", "n", "no", "off":
		return false
	}
	return fallback
}