// Command voice-assistant subscribes to "voice.request" messages and answers
// each one by running a speech pipeline: transcribe the audio, embed the
// query, (placeholder) retrieve and rerank context documents, generate a
// reply with the LLM and synthesize it back to audio.
package main

import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/nats-io/nats.go"

	"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
	"git.daviestechlabs.io/daviestechlabs/handler-base/config"
	"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
	"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
	"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
)

// main loads configuration, builds the service clients, registers the
// voice request handler and runs it. It exits with status 1 when the
// handler's Run method returns an error.
func main() {
	cfg := config.Load()
	cfg.ServiceName = "voice-assistant"
	cfg.NATSQueueGroup = "voice-assistants"

	// Voice-specific settings
	ragTopK := getEnvInt("RAG_TOP_K", 10)
	ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
	ragCollection := getEnv("RAG_COLLECTION", "documents")
	sttLanguage := getEnv("STT_LANGUAGE", "") // empty = auto-detect
	ttsLanguage := getEnv("TTS_LANGUAGE", "en")
	includeTranscription := getEnvBool("INCLUDE_TRANSCRIPTION", true)
	includeSources := getEnvBool("INCLUDE_SOURCES", false)

	// Service clients
	timeout := 60 * time.Second
	stt := clients.NewSTTClient(cfg.STTURL(), timeout)
	embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
	reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
	llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
	tts := clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
	milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)

	h := handler.New("voice.request", cfg)

	h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
		req, err := natsutil.Decode[messages.VoiceRequest](msg.Data)
		if err != nil {
			// No request ID is available when decoding fails, so the
			// response carries only the error text.
			slog.Warn("invalid request encoding", "error", err)
			return &messages.VoiceResponse{Error: "Invalid request encoding"}, nil
		}

		// Per-request values fall back to the service-level defaults.
		requestID := req.RequestID
		if requestID == "" {
			requestID = "unknown"
		}
		language := req.Language
		if language == "" {
			language = sttLanguage
		}
		collection := req.Collection
		if collection == "" {
			collection = ragCollection
		}

		slog.Info("processing voice request", "request_id", requestID)

		// errResp builds an error response for this request. Failures are
		// reported in the response body; the handler error stays nil.
		errResp := func(reason string) (any, error) {
			return &messages.VoiceResponse{RequestID: requestID, Error: reason}, nil
		}

		// 1. Audio arrives as raw bytes — no base64 decode needed
		if len(req.Audio) == 0 {
			return errResp("No audio data")
		}

		// 2. Transcribe audio → text
		transcription, err := stt.Transcribe(ctx, req.Audio, language)
		if err != nil {
			slog.Error("STT failed", "request_id", requestID, "error", err)
			return errResp("Transcription failed")
		}
		query := strings.TrimSpace(transcription.Text)
		if query == "" {
			slog.Warn("empty transcription", "request_id", requestID)
			return errResp("Could not transcribe audio")
		}
		slog.Info("transcribed", "text", truncate(query, 50))

		// 3. Generate query embedding
		embedding, err := embeddings.EmbedSingle(ctx, query)
		if err != nil {
			slog.Error("embedding failed", "request_id", requestID, "error", err)
			return errResp("Embedding failed")
		}

		// 4. Search Milvus for context (placeholder — requires Milvus SDK)
		// The blank assignments keep the not-yet-used values referenced so
		// the file compiles until the search is implemented.
		_ = milvus
		_ = collection
		_ = embedding
		_ = ragTopK

		type docResult struct {
			Document string
			Score    float64
		}
		var documents []docResult // Milvus results placeholder

		// 5. Rerank documents
		if len(documents) > 0 {
			texts := make([]string, len(documents))
			for i, d := range documents {
				texts[i] = d.Document
			}
			reranked, rerankErr := reranker.Rerank(ctx, query, texts, ragRerankTopK)
			if rerankErr != nil {
				// Best effort: on failure the documents keep their
				// original order and the request continues.
				slog.Error("rerank failed", "request_id", requestID, "error", rerankErr)
			} else {
				documents = make([]docResult, len(reranked))
				for i, r := range reranked {
					documents[i] = docResult{Document: r.Document, Score: r.Score}
				}
			}
		}

		// 6. Build context
		contextParts := make([]string, 0, len(documents))
		for _, d := range documents {
			contextParts = append(contextParts, d.Document)
		}
		contextText := strings.Join(contextParts, "\n\n")

		// 7. Generate LLM response
		responseText, err := llm.Generate(ctx, query, contextText, "")
		if err != nil {
			slog.Error("LLM generation failed", "request_id", requestID, "error", err)
			return errResp("Generation failed")
		}

		// 8. Synthesize speech — response audio is raw bytes
		responseAudio, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
		if err != nil {
			slog.Error("TTS failed", "request_id", requestID, "error", err)
			return errResp("Speech synthesis failed")
		}

		// Build typed response
		result := &messages.VoiceResponse{
			RequestID: requestID,
			Response:  responseText,
			Audio:     responseAudio,
		}
		if includeTranscription {
			result.Transcription = query
		}
		if includeSources && len(documents) > 0 {
			// At most three sources, each clipped to 200 bytes.
			limit := 3
			if len(documents) < limit {
				limit = len(documents)
			}
			result.Sources = make([]messages.DocumentSource, limit)
			for i := 0; i < limit; i++ {
				result.Sources[i] = messages.DocumentSource{
					// Clip on a rune boundary so the snippet stays valid UTF-8.
					Text:  clipUTF8(documents[i].Document, 200),
					Score: documents[i].Score,
				}
			}
		}

		// Publish to response subject. A publish failure is logged but does
		// not fail the request: the result is still returned to the handler.
		responseSubject := fmt.Sprintf("voice.response.%s", requestID)
		if pubErr := h.NATS.Publish(responseSubject, result); pubErr != nil {
			slog.Error("publishing voice response failed",
				"request_id", requestID, "subject", responseSubject, "error", pubErr)
		}

		slog.Info("completed voice request", "request_id", requestID)
		return result, nil
	})

	if err := h.Run(); err != nil {
		slog.Error("handler failed", "error", err)
		os.Exit(1)
	}
}

// Helpers

// getEnv returns the value of the environment variable key, or fallback
// when the variable is unset or empty.
func getEnv(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// getEnvInt returns the environment variable key parsed with strconv.Atoi.
// It returns fallback when the variable is unset, empty or not parseable.
func getEnvInt(key string, fallback int) int {
	if v := os.Getenv(key); v != "" {
		if i, err := strconv.Atoi(v); err == nil {
			return i
		}
	}
	return fallback
}

// getEnvBool interprets the environment variable key as a boolean. It
// returns fallback when the variable is unset or empty. Otherwise it returns
// true only when the value is "1" or equals "true" ignoring case; every other
// non-empty value yields false.
func getEnvBool(key string, fallback bool) bool {
	if v := os.Getenv(key); v != "" {
		return strings.EqualFold(v, "true") || v == "1"
	}
	return fallback
}

// clipUTF8 returns a prefix of s that is at most n bytes long. When byte n
// falls inside a multi-byte UTF-8 sequence the cut moves back to the start of
// that sequence, so a character is never split. A non-positive n yields "".
func clipUTF8(s string, n int) string {
	if n <= 0 {
		return ""
	}
	if len(s) <= n {
		return s
	}
	// s[n] exists because len(s) > n. While it is a continuation byte the
	// cut point is mid-character, so step back to the leading byte.
	for n > 0 && !utf8.RuneStart(s[n]) {
		n--
	}
	return s[:n]
}

// truncate shortens s to at most n bytes and appends "..." when anything was
// removed. Strings of n bytes or fewer are returned unchanged. The cut never
// splits a multi-byte UTF-8 character.
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	return clipUTF8(s, n) + "..."
}