feat: migrate to typed messages, drop base64
- Switch OnMessage → OnTypedMessage with natsutil.Decode[messages.VoiceRequest]
- Return *messages.VoiceResponse with raw []byte audio (no base64)
- Use messages.DocumentSource for RAG sources
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (7 tests pass)
This commit is contained in:
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
||||
.git
|
||||
.gitignore
|
||||
*.md
|
||||
LICENSE
|
||||
renovate.json
|
||||
*_test.go
|
||||
e2e_test.go
|
||||
__pycache__
|
||||
.env*
|
||||
@@ -10,7 +10,7 @@ RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /voice-assistant .
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /voice-assistant .
|
||||
|
||||
# Runtime stage
|
||||
FROM scratch
|
||||
|
||||
126
main.go
126
main.go
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -42,31 +43,47 @@ func main() {
|
||||
|
||||
h := handler.New("voice.request", cfg)
|
||||
|
||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||
requestID := strVal(data, "request_id", "unknown")
|
||||
audioB64 := strVal(data, "audio", "")
|
||||
language := strVal(data, "language", sttLanguage)
|
||||
collection := strVal(data, "collection", ragCollection)
|
||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||
req, err := natsutil.Decode[messages.VoiceRequest](msg.Data)
|
||||
if err != nil {
|
||||
return &messages.VoiceResponse{Error: "Invalid request encoding"}, nil
|
||||
}
|
||||
|
||||
requestID := req.RequestID
|
||||
if requestID == "" {
|
||||
requestID = "unknown"
|
||||
}
|
||||
language := req.Language
|
||||
if language == "" {
|
||||
language = sttLanguage
|
||||
}
|
||||
collection := req.Collection
|
||||
if collection == "" {
|
||||
collection = ragCollection
|
||||
}
|
||||
|
||||
slog.Info("processing voice request", "request_id", requestID)
|
||||
|
||||
// 1. Decode audio
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(audioB64)
|
||||
if err != nil {
|
||||
return map[string]any{"request_id": requestID, "error": "Invalid audio encoding"}, nil
|
||||
errResp := func(msg string) (any, error) {
|
||||
return &messages.VoiceResponse{RequestID: requestID, Error: msg}, nil
|
||||
}
|
||||
|
||||
// 1. Audio arrives as raw bytes — no base64 decode needed
|
||||
if len(req.Audio) == 0 {
|
||||
return errResp("No audio data")
|
||||
}
|
||||
|
||||
// 2. Transcribe audio → text
|
||||
transcription, err := stt.Transcribe(ctx, audioBytes, language)
|
||||
transcription, err := stt.Transcribe(ctx, req.Audio, language)
|
||||
if err != nil {
|
||||
slog.Error("STT failed", "error", err)
|
||||
return map[string]any{"request_id": requestID, "error": "Transcription failed"}, nil
|
||||
return errResp("Transcription failed")
|
||||
}
|
||||
|
||||
query := strings.TrimSpace(transcription.Text)
|
||||
if query == "" {
|
||||
slog.Warn("empty transcription", "request_id", requestID)
|
||||
return map[string]any{"request_id": requestID, "error": "Could not transcribe audio"}, nil
|
||||
return errResp("Could not transcribe audio")
|
||||
}
|
||||
|
||||
slog.Info("transcribed", "text", truncate(query, 50))
|
||||
@@ -75,7 +92,7 @@ func main() {
|
||||
embedding, err := embeddings.EmbedSingle(ctx, query)
|
||||
if err != nil {
|
||||
slog.Error("embedding failed", "error", err)
|
||||
return map[string]any{"request_id": requestID, "error": "Embedding failed"}, nil
|
||||
return errResp("Embedding failed")
|
||||
}
|
||||
|
||||
// 4. Search Milvus for context (placeholder — requires Milvus SDK)
|
||||
@@ -83,23 +100,25 @@ func main() {
|
||||
_ = collection
|
||||
_ = embedding
|
||||
_ = ragTopK
|
||||
var documents []map[string]any // Milvus results placeholder
|
||||
type docResult struct {
|
||||
Document string
|
||||
Score float64
|
||||
}
|
||||
var documents []docResult // Milvus results placeholder
|
||||
|
||||
// 5. Rerank documents
|
||||
if len(documents) > 0 {
|
||||
texts := make([]string, len(documents))
|
||||
for i, d := range documents {
|
||||
if t, ok := d["text"].(string); ok {
|
||||
texts[i] = t
|
||||
}
|
||||
texts[i] = d.Document
|
||||
}
|
||||
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
|
||||
if err != nil {
|
||||
slog.Error("rerank failed", "error", err)
|
||||
} else {
|
||||
documents = make([]map[string]any, len(reranked))
|
||||
documents = make([]docResult, len(reranked))
|
||||
for i, r := range reranked {
|
||||
documents[i] = map[string]any{"document": r.Document, "score": r.Score}
|
||||
documents[i] = docResult{Document: r.Document, Score: r.Score}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,9 +126,7 @@ func main() {
|
||||
// 6. Build context
|
||||
var contextParts []string
|
||||
for _, d := range documents {
|
||||
if t, ok := d["document"].(string); ok {
|
||||
contextParts = append(contextParts, t)
|
||||
}
|
||||
contextParts = append(contextParts, d.Document)
|
||||
}
|
||||
contextText := strings.Join(contextParts, "\n\n")
|
||||
|
||||
@@ -117,47 +134,40 @@ func main() {
|
||||
responseText, err := llm.Generate(ctx, query, contextText, "")
|
||||
if err != nil {
|
||||
slog.Error("LLM generation failed", "error", err)
|
||||
return map[string]any{"request_id": requestID, "error": "Generation failed"}, nil
|
||||
return errResp("Generation failed")
|
||||
}
|
||||
|
||||
// 8. Synthesize speech
|
||||
responseAudioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
||||
// 8. Synthesize speech — response audio is raw bytes
|
||||
responseAudio, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
||||
if err != nil {
|
||||
slog.Error("TTS failed", "error", err)
|
||||
return map[string]any{"request_id": requestID, "error": "Speech synthesis failed"}, nil
|
||||
return errResp("Speech synthesis failed")
|
||||
}
|
||||
responseAudioB64 := base64.StdEncoding.EncodeToString(responseAudioBytes)
|
||||
|
||||
// Build response
|
||||
result := map[string]any{
|
||||
"request_id": requestID,
|
||||
"response": responseText,
|
||||
"audio": responseAudioB64,
|
||||
// Build typed response
|
||||
result := &messages.VoiceResponse{
|
||||
RequestID: requestID,
|
||||
Response: responseText,
|
||||
Audio: responseAudio,
|
||||
}
|
||||
|
||||
if includeTranscription {
|
||||
result["transcription"] = query
|
||||
result.Transcription = query
|
||||
}
|
||||
|
||||
if includeSources && len(documents) > 0 {
|
||||
sources := make([]map[string]any, 0, 3)
|
||||
for i, d := range documents {
|
||||
if i >= 3 {
|
||||
break
|
||||
limit := 3
|
||||
if len(documents) < limit {
|
||||
limit = len(documents)
|
||||
}
|
||||
text := ""
|
||||
if t, ok := d["document"].(string); ok && len(t) > 200 {
|
||||
text = t[:200]
|
||||
} else if t, ok := d["document"].(string); ok {
|
||||
text = t
|
||||
result.Sources = make([]messages.DocumentSource, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
text := documents[i].Document
|
||||
if len(text) > 200 {
|
||||
text = text[:200]
|
||||
}
|
||||
score := 0.0
|
||||
if s, ok := d["score"].(float64); ok {
|
||||
score = s
|
||||
result.Sources[i] = messages.DocumentSource{Text: text, Score: documents[i].Score}
|
||||
}
|
||||
sources = append(sources, map[string]any{"text": text, "score": score})
|
||||
}
|
||||
result["sources"] = sources
|
||||
}
|
||||
|
||||
// Publish to response subject
|
||||
@@ -176,24 +186,6 @@ func main() {
|
||||
|
||||
// Helpers
|
||||
|
||||
func strVal(m map[string]any, key, fallback string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func boolVal(m map[string]any, key string, fallback bool) bool {
|
||||
if v, ok := m[key]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
|
||||
58
main_test.go
58
main_test.go
@@ -2,25 +2,59 @@ package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
func TestStrVal(t *testing.T) {
|
||||
m := map[string]any{"key": "value"}
|
||||
if got := strVal(m, "key", ""); got != "value" {
|
||||
t.Errorf("strVal(key) = %q", got)
|
||||
func TestVoiceRequestDecode(t *testing.T) {
|
||||
req := messages.VoiceRequest{
|
||||
RequestID: "req-123",
|
||||
Audio: []byte{0x01, 0x02, 0x03},
|
||||
Language: "en",
|
||||
Collection: "docs",
|
||||
}
|
||||
if got := strVal(m, "missing", "def"); got != "def" {
|
||||
t.Errorf("strVal(missing) = %q", got)
|
||||
data, err := msgpack.Marshal(&req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded, err := natsutil.Decode[messages.VoiceRequest](data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if decoded.RequestID != "req-123" {
|
||||
t.Errorf("RequestID = %q", decoded.RequestID)
|
||||
}
|
||||
if len(decoded.Audio) != 3 {
|
||||
t.Errorf("Audio len = %d", len(decoded.Audio))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolVal(t *testing.T) {
|
||||
m := map[string]any{"flag": true}
|
||||
if got := boolVal(m, "flag", false); !got {
|
||||
t.Error("expected true")
|
||||
func TestVoiceResponseRoundtrip(t *testing.T) {
|
||||
resp := messages.VoiceResponse{
|
||||
RequestID: "req-456",
|
||||
Response: "It is sunny today.",
|
||||
Audio: make([]byte, 8000),
|
||||
Transcription: "What is the weather?",
|
||||
Sources: []messages.DocumentSource{{Text: "weather doc", Score: 0.9}},
|
||||
}
|
||||
if got := boolVal(m, "missing", false); got {
|
||||
t.Error("expected false fallback")
|
||||
data, err := msgpack.Marshal(&resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got messages.VoiceResponse
|
||||
if err := msgpack.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Response != "It is sunny today." {
|
||||
t.Errorf("Response = %q", got.Response)
|
||||
}
|
||||
if len(got.Audio) != 8000 {
|
||||
t.Errorf("Audio len = %d", len(got.Audio))
|
||||
}
|
||||
if len(got.Sources) != 1 || got.Sources[0].Text != "weather doc" {
|
||||
t.Errorf("Sources = %v", got.Sources)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user