feature/go-handler-refactor #1

Merged
billy merged 3 commits from feature/go-handler-refactor into main 2026-02-20 12:33:52 +00:00
4 changed files with 117 additions and 82 deletions
Showing only changes of commit 8ef0c93e47 - Show all commits

9
.dockerignore Normal file
View File

@@ -0,0 +1,9 @@
.git
.gitignore
*.md
LICENSE
renovate.json
*_test.go
e2e_test.go
__pycache__
.env*

Dockerfile
View File

@@ -10,7 +10,7 @@ RUN go mod download
COPY . . COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /voice-assistant . RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /voice-assistant .
# Runtime stage # Runtime stage
FROM scratch FROM scratch

130
main.go
View File

@@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"encoding/base64"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
@@ -15,6 +14,8 @@ import (
"git.daviestechlabs.io/daviestechlabs/handler-base/clients" "git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler" "git.daviestechlabs.io/daviestechlabs/handler-base/handler"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
) )
func main() { func main() {
@@ -42,31 +43,47 @@ func main() {
h := handler.New("voice.request", cfg) h := handler.New("voice.request", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
requestID := strVal(data, "request_id", "unknown") req, err := natsutil.Decode[messages.VoiceRequest](msg.Data)
audioB64 := strVal(data, "audio", "") if err != nil {
language := strVal(data, "language", sttLanguage) return &messages.VoiceResponse{Error: "Invalid request encoding"}, nil
collection := strVal(data, "collection", ragCollection) }
requestID := req.RequestID
if requestID == "" {
requestID = "unknown"
}
language := req.Language
if language == "" {
language = sttLanguage
}
collection := req.Collection
if collection == "" {
collection = ragCollection
}
slog.Info("processing voice request", "request_id", requestID) slog.Info("processing voice request", "request_id", requestID)
// 1. Decode audio errResp := func(msg string) (any, error) {
audioBytes, err := base64.StdEncoding.DecodeString(audioB64) return &messages.VoiceResponse{RequestID: requestID, Error: msg}, nil
if err != nil { }
return map[string]any{"request_id": requestID, "error": "Invalid audio encoding"}, nil
// 1. Audio arrives as raw bytes — no base64 decode needed
if len(req.Audio) == 0 {
return errResp("No audio data")
} }
// 2. Transcribe audio → text // 2. Transcribe audio → text
transcription, err := stt.Transcribe(ctx, audioBytes, language) transcription, err := stt.Transcribe(ctx, req.Audio, language)
if err != nil { if err != nil {
slog.Error("STT failed", "error", err) slog.Error("STT failed", "error", err)
return map[string]any{"request_id": requestID, "error": "Transcription failed"}, nil return errResp("Transcription failed")
} }
query := strings.TrimSpace(transcription.Text) query := strings.TrimSpace(transcription.Text)
if query == "" { if query == "" {
slog.Warn("empty transcription", "request_id", requestID) slog.Warn("empty transcription", "request_id", requestID)
return map[string]any{"request_id": requestID, "error": "Could not transcribe audio"}, nil return errResp("Could not transcribe audio")
} }
slog.Info("transcribed", "text", truncate(query, 50)) slog.Info("transcribed", "text", truncate(query, 50))
@@ -75,7 +92,7 @@ func main() {
embedding, err := embeddings.EmbedSingle(ctx, query) embedding, err := embeddings.EmbedSingle(ctx, query)
if err != nil { if err != nil {
slog.Error("embedding failed", "error", err) slog.Error("embedding failed", "error", err)
return map[string]any{"request_id": requestID, "error": "Embedding failed"}, nil return errResp("Embedding failed")
} }
// 4. Search Milvus for context (placeholder — requires Milvus SDK) // 4. Search Milvus for context (placeholder — requires Milvus SDK)
@@ -83,23 +100,25 @@ func main() {
_ = collection _ = collection
_ = embedding _ = embedding
_ = ragTopK _ = ragTopK
var documents []map[string]any // Milvus results placeholder type docResult struct {
Document string
Score float64
}
var documents []docResult // Milvus results placeholder
// 5. Rerank documents // 5. Rerank documents
if len(documents) > 0 { if len(documents) > 0 {
texts := make([]string, len(documents)) texts := make([]string, len(documents))
for i, d := range documents { for i, d := range documents {
if t, ok := d["text"].(string); ok { texts[i] = d.Document
texts[i] = t
}
} }
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK) reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
if err != nil { if err != nil {
slog.Error("rerank failed", "error", err) slog.Error("rerank failed", "error", err)
} else { } else {
documents = make([]map[string]any, len(reranked)) documents = make([]docResult, len(reranked))
for i, r := range reranked { for i, r := range reranked {
documents[i] = map[string]any{"document": r.Document, "score": r.Score} documents[i] = docResult{Document: r.Document, Score: r.Score}
} }
} }
} }
@@ -107,9 +126,7 @@ func main() {
// 6. Build context // 6. Build context
var contextParts []string var contextParts []string
for _, d := range documents { for _, d := range documents {
if t, ok := d["document"].(string); ok { contextParts = append(contextParts, d.Document)
contextParts = append(contextParts, t)
}
} }
contextText := strings.Join(contextParts, "\n\n") contextText := strings.Join(contextParts, "\n\n")
@@ -117,47 +134,40 @@ func main() {
responseText, err := llm.Generate(ctx, query, contextText, "") responseText, err := llm.Generate(ctx, query, contextText, "")
if err != nil { if err != nil {
slog.Error("LLM generation failed", "error", err) slog.Error("LLM generation failed", "error", err)
return map[string]any{"request_id": requestID, "error": "Generation failed"}, nil return errResp("Generation failed")
} }
// 8. Synthesize speech // 8. Synthesize speech — response audio is raw bytes
responseAudioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "") responseAudio, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
if err != nil { if err != nil {
slog.Error("TTS failed", "error", err) slog.Error("TTS failed", "error", err)
return map[string]any{"request_id": requestID, "error": "Speech synthesis failed"}, nil return errResp("Speech synthesis failed")
} }
responseAudioB64 := base64.StdEncoding.EncodeToString(responseAudioBytes)
// Build response // Build typed response
result := map[string]any{ result := &messages.VoiceResponse{
"request_id": requestID, RequestID: requestID,
"response": responseText, Response: responseText,
"audio": responseAudioB64, Audio: responseAudio,
} }
if includeTranscription { if includeTranscription {
result["transcription"] = query result.Transcription = query
} }
if includeSources && len(documents) > 0 { if includeSources && len(documents) > 0 {
sources := make([]map[string]any, 0, 3) limit := 3
for i, d := range documents { if len(documents) < limit {
if i >= 3 { limit = len(documents)
break }
} result.Sources = make([]messages.DocumentSource, limit)
text := "" for i := 0; i < limit; i++ {
if t, ok := d["document"].(string); ok && len(t) > 200 { text := documents[i].Document
text = t[:200] if len(text) > 200 {
} else if t, ok := d["document"].(string); ok { text = text[:200]
text = t }
} result.Sources[i] = messages.DocumentSource{Text: text, Score: documents[i].Score}
score := 0.0
if s, ok := d["score"].(float64); ok {
score = s
}
sources = append(sources, map[string]any{"text": text, "score": score})
} }
result["sources"] = sources
} }
// Publish to response subject // Publish to response subject
@@ -176,24 +186,6 @@ func main() {
// Helpers // Helpers
func strVal(m map[string]any, key, fallback string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return fallback
}
func boolVal(m map[string]any, key string, fallback bool) bool {
if v, ok := m[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return fallback
}
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v

View File

@@ -2,25 +2,59 @@ package main
import ( import (
"testing" "testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"github.com/vmihailenco/msgpack/v5"
) )
func TestStrVal(t *testing.T) { func TestVoiceRequestDecode(t *testing.T) {
m := map[string]any{"key": "value"} req := messages.VoiceRequest{
if got := strVal(m, "key", ""); got != "value" { RequestID: "req-123",
t.Errorf("strVal(key) = %q", got) Audio: []byte{0x01, 0x02, 0x03},
Language: "en",
Collection: "docs",
} }
if got := strVal(m, "missing", "def"); got != "def" { data, err := msgpack.Marshal(&req)
t.Errorf("strVal(missing) = %q", got) if err != nil {
t.Fatal(err)
}
decoded, err := natsutil.Decode[messages.VoiceRequest](data)
if err != nil {
t.Fatal(err)
}
if decoded.RequestID != "req-123" {
t.Errorf("RequestID = %q", decoded.RequestID)
}
if len(decoded.Audio) != 3 {
t.Errorf("Audio len = %d", len(decoded.Audio))
} }
} }
func TestBoolVal(t *testing.T) { func TestVoiceResponseRoundtrip(t *testing.T) {
m := map[string]any{"flag": true} resp := messages.VoiceResponse{
if got := boolVal(m, "flag", false); !got { RequestID: "req-456",
t.Error("expected true") Response: "It is sunny today.",
Audio: make([]byte, 8000),
Transcription: "What is the weather?",
Sources: []messages.DocumentSource{{Text: "weather doc", Score: 0.9}},
} }
if got := boolVal(m, "missing", false); got { data, err := msgpack.Marshal(&resp)
t.Error("expected false fallback") if err != nil {
t.Fatal(err)
}
var got messages.VoiceResponse
if err := msgpack.Unmarshal(data, &got); err != nil {
t.Fatal(err)
}
if got.Response != "It is sunny today." {
t.Errorf("Response = %q", got.Response)
}
if len(got.Audio) != 8000 {
t.Errorf("Audio len = %d", len(got.Audio))
}
if len(got.Sources) != 1 || got.Sources[0].Text != "weather doc" {
t.Errorf("Sources = %v", got.Sources)
} }
} }