- Replace msgpack encoding with protobuf wire format
- Update field names to proto convention (UserId, RequestId, EnableRag, etc.)
- Use the standalone messages.EffectiveQuery() function
- Cast TopK to int32 for proto compatibility
- Rewrite tests for proto round-trips
252 lines
6.5 KiB
Go
252 lines
6.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
func main() {
|
|
cfg := config.Load()
|
|
cfg.ServiceName = "chat-handler"
|
|
cfg.NATSQueueGroup = "chat-handlers"
|
|
|
|
// Chat-specific settings
|
|
ragTopK := getEnvInt("RAG_TOP_K", 10)
|
|
ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
|
|
ragCollection := getEnv("RAG_COLLECTION", "documents")
|
|
includeSources := getEnvBool("INCLUDE_SOURCES", true)
|
|
enableTTS := getEnvBool("ENABLE_TTS", false)
|
|
ttsLanguage := getEnv("TTS_LANGUAGE", "en")
|
|
|
|
// Service clients
|
|
timeout := 60 * time.Second
|
|
embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
|
|
reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
|
|
llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
|
|
milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)
|
|
|
|
var tts *clients.TTSClient
|
|
if enableTTS {
|
|
tts = clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
|
|
}
|
|
|
|
h := handler.New("ai.chat.user.*.message", cfg)
|
|
|
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
|
var req messages.ChatRequest
|
|
if err := natsutil.Decode(msg.Data, &req); err != nil {
|
|
slog.Error("decode failed", "error", err)
|
|
return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
|
|
}
|
|
|
|
query := messages.EffectiveQuery(&req)
|
|
requestID := req.RequestId
|
|
if requestID == "" {
|
|
requestID = "unknown"
|
|
}
|
|
userID := req.UserId
|
|
if userID == "" {
|
|
userID = "unknown"
|
|
}
|
|
enableRAG := req.EnableRag
|
|
if !enableRAG && req.Premium {
|
|
enableRAG = true
|
|
}
|
|
enableReranker := req.EnableReranker
|
|
if !enableReranker && enableRAG {
|
|
enableReranker = true
|
|
}
|
|
topK := req.TopK
|
|
if topK == 0 {
|
|
topK = int32(ragTopK)
|
|
}
|
|
collection := req.Collection
|
|
if collection == "" {
|
|
collection = ragCollection
|
|
}
|
|
reqEnableTTS := req.EnableTts || enableTTS
|
|
systemPrompt := req.SystemPrompt
|
|
responseSubject := req.ResponseSubject
|
|
if responseSubject == "" {
|
|
responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
|
|
}
|
|
|
|
slog.Info("processing request", "request_id", requestID, "query_len", len(query))
|
|
|
|
contextText := ""
|
|
var ragSources []string
|
|
usedRAG := false
|
|
|
|
// RAG pipeline
|
|
if enableRAG {
|
|
// 1. Embed query
|
|
embedding, err := embeddings.EmbedSingle(ctx, query)
|
|
if err != nil {
|
|
slog.Error("embedding failed", "error", err)
|
|
} else {
|
|
// 2. Search Milvus
|
|
_ = milvus
|
|
_ = collection
|
|
_ = topK
|
|
_ = embedding
|
|
// NOTE: Milvus search uses the gRPC SDK (requires milvus-sdk-go)
|
|
// For now, we pass through without search; Milvus client will be
|
|
// connected when the SDK is integrated.
|
|
// documents := milvus.Search(ctx, embedding, topK)
|
|
|
|
var documents []map[string]any // placeholder for Milvus results
|
|
|
|
// 3. Rerank
|
|
if enableReranker && len(documents) > 0 {
|
|
texts := make([]string, len(documents))
|
|
for i, d := range documents {
|
|
if t, ok := d["text"].(string); ok {
|
|
texts[i] = t
|
|
}
|
|
}
|
|
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
|
|
if err != nil {
|
|
slog.Error("rerank failed", "error", err)
|
|
} else {
|
|
documents = make([]map[string]any, len(reranked))
|
|
for i, r := range reranked {
|
|
documents[i] = map[string]any{"document": r.Document, "score": r.Score}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 4. Build context
|
|
if len(documents) > 0 {
|
|
var parts []string
|
|
for i, d := range documents {
|
|
text := ""
|
|
if t, ok := d["document"].(string); ok {
|
|
text = t
|
|
}
|
|
parts = append(parts, fmt.Sprintf("[%d] %s", i+1, text))
|
|
}
|
|
contextText = strings.Join(parts, "\n\n")
|
|
|
|
for _, d := range documents {
|
|
if len(ragSources) >= 3 {
|
|
break
|
|
}
|
|
src := ""
|
|
if s, ok := d["source"].(string); ok {
|
|
src = s
|
|
} else if s, ok := d["document"].(string); ok && len(s) > 80 {
|
|
src = s[:80]
|
|
}
|
|
ragSources = append(ragSources, src)
|
|
}
|
|
usedRAG = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// 5. Generate LLM response (streaming when requested)
|
|
var responseText string
|
|
var err error
|
|
if req.EnableStreaming {
|
|
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
|
|
responseText, err = llm.StreamGenerate(ctx, query, contextText, systemPrompt, func(token string) {
|
|
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
|
RequestId: requestID,
|
|
Type: "chunk",
|
|
Content: token,
|
|
Timestamp: messages.Timestamp(),
|
|
})
|
|
})
|
|
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
|
|
RequestId: requestID,
|
|
Type: "done",
|
|
Done: true,
|
|
Timestamp: messages.Timestamp(),
|
|
})
|
|
} else {
|
|
responseText, err = llm.Generate(ctx, query, contextText, systemPrompt)
|
|
}
|
|
if err != nil {
|
|
slog.Error("LLM generation failed", "error", err)
|
|
return &messages.ChatResponse{
|
|
UserId: userID,
|
|
Success: false,
|
|
Error: err.Error(),
|
|
}, nil
|
|
}
|
|
|
|
// 6. Optional TTS — audio as raw bytes (no base64)
|
|
var audio []byte
|
|
if reqEnableTTS && tts != nil {
|
|
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
|
if err != nil {
|
|
slog.Error("TTS failed", "error", err)
|
|
} else {
|
|
audio = audioBytes
|
|
}
|
|
}
|
|
|
|
result := &messages.ChatResponse{
|
|
UserId: userID,
|
|
Response: responseText,
|
|
ResponseText: responseText,
|
|
UsedRag: usedRAG,
|
|
Success: true,
|
|
Audio: audio,
|
|
}
|
|
if includeSources {
|
|
result.RagSources = ragSources
|
|
}
|
|
|
|
// Publish to the response subject the frontend is waiting on
|
|
_ = h.NATS.Publish(responseSubject, result)
|
|
|
|
slog.Info("completed request", "request_id", requestID, "rag", usedRAG)
|
|
return result, nil
|
|
})
|
|
|
|
if err := h.Run(); err != nil {
|
|
slog.Error("handler failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// Helpers
|
|
|
|
// getEnv reads the environment variable named key and returns fallback
// when the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns the integer value of the environment variable key, or
// fallback when the variable is unset or empty. Surrounding whitespace is
// ignored. A value that still fails to parse is logged as a warning and
// replaced by fallback, so the misconfiguration is visible in the logs.
func getEnvInt(key string, fallback int) int {
	v := strings.TrimSpace(os.Getenv(key))
	if v == "" {
		return fallback
	}
	i, err := strconv.Atoi(v)
	if err != nil {
		slog.Warn("invalid integer environment variable, using default",
			"key", key, "value", v, "default", fallback, "error", err)
		return fallback
	}
	return i
}
|
|
|
|
// getEnvBool returns the boolean value of the environment variable key, or
// fallback when it is unset or empty. Matching ignores case and surrounding
// whitespace:
//   - "true", "1", "yes" and "on" mean true;
//   - "false", "0", "no" and "off" mean false.
//
// Any other value is logged as a warning and replaced by fallback, mirroring
// getEnvInt, instead of silently reading as false.
func getEnvBool(key string, fallback bool) bool {
	v := strings.ToLower(strings.TrimSpace(os.Getenv(key)))
	switch v {
	case "":
		return fallback
	case "true", "1", "yes", "on":
		return true
	case "false", "0", "no", "off":
		return false
	default:
		slog.Warn("invalid boolean environment variable, using default",
			"key", key, "value", v, "default", fallback)
		return fallback
	}
}
|