Files
voice-assistant/main.go
Billy D. 2e66cac1e9 feat: rewrite voice-assistant in Go
Replace Python voice assistant with Go for smaller container images.
Uses handler-base Go module for NATS, health, telemetry, and all service clients.

- Full pipeline: STT → embed → Milvus → rerank → LLM → TTS
- Base64 audio encode/decode
- Dockerfile: multi-stage golang:1.25-alpine → scratch
- CI: Gitea Actions with lint/test/release/docker/notify
2026-02-19 18:00:58 -05:00

226 lines
6.0 KiB
Go

package main
import (
"context"
"encoding/base64"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"
"github.com/nats-io/nats.go"
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
)
// main configures and runs the voice-assistant service.
//
// It registers a handler for NATS subject "voice.request". For each message the
// handler runs: base64-decode audio → STT → query embedding → Milvus retrieval
// (placeholder, see step 4) → rerank → LLM → TTS. The result map is returned
// from the handler and also published on "voice.response.<request_id>".
//
// Request fields read from the decoded message:
//   - "request_id": optional string, defaults to "unknown"
//   - "audio":      audio bytes, base64 standard encoding
//   - "language":   optional STT language, defaults to STT_LANGUAGE ("" = auto)
//   - "collection": optional RAG collection, defaults to RAG_COLLECTION
//
// Pipeline failures are reported in-band as {"request_id", "error"} maps with a
// nil Go error; only a failure of h.Run terminates the process (exit code 1).
func main() {
	cfg := config.Load()
	cfg.ServiceName = "voice-assistant"
	cfg.NATSQueueGroup = "voice-assistants"

	// Voice-specific settings.
	ragTopK := getEnvInt("RAG_TOP_K", 10)
	ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
	ragCollection := getEnv("RAG_COLLECTION", "documents")
	sttLanguage := getEnv("STT_LANGUAGE", "") // empty = auto-detect
	ttsLanguage := getEnv("TTS_LANGUAGE", "en")
	includeTranscription := getEnvBool("INCLUDE_TRANSCRIPTION", true)
	includeSources := getEnvBool("INCLUDE_SOURCES", false)

	// Service clients (all share the same request timeout).
	timeout := 60 * time.Second
	stt := clients.NewSTTClient(cfg.STTURL, timeout)
	embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL, timeout, "")
	reranker := clients.NewRerankerClient(cfg.RerankerURL, timeout)
	llm := clients.NewLLMClient(cfg.LLMURL, timeout)
	tts := clients.NewTTSClient(cfg.TTSURL, timeout, ttsLanguage)
	milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)

	h := handler.New("voice.request", cfg)
	h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
		requestID := strVal(data, "request_id", "unknown")
		audioB64 := strVal(data, "audio", "")
		language := strVal(data, "language", sttLanguage)
		collection := strVal(data, "collection", ragCollection)
		slog.Info("processing voice request", "request_id", requestID)

		// 1. Decode audio.
		audioBytes, err := base64.StdEncoding.DecodeString(audioB64)
		if err != nil {
			slog.Warn("invalid audio encoding", "request_id", requestID, "error", err)
			return map[string]any{"request_id": requestID, "error": "Invalid audio encoding"}, nil
		}
		// A missing or empty "audio" field decodes without error to zero bytes;
		// reject it here instead of sending an empty payload to the STT service.
		if len(audioBytes) == 0 {
			slog.Warn("empty audio payload", "request_id", requestID)
			return map[string]any{"request_id": requestID, "error": "No audio provided"}, nil
		}

		// 2. Transcribe audio → text.
		transcription, err := stt.Transcribe(ctx, audioBytes, language)
		if err != nil {
			slog.Error("STT failed", "request_id", requestID, "error", err)
			return map[string]any{"request_id": requestID, "error": "Transcription failed"}, nil
		}
		query := strings.TrimSpace(transcription.Text)
		if query == "" {
			slog.Warn("empty transcription", "request_id", requestID)
			return map[string]any{"request_id": requestID, "error": "Could not transcribe audio"}, nil
		}
		slog.Info("transcribed", "request_id", requestID, "text", truncate(query, 50))

		// 3. Generate query embedding.
		embedding, err := embeddings.EmbedSingle(ctx, query)
		if err != nil {
			slog.Error("embedding failed", "request_id", requestID, "error", err)
			return map[string]any{"request_id": requestID, "error": "Embedding failed"}, nil
		}

		// 4. Search Milvus for context (placeholder — requires Milvus SDK).
		// Until implemented, no documents are retrieved and the LLM runs with an
		// empty context.
		_ = milvus
		_ = collection
		_ = embedding
		_ = ragTopK
		var documents []map[string]any // Milvus results placeholder

		// 5. Rerank documents (best effort: on failure the retrieval order is kept).
		if len(documents) > 0 {
			texts := make([]string, len(documents))
			for i, d := range documents {
				texts[i] = documentText(d)
			}
			reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
			if err != nil {
				slog.Error("rerank failed", "request_id", requestID, "error", err)
			} else {
				documents = make([]map[string]any, len(reranked))
				for i, r := range reranked {
					documents[i] = map[string]any{"document": r.Document, "score": r.Score}
				}
			}
		}

		// 6. Build context. documentText reads either the reranked ("document")
		// or the raw retrieval ("text") key, so context survives a rerank failure.
		contextParts := make([]string, 0, len(documents))
		for _, d := range documents {
			if t := documentText(d); t != "" {
				contextParts = append(contextParts, t)
			}
		}
		contextText := strings.Join(contextParts, "\n\n")

		// 7. Generate LLM response.
		responseText, err := llm.Generate(ctx, query, contextText, "")
		if err != nil {
			slog.Error("LLM generation failed", "request_id", requestID, "error", err)
			return map[string]any{"request_id": requestID, "error": "Generation failed"}, nil
		}

		// 8. Synthesize speech.
		responseAudioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
		if err != nil {
			slog.Error("TTS failed", "request_id", requestID, "error", err)
			return map[string]any{"request_id": requestID, "error": "Speech synthesis failed"}, nil
		}
		responseAudioB64 := base64.StdEncoding.EncodeToString(responseAudioBytes)

		// Build response.
		result := map[string]any{
			"request_id": requestID,
			"response":   responseText,
			"audio":      responseAudioB64,
		}
		if includeTranscription {
			result["transcription"] = query
		}
		if includeSources && len(documents) > 0 {
			// Expose at most the top 3 documents, each clipped to 200 characters
			// on a rune boundary so the snippet stays valid UTF-8.
			top := documents[:min(len(documents), 3)]
			sources := make([]map[string]any, 0, len(top))
			for _, d := range top {
				// NOTE(review): assumes the reranker score is a float64; any
				// other numeric type is reported as 0 — confirm against clients.
				score, _ := d["score"].(float64)
				sources = append(sources, map[string]any{
					"text":  clipRunes(documentText(d), 200),
					"score": score,
				})
			}
			result["sources"] = sources
		}

		// Publish to the per-request response subject. This is best effort: a
		// publish failure is logged and the result is still returned below.
		responseSubject := fmt.Sprintf("voice.response.%s", requestID)
		if err := h.NATS.Publish(responseSubject, result); err != nil {
			slog.Error("publish response failed", "request_id", requestID, "subject", responseSubject, "error", err)
		}
		slog.Info("completed voice request", "request_id", requestID)
		return result, nil
	})

	if err := h.Run(); err != nil {
		slog.Error("handler failed", "error", err)
		os.Exit(1)
	}
}

// documentText returns the text of a pipeline document. Reranked entries store
// it under "document" and raw retrieval results under "text"; "" is returned
// when neither key holds a string.
func documentText(d map[string]any) string {
	if t, ok := d["document"].(string); ok {
		return t
	}
	if t, ok := d["text"].(string); ok {
		return t
	}
	return ""
}

// clipRunes returns at most the first n runes of s without splitting a
// multi-byte UTF-8 sequence. n <= 0 yields "".
func clipRunes(s string, n int) string {
	if n <= 0 {
		return ""
	}
	count := 0
	for i := range s { // i is the byte offset of each rune start
		if count == n {
			return s[:i]
		}
		count++
	}
	return s
}
// Helpers
// strVal returns m[key] when it holds a string. A missing key, a nil map, or a
// value of any other type yields fallback.
func strVal(m map[string]any, key, fallback string) string {
	s, isString := m[key].(string)
	if !isString {
		return fallback
	}
	return s
}
// boolVal returns m[key] when it holds a bool. A missing key, a nil map, or a
// value of any other type (including the string "true") yields fallback.
func boolVal(m map[string]any, key string, fallback bool) bool {
	b, isBool := m[key].(bool)
	if !isBool {
		return fallback
	}
	return b
}
// getEnv returns the value of environment variable key, or fallback when the
// variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present || value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns environment variable key parsed as a base-10 int.
// Surrounding whitespace is ignored. An unset, empty, or whitespace-only value
// returns fallback. A value that is set but does not parse also returns
// fallback, and a warning is logged so a misconfigured deployment is visible
// instead of silently running with the default.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		slog.Warn("invalid integer environment variable; using default",
			"key", key, "value", raw, "default", fallback, "error", err)
		return fallback
	}
	return n
}
// getEnvBool interprets environment variable key as a boolean flag.
// Unset or empty returns fallback. "1" or any casing of "true" is true. Every
// other non-empty value is false, regardless of fallback.
func getEnvBool(key string, fallback bool) bool {
	raw := os.Getenv(key)
	switch {
	case raw == "":
		return fallback
	case raw == "1":
		return true
	default:
		return strings.EqualFold(raw, "true")
	}
}
// truncate shortens s to at most n runes, appending "..." when anything was cut.
// It cuts on rune boundaries, so the result stays valid UTF-8 even when s
// contains multi-byte characters; for ASCII input this is identical to byte
// truncation. A negative n is treated as 0.
func truncate(s string, n int) string {
	if n < 0 {
		n = 0
	}
	count := 0
	for i := range s { // i is the byte offset where each rune starts
		if count == n {
			return s[:i] + "..."
		}
		count++
	}
	return s
}