- Decode TTSRequest via natsutil.Decode[messages.TTSRequest] - Stream audio as raw bytes via messages.TTSAudioChunk (no base64) - Non-stream response uses messages.TTSFullResponse - Status updates use messages.TTSStatus - Voice list/refresh use messages.TTSVoiceListResponse/TTSVoiceRefreshResponse - Registry returns []messages.TTSVoiceInfo (not []map[string]any) - Remove strVal/boolVal helpers - Add .dockerignore, GOAMD64=v3 in Dockerfile - Update tests for typed structs (13 tests pass)
279 lines
8.1 KiB
Go
279 lines
8.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
)
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// E2E tests: voice registry + XTTS synthesis + audio streaming pipeline
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestSynthesisE2E_StreamChunks(t *testing.T) {
|
|
// Mock XTTS returning 64 KB of audio
|
|
audioSize := 65536
|
|
xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload map[string]any
|
|
json.NewDecoder(r.Body).Decode(&payload)
|
|
if payload["text"] == nil || payload["text"] == "" {
|
|
w.WriteHeader(400)
|
|
w.Write([]byte("empty text"))
|
|
return
|
|
}
|
|
w.Write(make([]byte, audioSize))
|
|
}))
|
|
defer xttsSrv.Close()
|
|
|
|
// Test synthesize + chunking logic
|
|
client := &http.Client{}
|
|
body := `{"text":"hello world","speaker":"default","language":"en"}`
|
|
resp, err := client.Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
|
|
bytes.NewReader([]byte(body)))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
audioBytes, _ := io.ReadAll(resp.Body)
|
|
if len(audioBytes) != audioSize {
|
|
t.Fatalf("audio size = %d, want %d", len(audioBytes), audioSize)
|
|
}
|
|
|
|
// Simulate streaming: chunk into 32 KB pieces
|
|
chunkSize := 32768
|
|
totalChunks := (len(audioBytes) + chunkSize - 1) / chunkSize
|
|
if totalChunks != 2 {
|
|
t.Errorf("totalChunks = %d, want 2", totalChunks)
|
|
}
|
|
|
|
for i := 0; i < len(audioBytes); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(audioBytes) {
|
|
end = len(audioBytes)
|
|
}
|
|
chunk := audioBytes[i:end]
|
|
chunkIdx := i / chunkSize
|
|
isLast := end >= len(audioBytes)
|
|
|
|
// Verify typed chunk struct
|
|
msg := messages.TTSAudioChunk{
|
|
SessionID: "test-session",
|
|
ChunkIndex: chunkIdx,
|
|
TotalChunks: totalChunks,
|
|
Audio: chunk,
|
|
IsLast: isLast,
|
|
SampleRate: 24000,
|
|
}
|
|
|
|
// Round-trip through msgpack
|
|
data, _ := msgpack.Marshal(&msg)
|
|
var decoded messages.TTSAudioChunk
|
|
msgpack.Unmarshal(data, &decoded)
|
|
|
|
if decoded.SessionID != "test-session" {
|
|
t.Errorf("chunk %d: session = %v", chunkIdx, decoded.SessionID)
|
|
}
|
|
if decoded.IsLast != isLast {
|
|
t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded.IsLast, isLast)
|
|
}
|
|
if len(decoded.Audio) != len(chunk) {
|
|
t.Errorf("chunk %d: audio len = %d, want %d", chunkIdx, len(decoded.Audio), len(chunk))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSynthesisE2E_CustomVoice(t *testing.T) {
|
|
// Set up voice registry with temp dir
|
|
dir := t.TempDir()
|
|
voiceDir := filepath.Join(dir, "custom-en")
|
|
os.MkdirAll(voiceDir, 0o755)
|
|
info := map[string]string{
|
|
"name": "custom-en", "language": "en",
|
|
"type": "coqui-tts", "created_at": "2024-06-01",
|
|
}
|
|
infoData, _ := json.Marshal(info)
|
|
os.WriteFile(filepath.Join(voiceDir, "model_info.json"), infoData, 0o644)
|
|
os.WriteFile(filepath.Join(voiceDir, "model.pth"), []byte("fake-model"), 0o644)
|
|
os.WriteFile(filepath.Join(voiceDir, "config.json"), []byte("{}"), 0o644)
|
|
|
|
registry := newVoiceRegistry(dir)
|
|
count := registry.refresh()
|
|
if count != 1 {
|
|
t.Fatalf("refresh() = %d, want 1", count)
|
|
}
|
|
|
|
// XTTS mock that validates custom voice fields
|
|
xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload map[string]any
|
|
json.NewDecoder(r.Body).Decode(&payload)
|
|
|
|
// When custom voice is used, model_path should be set
|
|
if payload["model_path"] == nil {
|
|
t.Error("expected model_path in custom voice request")
|
|
}
|
|
if payload["config_path"] == nil {
|
|
t.Error("expected config_path for voice with config")
|
|
}
|
|
w.Write(make([]byte, 4000))
|
|
}))
|
|
defer xttsSrv.Close()
|
|
|
|
voice := registry.get("custom-en")
|
|
if voice == nil {
|
|
t.Fatal("voice 'custom-en' not found")
|
|
}
|
|
|
|
// Build request payload like main.go does
|
|
payload := map[string]any{
|
|
"text": "hello custom voice",
|
|
"speaker": "custom-en",
|
|
"language": "en",
|
|
}
|
|
if voice != nil {
|
|
payload["model_path"] = voice.ModelPath
|
|
if voice.ConfigPath != "" {
|
|
payload["config_path"] = voice.ConfigPath
|
|
}
|
|
}
|
|
|
|
data, _ := json.Marshal(payload)
|
|
resp, err := http.Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
|
|
bytes.NewReader(data))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// TestSynthesisE2E_XTTSError checks that an XTTS backend failure is observable
// by the caller as the backend's own HTTP status. The mock always answers 503
// with a short plain-text reason, and the test expects to see exactly 503.
func TestSynthesisE2E_XTTSError(t *testing.T) {
	const wantStatus = 503

	unavailable := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(wantStatus)
		w.Write([]byte("model not loaded"))
	})
	srv := httptest.NewServer(unavailable)
	defer srv.Close()

	reqBody := bytes.NewReader([]byte(`{"text":"test"}`))
	res, err := http.Post(srv.URL+"/v1/audio/speech", "application/json", reqBody)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	if got := res.StatusCode; got != wantStatus {
		t.Errorf("status = %d, want 503", got)
	}
}
|
|
|
|
func TestVoiceRegistryMultiple(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create 3 voices
|
|
for _, name := range []string{"alice", "bob", "charlie"} {
|
|
vDir := filepath.Join(dir, name)
|
|
os.MkdirAll(vDir, 0o755)
|
|
info := map[string]string{"name": name, "language": "en"}
|
|
data, _ := json.Marshal(info)
|
|
os.WriteFile(filepath.Join(vDir, "model_info.json"), data, 0o644)
|
|
os.WriteFile(filepath.Join(vDir, "model.pth"), []byte("fake"), 0o644)
|
|
}
|
|
|
|
registry := newVoiceRegistry(dir)
|
|
count := registry.refresh()
|
|
if count != 3 {
|
|
t.Errorf("refresh() = %d, want 3", count)
|
|
}
|
|
|
|
voices := registry.listVoices()
|
|
if len(voices) != 3 {
|
|
t.Errorf("listVoices() = %d, want 3", len(voices))
|
|
}
|
|
|
|
for _, name := range []string{"alice", "bob", "charlie"} {
|
|
if v := registry.get(name); v == nil {
|
|
t.Errorf("voice %q not found", name)
|
|
}
|
|
}
|
|
if v := registry.get("nonexistent"); v != nil {
|
|
t.Error("expected nil for nonexistent voice")
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// Benchmarks
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
// BenchmarkSynthesisRoundtrip measures one HTTP synthesis round trip against a
// local mock XTTS server that returns 16000 bytes of audio. Each iteration
// POSTs the request, reads the full response body and closes it.
func BenchmarkSynthesisRoundtrip(b *testing.B) {
	// Allocate the mock response once, so the timed loop measures the round
	// trip rather than the mock's per-request allocation. The handler only
	// reads the slice, so sharing it across requests is safe.
	audio := make([]byte, 16000)
	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write(audio)
	}))
	defer xttsSrv.Close()

	client := xttsSrv.Client()
	endpoint := xttsSrv.URL + "/v1/audio/speech"
	body := []byte(`{"text":"benchmark text","speaker":"default","language":"en"}`)

	// b.Loop resets the timer on its first call, so the setup above is not
	// measured and no explicit b.ResetTimer is needed.
	for b.Loop() {
		resp, err := client.Post(endpoint, "application/json", bytes.NewReader(body))
		if err != nil {
			// resp is nil here; stop cleanly instead of dereferencing it.
			b.Fatal(err)
		}
		_, err = io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			b.Fatal(err)
		}
	}
}
|
|
|
|
func BenchmarkVoiceRegistryRefresh(b *testing.B) {
|
|
dir := b.TempDir()
|
|
for i := 0; i < 10; i++ {
|
|
name := "voice-" + strconv.Itoa(i)
|
|
vDir := filepath.Join(dir, name)
|
|
os.MkdirAll(vDir, 0o755)
|
|
info := map[string]string{"name": name}
|
|
data, _ := json.Marshal(info)
|
|
os.WriteFile(filepath.Join(vDir, "model_info.json"), data, 0o644)
|
|
os.WriteFile(filepath.Join(vDir, "model.pth"), []byte("fake"), 0o644)
|
|
}
|
|
|
|
registry := newVoiceRegistry(dir)
|
|
|
|
b.ResetTimer()
|
|
for b.Loop() {
|
|
registry.refresh()
|
|
}
|
|
}
|
|
|
|
func BenchmarkAudioChunking(b *testing.B) {
|
|
audioBytes := make([]byte, 256*1024) // 256 KB audio
|
|
chunkSize := 32768
|
|
|
|
b.ResetTimer()
|
|
for b.Loop() {
|
|
totalChunks := (len(audioBytes) + chunkSize - 1) / chunkSize
|
|
for i := 0; i < len(audioBytes); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(audioBytes) {
|
|
end = len(audioBytes)
|
|
}
|
|
chunk := audioBytes[i:end]
|
|
msg := &messages.TTSAudioChunk{
|
|
SessionID: "bench",
|
|
ChunkIndex: i / chunkSize,
|
|
TotalChunks: totalChunks,
|
|
Audio: chunk,
|
|
SampleRate: 24000,
|
|
}
|
|
msgpack.Marshal(msg)
|
|
}
|
|
}
|
|
}
|