feature/go-handler-refactor #1
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
*.md
|
||||||
|
LICENSE
|
||||||
|
renovate.json
|
||||||
|
*_test.go
|
||||||
|
e2e_test.go
|
||||||
|
__pycache__
|
||||||
|
.env*
|
||||||
@@ -10,7 +10,7 @@ RUN go mod download
|
|||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /voice-assistant .
|
RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /voice-assistant .
|
||||||
|
|
||||||
# Runtime stage
|
# Runtime stage
|
||||||
FROM scratch
|
FROM scratch
|
||||||
|
|||||||
130
main.go
130
main.go
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,6 +14,8 @@ import (
|
|||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -42,31 +43,47 @@ func main() {
|
|||||||
|
|
||||||
h := handler.New("voice.request", cfg)
|
h := handler.New("voice.request", cfg)
|
||||||
|
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
requestID := strVal(data, "request_id", "unknown")
|
req, err := natsutil.Decode[messages.VoiceRequest](msg.Data)
|
||||||
audioB64 := strVal(data, "audio", "")
|
if err != nil {
|
||||||
language := strVal(data, "language", sttLanguage)
|
return &messages.VoiceResponse{Error: "Invalid request encoding"}, nil
|
||||||
collection := strVal(data, "collection", ragCollection)
|
}
|
||||||
|
|
||||||
|
requestID := req.RequestID
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = "unknown"
|
||||||
|
}
|
||||||
|
language := req.Language
|
||||||
|
if language == "" {
|
||||||
|
language = sttLanguage
|
||||||
|
}
|
||||||
|
collection := req.Collection
|
||||||
|
if collection == "" {
|
||||||
|
collection = ragCollection
|
||||||
|
}
|
||||||
|
|
||||||
slog.Info("processing voice request", "request_id", requestID)
|
slog.Info("processing voice request", "request_id", requestID)
|
||||||
|
|
||||||
// 1. Decode audio
|
errResp := func(msg string) (any, error) {
|
||||||
audioBytes, err := base64.StdEncoding.DecodeString(audioB64)
|
return &messages.VoiceResponse{RequestID: requestID, Error: msg}, nil
|
||||||
if err != nil {
|
}
|
||||||
return map[string]any{"request_id": requestID, "error": "Invalid audio encoding"}, nil
|
|
||||||
|
// 1. Audio arrives as raw bytes — no base64 decode needed
|
||||||
|
if len(req.Audio) == 0 {
|
||||||
|
return errResp("No audio data")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Transcribe audio → text
|
// 2. Transcribe audio → text
|
||||||
transcription, err := stt.Transcribe(ctx, audioBytes, language)
|
transcription, err := stt.Transcribe(ctx, req.Audio, language)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("STT failed", "error", err)
|
slog.Error("STT failed", "error", err)
|
||||||
return map[string]any{"request_id": requestID, "error": "Transcription failed"}, nil
|
return errResp("Transcription failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
query := strings.TrimSpace(transcription.Text)
|
query := strings.TrimSpace(transcription.Text)
|
||||||
if query == "" {
|
if query == "" {
|
||||||
slog.Warn("empty transcription", "request_id", requestID)
|
slog.Warn("empty transcription", "request_id", requestID)
|
||||||
return map[string]any{"request_id": requestID, "error": "Could not transcribe audio"}, nil
|
return errResp("Could not transcribe audio")
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("transcribed", "text", truncate(query, 50))
|
slog.Info("transcribed", "text", truncate(query, 50))
|
||||||
@@ -75,7 +92,7 @@ func main() {
|
|||||||
embedding, err := embeddings.EmbedSingle(ctx, query)
|
embedding, err := embeddings.EmbedSingle(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("embedding failed", "error", err)
|
slog.Error("embedding failed", "error", err)
|
||||||
return map[string]any{"request_id": requestID, "error": "Embedding failed"}, nil
|
return errResp("Embedding failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Search Milvus for context (placeholder — requires Milvus SDK)
|
// 4. Search Milvus for context (placeholder — requires Milvus SDK)
|
||||||
@@ -83,23 +100,25 @@ func main() {
|
|||||||
_ = collection
|
_ = collection
|
||||||
_ = embedding
|
_ = embedding
|
||||||
_ = ragTopK
|
_ = ragTopK
|
||||||
var documents []map[string]any // Milvus results placeholder
|
type docResult struct {
|
||||||
|
Document string
|
||||||
|
Score float64
|
||||||
|
}
|
||||||
|
var documents []docResult // Milvus results placeholder
|
||||||
|
|
||||||
// 5. Rerank documents
|
// 5. Rerank documents
|
||||||
if len(documents) > 0 {
|
if len(documents) > 0 {
|
||||||
texts := make([]string, len(documents))
|
texts := make([]string, len(documents))
|
||||||
for i, d := range documents {
|
for i, d := range documents {
|
||||||
if t, ok := d["text"].(string); ok {
|
texts[i] = d.Document
|
||||||
texts[i] = t
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
|
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("rerank failed", "error", err)
|
slog.Error("rerank failed", "error", err)
|
||||||
} else {
|
} else {
|
||||||
documents = make([]map[string]any, len(reranked))
|
documents = make([]docResult, len(reranked))
|
||||||
for i, r := range reranked {
|
for i, r := range reranked {
|
||||||
documents[i] = map[string]any{"document": r.Document, "score": r.Score}
|
documents[i] = docResult{Document: r.Document, Score: r.Score}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -107,9 +126,7 @@ func main() {
|
|||||||
// 6. Build context
|
// 6. Build context
|
||||||
var contextParts []string
|
var contextParts []string
|
||||||
for _, d := range documents {
|
for _, d := range documents {
|
||||||
if t, ok := d["document"].(string); ok {
|
contextParts = append(contextParts, d.Document)
|
||||||
contextParts = append(contextParts, t)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
contextText := strings.Join(contextParts, "\n\n")
|
contextText := strings.Join(contextParts, "\n\n")
|
||||||
|
|
||||||
@@ -117,47 +134,40 @@ func main() {
|
|||||||
responseText, err := llm.Generate(ctx, query, contextText, "")
|
responseText, err := llm.Generate(ctx, query, contextText, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("LLM generation failed", "error", err)
|
slog.Error("LLM generation failed", "error", err)
|
||||||
return map[string]any{"request_id": requestID, "error": "Generation failed"}, nil
|
return errResp("Generation failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8. Synthesize speech
|
// 8. Synthesize speech — response audio is raw bytes
|
||||||
responseAudioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
responseAudio, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("TTS failed", "error", err)
|
slog.Error("TTS failed", "error", err)
|
||||||
return map[string]any{"request_id": requestID, "error": "Speech synthesis failed"}, nil
|
return errResp("Speech synthesis failed")
|
||||||
}
|
}
|
||||||
responseAudioB64 := base64.StdEncoding.EncodeToString(responseAudioBytes)
|
|
||||||
|
|
||||||
// Build response
|
// Build typed response
|
||||||
result := map[string]any{
|
result := &messages.VoiceResponse{
|
||||||
"request_id": requestID,
|
RequestID: requestID,
|
||||||
"response": responseText,
|
Response: responseText,
|
||||||
"audio": responseAudioB64,
|
Audio: responseAudio,
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeTranscription {
|
if includeTranscription {
|
||||||
result["transcription"] = query
|
result.Transcription = query
|
||||||
}
|
}
|
||||||
|
|
||||||
if includeSources && len(documents) > 0 {
|
if includeSources && len(documents) > 0 {
|
||||||
sources := make([]map[string]any, 0, 3)
|
limit := 3
|
||||||
for i, d := range documents {
|
if len(documents) < limit {
|
||||||
if i >= 3 {
|
limit = len(documents)
|
||||||
break
|
}
|
||||||
}
|
result.Sources = make([]messages.DocumentSource, limit)
|
||||||
text := ""
|
for i := 0; i < limit; i++ {
|
||||||
if t, ok := d["document"].(string); ok && len(t) > 200 {
|
text := documents[i].Document
|
||||||
text = t[:200]
|
if len(text) > 200 {
|
||||||
} else if t, ok := d["document"].(string); ok {
|
text = text[:200]
|
||||||
text = t
|
}
|
||||||
}
|
result.Sources[i] = messages.DocumentSource{Text: text, Score: documents[i].Score}
|
||||||
score := 0.0
|
|
||||||
if s, ok := d["score"].(float64); ok {
|
|
||||||
score = s
|
|
||||||
}
|
|
||||||
sources = append(sources, map[string]any{"text": text, "score": score})
|
|
||||||
}
|
}
|
||||||
result["sources"] = sources
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish to response subject
|
// Publish to response subject
|
||||||
@@ -176,24 +186,6 @@ func main() {
|
|||||||
|
|
||||||
// Helpers
|
// Helpers
|
||||||
|
|
||||||
func strVal(m map[string]any, key, fallback string) string {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
if s, ok := v.(string); ok {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func boolVal(m map[string]any, key string, fallback bool) bool {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
if b, ok := v.(bool); ok {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnv(key, fallback string) string {
|
func getEnv(key, fallback string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return v
|
return v
|
||||||
|
|||||||
58
main_test.go
58
main_test.go
@@ -2,25 +2,59 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStrVal(t *testing.T) {
|
func TestVoiceRequestDecode(t *testing.T) {
|
||||||
m := map[string]any{"key": "value"}
|
req := messages.VoiceRequest{
|
||||||
if got := strVal(m, "key", ""); got != "value" {
|
RequestID: "req-123",
|
||||||
t.Errorf("strVal(key) = %q", got)
|
Audio: []byte{0x01, 0x02, 0x03},
|
||||||
|
Language: "en",
|
||||||
|
Collection: "docs",
|
||||||
}
|
}
|
||||||
if got := strVal(m, "missing", "def"); got != "def" {
|
data, err := msgpack.Marshal(&req)
|
||||||
t.Errorf("strVal(missing) = %q", got)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
decoded, err := natsutil.Decode[messages.VoiceRequest](data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if decoded.RequestID != "req-123" {
|
||||||
|
t.Errorf("RequestID = %q", decoded.RequestID)
|
||||||
|
}
|
||||||
|
if len(decoded.Audio) != 3 {
|
||||||
|
t.Errorf("Audio len = %d", len(decoded.Audio))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBoolVal(t *testing.T) {
|
func TestVoiceResponseRoundtrip(t *testing.T) {
|
||||||
m := map[string]any{"flag": true}
|
resp := messages.VoiceResponse{
|
||||||
if got := boolVal(m, "flag", false); !got {
|
RequestID: "req-456",
|
||||||
t.Error("expected true")
|
Response: "It is sunny today.",
|
||||||
|
Audio: make([]byte, 8000),
|
||||||
|
Transcription: "What is the weather?",
|
||||||
|
Sources: []messages.DocumentSource{{Text: "weather doc", Score: 0.9}},
|
||||||
}
|
}
|
||||||
if got := boolVal(m, "missing", false); got {
|
data, err := msgpack.Marshal(&resp)
|
||||||
t.Error("expected false fallback")
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var got messages.VoiceResponse
|
||||||
|
if err := msgpack.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got.Response != "It is sunny today." {
|
||||||
|
t.Errorf("Response = %q", got.Response)
|
||||||
|
}
|
||||||
|
if len(got.Audio) != 8000 {
|
||||||
|
t.Errorf("Audio len = %d", len(got.Audio))
|
||||||
|
}
|
||||||
|
if len(got.Sources) != 1 || got.Sources[0].Text != "weather doc" {
|
||||||
|
t.Errorf("Sources = %v", got.Sources)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user