// Package main implements the chat-handler service. It subscribes to chat
// request messages on NATS, optionally runs a RAG pipeline (embed -> search ->
// rerank -> context assembly), generates an LLM response, and publishes the
// reply (plus optional stream chunks and TTS audio) back to NATS.
package main

import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/nats-io/nats.go"

	"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
	"git.daviestechlabs.io/daviestechlabs/handler-base/config"
	"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
	"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
	"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
)

func main() {
	cfg := config.Load()
	cfg.ServiceName = "chat-handler"
	cfg.NATSQueueGroup = "chat-handlers"

	// Chat-specific settings (environment-driven defaults).
	ragTopK := getEnvInt("RAG_TOP_K", 10)
	ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
	ragCollection := getEnv("RAG_COLLECTION", "documents")
	includeSources := getEnvBool("INCLUDE_SOURCES", true)
	enableTTS := getEnvBool("ENABLE_TTS", false)
	ttsLanguage := getEnv("TTS_LANGUAGE", "en")

	// Service clients.
	timeout := 60 * time.Second
	embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
	reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
	llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
	milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)
	var tts *clients.TTSClient
	if enableTTS {
		tts = clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
	}

	h := handler.New("ai.chat.user.*.message", cfg)
	h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
		req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
		if err != nil {
			slog.Error("decode failed", "error", err)
			return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
		}

		query := req.EffectiveQuery()
		requestID := req.RequestID
		if requestID == "" {
			requestID = "unknown"
		}
		userID := req.UserID
		if userID == "" {
			userID = "unknown"
		}

		// Premium implies RAG, and enabling RAG implies the reranker.
		enableRAG := req.EnableRAG
		if !enableRAG && req.Premium {
			enableRAG = true
		}
		enableReranker := req.EnableReranker
		if !enableReranker && enableRAG {
			enableReranker = true
		}
		topK := req.TopK
		if topK == 0 {
			topK = ragTopK
		}
		collection := req.Collection
		if collection == "" {
			collection = ragCollection
		}
		reqEnableTTS := req.EnableTTS || enableTTS
		systemPrompt := req.SystemPrompt
		responseSubject := req.ResponseSubject
		if responseSubject == "" {
			responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
		}

		slog.Info("processing request", "request_id", requestID, "query_len", len(query))

		contextText := ""
		var ragSources []string
		usedRAG := false

		// RAG pipeline.
		if enableRAG {
			// 1. Embed query.
			embedding, err := embeddings.EmbedSingle(ctx, query)
			if err != nil {
				// Embedding failure degrades gracefully to a plain LLM call.
				slog.Error("embedding failed", "error", err)
			} else {
				// 2. Search Milvus.
				_ = milvus
				_ = collection
				_ = topK
				_ = embedding
				// NOTE: Milvus search uses the gRPC SDK (requires milvus-sdk-go)
				// For now, we pass through without search; Milvus client will be
				// connected when the SDK is integrated.
				// documents := milvus.Search(ctx, embedding, topK)
				var documents []map[string]any // placeholder for Milvus results

				// 3. Rerank.
				if enableReranker && len(documents) > 0 {
					texts := make([]string, len(documents))
					for i, d := range documents {
						if t, ok := d["text"].(string); ok {
							texts[i] = t
						}
					}
					reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
					if err != nil {
						// Keep the unreranked documents on failure.
						slog.Error("rerank failed", "error", err)
					} else {
						documents = make([]map[string]any, len(reranked))
						for i, r := range reranked {
							documents[i] = map[string]any{"document": r.Document, "score": r.Score}
						}
					}
				}

				// 4. Build the numbered context passed to the LLM.
				if len(documents) > 0 {
					var parts []string
					for i, d := range documents {
						text := ""
						if t, ok := d["document"].(string); ok {
							text = t
						}
						parts = append(parts, fmt.Sprintf("[%d] %s", i+1, text))
					}
					contextText = strings.Join(parts, "\n\n")

					// Collect up to three non-empty source labels for the response.
					for _, d := range documents {
						if len(ragSources) >= 3 {
							break
						}
						src := ""
						if s, ok := d["source"].(string); ok {
							src = s
						} else if s, ok := d["document"].(string); ok && len(s) > 80 {
							// No explicit source: fall back to a snippet of the
							// document text, truncated on a rune boundary so we
							// never emit invalid UTF-8.
							src = truncateRuneSafe(s, 80)
						}
						// Skip blank labels (no source field, short document).
						if src != "" {
							ragSources = append(ragSources, src)
						}
					}
					usedRAG = true
				}
			}
		}

		// 5. Generate LLM response.
		responseText, err := llm.Generate(ctx, query, contextText, systemPrompt)
		if err != nil {
			slog.Error("LLM generation failed", "error", err)
			return &messages.ChatResponse{
				UserID:  userID,
				Success: false,
				Error:   err.Error(),
			}, nil
		}

		// 6. Stream chunks if requested. Streaming is best-effort: a lost
		// chunk is not fatal, so publish errors are deliberately ignored.
		if req.EnableStreaming {
			streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
			words := strings.Fields(responseText)
			chunkSize := 4
			for i := 0; i < len(words); i += chunkSize {
				end := i + chunkSize
				if end > len(words) {
					end = len(words)
				}
				chunk := strings.Join(words[i:end], " ")
				_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
					RequestID: requestID,
					Type:      "chunk",
					Content:   chunk,
					Timestamp: messages.Timestamp(),
				})
			}
			_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
				RequestID: requestID,
				Type:      "done",
				Done:      true,
				Timestamp: messages.Timestamp(),
			})
		}

		// 7. Optional TTS — audio as raw bytes (no base64).
		var audio []byte
		if reqEnableTTS && tts != nil {
			audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
			if err != nil {
				// TTS failure still returns the text response.
				slog.Error("TTS failed", "error", err)
			} else {
				audio = audioBytes
			}
		}

		result := &messages.ChatResponse{
			UserID:       userID,
			Response:     responseText,
			ResponseText: responseText,
			UsedRAG:      usedRAG,
			Success:      true,
			Audio:        audio,
		}
		if includeSources {
			result.RAGSources = ragSources
		}

		// Publish to the response subject the frontend is waiting on.
		_ = h.NATS.Publish(responseSubject, result)
		slog.Info("completed request", "request_id", requestID, "rag", usedRAG)
		return result, nil
	})

	if err := h.Run(); err != nil {
		slog.Error("handler failed", "error", err)
		os.Exit(1)
	}
}

// Helpers

// getEnv returns the value of the environment variable key, or fallback
// when the variable is unset or empty.
func getEnv(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// getEnvInt parses the environment variable key as an int, returning
// fallback when the variable is unset, empty, or not a valid integer.
func getEnvInt(key string, fallback int) int {
	if v := os.Getenv(key); v != "" {
		if i, err := strconv.Atoi(v); err == nil {
			return i
		}
	}
	return fallback
}

// getEnvBool interprets the environment variable key as a boolean:
// "true" (any case) and "1" are true, any other non-empty value is
// false, and an unset/empty variable yields fallback.
func getEnvBool(key string, fallback bool) bool {
	if v := os.Getenv(key); v != "" {
		return strings.EqualFold(v, "true") || v == "1"
	}
	return fallback
}

// truncateRuneSafe returns s limited to at most max bytes, backing the cut
// up so it never lands in the middle of a multi-byte UTF-8 sequence.
func truncateRuneSafe(s string, max int) string {
	if len(s) <= max {
		return s
	}
	cut := max
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut]
}