- Replace msgpack encoding with protobuf wire format
- Update field names to proto convention
- Use pointer slices for repeated message fields ([]*DocumentSource)
- Rewrite tests for proto round-trips
219 lines
5.7 KiB
Go
219 lines
5.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
func main() {
|
|
cfg := config.Load()
|
|
cfg.ServiceName = "voice-assistant"
|
|
cfg.NATSQueueGroup = "voice-assistants"
|
|
|
|
// Voice-specific settings
|
|
ragTopK := getEnvInt("RAG_TOP_K", 10)
|
|
ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
|
|
ragCollection := getEnv("RAG_COLLECTION", "documents")
|
|
sttLanguage := getEnv("STT_LANGUAGE", "") // empty = auto-detect
|
|
ttsLanguage := getEnv("TTS_LANGUAGE", "en")
|
|
includeTranscription := getEnvBool("INCLUDE_TRANSCRIPTION", true)
|
|
includeSources := getEnvBool("INCLUDE_SOURCES", false)
|
|
|
|
// Service clients
|
|
timeout := 60 * time.Second
|
|
stt := clients.NewSTTClient(cfg.STTURL(), timeout)
|
|
embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
|
|
reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
|
|
llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
|
|
tts := clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
|
|
milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)
|
|
|
|
h := handler.New("voice.request", cfg)
|
|
|
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
|
var req messages.VoiceRequest
|
|
if err := natsutil.Decode(msg.Data, &req); err != nil {
|
|
return &messages.VoiceResponse{Error: "Invalid request encoding"}, nil
|
|
}
|
|
|
|
requestID := req.RequestId
|
|
if requestID == "" {
|
|
requestID = "unknown"
|
|
}
|
|
language := req.Language
|
|
if language == "" {
|
|
language = sttLanguage
|
|
}
|
|
collection := req.Collection
|
|
if collection == "" {
|
|
collection = ragCollection
|
|
}
|
|
|
|
slog.Info("processing voice request", "request_id", requestID)
|
|
|
|
errResp := func(msg string) (proto.Message, error) {
|
|
return &messages.VoiceResponse{RequestId: requestID, Error: msg}, nil
|
|
}
|
|
|
|
// 1. Audio arrives as raw bytes — no base64 decode needed
|
|
if len(req.Audio) == 0 {
|
|
return errResp("No audio data")
|
|
}
|
|
|
|
// 2. Transcribe audio → text
|
|
transcription, err := stt.Transcribe(ctx, req.Audio, language)
|
|
if err != nil {
|
|
slog.Error("STT failed", "error", err)
|
|
return errResp("Transcription failed")
|
|
}
|
|
|
|
query := strings.TrimSpace(transcription.Text)
|
|
if query == "" {
|
|
slog.Warn("empty transcription", "request_id", requestID)
|
|
return errResp("Could not transcribe audio")
|
|
}
|
|
|
|
slog.Info("transcribed", "text", truncate(query, 50))
|
|
|
|
// 3. Generate query embedding
|
|
embedding, err := embeddings.EmbedSingle(ctx, query)
|
|
if err != nil {
|
|
slog.Error("embedding failed", "error", err)
|
|
return errResp("Embedding failed")
|
|
}
|
|
|
|
// 4. Search Milvus for context (placeholder — requires Milvus SDK)
|
|
_ = milvus
|
|
_ = collection
|
|
_ = embedding
|
|
_ = ragTopK
|
|
type docResult struct {
|
|
Document string
|
|
Score float64
|
|
}
|
|
var documents []docResult // Milvus results placeholder
|
|
|
|
// 5. Rerank documents
|
|
if len(documents) > 0 {
|
|
texts := make([]string, len(documents))
|
|
for i, d := range documents {
|
|
texts[i] = d.Document
|
|
}
|
|
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
|
|
if err != nil {
|
|
slog.Error("rerank failed", "error", err)
|
|
} else {
|
|
documents = make([]docResult, len(reranked))
|
|
for i, r := range reranked {
|
|
documents[i] = docResult{Document: r.Document, Score: r.Score}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 6. Build context
|
|
var contextParts []string
|
|
for _, d := range documents {
|
|
contextParts = append(contextParts, d.Document)
|
|
}
|
|
contextText := strings.Join(contextParts, "\n\n")
|
|
|
|
// 7. Generate LLM response
|
|
responseText, err := llm.Generate(ctx, query, contextText, "")
|
|
if err != nil {
|
|
slog.Error("LLM generation failed", "error", err)
|
|
return errResp("Generation failed")
|
|
}
|
|
|
|
// 8. Synthesize speech — response audio is raw bytes
|
|
responseAudio, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
|
if err != nil {
|
|
slog.Error("TTS failed", "error", err)
|
|
return errResp("Speech synthesis failed")
|
|
}
|
|
|
|
// Build typed response
|
|
result := &messages.VoiceResponse{
|
|
RequestId: requestID,
|
|
Response: responseText,
|
|
Audio: responseAudio,
|
|
}
|
|
|
|
if includeTranscription {
|
|
result.Transcription = query
|
|
}
|
|
|
|
if includeSources && len(documents) > 0 {
|
|
limit := 3
|
|
if len(documents) < limit {
|
|
limit = len(documents)
|
|
}
|
|
result.Sources = make([]*messages.DocumentSource, limit)
|
|
for i := 0; i < limit; i++ {
|
|
text := documents[i].Document
|
|
if len(text) > 200 {
|
|
text = text[:200]
|
|
}
|
|
result.Sources[i] = &messages.DocumentSource{Text: text, Score: documents[i].Score}
|
|
}
|
|
}
|
|
|
|
// Publish to response subject
|
|
responseSubject := fmt.Sprintf("voice.response.%s", requestID)
|
|
_ = h.NATS.Publish(responseSubject, result)
|
|
|
|
slog.Info("completed voice request", "request_id", requestID)
|
|
return result, nil
|
|
})
|
|
|
|
if err := h.Run(); err != nil {
|
|
slog.Error("handler failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// Helpers
|
|
|
|
// getEnv returns the value of the environment variable key. When the
// variable is unset or set to the empty string, fallback is returned.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns the base-10 integer value of environment variable key.
// Surrounding whitespace is ignored. fallback is returned when the variable
// is unset, empty, or unparsable. An unparsable value is logged as a warning
// so a misconfiguration is visible instead of silently becoming the default.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(strings.TrimSpace(raw))
	if err != nil {
		slog.Warn("invalid integer environment variable; using default",
			"key", key, "value", raw, "default", fallback, "error", err)
		return fallback
	}
	return n
}
|
|
|
|
// getEnvBool reports whether environment variable key is switched on.
// An unset or empty variable yields fallback. Otherwise only "1" or a
// case-insensitive "true" count as true; every other value is false.
func getEnvBool(key string, fallback bool) bool {
	raw := os.Getenv(key)
	switch {
	case raw == "":
		return fallback
	case raw == "1":
		return true
	default:
		return strings.EqualFold(raw, "true")
	}
}
|
|
|
|
// truncate shortens s to at most n characters (Unicode code points) and
// appends "..." when anything was removed. Counting runes rather than bytes
// keeps the result valid UTF-8; the previous s[:n] could split a multi-byte
// character. For ASCII input the result is the same as byte truncation.
func truncate(s string, n int) string {
	// Fast path: a string of at most n bytes has at most n runes.
	if len(s) <= n {
		return s
	}
	runes := 0
	for i := range s { // i is the byte offset of each rune start
		if runes == n {
			return s[:i] + "..."
		}
		runes++
	}
	// Fewer than n+1 runes even though the byte length exceeds n.
	return s
}
|