package main import ( "encoding/json" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "testing" "git.daviestechlabs.io/daviestechlabs/handler-base/messages" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "github.com/vmihailenco/msgpack/v5" ) func TestVoiceRegistryRefresh(t *testing.T) { dir := t.TempDir() // Create a voice directory voiceDir := filepath.Join(dir, "test-voice") os.MkdirAll(voiceDir, 0o755) info := map[string]string{"name": "test-voice", "language": "en", "type": "coqui-tts", "created_at": "2024-01-01"} infoData, _ := json.Marshal(info) os.WriteFile(filepath.Join(voiceDir, "model_info.json"), infoData, 0o644) os.WriteFile(filepath.Join(voiceDir, "model.pth"), []byte("fake"), 0o644) vr := newVoiceRegistry(dir) count := vr.refresh() if count != 1 { t.Errorf("refresh() = %d, want 1", count) } voice := vr.get("test-voice") if voice == nil { t.Fatal("expected voice 'test-voice'") } if voice.Language != "en" { t.Errorf("language = %q, want %q", voice.Language, "en") } if voice.ModelPath != filepath.Join(voiceDir, "model.pth") { t.Errorf("model_path = %q", voice.ModelPath) } voices := vr.listVoices() if len(voices) != 1 { t.Errorf("listVoices() len = %d, want 1", len(voices)) } } func TestVoiceRegistryMissing(t *testing.T) { vr := newVoiceRegistry("/nonexistent/path") count := vr.refresh() if count != 0 { t.Errorf("refresh() = %d, want 0", count) } } func TestVoiceRegistryNoModel(t *testing.T) { dir := t.TempDir() voiceDir := filepath.Join(dir, "bad-voice") os.MkdirAll(voiceDir, 0o755) info := map[string]string{"name": "bad-voice"} infoData, _ := json.Marshal(info) os.WriteFile(filepath.Join(voiceDir, "model_info.json"), infoData, 0o644) // No model.pth vr := newVoiceRegistry(dir) count := vr.refresh() if count != 0 { t.Errorf("refresh() = %d, want 0", count) } } func TestSynthesizeHTTP(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/audio/speech" { t.Errorf("unexpected path: %s", r.URL.Path) } if r.Method != http.MethodPost { t.Errorf("expected POST, got %s", r.Method) } var payload map[string]any json.NewDecoder(r.Body).Decode(&payload) if payload["text"] != "hello" { t.Errorf("unexpected text: %v", payload["text"]) } w.Write([]byte{0x01, 0x02, 0x03, 0x04}) })) defer ts.Close() // Test the XTTS synthesis HTTP call directly client := ts.Client() body := `{"text":"hello","speaker":"default","language":"en"}` resp, err := client.Post(ts.URL+"/v1/audio/speech", "application/json", strings.NewReader(body)) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Errorf("status = %d, want 200", resp.StatusCode) } } func TestTTSRequestDecode(t *testing.T) { req := messages.TTSRequest{ Text: "hello world", Speaker: "custom-en", Language: "en", Stream: true, } data, err := msgpack.Marshal(&req) if err != nil { t.Fatal(err) } decoded, err := natsutil.Decode[messages.TTSRequest](data) if err != nil { t.Fatal(err) } if decoded.Text != "hello world" { t.Errorf("Text = %q", decoded.Text) } if decoded.Speaker != "custom-en" { t.Errorf("Speaker = %q", decoded.Speaker) } if !decoded.Stream { t.Error("Stream should be true") } } func TestTTSAudioChunkRoundtrip(t *testing.T) { chunk := messages.TTSAudioChunk{ SessionID: "sess-001", ChunkIndex: 0, TotalChunks: 2, Audio: make([]byte, 32768), IsLast: false, Timestamp: 1234567890, SampleRate: 24000, } data, err := msgpack.Marshal(&chunk) if err != nil { t.Fatal(err) } var got messages.TTSAudioChunk if err := msgpack.Unmarshal(data, &got); err != nil { t.Fatal(err) } if got.SessionID != "sess-001" { t.Errorf("SessionID = %q", got.SessionID) } if len(got.Audio) != 32768 { t.Errorf("Audio len = %d", len(got.Audio)) } if got.SampleRate != 24000 { t.Errorf("SampleRate = %d", got.SampleRate) } } func TestGetEnv(t *testing.T) { t.Setenv("MY_TEST_VAR", "value") if got := getEnv("MY_TEST_VAR", "fallback"); got != "value" { t.Errorf("getEnv = %q", got) } if got := getEnv("NONEXISTENT_XYZ", "fallback"); got != "fallback" { t.Errorf("getEnv = %q", got) } } func TestGetEnvInt(t *testing.T) { t.Setenv("MY_PORT", "8080") if got := getEnvInt("MY_PORT", 3000); got != 8080 { t.Errorf("getEnvInt = %d", got) } if got := getEnvInt("NONEXISTENT_XYZ", 3000); got != 3000 { t.Errorf("getEnvInt = %d", got) } } func TestOrDefault(t *testing.T) { if got := orDefault("", "en"); got != "en" { t.Errorf("orDefault('', 'en') = %q", got) } if got := orDefault("fr", "en"); got != "fr" { t.Errorf("orDefault('fr', 'en') = %q", got) } }