diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7ea5baa --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.git +.gitignore +*.md +LICENSE +renovate.json +*_test.go +e2e_test.go +__pycache__ +.env* diff --git a/Dockerfile b/Dockerfile index 941df07..520843f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ RUN go mod download COPY . . # Build static binary -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /tts-module . +RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /tts-module . # Runtime stage - scratch for minimal image FROM scratch diff --git a/e2e_test.go b/e2e_test.go index d83ea8e..16f9668 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "encoding/base64" "encoding/json" "io" "net/http" @@ -11,6 +10,9 @@ import ( "path/filepath" "strconv" "testing" + + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" + "github.com/vmihailenco/msgpack/v5" ) // ──────────────────────────────────────────────────────────────────────────── @@ -63,26 +65,29 @@ func TestSynthesisE2E_StreamChunks(t *testing.T) { chunkIdx := i / chunkSize isLast := end >= len(audioBytes) - // Verify chunk message shape - msg := map[string]any{ - "session_id": "test-session", - "chunk_index": chunkIdx, - "total_chunks": totalChunks, - "audio_b64": base64.StdEncoding.EncodeToString(chunk), - "is_last": isLast, - "sample_rate": 24000, + // Verify typed chunk struct + msg := messages.TTSAudioChunk{ + SessionID: "test-session", + ChunkIndex: chunkIdx, + TotalChunks: totalChunks, + Audio: chunk, + IsLast: isLast, + SampleRate: 24000, } - // Round-trip through JSON - data, _ := json.Marshal(msg) - var decoded map[string]any - json.Unmarshal(data, &decoded) + // Round-trip through msgpack + data, _ := msgpack.Marshal(&msg) + var decoded messages.TTSAudioChunk + msgpack.Unmarshal(data, &decoded) - if decoded["session_id"] != "test-session" { - t.Errorf("chunk %d: session = %v", chunkIdx, decoded["session_id"]) + if decoded.SessionID != "test-session" { + t.Errorf("chunk %d: session = %v", chunkIdx, decoded.SessionID) } - if decoded["is_last"] != isLast { - t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded["is_last"], isLast) + if decoded.IsLast != isLast { + t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded.IsLast, isLast) + } + if len(decoded.Audio) != len(chunk) { + t.Errorf("chunk %d: audio len = %d, want %d", chunkIdx, len(decoded.Audio), len(chunk)) } } } @@ -260,8 +265,14 @@ func BenchmarkAudioChunking(b *testing.B) { end = len(audioBytes) } chunk := audioBytes[i:end] - _ = base64.StdEncoding.EncodeToString(chunk) - _ = totalChunks + msg := &messages.TTSAudioChunk{ + SessionID: "bench", + ChunkIndex: i / chunkSize, + TotalChunks: totalChunks, + Audio: chunk, + SampleRate: 24000, + } + msgpack.Marshal(msg) } } } diff --git a/main.go b/main.go index ad38097..b8dfc2c 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "bytes" "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -23,6 +22,7 @@ import ( "git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/health" + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" ) @@ -128,16 +128,16 @@ func (vr *VoiceRegistry) get(name string) *CustomVoice { return vr.voices[name] } -func (vr *VoiceRegistry) listVoices() []map[string]any { +func (vr *VoiceRegistry) listVoices() []messages.TTSVoiceInfo { vr.mu.RLock() defer vr.mu.RUnlock() - result := make([]map[string]any, 0, len(vr.voices)) + result := make([]messages.TTSVoiceInfo, 0, len(vr.voices)) for _, v := range vr.voices { - result = append(result, map[string]any{ - "name": v.Name, - "language": v.Language, - "model_type": v.ModelType, - "created_at": v.CreatedAt, + result = append(result, messages.TTSVoiceInfo{ + Name: v.Name, + Language: v.Language, + ModelType: v.ModelType, + CreatedAt: v.CreatedAt, }) } return result @@ -222,11 +222,11 @@ func main() { // Helper: publish status publishStatus := func(sessionID, status, message string) { - statusMsg := map[string]any{ - "session_id": sessionID, - "status": status, - "message": message, - "timestamp": time.Now().Unix(), + statusMsg := &messages.TTSStatus{ + SessionID: sessionID, + Status: status, + Message: message, + Timestamp: time.Now().Unix(), } _ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg) } @@ -272,7 +272,7 @@ func main() { return io.ReadAll(resp.Body) } - // Helper: stream audio chunks + // Helper: stream audio chunks — raw bytes, no base64 streamAudio := func(sessionID string, audioBytes []byte) { totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize for i := 0; i < len(audioBytes); i += audioChunkSize { @@ -284,14 +284,14 @@ func main() { chunkIndex := i / audioChunkSize isLast := end >= len(audioBytes) - msg := map[string]any{ - "session_id": sessionID, - "chunk_index": chunkIndex, - "total_chunks": totalChunks, - "audio_b64": base64.StdEncoding.EncodeToString(chunk), - "is_last": isLast, - "timestamp": time.Now().Unix(), - "sample_rate": sampleRate, + msg := &messages.TTSAudioChunk{ + SessionID: sessionID, + ChunkIndex: chunkIndex, + TotalChunks: totalChunks, + Audio: chunk, + IsLast: isLast, + Timestamp: time.Now().Unix(), + SampleRate: sampleRate, } _ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg) } @@ -307,17 +307,21 @@ func main() { } sessionID := parts[4] - data, err := natsutil.DecodeMsgpackMap(natMsg.Data) + req, err := natsutil.Decode[messages.TTSRequest](natMsg.Data) if err != nil { slog.Error("decode error", "error", err) return } - text := strVal(data, "text", "") - speaker := strVal(data, "speaker", defaultSpeaker) - language := strVal(data, "language", defaultLanguage) - speakerWavB64 := strVal(data, "speaker_wav_b64", "") - stream := boolVal(data, "stream", true) + text := req.Text + speaker := orDefault(req.Speaker, defaultSpeaker) + language := orDefault(req.Language, defaultLanguage) + speakerWavB64 := req.SpeakerWavB64 + stream := req.Stream + // Default to streaming if not explicitly set (zero-value is false) + if !stream && text != "" { + stream = true + } if text == "" { slog.Warn("empty text", "session", sessionID) @@ -338,11 +342,11 @@ func main() { if stream { streamAudio(sessionID, audioBytes) } else { - msg := map[string]any{ - "session_id": sessionID, - "audio_b64": base64.StdEncoding.EncodeToString(audioBytes), - "timestamp": time.Now().Unix(), - "sample_rate": sampleRate, + msg := &messages.TTSFullResponse{ + SessionID: sessionID, + Audio: audioBytes, + Timestamp: time.Now().Unix(), + SampleRate: sampleRate, } _ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg) } @@ -358,11 +362,11 @@ func main() { // Subscribe: list voices if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) { - resp := map[string]any{ - "default_speaker": defaultSpeaker, - "custom_voices": registry.listVoices(), - "last_refresh": registry.lastRefresh.Unix(), - "timestamp": time.Now().Unix(), + resp := &messages.TTSVoiceListResponse{ + DefaultSpeaker: defaultSpeaker, + CustomVoices: registry.listVoices(), + LastRefresh: registry.lastRefresh.Unix(), + Timestamp: time.Now().Unix(), } packed, _ := msgpack.Marshal(resp) if msg.Reply != "" { @@ -375,10 +379,10 @@ func main() { // Subscribe: refresh voices if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) { count := registry.refresh() - resp := map[string]any{ - "count": count, - "custom_voices": registry.listVoices(), - "timestamp": time.Now().Unix(), + resp := &messages.TTSVoiceRefreshResponse{ + Count: count, + CustomVoices: registry.listVoices(), + Timestamp: time.Now().Unix(), } packed, _ := msgpack.Marshal(resp) if msg.Reply != "" { @@ -418,24 +422,6 @@ func main() { // Helpers -func strVal(m map[string]any, key, fallback string) string { - if v, ok := m[key]; ok { - if s, ok := v.(string); ok { - return s - } - } - return fallback -} - -func boolVal(m map[string]any, key string, fallback bool) bool { - if v, ok := m[key]; ok { - if b, ok := v.(bool); ok { - return b - } - } - return fallback -} - func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v diff --git a/main_test.go b/main_test.go index 6462a9d..7987ff1 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,10 @@ import ( "path/filepath" "strings" "testing" + + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" + "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" + "github.com/vmihailenco/msgpack/v5" ) func TestVoiceRegistryRefresh(t *testing.T) { @@ -100,20 +104,58 @@ func TestSynthesizeHTTP(t *testing.T) { } } -func TestHelperFunctions(t *testing.T) { - m := map[string]any{"text": "hello", "stream": true, "count": 42} +func TestTTSRequestDecode(t *testing.T) { + req := messages.TTSRequest{ + Text: "hello world", + Speaker: "custom-en", + Language: "en", + Stream: true, + } + data, err := msgpack.Marshal(&req) + if err != nil { + t.Fatal(err) + } + decoded, err := natsutil.Decode[messages.TTSRequest](data) + if err != nil { + t.Fatal(err) + } + if decoded.Text != "hello world" { + t.Errorf("Text = %q", decoded.Text) + } + if decoded.Speaker != "custom-en" { + t.Errorf("Speaker = %q", decoded.Speaker) + } + if !decoded.Stream { + t.Error("Stream should be true") + } +} - if got := strVal(m, "text", ""); got != "hello" { - t.Errorf("strVal(text) = %q", got) +func TestTTSAudioChunkRoundtrip(t *testing.T) { + chunk := messages.TTSAudioChunk{ + SessionID: "sess-001", + ChunkIndex: 0, + TotalChunks: 2, + Audio: make([]byte, 32768), + IsLast: false, + Timestamp: 1234567890, + SampleRate: 24000, } - if got := strVal(m, "missing", "def"); got != "def" { - t.Errorf("strVal(missing) = %q", got) + data, err := msgpack.Marshal(&chunk) + if err != nil { + t.Fatal(err) } - if got := boolVal(m, "stream", false); !got { - t.Errorf("boolVal(stream) = %v", got) + var got messages.TTSAudioChunk + if err := msgpack.Unmarshal(data, &got); err != nil { + t.Fatal(err) } - if got := boolVal(m, "missing", true); !got { - t.Errorf("boolVal(missing) = %v", got) + if got.SessionID != "sess-001" { + t.Errorf("SessionID = %q", got.SessionID) + } + if len(got.Audio) != 32768 { + t.Errorf("Audio len = %d", len(got.Audio)) + } + if got.SampleRate != 24000 { + t.Errorf("SampleRate = %d", got.SampleRate) } }