diff --git a/e2e_test.go b/e2e_test.go
new file mode 100644
index 0000000..d83ea8e
--- /dev/null
+++ b/e2e_test.go
@@ -0,0 +1,325 @@
+package main
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"strconv"
+	"testing"
+)
+
+// ────────────────────────────────────────────────────────────────────────────
+// E2E tests: voice registry + XTTS synthesis + audio streaming pipeline
+// ────────────────────────────────────────────────────────────────────────────
+
+// writeE2EVoice creates the directory root/name and fills it with
+// model_info.json (the JSON encoding of info), a model.pth holding model and,
+// when withConfig is true, a config.json containing "{}". Any failure aborts
+// the calling test or benchmark instead of surfacing later as a miscount.
+func writeE2EVoice(tb testing.TB, root, name string, info map[string]string, model string, withConfig bool) {
+	tb.Helper()
+	voiceDir := filepath.Join(root, name)
+	if err := os.MkdirAll(voiceDir, 0o755); err != nil {
+		tb.Fatalf("creating voice dir %q: %v", voiceDir, err)
+	}
+	infoData, err := json.Marshal(info)
+	if err != nil {
+		tb.Fatalf("encoding model_info.json for %q: %v", name, err)
+	}
+	write := func(file string, content []byte) {
+		if err := os.WriteFile(filepath.Join(voiceDir, file), content, 0o644); err != nil {
+			tb.Fatalf("writing %s for voice %q: %v", file, name, err)
+		}
+	}
+	write("model_info.json", infoData)
+	write("model.pth", []byte(model))
+	if withConfig {
+		write("config.json", []byte("{}"))
+	}
+}
+
+// TestSynthesisE2E_StreamChunks posts a synthesis request to a mock XTTS
+// server that answers with 64 KB of audio, cuts the response into 32 KB
+// chunks, and checks that each chunk message survives a JSON round trip.
+func TestSynthesisE2E_StreamChunks(t *testing.T) {
+	// Mock XTTS returning 64 KB of audio
+	audioSize := 65536
+	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var payload map[string]any
+		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+			http.Error(w, "bad json", http.StatusBadRequest)
+			return
+		}
+		if payload["text"] == nil || payload["text"] == "" {
+			w.WriteHeader(400)
+			w.Write([]byte("empty text"))
+			return
+		}
+		w.Write(make([]byte, audioSize))
+	}))
+	defer xttsSrv.Close()
+
+	// Test synthesize + chunking logic
+	client := &http.Client{}
+	body := `{"text":"hello world","speaker":"default","language":"en"}`
+	resp, err := client.Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
+		bytes.NewReader([]byte(body)))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("status = %d, want 200", resp.StatusCode)
+	}
+
+	audioBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("reading audio: %v", err)
+	}
+	if len(audioBytes) != audioSize {
+		t.Fatalf("audio size = %d, want %d", len(audioBytes), audioSize)
+	}
+
+	// Simulate streaming: chunk into 32 KB pieces
+	chunkSize := 32768
+	totalChunks := (len(audioBytes) + chunkSize - 1) / chunkSize
+	if totalChunks != 2 {
+		t.Errorf("totalChunks = %d, want 2", totalChunks)
+	}
+
+	for i := 0; i < len(audioBytes); i += chunkSize {
+		end := min(i+chunkSize, len(audioBytes))
+		chunk := audioBytes[i:end]
+		chunkIdx := i / chunkSize
+		isLast := end >= len(audioBytes)
+
+		// Verify chunk message shape
+		msg := map[string]any{
+			"session_id":   "test-session",
+			"chunk_index":  chunkIdx,
+			"total_chunks": totalChunks,
+			"audio_b64":    base64.StdEncoding.EncodeToString(chunk),
+			"is_last":      isLast,
+			"sample_rate":  24000,
+		}
+
+		// Round-trip through JSON
+		data, err := json.Marshal(msg)
+		if err != nil {
+			t.Fatalf("chunk %d: marshal: %v", chunkIdx, err)
+		}
+		var decoded map[string]any
+		if err := json.Unmarshal(data, &decoded); err != nil {
+			t.Fatalf("chunk %d: unmarshal: %v", chunkIdx, err)
+		}
+
+		if decoded["session_id"] != "test-session" {
+			t.Errorf("chunk %d: session = %v", chunkIdx, decoded["session_id"])
+		}
+		if decoded["is_last"] != isLast {
+			t.Errorf("chunk %d: is_last = %v, want %v", chunkIdx, decoded["is_last"], isLast)
+		}
+		// encoding/json decodes JSON numbers into float64 when the target is any.
+		if decoded["chunk_index"] != float64(chunkIdx) {
+			t.Errorf("chunk %d: chunk_index = %v", chunkIdx, decoded["chunk_index"])
+		}
+		b64, _ := decoded["audio_b64"].(string) // a non-string yields "", which fails the comparison below
+		audio, err := base64.StdEncoding.DecodeString(b64)
+		if err != nil || !bytes.Equal(audio, chunk) {
+			t.Errorf("chunk %d: audio_b64 did not round-trip (err = %v)", chunkIdx, err)
+		}
+	}
+}
+
+// TestSynthesisE2E_CustomVoice registers a custom voice on disk and checks
+// that a request built from its registry entry delivers model_path and
+// config_path to the mock XTTS server.
+func TestSynthesisE2E_CustomVoice(t *testing.T) {
+	// Set up voice registry with temp dir
+	dir := t.TempDir()
+	info := map[string]string{
+		"name": "custom-en", "language": "en",
+		"type": "coqui-tts", "created_at": "2024-06-01",
+	}
+	writeE2EVoice(t, dir, "custom-en", info, "fake-model", true)
+
+	registry := newVoiceRegistry(dir)
+	count := registry.refresh()
+	if count != 1 {
+		t.Fatalf("refresh() = %d, want 1", count)
+	}
+
+	// XTTS mock that validates custom voice fields
+	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var payload map[string]any
+		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+			t.Errorf("decoding request body: %v", err)
+		}
+
+		// When custom voice is used, model_path should be set
+		if payload["model_path"] == nil {
+			t.Error("expected model_path in custom voice request")
+		}
+		if payload["config_path"] == nil {
+			t.Error("expected config_path for voice with config")
+		}
+		w.Write(make([]byte, 4000))
+	}))
+	defer xttsSrv.Close()
+
+	voice := registry.get("custom-en")
+	if voice == nil {
+		t.Fatal("voice 'custom-en' not found")
+	}
+
+	// Build request payload like main.go does; voice is non-nil past the check above
+	payload := map[string]any{
+		"text":     "hello custom voice",
+		"speaker":  "custom-en",
+		"language": "en",
+	}
+	payload["model_path"] = voice.ModelPath
+	if voice.ConfigPath != "" {
+		payload["config_path"] = voice.ConfigPath
+	}
+
+	data, err := json.Marshal(payload)
+	if err != nil {
+		t.Fatalf("encoding request: %v", err)
+	}
+	resp, err := http.Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
+		bytes.NewReader(data))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		t.Errorf("status = %d, want 200", resp.StatusCode)
+	}
+}
+
+// TestSynthesisE2E_XTTSError checks that a 503 from the XTTS backend reaches
+// the caller with its status code and body intact.
+func TestSynthesisE2E_XTTSError(t *testing.T) {
+	failSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(503)
+		w.Write([]byte("model not loaded"))
+	}))
+	defer failSrv.Close()
+
+	resp, err := http.Post(failSrv.URL+"/v1/audio/speech", "application/json",
+		bytes.NewReader([]byte(`{"text":"test"}`)))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 503 {
+		t.Errorf("status = %d, want 503", resp.StatusCode)
+	}
+	msg, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("reading error body: %v", err)
+	}
+	if string(msg) != "model not loaded" {
+		t.Errorf("body = %q, want %q", msg, "model not loaded")
+	}
+}
+
+// TestVoiceRegistryMultiple checks that refresh, listVoices and get agree
+// when three voices exist on disk.
+func TestVoiceRegistryMultiple(t *testing.T) {
+	dir := t.TempDir()
+
+	// Create 3 voices
+	names := []string{"alice", "bob", "charlie"}
+	for _, name := range names {
+		writeE2EVoice(t, dir, name, map[string]string{"name": name, "language": "en"}, "fake", false)
+	}
+
+	registry := newVoiceRegistry(dir)
+	count := registry.refresh()
+	if count != 3 {
+		t.Errorf("refresh() = %d, want 3", count)
+	}
+
+	voices := registry.listVoices()
+	if len(voices) != 3 {
+		t.Errorf("listVoices() = %d, want 3", len(voices))
+	}
+
+	for _, name := range names {
+		if v := registry.get(name); v == nil {
+			t.Errorf("voice %q not found", name)
+		}
+	}
+	if v := registry.get("nonexistent"); v != nil {
+		t.Error("expected nil for nonexistent voice")
+	}
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// Benchmarks
+// ────────────────────────────────────────────────────────────────────────────
+
+// BenchmarkSynthesisRoundtrip times one POST to a mock XTTS server that
+// returns 16000 bytes, including reading the whole response body.
+func BenchmarkSynthesisRoundtrip(b *testing.B) {
+	xttsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write(make([]byte, 16000))
+	}))
+	defer xttsSrv.Close()
+
+	client := &http.Client{}
+	body := []byte(`{"text":"benchmark text","speaker":"default","language":"en"}`)
+
+	// b.Loop resets the timer on its first call, so the setup above is untimed.
+	for b.Loop() {
+		resp, err := client.Post(xttsSrv.URL+"/v1/audio/speech", "application/json",
+			bytes.NewReader(body))
+		if err != nil {
+			b.Fatal(err) // resp is nil here; dereferencing it would panic
+		}
+		_, err = io.ReadAll(resp.Body)
+		resp.Body.Close()
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkVoiceRegistryRefresh times refresh over a directory of 10 voices.
+func BenchmarkVoiceRegistryRefresh(b *testing.B) {
+	dir := b.TempDir()
+	for i := 0; i < 10; i++ {
+		name := "voice-" + strconv.Itoa(i)
+		writeE2EVoice(b, dir, name, map[string]string{"name": name}, "fake", false)
+	}
+
+	registry := newVoiceRegistry(dir)
+
+	for b.Loop() {
+		registry.refresh()
+	}
+}
+
+// BenchmarkAudioChunking times splitting 256 KB of audio into 32 KB chunks
+// and base64-encoding each chunk.
+func BenchmarkAudioChunking(b *testing.B) {
+	audioBytes := make([]byte, 256*1024) // 256 KB audio
+	chunkSize := 32768
+
+	for b.Loop() {
+		totalChunks := (len(audioBytes) + chunkSize - 1) / chunkSize
+		for i := 0; i < len(audioBytes); i += chunkSize {
+			end := min(i+chunkSize, len(audioBytes))
+			chunk := audioBytes[i:end]
+			_ = base64.StdEncoding.EncodeToString(chunk)
+			_ = totalChunks
+		}
+	}
+}
diff --git a/go.mod b/go.mod
index 0f5946f..c0013d2 100644
--- a/go.mod
+++ b/go.mod
@@ -11,6 +11,7 @@ require (
 require (
 	github.com/cenkalti/backoff/v5 v5.0.3 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
+	github.com/fsnotify/fsnotify v1.9.0 // indirect
 	github.com/go-logr/logr v1.4.3 // indirect
 	github.com/go-logr/stdr v1.2.2 // indirect
 	github.com/google/uuid v1.6.0 // indirect
diff --git a/go.sum b/go.sum
index 4a3f959..b9a1b68 100644
--- a/go.sum
+++ b/go.sum
@@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
 github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
+github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
 github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
 github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
 github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=