feature/go-handler-refactor #1
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
*.md
|
||||||
|
LICENSE
|
||||||
|
renovate.json
|
||||||
|
*_test.go
|
||||||
|
e2e_test.go
|
||||||
|
__pycache__
|
||||||
|
.env*
|
||||||
@@ -14,7 +14,7 @@ RUN go mod download
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build static binary
|
# Build static binary
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /tts-module .
|
RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /tts-module .
|
||||||
|
|
||||||
# Runtime stage - scratch for minimal image
|
# Runtime stage - scratch for minimal image
|
||||||
FROM scratch
|
FROM scratch
|
||||||
|
|||||||
49
e2e_test.go
49
e2e_test.go
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -11,6 +10,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -63,26 +65,29 @@ func TestSynthesisE2E_StreamChunks(t *testing.T) {
|
|||||||
chunkIdx := i / chunkSize
|
chunkIdx := i / chunkSize
|
||||||
isLast := end >= len(audioBytes)
|
isLast := end >= len(audioBytes)
|
||||||
|
|
||||||
// Verify chunk message shape
|
// Verify typed chunk struct
|
||||||
msg := map[string]any{
|
msg := messages.TTSAudioChunk{
|
||||||
"session_id": "test-session",
|
SessionID: "test-session",
|
||||||
"chunk_index": chunkIdx,
|
ChunkIndex: chunkIdx,
|
||||||
"total_chunks": totalChunks,
|
TotalChunks: totalChunks,
|
||||||
"audio_b64": base64.StdEncoding.EncodeToString(chunk),
|
Audio: chunk,
|
||||||
"is_last": isLast,
|
IsLast: isLast,
|
||||||
"sample_rate": 24000,
|
SampleRate: 24000,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Round-trip through JSON
|
// Round-trip through msgpack
|
||||||
data, _ := json.Marshal(msg)
|
data, _ := msgpack.Marshal(&msg)
|
||||||
var decoded map[string]any
|
var decoded messages.TTSAudioChunk
|
||||||
json.Unmarshal(data, &decoded)
|
msgpack.Unmarshal(data, &decoded)
|
||||||
|
|
||||||
if decoded["session_id"] != "test-session" {
|
if decoded.SessionID != "test-session" {
|
||||||
t.Errorf("chunk %d: session = %v", chunkIdx, decoded["session_id"])
|
t.Errorf("chunk %d: session = %v", chunkIdx, decoded.SessionID)
|
||||||
}
|
}
|
||||||
if decoded["is_last"] != isLast {
|
if decoded.IsLast != isLast {
|
||||||
t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded["is_last"], isLast)
|
t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded.IsLast, isLast)
|
||||||
|
}
|
||||||
|
if len(decoded.Audio) != len(chunk) {
|
||||||
|
t.Errorf("chunk %d: audio len = %d, want %d", chunkIdx, len(decoded.Audio), len(chunk))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -260,8 +265,14 @@ func BenchmarkAudioChunking(b *testing.B) {
|
|||||||
end = len(audioBytes)
|
end = len(audioBytes)
|
||||||
}
|
}
|
||||||
chunk := audioBytes[i:end]
|
chunk := audioBytes[i:end]
|
||||||
_ = base64.StdEncoding.EncodeToString(chunk)
|
msg := &messages.TTSAudioChunk{
|
||||||
_ = totalChunks
|
SessionID: "bench",
|
||||||
|
ChunkIndex: i / chunkSize,
|
||||||
|
TotalChunks: totalChunks,
|
||||||
|
Audio: chunk,
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
msgpack.Marshal(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
106
main.go
106
main.go
@@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -23,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
||||||
)
|
)
|
||||||
@@ -128,16 +128,16 @@ func (vr *VoiceRegistry) get(name string) *CustomVoice {
|
|||||||
return vr.voices[name]
|
return vr.voices[name]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vr *VoiceRegistry) listVoices() []map[string]any {
|
func (vr *VoiceRegistry) listVoices() []messages.TTSVoiceInfo {
|
||||||
vr.mu.RLock()
|
vr.mu.RLock()
|
||||||
defer vr.mu.RUnlock()
|
defer vr.mu.RUnlock()
|
||||||
result := make([]map[string]any, 0, len(vr.voices))
|
result := make([]messages.TTSVoiceInfo, 0, len(vr.voices))
|
||||||
for _, v := range vr.voices {
|
for _, v := range vr.voices {
|
||||||
result = append(result, map[string]any{
|
result = append(result, messages.TTSVoiceInfo{
|
||||||
"name": v.Name,
|
Name: v.Name,
|
||||||
"language": v.Language,
|
Language: v.Language,
|
||||||
"model_type": v.ModelType,
|
ModelType: v.ModelType,
|
||||||
"created_at": v.CreatedAt,
|
CreatedAt: v.CreatedAt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
@@ -222,11 +222,11 @@ func main() {
|
|||||||
|
|
||||||
// Helper: publish status
|
// Helper: publish status
|
||||||
publishStatus := func(sessionID, status, message string) {
|
publishStatus := func(sessionID, status, message string) {
|
||||||
statusMsg := map[string]any{
|
statusMsg := &messages.TTSStatus{
|
||||||
"session_id": sessionID,
|
SessionID: sessionID,
|
||||||
"status": status,
|
Status: status,
|
||||||
"message": message,
|
Message: message,
|
||||||
"timestamp": time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
_ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg)
|
_ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg)
|
||||||
}
|
}
|
||||||
@@ -272,7 +272,7 @@ func main() {
|
|||||||
return io.ReadAll(resp.Body)
|
return io.ReadAll(resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper: stream audio chunks
|
// Helper: stream audio chunks — raw bytes, no base64
|
||||||
streamAudio := func(sessionID string, audioBytes []byte) {
|
streamAudio := func(sessionID string, audioBytes []byte) {
|
||||||
totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
|
totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
|
||||||
for i := 0; i < len(audioBytes); i += audioChunkSize {
|
for i := 0; i < len(audioBytes); i += audioChunkSize {
|
||||||
@@ -284,14 +284,14 @@ func main() {
|
|||||||
chunkIndex := i / audioChunkSize
|
chunkIndex := i / audioChunkSize
|
||||||
isLast := end >= len(audioBytes)
|
isLast := end >= len(audioBytes)
|
||||||
|
|
||||||
msg := map[string]any{
|
msg := &messages.TTSAudioChunk{
|
||||||
"session_id": sessionID,
|
SessionID: sessionID,
|
||||||
"chunk_index": chunkIndex,
|
ChunkIndex: chunkIndex,
|
||||||
"total_chunks": totalChunks,
|
TotalChunks: totalChunks,
|
||||||
"audio_b64": base64.StdEncoding.EncodeToString(chunk),
|
Audio: chunk,
|
||||||
"is_last": isLast,
|
IsLast: isLast,
|
||||||
"timestamp": time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
"sample_rate": sampleRate,
|
SampleRate: sampleRate,
|
||||||
}
|
}
|
||||||
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
||||||
}
|
}
|
||||||
@@ -307,17 +307,21 @@ func main() {
|
|||||||
}
|
}
|
||||||
sessionID := parts[4]
|
sessionID := parts[4]
|
||||||
|
|
||||||
data, err := natsutil.DecodeMsgpackMap(natMsg.Data)
|
req, err := natsutil.Decode[messages.TTSRequest](natMsg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("decode error", "error", err)
|
slog.Error("decode error", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
text := strVal(data, "text", "")
|
text := req.Text
|
||||||
speaker := strVal(data, "speaker", defaultSpeaker)
|
speaker := orDefault(req.Speaker, defaultSpeaker)
|
||||||
language := strVal(data, "language", defaultLanguage)
|
language := orDefault(req.Language, defaultLanguage)
|
||||||
speakerWavB64 := strVal(data, "speaker_wav_b64", "")
|
speakerWavB64 := req.SpeakerWavB64
|
||||||
stream := boolVal(data, "stream", true)
|
stream := req.Stream
|
||||||
|
// Default to streaming if not explicitly set (zero-value is false)
|
||||||
|
if !stream && text != "" {
|
||||||
|
stream = true
|
||||||
|
}
|
||||||
|
|
||||||
if text == "" {
|
if text == "" {
|
||||||
slog.Warn("empty text", "session", sessionID)
|
slog.Warn("empty text", "session", sessionID)
|
||||||
@@ -338,11 +342,11 @@ func main() {
|
|||||||
if stream {
|
if stream {
|
||||||
streamAudio(sessionID, audioBytes)
|
streamAudio(sessionID, audioBytes)
|
||||||
} else {
|
} else {
|
||||||
msg := map[string]any{
|
msg := &messages.TTSFullResponse{
|
||||||
"session_id": sessionID,
|
SessionID: sessionID,
|
||||||
"audio_b64": base64.StdEncoding.EncodeToString(audioBytes),
|
Audio: audioBytes,
|
||||||
"timestamp": time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
"sample_rate": sampleRate,
|
SampleRate: sampleRate,
|
||||||
}
|
}
|
||||||
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
||||||
}
|
}
|
||||||
@@ -358,11 +362,11 @@ func main() {
|
|||||||
|
|
||||||
// Subscribe: list voices
|
// Subscribe: list voices
|
||||||
if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
|
if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
|
||||||
resp := map[string]any{
|
resp := &messages.TTSVoiceListResponse{
|
||||||
"default_speaker": defaultSpeaker,
|
DefaultSpeaker: defaultSpeaker,
|
||||||
"custom_voices": registry.listVoices(),
|
CustomVoices: registry.listVoices(),
|
||||||
"last_refresh": registry.lastRefresh.Unix(),
|
LastRefresh: registry.lastRefresh.Unix(),
|
||||||
"timestamp": time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
packed, _ := msgpack.Marshal(resp)
|
packed, _ := msgpack.Marshal(resp)
|
||||||
if msg.Reply != "" {
|
if msg.Reply != "" {
|
||||||
@@ -375,10 +379,10 @@ func main() {
|
|||||||
// Subscribe: refresh voices
|
// Subscribe: refresh voices
|
||||||
if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
|
if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
|
||||||
count := registry.refresh()
|
count := registry.refresh()
|
||||||
resp := map[string]any{
|
resp := &messages.TTSVoiceRefreshResponse{
|
||||||
"count": count,
|
Count: count,
|
||||||
"custom_voices": registry.listVoices(),
|
CustomVoices: registry.listVoices(),
|
||||||
"timestamp": time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
packed, _ := msgpack.Marshal(resp)
|
packed, _ := msgpack.Marshal(resp)
|
||||||
if msg.Reply != "" {
|
if msg.Reply != "" {
|
||||||
@@ -418,24 +422,6 @@ func main() {
|
|||||||
|
|
||||||
// Helpers
|
// Helpers
|
||||||
|
|
||||||
func strVal(m map[string]any, key, fallback string) string {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
if s, ok := v.(string); ok {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func boolVal(m map[string]any, key string, fallback bool) bool {
|
|
||||||
if v, ok := m[key]; ok {
|
|
||||||
if b, ok := v.(bool); ok {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnv(key, fallback string) string {
|
func getEnv(key, fallback string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return v
|
return v
|
||||||
|
|||||||
62
main_test.go
62
main_test.go
@@ -8,6 +8,10 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVoiceRegistryRefresh(t *testing.T) {
|
func TestVoiceRegistryRefresh(t *testing.T) {
|
||||||
@@ -100,20 +104,58 @@ func TestSynthesizeHTTP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHelperFunctions(t *testing.T) {
|
func TestTTSRequestDecode(t *testing.T) {
|
||||||
m := map[string]any{"text": "hello", "stream": true, "count": 42}
|
req := messages.TTSRequest{
|
||||||
|
Text: "hello world",
|
||||||
|
Speaker: "custom-en",
|
||||||
|
Language: "en",
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(&req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
decoded, err := natsutil.Decode[messages.TTSRequest](data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if decoded.Text != "hello world" {
|
||||||
|
t.Errorf("Text = %q", decoded.Text)
|
||||||
|
}
|
||||||
|
if decoded.Speaker != "custom-en" {
|
||||||
|
t.Errorf("Speaker = %q", decoded.Speaker)
|
||||||
|
}
|
||||||
|
if !decoded.Stream {
|
||||||
|
t.Error("Stream should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if got := strVal(m, "text", ""); got != "hello" {
|
func TestTTSAudioChunkRoundtrip(t *testing.T) {
|
||||||
t.Errorf("strVal(text) = %q", got)
|
chunk := messages.TTSAudioChunk{
|
||||||
|
SessionID: "sess-001",
|
||||||
|
ChunkIndex: 0,
|
||||||
|
TotalChunks: 2,
|
||||||
|
Audio: make([]byte, 32768),
|
||||||
|
IsLast: false,
|
||||||
|
Timestamp: 1234567890,
|
||||||
|
SampleRate: 24000,
|
||||||
}
|
}
|
||||||
if got := strVal(m, "missing", "def"); got != "def" {
|
data, err := msgpack.Marshal(&chunk)
|
||||||
t.Errorf("strVal(missing) = %q", got)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if got := boolVal(m, "stream", false); !got {
|
var got messages.TTSAudioChunk
|
||||||
t.Errorf("boolVal(stream) = %v", got)
|
if err := msgpack.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if got := boolVal(m, "missing", true); !got {
|
if got.SessionID != "sess-001" {
|
||||||
t.Errorf("boolVal(missing) = %v", got)
|
t.Errorf("SessionID = %q", got.SessionID)
|
||||||
|
}
|
||||||
|
if len(got.Audio) != 32768 {
|
||||||
|
t.Errorf("Audio len = %d", len(got.Audio))
|
||||||
|
}
|
||||||
|
if got.SampleRate != 24000 {
|
||||||
|
t.Errorf("SampleRate = %d", got.SampleRate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user