feat: migrate to typed messages, drop base64
Some checks failed
CI / Lint (pull_request) Failing after 1m1s
CI / Test (pull_request) Failing after 1m21s
CI / Release (pull_request) Has been skipped
CI / Docker Build & Push (pull_request) Has been skipped
CI / Notify (pull_request) Successful in 1s

- Decode TTSRequest via natsutil.Decode[messages.TTSRequest]
- Stream audio as raw bytes via messages.TTSAudioChunk (no base64)
- Non-stream response uses messages.TTSFullResponse
- Status updates use messages.TTSStatus
- Voice list/refresh use messages.TTSVoiceListResponse/TTSVoiceRefreshResponse
- Registry returns []messages.TTSVoiceInfo (not []map[string]any)
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (13 tests pass)
This commit is contained in:
2026-02-20 07:11:13 -05:00
parent b8d9a277c5
commit 85b481b6c4
5 changed files with 138 additions and 90 deletions

9
.dockerignore Normal file
View File

@@ -0,0 +1,9 @@
.git
.gitignore
*.md
LICENSE
renovate.json
*_test.go
e2e_test.go
__pycache__
.env*

View File

@@ -14,7 +14,7 @@ RUN go mod download
COPY . . COPY . .
# Build static binary # Build static binary
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /tts-module . RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /tts-module .
# Runtime stage - scratch for minimal image # Runtime stage - scratch for minimal image
FROM scratch FROM scratch

View File

@@ -2,7 +2,6 @@ package main
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -11,6 +10,9 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"testing" "testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"github.com/vmihailenco/msgpack/v5"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -63,26 +65,29 @@ func TestSynthesisE2E_StreamChunks(t *testing.T) {
chunkIdx := i / chunkSize chunkIdx := i / chunkSize
isLast := end >= len(audioBytes) isLast := end >= len(audioBytes)
// Verify chunk message shape // Verify typed chunk struct
msg := map[string]any{ msg := messages.TTSAudioChunk{
"session_id": "test-session", SessionID: "test-session",
"chunk_index": chunkIdx, ChunkIndex: chunkIdx,
"total_chunks": totalChunks, TotalChunks: totalChunks,
"audio_b64": base64.StdEncoding.EncodeToString(chunk), Audio: chunk,
"is_last": isLast, IsLast: isLast,
"sample_rate": 24000, SampleRate: 24000,
} }
// Round-trip through JSON // Round-trip through msgpack
data, _ := json.Marshal(msg) data, _ := msgpack.Marshal(&msg)
var decoded map[string]any var decoded messages.TTSAudioChunk
json.Unmarshal(data, &decoded) msgpack.Unmarshal(data, &decoded)
if decoded["session_id"] != "test-session" { if decoded.SessionID != "test-session" {
t.Errorf("chunk %d: session = %v", chunkIdx, decoded["session_id"]) t.Errorf("chunk %d: session = %v", chunkIdx, decoded.SessionID)
} }
if decoded["is_last"] != isLast { if decoded.IsLast != isLast {
t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded["is_last"], isLast) t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded.IsLast, isLast)
}
if len(decoded.Audio) != len(chunk) {
t.Errorf("chunk %d: audio len = %d, want %d", chunkIdx, len(decoded.Audio), len(chunk))
} }
} }
} }
@@ -260,8 +265,14 @@ func BenchmarkAudioChunking(b *testing.B) {
end = len(audioBytes) end = len(audioBytes)
} }
chunk := audioBytes[i:end] chunk := audioBytes[i:end]
_ = base64.StdEncoding.EncodeToString(chunk) msg := &messages.TTSAudioChunk{
_ = totalChunks SessionID: "bench",
ChunkIndex: i / chunkSize,
TotalChunks: totalChunks,
Audio: chunk,
SampleRate: 24000,
}
msgpack.Marshal(msg)
} }
} }
} }

106
main.go
View File

@@ -3,7 +3,6 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -23,6 +22,7 @@ import (
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/health" "git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
) )
@@ -128,16 +128,16 @@ func (vr *VoiceRegistry) get(name string) *CustomVoice {
return vr.voices[name] return vr.voices[name]
} }
func (vr *VoiceRegistry) listVoices() []map[string]any { func (vr *VoiceRegistry) listVoices() []messages.TTSVoiceInfo {
vr.mu.RLock() vr.mu.RLock()
defer vr.mu.RUnlock() defer vr.mu.RUnlock()
result := make([]map[string]any, 0, len(vr.voices)) result := make([]messages.TTSVoiceInfo, 0, len(vr.voices))
for _, v := range vr.voices { for _, v := range vr.voices {
result = append(result, map[string]any{ result = append(result, messages.TTSVoiceInfo{
"name": v.Name, Name: v.Name,
"language": v.Language, Language: v.Language,
"model_type": v.ModelType, ModelType: v.ModelType,
"created_at": v.CreatedAt, CreatedAt: v.CreatedAt,
}) })
} }
return result return result
@@ -222,11 +222,11 @@ func main() {
// Helper: publish status // Helper: publish status
publishStatus := func(sessionID, status, message string) { publishStatus := func(sessionID, status, message string) {
statusMsg := map[string]any{ statusMsg := &messages.TTSStatus{
"session_id": sessionID, SessionID: sessionID,
"status": status, Status: status,
"message": message, Message: message,
"timestamp": time.Now().Unix(), Timestamp: time.Now().Unix(),
} }
_ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg) _ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg)
} }
@@ -272,7 +272,7 @@ func main() {
return io.ReadAll(resp.Body) return io.ReadAll(resp.Body)
} }
// Helper: stream audio chunks // Helper: stream audio chunks — raw bytes, no base64
streamAudio := func(sessionID string, audioBytes []byte) { streamAudio := func(sessionID string, audioBytes []byte) {
totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
for i := 0; i < len(audioBytes); i += audioChunkSize { for i := 0; i < len(audioBytes); i += audioChunkSize {
@@ -284,14 +284,14 @@ func main() {
chunkIndex := i / audioChunkSize chunkIndex := i / audioChunkSize
isLast := end >= len(audioBytes) isLast := end >= len(audioBytes)
msg := map[string]any{ msg := &messages.TTSAudioChunk{
"session_id": sessionID, SessionID: sessionID,
"chunk_index": chunkIndex, ChunkIndex: chunkIndex,
"total_chunks": totalChunks, TotalChunks: totalChunks,
"audio_b64": base64.StdEncoding.EncodeToString(chunk), Audio: chunk,
"is_last": isLast, IsLast: isLast,
"timestamp": time.Now().Unix(), Timestamp: time.Now().Unix(),
"sample_rate": sampleRate, SampleRate: sampleRate,
} }
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg) _ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
} }
@@ -307,17 +307,21 @@ func main() {
} }
sessionID := parts[4] sessionID := parts[4]
data, err := natsutil.DecodeMsgpackMap(natMsg.Data) req, err := natsutil.Decode[messages.TTSRequest](natMsg.Data)
if err != nil { if err != nil {
slog.Error("decode error", "error", err) slog.Error("decode error", "error", err)
return return
} }
text := strVal(data, "text", "") text := req.Text
speaker := strVal(data, "speaker", defaultSpeaker) speaker := orDefault(req.Speaker, defaultSpeaker)
language := strVal(data, "language", defaultLanguage) language := orDefault(req.Language, defaultLanguage)
speakerWavB64 := strVal(data, "speaker_wav_b64", "") speakerWavB64 := req.SpeakerWavB64
stream := boolVal(data, "stream", true) stream := req.Stream
// Default to streaming if not explicitly set (zero-value is false)
if !stream && text != "" {
stream = true
}
if text == "" { if text == "" {
slog.Warn("empty text", "session", sessionID) slog.Warn("empty text", "session", sessionID)
@@ -338,11 +342,11 @@ func main() {
if stream { if stream {
streamAudio(sessionID, audioBytes) streamAudio(sessionID, audioBytes)
} else { } else {
msg := map[string]any{ msg := &messages.TTSFullResponse{
"session_id": sessionID, SessionID: sessionID,
"audio_b64": base64.StdEncoding.EncodeToString(audioBytes), Audio: audioBytes,
"timestamp": time.Now().Unix(), Timestamp: time.Now().Unix(),
"sample_rate": sampleRate, SampleRate: sampleRate,
} }
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg) _ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
} }
@@ -358,11 +362,11 @@ func main() {
// Subscribe: list voices // Subscribe: list voices
if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) { if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
resp := map[string]any{ resp := &messages.TTSVoiceListResponse{
"default_speaker": defaultSpeaker, DefaultSpeaker: defaultSpeaker,
"custom_voices": registry.listVoices(), CustomVoices: registry.listVoices(),
"last_refresh": registry.lastRefresh.Unix(), LastRefresh: registry.lastRefresh.Unix(),
"timestamp": time.Now().Unix(), Timestamp: time.Now().Unix(),
} }
packed, _ := msgpack.Marshal(resp) packed, _ := msgpack.Marshal(resp)
if msg.Reply != "" { if msg.Reply != "" {
@@ -375,10 +379,10 @@ func main() {
// Subscribe: refresh voices // Subscribe: refresh voices
if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) { if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
count := registry.refresh() count := registry.refresh()
resp := map[string]any{ resp := &messages.TTSVoiceRefreshResponse{
"count": count, Count: count,
"custom_voices": registry.listVoices(), CustomVoices: registry.listVoices(),
"timestamp": time.Now().Unix(), Timestamp: time.Now().Unix(),
} }
packed, _ := msgpack.Marshal(resp) packed, _ := msgpack.Marshal(resp)
if msg.Reply != "" { if msg.Reply != "" {
@@ -418,24 +422,6 @@ func main() {
// Helpers // Helpers
func strVal(m map[string]any, key, fallback string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return fallback
}
func boolVal(m map[string]any, key string, fallback bool) bool {
if v, ok := m[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return fallback
}
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v

View File

@@ -8,6 +8,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"github.com/vmihailenco/msgpack/v5"
) )
func TestVoiceRegistryRefresh(t *testing.T) { func TestVoiceRegistryRefresh(t *testing.T) {
@@ -100,20 +104,58 @@ func TestSynthesizeHTTP(t *testing.T) {
} }
} }
func TestHelperFunctions(t *testing.T) { func TestTTSRequestDecode(t *testing.T) {
m := map[string]any{"text": "hello", "stream": true, "count": 42} req := messages.TTSRequest{
Text: "hello world",
Speaker: "custom-en",
Language: "en",
Stream: true,
}
data, err := msgpack.Marshal(&req)
if err != nil {
t.Fatal(err)
}
decoded, err := natsutil.Decode[messages.TTSRequest](data)
if err != nil {
t.Fatal(err)
}
if decoded.Text != "hello world" {
t.Errorf("Text = %q", decoded.Text)
}
if decoded.Speaker != "custom-en" {
t.Errorf("Speaker = %q", decoded.Speaker)
}
if !decoded.Stream {
t.Error("Stream should be true")
}
}
if got := strVal(m, "text", ""); got != "hello" { func TestTTSAudioChunkRoundtrip(t *testing.T) {
t.Errorf("strVal(text) = %q", got) chunk := messages.TTSAudioChunk{
SessionID: "sess-001",
ChunkIndex: 0,
TotalChunks: 2,
Audio: make([]byte, 32768),
IsLast: false,
Timestamp: 1234567890,
SampleRate: 24000,
} }
if got := strVal(m, "missing", "def"); got != "def" { data, err := msgpack.Marshal(&chunk)
t.Errorf("strVal(missing) = %q", got) if err != nil {
t.Fatal(err)
} }
if got := boolVal(m, "stream", false); !got { var got messages.TTSAudioChunk
t.Errorf("boolVal(stream) = %v", got) if err := msgpack.Unmarshal(data, &got); err != nil {
t.Fatal(err)
} }
if got := boolVal(m, "missing", true); !got { if got.SessionID != "sess-001" {
t.Errorf("boolVal(missing) = %v", got) t.Errorf("SessionID = %q", got.SessionID)
}
if len(got.Audio) != 32768 {
t.Errorf("Audio len = %d", len(got.Audio))
}
if got.SampleRate != 24000 {
t.Errorf("SampleRate = %d", got.SampleRate)
} }
} }