Files
voice-assistant/main.go
Billy D. 8ef0c93e47
Some checks failed
CI / Lint (pull_request) Failing after 1m22s
CI / Test (pull_request) Failing after 1m21s
CI / Release (pull_request) Has been skipped
CI / Notify (pull_request) Successful in 1s
feat: migrate to typed messages, drop base64
- Switch OnMessage → OnTypedMessage with natsutil.Decode[messages.VoiceRequest]
- Return *messages.VoiceResponse with raw []byte audio (no base64)
- Use messages.DocumentSource for RAG sources
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (7 tests pass)
2026-02-20 07:10:51 -05:00

218 lines
5.6 KiB
Go

package main
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"
"github.com/nats-io/nats.go"
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
)
// main wires up the voice-assistant service: it loads configuration, reads the
// voice-specific environment settings, constructs the downstream service
// clients, registers a typed handler on the "voice.request" subject and runs
// it. The process exits with status 1 if the handler's Run returns an error.
//
// For each request the handler: decodes a messages.VoiceRequest, transcribes
// the audio, embeds the transcription, (placeholder) searches Milvus, reranks
// any documents, asks the LLM for a response and synthesizes speech. The
// resulting *messages.VoiceResponse is published to
// "voice.response.<request_id>" and also returned from the handler.
// Pipeline failures are reported through the response's Error field; the
// handler itself always returns a nil error.
func main() {
	cfg := config.Load()
	cfg.ServiceName = "voice-assistant"
	cfg.NATSQueueGroup = "voice-assistants"

	// Voice-specific settings
	ragTopK := getEnvInt("RAG_TOP_K", 10)
	ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
	ragCollection := getEnv("RAG_COLLECTION", "documents")
	sttLanguage := getEnv("STT_LANGUAGE", "") // empty = auto-detect
	ttsLanguage := getEnv("TTS_LANGUAGE", "en")
	includeTranscription := getEnvBool("INCLUDE_TRANSCRIPTION", true)
	includeSources := getEnvBool("INCLUDE_SOURCES", false)

	// Service clients (all share the same request timeout)
	timeout := 60 * time.Second
	stt := clients.NewSTTClient(cfg.STTURL(), timeout)
	embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
	reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
	llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
	tts := clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
	milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)

	h := handler.New("voice.request", cfg)
	h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
		req, err := natsutil.Decode[messages.VoiceRequest](msg.Data)
		if err != nil {
			// Record why decoding failed; the caller only gets a generic message.
			slog.Warn("invalid voice request", "error", err)
			return &messages.VoiceResponse{Error: "Invalid request encoding"}, nil
		}

		// Fill in defaults for optional request fields.
		requestID := req.RequestID
		if requestID == "" {
			requestID = "unknown"
		}
		language := req.Language
		if language == "" {
			language = sttLanguage
		}
		collection := req.Collection
		if collection == "" {
			collection = ragCollection
		}

		// Request-scoped logger so every log line carries the request ID.
		logger := slog.With("request_id", requestID)
		logger.Info("processing voice request")

		// errResp builds an error response for this request. The Go error is
		// always nil: failures are reported in-band via the Error field.
		errResp := func(reason string) (any, error) {
			return &messages.VoiceResponse{RequestID: requestID, Error: reason}, nil
		}

		// 1. Audio arrives as raw bytes — no base64 decode needed
		if len(req.Audio) == 0 {
			return errResp("No audio data")
		}

		// 2. Transcribe audio → text
		transcription, err := stt.Transcribe(ctx, req.Audio, language)
		if err != nil {
			logger.Error("STT failed", "error", err)
			return errResp("Transcription failed")
		}
		query := strings.TrimSpace(transcription.Text)
		if query == "" {
			logger.Warn("empty transcription")
			return errResp("Could not transcribe audio")
		}
		logger.Info("transcribed", "text", truncate(query, 50))

		// 3. Generate query embedding
		embedding, err := embeddings.EmbedSingle(ctx, query)
		if err != nil {
			logger.Error("embedding failed", "error", err)
			return errResp("Embedding failed")
		}

		// 4. Search Milvus for context (placeholder — requires Milvus SDK).
		// The blank assignments keep the not-yet-used values compiling.
		_ = milvus
		_ = collection
		_ = embedding
		_ = ragTopK
		type docResult struct {
			Document string
			Score    float64
		}
		// Milvus results placeholder: always empty for now, so the rerank,
		// context and sources branches below are currently no-ops.
		var documents []docResult

		// 5. Rerank documents. A rerank failure is logged and the un-reranked
		// documents are used as-is.
		if len(documents) > 0 {
			texts := make([]string, len(documents))
			for i, d := range documents {
				texts[i] = d.Document
			}
			if reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK); err != nil {
				logger.Error("rerank failed", "error", err)
			} else {
				documents = make([]docResult, len(reranked))
				for i, r := range reranked {
					documents[i] = docResult{Document: r.Document, Score: r.Score}
				}
			}
		}

		// 6. Build context: document texts separated by blank lines
		contextParts := make([]string, 0, len(documents))
		for _, d := range documents {
			contextParts = append(contextParts, d.Document)
		}
		contextText := strings.Join(contextParts, "\n\n")

		// 7. Generate LLM response
		responseText, err := llm.Generate(ctx, query, contextText, "")
		if err != nil {
			logger.Error("LLM generation failed", "error", err)
			return errResp("Generation failed")
		}

		// 8. Synthesize speech — response audio is raw bytes
		responseAudio, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
		if err != nil {
			logger.Error("TTS failed", "error", err)
			return errResp("Speech synthesis failed")
		}

		// Build typed response
		result := &messages.VoiceResponse{
			RequestID: requestID,
			Response:  responseText,
			Audio:     responseAudio,
		}
		if includeTranscription {
			result.Transcription = query
		}
		if includeSources && len(documents) > 0 {
			const (
				maxSources     = 3   // at most this many sources are returned
				maxSourceBytes = 200 // each snippet is capped at this many bytes
			)
			limit := maxSources
			if len(documents) < limit {
				limit = len(documents)
			}
			result.Sources = make([]messages.DocumentSource, limit)
			for i, d := range documents[:limit] {
				text := d.Document
				if len(text) > maxSourceBytes {
					// Back off over UTF-8 continuation bytes (10xxxxxx) so the
					// snippet never ends in the middle of a multi-byte rune.
					cut := maxSourceBytes
					for cut > 0 && text[cut]&0xC0 == 0x80 {
						cut--
					}
					text = text[:cut]
				}
				result.Sources[i] = messages.DocumentSource{Text: text, Score: d.Score}
			}
		}

		// Publish to the per-request response subject. This stays best-effort:
		// a publish failure is logged but the result is still returned.
		responseSubject := fmt.Sprintf("voice.response.%s", requestID)
		if err := h.NATS.Publish(responseSubject, result); err != nil {
			logger.Error("publishing response failed", "subject", responseSubject, "error", err)
		}
		logger.Info("completed voice request")
		return result, nil
	})

	if err := h.Run(); err != nil {
		slog.Error("handler failed", "error", err)
		os.Exit(1)
	}
}
// Helpers
// getEnv returns the value of the environment variable named key, or fallback
// when the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable named key parsed as a base-10
// int. It returns fallback when the variable is unset, empty, or does not
// parse with strconv.Atoi.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
// getEnvBool interprets the environment variable named key as a boolean.
// An unset or empty variable yields fallback. Otherwise the result is true
// only for "1" or any casing of "true"; every other non-empty value is false,
// regardless of fallback.
func getEnvBool(key string, fallback bool) bool {
	raw := os.Getenv(key)
	switch {
	case raw == "":
		return fallback
	case raw == "1", strings.EqualFold(raw, "true"):
		return true
	default:
		return false
	}
}
// truncate shortens s to at most n bytes and appends "..." when anything was
// cut off; strings of n bytes or fewer are returned unchanged.
//
// The cut point is moved back to the nearest rune boundary so that a
// multi-byte UTF-8 character is never split (which would put invalid UTF-8
// into the output). A negative n is treated as 0 instead of panicking.
func truncate(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}
	// n < len(s) here, so s[cut] is always in range. Bytes of the form
	// 10xxxxxx are UTF-8 continuation bytes; step back past them to reach
	// the start of the rune that straddles the limit.
	cut := n
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}