"""Unit tests for TTS streaming service."""

import base64
import json
import shutil
from unittest.mock import AsyncMock, MagicMock, patch

import msgpack
import pytest

from tts_streaming import (
    AUDIO_SUBJECT_PREFIX,
    DEFAULT_LANGUAGE,
    DEFAULT_SPEAKER,
    STATUS_SUBJECT_PREFIX,
    CustomVoice,
    StreamingTTS,
    VoiceRegistry,
)


# ---------------------------------------------------------------------------
# VoiceRegistry tests
# ---------------------------------------------------------------------------


class TestCustomVoice:
    """Tests for CustomVoice dataclass."""

    def test_defaults(self):
        """Omitted language/model_type fall back to "en" / "coqui-tts"."""
        voice = CustomVoice(
            name="test-voice",
            model_path="/models/tts/custom/test-voice/model.pth",
            config_path="/models/tts/custom/test-voice/config.json",
            created_at="2026-02-13T00:00:00",
        )
        assert voice.name == "test-voice"
        assert voice.language == "en"
        assert voice.model_type == "coqui-tts"

    def test_custom_fields(self):
        """Explicit language/model_type values are stored as given."""
        voice = CustomVoice(
            name="german-voice",
            model_path="/m/model.pth",
            config_path="",
            created_at="2026-01-01",
            language="de",
            model_type="custom-vits",
        )
        assert voice.language == "de"
        assert voice.model_type == "custom-vits"


class TestVoiceRegistry:
    """Tests for VoiceRegistry discovery."""

    def test_empty_store(self, tmp_path):
        """Registry returns 0 when the store directory is empty."""
        registry = VoiceRegistry(str(tmp_path))
        count = registry.refresh()
        assert count == 0
        assert registry.list_voices() == []

    def test_missing_store(self, tmp_path):
        """Registry handles a non-existent store gracefully."""
        registry = VoiceRegistry(str(tmp_path / "does-not-exist"))
        count = registry.refresh()
        assert count == 0

    def test_discovers_valid_voice(self, tmp_path):
        """Registry discovers a voice directory with model_info.json + model.pth."""
        voice_dir = tmp_path / "alice"
        voice_dir.mkdir()
        (voice_dir / "model.pth").write_bytes(b"fake-weights")
        (voice_dir / "config.json").write_text("{}")
        (voice_dir / "model_info.json").write_text(
            json.dumps(
                {
                    "name": "alice",
                    "created_at": "2026-02-13T12:00:00",
                    "type": "coqui-tts",
                }
            )
        )

        registry = VoiceRegistry(str(tmp_path))
        count = registry.refresh()
        assert count == 1

        voice = registry.get("alice")
        assert voice is not None
        assert voice.name == "alice"
        assert voice.model_path == str(voice_dir / "model.pth")
        assert voice.config_path == str(voice_dir / "config.json")

    def test_skips_dir_without_model_pth(self, tmp_path):
        """Directories missing model.pth are skipped."""
        voice_dir = tmp_path / "broken"
        voice_dir.mkdir()
        (voice_dir / "model_info.json").write_text(json.dumps({"name": "broken"}))
        # no model.pth

        registry = VoiceRegistry(str(tmp_path))
        assert registry.refresh() == 0

    def test_skips_dir_without_model_info(self, tmp_path):
        """Directories missing model_info.json are skipped."""
        voice_dir = tmp_path / "no-info"
        voice_dir.mkdir()
        (voice_dir / "model.pth").write_bytes(b"data")

        registry = VoiceRegistry(str(tmp_path))
        assert registry.refresh() == 0

    def test_skips_plain_files(self, tmp_path):
        """Plain files in the store root are ignored."""
        (tmp_path / "readme.txt").write_text("hello")
        registry = VoiceRegistry(str(tmp_path))
        assert registry.refresh() == 0

    def test_multiple_voices(self, tmp_path):
        """Multiple valid voices are all discovered."""
        for name in ("alice", "bob", "charlie"):
            d = tmp_path / name
            d.mkdir()
            (d / "model.pth").write_bytes(b"w")
            (d / "model_info.json").write_text(json.dumps({"name": name}))

        registry = VoiceRegistry(str(tmp_path))
        assert registry.refresh() == 3
        names = {v["name"] for v in registry.list_voices()}
        assert names == {"alice", "bob", "charlie"}

    def test_refresh_detects_new_and_removed(self, tmp_path):
        """Subsequent refresh picks up additions and removals."""
        d = tmp_path / "v1"
        d.mkdir()
        (d / "model.pth").write_bytes(b"w")
        (d / "model_info.json").write_text(json.dumps({"name": "v1"}))

        registry = VoiceRegistry(str(tmp_path))
        assert registry.refresh() == 1

        # Add a second voice
        d2 = tmp_path / "v2"
        d2.mkdir()
        (d2 / "model.pth").write_bytes(b"w")
        (d2 / "model_info.json").write_text(json.dumps({"name": "v2"}))

        assert registry.refresh() == 2
        assert registry.get("v2") is not None

        # Remove first voice
        shutil.rmtree(d)
        assert registry.refresh() == 1
        assert registry.get("v1") is None

    def test_get_returns_none_for_unknown(self, tmp_path):
        """get() returns None for a name that was never discovered."""
        registry = VoiceRegistry(str(tmp_path))
        registry.refresh()
        assert registry.get("nonexistent") is None

    def test_list_voices_serialization(self, tmp_path):
        """list_voices returns dicts suitable for msgpack."""
        d = tmp_path / "voice1"
        d.mkdir()
        (d / "model.pth").write_bytes(b"w")
        (d / "model_info.json").write_text(
            json.dumps(
                {
                    "name": "voice1",
                    "language": "fr",
                    "type": "coqui-tts",
                    "created_at": "2026-02-13",
                }
            )
        )

        registry = VoiceRegistry(str(tmp_path))
        registry.refresh()
        voices = registry.list_voices()

        assert len(voices) == 1
        v = voices[0]
        assert v["name"] == "voice1"
        assert v["language"] == "fr"

        # Verify it's msgpack-serializable
        packed = msgpack.packb(voices)
        assert msgpack.unpackb(packed, raw=False) == voices


# ---------------------------------------------------------------------------
# StreamingTTS tests
# ---------------------------------------------------------------------------


class TestStreamingTTS:
    """Tests for the StreamingTTS service."""

    @pytest.fixture
    def service(self, tmp_path):
        """Create a StreamingTTS instance with mocked dependencies."""
        with patch("tts_streaming.VOICE_MODEL_STORE", str(tmp_path)):
            svc = StreamingTTS()
            svc.nc = AsyncMock()
            svc.js = AsyncMock()
            svc.http_client = AsyncMock()
            svc.is_healthy = True

            # Setup voice registry with a test voice
            d = tmp_path / "test-voice"
            d.mkdir()
            (d / "model.pth").write_bytes(b"w")
            (d / "config.json").write_text("{}")
            (d / "model_info.json").write_text(
                json.dumps({"name": "test-voice", "language": "en", "type": "coqui-tts"})
            )
            svc.voice_registry = VoiceRegistry(str(tmp_path))
            svc.voice_registry.refresh()

            yield svc

    @pytest.mark.asyncio
    async def test_synthesize_default_speaker(self, service, mock_http_client):
        """Synthesis with default speaker sends basic payload."""
        service.http_client = mock_http_client

        result = await service.synthesize("Hello world")
        assert result is not None

        call_kwargs = mock_http_client.post.call_args
        payload = call_kwargs.kwargs["json"]
        assert payload["text"] == "Hello world"
        assert payload["speaker"] == DEFAULT_SPEAKER
        assert payload["language"] == DEFAULT_LANGUAGE
        assert "model_path" not in payload

    @pytest.mark.asyncio
    async def test_synthesize_custom_voice(self, service, mock_http_client):
        """Synthesis with a registered custom voice includes model_path."""
        service.http_client = mock_http_client

        result = await service.synthesize("Hello", speaker="test-voice")
        assert result is not None

        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker"] == "test-voice"
        assert "model_path" in payload
        assert payload["model_path"].endswith("model.pth")
        assert "config_path" in payload

    @pytest.mark.asyncio
    async def test_synthesize_ad_hoc_cloning(self, service, mock_http_client):
        """Synthesis with speaker_wav_b64 uses ad-hoc voice cloning."""
        service.http_client = mock_http_client
        wav_b64 = base64.b64encode(b"fake-audio").decode()

        result = await service.synthesize("Hello", speaker_wav_b64=wav_b64)
        assert result is not None

        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker_wav"] == wav_b64
        assert "model_path" not in payload

    @pytest.mark.asyncio
    async def test_synthesize_unknown_speaker_no_clone(self, service, mock_http_client):
        """Unknown speaker without reference audio is sent as-is, with no
        model_path or speaker_wav in the payload."""
        service.http_client = mock_http_client

        await service.synthesize("Hello", speaker="unknown-voice")

        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker"] == "unknown-voice"
        assert "model_path" not in payload
        assert "speaker_wav" not in payload

    @pytest.mark.asyncio
    async def test_synthesize_failure_returns_none(self, service):
        """Synthesis returns None on HTTP error."""
        service.http_client = AsyncMock()
        service.http_client.post = AsyncMock(side_effect=Exception("connection refused"))

        result = await service.synthesize("Hello")
        assert result is None

    @pytest.mark.asyncio
    async def test_stream_audio_publishes_chunks(self, service, sample_audio_bytes):
        """stream_audio publishes chunked messages to NATS."""
        await service.stream_audio("sess-1", sample_audio_bytes)

        assert service.nc.publish.called
        # Verify at least one chunk was published
        call_args = service.nc.publish.call_args_list
        assert len(call_args) >= 1

        subject = call_args[0].args[0]
        assert subject == f"{AUDIO_SUBJECT_PREFIX}.sess-1"

        # The final published message must be flagged as the last chunk.
        data = msgpack.unpackb(call_args[-1].args[1], raw=False)
        assert data["is_last"] is True
        assert data["session_id"] == "sess-1"

    @pytest.mark.asyncio
    async def test_handle_request_empty_text(self, service):
        """Requests with empty text publish an error status."""
        msg = MagicMock()
        msg.subject = "ai.voice.tts.request.sess-1"
        msg.data = msgpack.packb({"text": ""})

        await service.handle_request(msg)

        # Should publish error status
        status_calls = [
            c
            for c in service.nc.publish.call_args_list
            if c.args[0].startswith(STATUS_SUBJECT_PREFIX)
        ]
        assert len(status_calls) >= 1
        status_data = msgpack.unpackb(status_calls[0].args[1], raw=False)
        assert status_data["status"] == "error"

    @pytest.mark.asyncio
    async def test_handle_request_success(self, service, mock_http_client):
        """Successful request publishes audio and completed status."""
        service.http_client = mock_http_client

        msg = MagicMock()
        msg.subject = "ai.voice.tts.request.sess-2"
        msg.data = msgpack.packb({"text": "Hello world", "stream": True})

        await service.handle_request(msg)

        subjects = [c.args[0] for c in service.nc.publish.call_args_list]
        # Should have status (processing), audio chunk(s), and status (completed)
        assert any(s.startswith(STATUS_SUBJECT_PREFIX) for s in subjects)
        assert any(s.startswith(AUDIO_SUBJECT_PREFIX) for s in subjects)

    @pytest.mark.asyncio
    async def test_handle_request_with_custom_voice(self, service, mock_http_client):
        """Request with custom voice speaker uses trained model."""
        service.http_client = mock_http_client

        msg = MagicMock()
        msg.subject = "ai.voice.tts.request.sess-3"
        msg.data = msgpack.packb({"text": "Hello", "speaker": "test-voice", "stream": False})

        await service.handle_request(msg)

        # Verify XTTS was called with model_path
        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker"] == "test-voice"
        assert "model_path" in payload

    @pytest.mark.asyncio
    async def test_handle_list_voices(self, service):
        """handle_list_voices returns voice list via request-reply."""
        msg = MagicMock()
        msg.reply = "reply-inbox"
        msg.respond = AsyncMock()

        await service.handle_list_voices(msg)

        msg.respond.assert_called_once()
        data = msgpack.unpackb(msg.respond.call_args.args[0], raw=False)
        assert data["default_speaker"] == DEFAULT_SPEAKER
        assert len(data["custom_voices"]) == 1
        assert data["custom_voices"][0]["name"] == "test-voice"

    @pytest.mark.asyncio
    async def test_handle_list_voices_no_reply(self, service):
        """handle_list_voices does not crash when msg.reply is None."""
        msg = MagicMock()
        msg.reply = None
        msg.respond = AsyncMock()

        await service.handle_list_voices(msg)
        msg.respond.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_refresh_voices(self, service):
        """handle_refresh_voices rescans and returns updated count."""
        msg = MagicMock()
        msg.reply = "reply-inbox"
        msg.respond = AsyncMock()

        await service.handle_refresh_voices(msg)

        msg.respond.assert_called_once()
        data = msgpack.unpackb(msg.respond.call_args.args[0], raw=False)
        assert data["count"] == 1

    @pytest.mark.asyncio
    async def test_publish_status(self, service):
        """publish_status sends msgpack status to correct subject."""
        await service.publish_status("sess-1", "completed", "Done")

        service.nc.publish.assert_called_once()
        subject = service.nc.publish.call_args.args[0]
        assert subject == f"{STATUS_SUBJECT_PREFIX}.sess-1"

        data = msgpack.unpackb(service.nc.publish.call_args.args[1], raw=False)
        assert data["status"] == "completed"
        assert data["message"] == "Done"

    @pytest.mark.asyncio
    async def test_invalid_subject_format(self, service):
        """Requests with too few subject segments are skipped."""
        msg = MagicMock()
        msg.subject = "ai.voice.tts"  # Missing request.{session_id}
        msg.data = msgpack.packb({"text": "test"})

        # Should not raise. Awaited under pytest-asyncio like the other async
        # tests, rather than via the deprecated
        # asyncio.get_event_loop().run_until_complete() pattern.
        await service.handle_request(msg)