// Command chat-handler subscribes to chat messages on NATS, optionally
// enriches the query with retrieved context (RAG), asks the LLM for a reply,
// and publishes the result (optionally streamed and/or synthesized to audio).
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/nats-io/nats.go"

	"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
	"git.daviestechlabs.io/daviestechlabs/handler-base/config"
	"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
)

const (
	// maxSources caps how many source entries are reported per response.
	maxSources = 3
	// sourcePreviewLen is the number of runes of document text used as a
	// source label when a document carries no "source" field.
	sourcePreviewLen = 80
	// streamChunkWords is the number of words sent per streamed chunk.
	streamChunkWords = 4
)

func main() {
	cfg := config.Load()
	cfg.ServiceName = "chat-handler"
	cfg.NATSQueueGroup = "chat-handlers"

	// Chat-specific settings
	ragTopK := getEnvInt("RAG_TOP_K", 10)
	ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
	ragCollection := getEnv("RAG_COLLECTION", "documents")
	includeSources := getEnvBool("INCLUDE_SOURCES", true)
	enableTTS := getEnvBool("ENABLE_TTS", false)
	ttsLanguage := getEnv("TTS_LANGUAGE", "en")

	// Service clients
	timeout := 60 * time.Second
	embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL, timeout, "")
	reranker := clients.NewRerankerClient(cfg.RerankerURL, timeout)
	llm := clients.NewLLMClient(cfg.LLMURL, timeout)
	milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)

	// The TTS client only exists when TTS is enabled for the service; a
	// request cannot turn TTS on if the service has it disabled.
	var tts *clients.TTSClient
	if enableTTS {
		tts = clients.NewTTSClient(cfg.TTSURL, timeout, ttsLanguage)
	}

	h := handler.New("ai.chat.user.*.message", cfg)

	// publish is best-effort: a failed publish must not fail the request,
	// but it is logged instead of being silently dropped.
	publish := func(subject string, payload map[string]any) {
		if err := h.NATS.Publish(subject, payload); err != nil {
			slog.Warn("publish failed", "subject", subject, "error", err)
		}
	}

	h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
		requestID := strVal(data, "request_id", "unknown")
		userID := strVal(data, "user_id", "unknown")

		// "message" is preferred; "query" is accepted as a fallback key.
		query := strVal(data, "message", "")
		if query == "" {
			query = strVal(data, "query", "")
		}

		// Defaults cascade: premium enables RAG, RAG enables the reranker,
		// unless the request sets the flags explicitly.
		premium := boolVal(data, "premium", false)
		enableRAG := boolVal(data, "enable_rag", premium)
		enableReranker := boolVal(data, "enable_reranker", enableRAG)
		enableStreaming := boolVal(data, "enable_streaming", false)
		topK := intVal(data, "top_k", ragTopK)
		collection := strVal(data, "collection", ragCollection)
		reqEnableTTS := boolVal(data, "enable_tts", enableTTS)
		systemPrompt := strVal(data, "system_prompt", "")
		responseSubject := strVal(data, "response_subject", fmt.Sprintf("ai.chat.response.%s", requestID))

		slog.Info("processing request", "request_id", requestID, "query_len", len(query))

		contextText := ""
		var ragSources []string
		usedRAG := false

		// RAG pipeline
		if enableRAG {
			// 1. Embed query
			embedding, err := embeddings.EmbedSingle(ctx, query)
			if err != nil {
				// RAG is an enhancement: fall through to a plain LLM call.
				slog.Error("embedding failed", "error", err)
			} else {
				// 2. Search Milvus
				_ = milvus
				_ = collection
				_ = topK
				_ = embedding
				// NOTE: Milvus search uses the gRPC SDK (requires milvus-sdk-go).
				// For now, we pass through without search; the Milvus client
				// will be connected when the SDK is integrated.
				// documents := milvus.Search(ctx, embedding, topK)
				var documents []map[string]any // placeholder for Milvus results

				// 3. Rerank
				if enableReranker && len(documents) > 0 {
					texts := make([]string, len(documents))
					// Remember each document's source so it can be restored
					// after reranking, which only returns text and score.
					sourceByText := make(map[string]string, len(documents))
					for i, d := range documents {
						text := docText(d)
						texts[i] = text
						if src, ok := d["source"].(string); ok {
							if _, seen := sourceByText[text]; !seen {
								sourceByText[text] = src
							}
						}
					}

					reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
					if err != nil {
						// Keep the un-reranked documents on failure.
						slog.Error("rerank failed", "error", err)
					} else {
						documents = make([]map[string]any, len(reranked))
						for i, r := range reranked {
							doc := map[string]any{"document": r.Document, "score": r.Score}
							// Assumes the reranker returns the document text
							// unchanged — TODO confirm against the client.
							if text, ok := doc["document"].(string); ok {
								if src, found := sourceByText[text]; found {
									doc["source"] = src
								}
							}
							documents[i] = doc
						}
					}
				}

				// 4. Build context
				if len(documents) > 0 {
					contextText, ragSources = buildContext(documents)
					usedRAG = true
				}
			}
		}

		// 5. Generate LLM response
		responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
		if err != nil {
			slog.Error("LLM generation failed", "error", err)
			// The failure is reported in the payload rather than as a Go
			// error, so the returned map carries it.
			return map[string]any{
				"user_id": userID,
				"success": false,
				"error":   err.Error(),
			}, nil
		}

		// 6. Stream chunks if requested
		if enableStreaming {
			streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
			words := strings.Fields(responseText)
			for i := 0; i < len(words); i += streamChunkWords {
				end := min(i+streamChunkWords, len(words))
				publish(streamSubject, map[string]any{
					"request_id": requestID,
					"type":       "chunk",
					"content":    strings.Join(words[i:end], " "),
					"done":       false,
					"timestamp":  time.Now().Unix(),
				})
			}
			// Terminal marker so consumers know the stream is complete.
			publish(streamSubject, map[string]any{
				"request_id": requestID,
				"type":       "done",
				"content":    "",
				"done":       true,
				"timestamp":  time.Now().Unix(),
			})
		}

		// 7. Optional TTS
		var audioB64 string
		if reqEnableTTS && tts != nil {
			audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
			if err != nil {
				// Audio is optional; the text response is still returned.
				slog.Error("TTS failed", "error", err)
			} else {
				audioB64 = base64.StdEncoding.EncodeToString(audioBytes)
			}
		}

		result := map[string]any{
			"user_id":       userID,
			"response":      responseText,
			"response_text": responseText,
			"used_rag":      usedRAG,
			"success":       true,
		}
		// Sources are only exposed when INCLUDE_SOURCES allows it.
		if includeSources {
			result["rag_sources"] = ragSources
		}
		if audioB64 != "" {
			result["audio"] = audioB64
		}

		// Publish to the response subject the frontend is waiting on
		publish(responseSubject, result)

		slog.Info("completed request", "request_id", requestID, "rag", usedRAG)
		return result, nil
	})

	if err := h.Run(); err != nil {
		slog.Error("handler failed", "error", err)
		os.Exit(1)
	}
}

// Helpers

// docText returns the text of a retrieved document. Reranked documents store
// it under "document"; documents that have not been reranked store it under
// "text". It returns "" when neither key holds a string.
func docText(d map[string]any) string {
	if s, ok := d["document"].(string); ok {
		return s
	}
	if s, ok := d["text"].(string); ok {
		return s
	}
	return ""
}

// truncateRunes returns at most limit runes of s, never splitting a
// multi-byte character. A non-positive limit yields "".
func truncateRunes(s string, limit int) string {
	if limit <= 0 {
		return ""
	}
	count := 0
	for i := range s {
		if count == limit {
			return s[:i]
		}
		count++
	}
	return s
}

// buildContext renders documents as a numbered context string ("[1] ...",
// separated by blank lines) and returns up to maxSources source labels. A
// document's label is its "source" field when present, otherwise a preview of
// its text limited to sourcePreviewLen runes.
func buildContext(documents []map[string]any) (string, []string) {
	if len(documents) == 0 {
		return "", nil
	}
	parts := make([]string, 0, len(documents))
	var sources []string
	for i, d := range documents {
		text := docText(d)
		parts = append(parts, fmt.Sprintf("[%d] %s", i+1, text))
		if len(sources) >= maxSources {
			continue
		}
		src, ok := d["source"].(string)
		if !ok {
			src = truncateRunes(text, sourcePreviewLen)
		}
		sources = append(sources, src)
	}
	return strings.Join(parts, "\n\n"), sources
}

// strVal returns m[key] when it is a string, otherwise fallback.
func strVal(m map[string]any, key, fallback string) string {
	if s, ok := m[key].(string); ok {
		return s
	}
	return fallback
}

// boolVal returns m[key] when it is a bool, otherwise fallback.
func boolVal(m map[string]any, key string, fallback bool) bool {
	if b, ok := m[key].(bool); ok {
		return b
	}
	return fallback
}

// intVal returns m[key] converted to int when it is an int, int64 or float64
// (float64 values are truncated), otherwise fallback.
func intVal(m map[string]any, key string, fallback int) int {
	switch n := m[key].(type) {
	case int:
		return n
	case int64:
		return int(n)
	case float64:
		return int(n)
	}
	return fallback
}

// getEnv returns the environment variable key, or fallback when it is unset
// or empty.
func getEnv(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// getEnvInt parses the environment variable key as a decimal int. It returns
// fallback when the variable is unset, empty, or not a valid integer.
func getEnvInt(key string, fallback int) int {
	if v := os.Getenv(key); v != "" {
		if i, err := strconv.Atoi(v); err == nil {
			return i
		}
	}
	return fallback
}

// getEnvBool reports whether the environment variable key is "true"
// (case-insensitive) or "1". Any other non-empty value is false; an unset or
// empty variable yields fallback.
func getEnvBool(key string, fallback bool) bool {
	if v := os.Getenv(key); v != "" {
		return strings.EqualFold(v, "true") || v == "1"
	}
	return fallback
}