Files
tts-module/e2e_test.go
Billy D. 85b481b6c4
Some checks failed
CI / Lint (pull_request) Failing after 1m1s
CI / Test (pull_request) Failing after 1m21s
CI / Release (pull_request) Has been skipped
CI / Docker Build & Push (pull_request) Has been skipped
CI / Notify (pull_request) Successful in 1s
feat: migrate to typed messages, drop base64
- Decode TTSRequest via natsutil.Decode[messages.TTSRequest]
- Stream audio as raw bytes via messages.TTSAudioChunk (no base64)
- Non-stream response uses messages.TTSFullResponse
- Status updates use messages.TTSStatus
- Voice list/refresh use messages.TTSVoiceListResponse/TTSVoiceRefreshResponse
- Registry returns []messages.TTSVoiceInfo (not []map[string]any)
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (13 tests pass)
2026-02-20 07:11:13 -05:00

279 lines
8.1 KiB
Go

package main
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"github.com/vmihailenco/msgpack/v5"
)
// ────────────────────────────────────────────────────────────────────────────
// E2E tests: voice registry + XTTS synthesis + audio streaming pipeline
// ────────────────────────────────────────────────────────────────────────────
// TestSynthesisE2E_StreamChunks drives the synthesize-then-stream path end to
// end: a mock XTTS server returns 64 KiB of audio, the audio is split into
// 32 KiB pieces, and each piece is wrapped in a messages.TTSAudioChunk and
// round-tripped through msgpack. Every chunk field must survive the round
// trip, and the decoded chunks must reassemble into the original audio.
func TestSynthesisE2E_StreamChunks(t *testing.T) {
	// Mock XTTS returning 64 KB of audio. The audio is a repeating byte
	// pattern (not zeros) so a reordered or mis-sliced chunk is detectable.
	const audioSize = 65536
	wantAudio := make([]byte, audioSize)
	for i := range wantAudio {
		wantAudio[i] = byte(i)
	}
	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var payload map[string]any
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			w.Write([]byte("invalid json"))
			return
		}
		if payload["text"] == nil || payload["text"] == "" {
			w.WriteHeader(http.StatusBadRequest)
			w.Write([]byte("empty text"))
			return
		}
		w.Write(wantAudio)
	}))
	defer xttsSrv.Close()

	// Test synthesize + chunking logic. Server.Client is preconfigured for
	// the test server and closes its idle connections on Server.Close.
	body := `{"text":"hello world","speaker":"default","language":"en"}`
	resp, err := xttsSrv.Client().Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
		bytes.NewReader([]byte(body)))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
	}
	audioBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading audio: %v", err)
	}
	if len(audioBytes) != audioSize {
		t.Fatalf("audio size = %d, want %d", len(audioBytes), audioSize)
	}

	// Simulate streaming: chunk into 32 KB pieces.
	const chunkSize = 32768
	totalChunks := (len(audioBytes) + chunkSize - 1) / chunkSize // ceiling division
	if totalChunks != 2 {
		t.Errorf("totalChunks = %d, want 2", totalChunks)
	}
	reassembled := make([]byte, 0, len(audioBytes))
	for i := 0; i < len(audioBytes); i += chunkSize {
		end := min(i+chunkSize, len(audioBytes))
		chunk := audioBytes[i:end]
		chunkIdx := i / chunkSize
		isLast := end == len(audioBytes)

		// Verify typed chunk struct
		msg := messages.TTSAudioChunk{
			SessionID:   "test-session",
			ChunkIndex:  chunkIdx,
			TotalChunks: totalChunks,
			Audio:       chunk,
			IsLast:      isLast,
			SampleRate:  24000,
		}
		// Round-trip through msgpack. A codec failure must fail the test
		// rather than show up as zero-value field mismatches below.
		data, err := msgpack.Marshal(&msg)
		if err != nil {
			t.Fatalf("chunk %d: marshal: %v", chunkIdx, err)
		}
		var decoded messages.TTSAudioChunk
		if err := msgpack.Unmarshal(data, &decoded); err != nil {
			t.Fatalf("chunk %d: unmarshal: %v", chunkIdx, err)
		}

		if decoded.SessionID != "test-session" {
			t.Errorf("chunk %d: session = %v", chunkIdx, decoded.SessionID)
		}
		if decoded.ChunkIndex != chunkIdx {
			t.Errorf("chunk %d: chunk_index = %v, want %d", chunkIdx, decoded.ChunkIndex, chunkIdx)
		}
		if decoded.TotalChunks != totalChunks {
			t.Errorf("chunk %d: total_chunks = %v, want %d", chunkIdx, decoded.TotalChunks, totalChunks)
		}
		if decoded.IsLast != isLast {
			t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded.IsLast, isLast)
		}
		if decoded.SampleRate != 24000 {
			t.Errorf("chunk %d: sample_rate = %v, want 24000", chunkIdx, decoded.SampleRate)
		}
		if len(decoded.Audio) != len(chunk) {
			t.Errorf("chunk %d: audio len = %d, want %d", chunkIdx, len(decoded.Audio), len(chunk))
		}
		reassembled = append(reassembled, decoded.Audio...)
	}

	// Concatenating the decoded chunks in order must reproduce the input.
	if !bytes.Equal(reassembled, audioBytes) {
		t.Errorf("reassembled audio (%d bytes) does not match original (%d bytes)",
			len(reassembled), len(audioBytes))
	}
}
// TestSynthesisE2E_CustomVoice writes a custom voice (model_info.json,
// model.pth, config.json) into a temp directory and loads it through the
// voice registry. It then checks that a synthesis request built from that
// voice sends model_path and config_path to a mock XTTS backend.
func TestSynthesisE2E_CustomVoice(t *testing.T) {
	// Set up voice registry with temp dir. Setup failures abort the test so
	// they are not misreported later as registry or request failures.
	dir := t.TempDir()
	voiceDir := filepath.Join(dir, "custom-en")
	if err := os.MkdirAll(voiceDir, 0o755); err != nil {
		t.Fatalf("creating voice dir: %v", err)
	}
	info := map[string]string{
		"name": "custom-en", "language": "en",
		"type": "coqui-tts", "created_at": "2024-06-01",
	}
	infoData, err := json.Marshal(info)
	if err != nil {
		t.Fatalf("marshaling model_info: %v", err)
	}
	for name, content := range map[string][]byte{
		"model_info.json": infoData,
		"model.pth":       []byte("fake-model"),
		"config.json":     []byte("{}"),
	} {
		if err := os.WriteFile(filepath.Join(voiceDir, name), content, 0o644); err != nil {
			t.Fatalf("writing %s: %v", name, err)
		}
	}

	registry := newVoiceRegistry(dir)
	if count := registry.refresh(); count != 1 {
		t.Fatalf("refresh() = %d, want 1", count)
	}
	voice := registry.get("custom-en")
	if voice == nil {
		t.Fatal("voice 'custom-en' not found")
	}

	// XTTS mock that validates custom voice fields. The handler runs on the
	// server's goroutine, so it reports with t.Error rather than t.Fatal.
	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var payload map[string]any
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			t.Errorf("decoding XTTS request: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// When custom voice is used, model_path should be set
		if payload["model_path"] == nil {
			t.Error("expected model_path in custom voice request")
		}
		if payload["config_path"] == nil {
			t.Error("expected config_path for voice with config")
		}
		w.Write(make([]byte, 4000))
	}))
	defer xttsSrv.Close()

	// Build request payload like main.go does. voice is non-nil here (the
	// t.Fatal above guarantees it), so no further nil check is needed.
	payload := map[string]any{
		"text":       "hello custom voice",
		"speaker":    "custom-en",
		"language":   "en",
		"model_path": voice.ModelPath,
	}
	if voice.ConfigPath != "" {
		payload["config_path"] = voice.ConfigPath
	}
	data, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshaling request: %v", err)
	}
	resp, err := xttsSrv.Client().Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
		bytes.NewReader(data))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Errorf("status = %d, want 200", resp.StatusCode)
	}
}
// TestSynthesisE2E_XTTSError stands up a mock XTTS backend that always
// answers 503 "model not loaded". It checks that a POST to
// /v1/audio/speech observes exactly that 503 status.
func TestSynthesisE2E_XTTSError(t *testing.T) {
	unavailable := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
		w.Write([]byte("model not loaded"))
	}
	srv := httptest.NewServer(http.HandlerFunc(unavailable))
	defer srv.Close()

	reqBody := bytes.NewReader([]byte(`{"text":"test"}`))
	resp, err := http.Post(srv.URL+"/v1/audio/speech", "application/json", reqBody)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusServiceUnavailable {
		t.Errorf("status = %d, want 503", got)
	}
}
// TestVoiceRegistryMultiple populates a registry directory with three voices.
// It checks that refresh and listVoices both report exactly that many, and
// that get finds each voice by name. It also checks that get returns nil
// for an unknown name.
func TestVoiceRegistryMultiple(t *testing.T) {
	dir := t.TempDir()
	names := []string{"alice", "bob", "charlie"}

	// Create 3 voices. Setup errors abort the test so they are not
	// misreported as registry failures.
	for _, name := range names {
		vDir := filepath.Join(dir, name)
		if err := os.MkdirAll(vDir, 0o755); err != nil {
			t.Fatalf("creating %s: %v", vDir, err)
		}
		data, err := json.Marshal(map[string]string{"name": name, "language": "en"})
		if err != nil {
			t.Fatalf("marshaling model_info for %s: %v", name, err)
		}
		if err := os.WriteFile(filepath.Join(vDir, "model_info.json"), data, 0o644); err != nil {
			t.Fatalf("writing model_info.json for %s: %v", name, err)
		}
		if err := os.WriteFile(filepath.Join(vDir, "model.pth"), []byte("fake"), 0o644); err != nil {
			t.Fatalf("writing model.pth for %s: %v", name, err)
		}
	}

	registry := newVoiceRegistry(dir)
	if count := registry.refresh(); count != len(names) {
		t.Errorf("refresh() = %d, want %d", count, len(names))
	}
	if voices := registry.listVoices(); len(voices) != len(names) {
		t.Errorf("listVoices() = %d, want %d", len(voices), len(names))
	}
	for _, name := range names {
		if registry.get(name) == nil {
			t.Errorf("voice %q not found", name)
		}
	}
	if registry.get("nonexistent") != nil {
		t.Error("expected nil for nonexistent voice")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmarks
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkSynthesisRoundtrip measures a single HTTP POST to a mock XTTS
// server that returns 16000 bytes of audio. The timing includes draining the
// full response body.
func BenchmarkSynthesisRoundtrip(b *testing.B) {
	// Allocate the canned response once so the handler does not add a
	// per-request allocation to what is being measured.
	audio := make([]byte, 16000)
	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write(audio)
	}))
	defer xttsSrv.Close()
	client := xttsSrv.Client()
	body := []byte(`{"text":"benchmark text","speaker":"default","language":"en"}`)

	// b.Loop resets the timer on its first call, so the setup above is
	// excluded without an explicit b.ResetTimer.
	for b.Loop() {
		resp, err := client.Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
			bytes.NewReader(body))
		if err != nil {
			// resp is nil when err != nil; touching resp.Body would panic.
			b.Fatal(err)
		}
		// Drain fully so the transport can reuse the keep-alive connection.
		_, err = io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
		if err != nil {
			b.Fatalf("reading response: %v", err)
		}
	}
}
func BenchmarkVoiceRegistryRefresh(b *testing.B) {
dir := b.TempDir()
for i := 0; i < 10; i++ {
name := "voice-" + strconv.Itoa(i)
vDir := filepath.Join(dir, name)
os.MkdirAll(vDir, 0o755)
info := map[string]string{"name": name}
data, _ := json.Marshal(info)
os.WriteFile(filepath.Join(vDir, "model_info.json"), data, 0o644)
os.WriteFile(filepath.Join(vDir, "model.pth"), []byte("fake"), 0o644)
}
registry := newVoiceRegistry(dir)
b.ResetTimer()
for b.Loop() {
registry.refresh()
}
}
// BenchmarkAudioChunking measures splitting 256 KiB of audio into 32 KiB
// messages.TTSAudioChunk values and msgpack-encoding each one. This mirrors
// the chunking used by the streaming e2e test, including IsLast on the
// final chunk.
func BenchmarkAudioChunking(b *testing.B) {
	audioBytes := make([]byte, 256*1024) // 256 KB audio
	const chunkSize = 32768
	// Loop-invariant: the chunk count depends only on the input size.
	totalChunks := (len(audioBytes) + chunkSize - 1) / chunkSize
	// Report throughput as audio bytes processed per iteration.
	b.SetBytes(int64(len(audioBytes)))

	// b.Loop resets the timer on its first call, so setup is not timed.
	for b.Loop() {
		for i := 0; i < len(audioBytes); i += chunkSize {
			end := min(i+chunkSize, len(audioBytes))
			msg := &messages.TTSAudioChunk{
				SessionID:   "bench",
				ChunkIndex:  i / chunkSize,
				TotalChunks: totalChunks,
				Audio:       audioBytes[i:end],
				IsLast:      end == len(audioBytes),
				SampleRate:  24000,
			}
			if _, err := msgpack.Marshal(msg); err != nil {
				b.Fatalf("chunk %d: marshal: %v", i/chunkSize, err)
			}
		}
	}
}