- VoiceRegistry for trained voices from Kubeflow pipeline
- Custom voice routing in synthesize()
- NATS subjects for listing/refreshing voices
- pyproject.toml with ruff/pytest config
- Full test suite (26 tests)
- Gitea Actions CI (lint, test, docker, notify)
- Renovate config for automated dependency updates

Ref: ADR-0056, ADR-0057
417 lines
15 KiB
Python
417 lines
15 KiB
Python
"""
|
|
Unit tests for TTS streaming service.
|
|
"""
|
|
|
|
import base64
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import msgpack
|
|
import pytest
|
|
|
|
from tts_streaming import (
|
|
AUDIO_SUBJECT_PREFIX,
|
|
DEFAULT_LANGUAGE,
|
|
DEFAULT_SPEAKER,
|
|
STATUS_SUBJECT_PREFIX,
|
|
CustomVoice,
|
|
StreamingTTS,
|
|
VoiceRegistry,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# VoiceRegistry tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCustomVoice:
    """Checks on the field values exposed by a constructed CustomVoice."""

    def test_defaults(self):
        """Fields that are not supplied take the expected default values."""
        required_only = {
            "name": "test-voice",
            "model_path": "/models/tts/custom/test-voice/model.pth",
            "config_path": "/models/tts/custom/test-voice/config.json",
            "created_at": "2026-02-13T00:00:00",
        }

        custom = CustomVoice(**required_only)

        assert custom.name == "test-voice"
        # language and model_type were not passed, so these are the defaults.
        assert custom.language == "en"
        assert custom.model_type == "coqui-tts"

    def test_custom_fields(self):
        """Explicit language and model_type values are stored as given."""
        overrides = {"language": "de", "model_type": "custom-vits"}

        custom = CustomVoice(
            name="german-voice",
            model_path="/m/model.pth",
            config_path="",
            created_at="2026-01-01",
            **overrides,
        )

        assert custom.language == "de"
        assert custom.model_type == "custom-vits"
|
|
|
|
|
|
class TestVoiceRegistry:
    """Tests for how VoiceRegistry discovers voices in a store directory."""

    @staticmethod
    def _add_voice(store, name, info, *, weights=b"w", with_config=False):
        """Create ``store/name`` populated with voice files and return the path.

        Writes ``model.pth`` (raw *weights*), optionally an empty-object
        ``config.json``, and ``model_info.json`` holding *info* as JSON.
        """
        voice_dir = store / name
        voice_dir.mkdir()
        (voice_dir / "model.pth").write_bytes(weights)
        if with_config:
            (voice_dir / "config.json").write_text("{}")
        (voice_dir / "model_info.json").write_text(json.dumps(info))
        return voice_dir

    def test_empty_store(self, tmp_path):
        """An empty store directory yields a count of 0 and no voices."""
        registry = VoiceRegistry(str(tmp_path))

        assert registry.refresh() == 0
        assert registry.list_voices() == []

    def test_missing_store(self, tmp_path):
        """A store path that does not exist yields a count of 0."""
        absent = tmp_path / "does-not-exist"
        registry = VoiceRegistry(str(absent))

        assert registry.refresh() == 0

    def test_discovers_valid_voice(self, tmp_path):
        """A directory with model_info.json + model.pth is registered."""
        alice_dir = self._add_voice(
            tmp_path,
            "alice",
            {
                "name": "alice",
                "created_at": "2026-02-13T12:00:00",
                "type": "coqui-tts",
            },
            weights=b"fake-weights",
            with_config=True,
        )

        registry = VoiceRegistry(str(tmp_path))
        found = registry.refresh()

        assert found == 1
        alice = registry.get("alice")
        assert alice is not None
        assert alice.name == "alice"
        # Paths are reported as plain strings pointing into the voice directory.
        assert alice.model_path == str(alice_dir / "model.pth")
        assert alice.config_path == str(alice_dir / "config.json")

    def test_skips_dir_without_model_pth(self, tmp_path):
        """A directory lacking model.pth is not counted."""
        broken_dir = tmp_path / "broken"
        broken_dir.mkdir()
        # Only the metadata file is written; model.pth is deliberately absent.
        (broken_dir / "model_info.json").write_text(json.dumps({"name": "broken"}))

        registry = VoiceRegistry(str(tmp_path))
        found = registry.refresh()

        assert found == 0

    def test_skips_dir_without_model_info(self, tmp_path):
        """A directory lacking model_info.json is not counted."""
        bare_dir = tmp_path / "no-info"
        bare_dir.mkdir()
        (bare_dir / "model.pth").write_bytes(b"data")

        registry = VoiceRegistry(str(tmp_path))
        found = registry.refresh()

        assert found == 0

    def test_skips_plain_files(self, tmp_path):
        """A regular file in the store root is not counted as a voice."""
        (tmp_path / "readme.txt").write_text("hello")

        registry = VoiceRegistry(str(tmp_path))
        found = registry.refresh()

        assert found == 0

    def test_multiple_voices(self, tmp_path):
        """Every valid voice directory in the store is discovered."""
        for voice_name in ("alice", "bob", "charlie"):
            self._add_voice(tmp_path, voice_name, {"name": voice_name})

        registry = VoiceRegistry(str(tmp_path))

        assert registry.refresh() == 3
        listed = {entry["name"] for entry in registry.list_voices()}
        assert listed == {"alice", "bob", "charlie"}

    def test_refresh_detects_new_and_removed(self, tmp_path):
        """Repeated refresh calls track voices added to and removed from disk."""
        import shutil

        first_dir = self._add_voice(tmp_path, "v1", {"name": "v1"})

        registry = VoiceRegistry(str(tmp_path))
        assert registry.refresh() == 1

        # A second voice appears on disk after the initial scan.
        self._add_voice(tmp_path, "v2", {"name": "v2"})

        assert registry.refresh() == 2
        assert registry.get("v2") is not None

        # The first voice is deleted from disk; the next scan must drop it.
        shutil.rmtree(first_dir)
        assert registry.refresh() == 1
        assert registry.get("v1") is None

    def test_get_returns_none_for_unknown(self, tmp_path):
        """Looking up a name that was never registered returns None."""
        registry = VoiceRegistry(str(tmp_path))
        registry.refresh()

        missing = registry.get("nonexistent")

        assert missing is None

    def test_list_voices_serialization(self, tmp_path):
        """list_voices returns dicts that survive a msgpack round trip."""
        self._add_voice(
            tmp_path,
            "voice1",
            {
                "name": "voice1",
                "language": "fr",
                "type": "coqui-tts",
                "created_at": "2026-02-13",
            },
        )

        registry = VoiceRegistry(str(tmp_path))
        registry.refresh()
        listing = registry.list_voices()

        assert len(listing) == 1
        entry = listing[0]
        assert entry["name"] == "voice1"
        assert entry["language"] == "fr"
        # Round-trip through msgpack to prove the listing is serializable.
        encoded = msgpack.packb(listing)
        assert msgpack.unpackb(encoded, raw=False) == listing
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# StreamingTTS tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStreamingTTS:
    """Tests for the StreamingTTS service.

    The ``mock_http_client`` and ``sample_audio_bytes`` fixtures used below
    are not defined in this class; they are expected to come from shared
    pytest configuration (presumably conftest.py -- confirm).
    """

    @staticmethod
    def _request_msg(subject, payload):
        """Build a mock NATS message with *subject* and msgpack-packed *payload*."""
        msg = MagicMock()
        msg.subject = subject
        msg.data = msgpack.packb(payload)
        return msg

    @staticmethod
    def _reply_msg(reply):
        """Build a mock request-reply message whose ``respond`` is awaitable."""
        msg = MagicMock()
        msg.reply = reply
        msg.respond = AsyncMock()
        return msg

    @pytest.fixture
    def service(self, tmp_path):
        """Yield a StreamingTTS whose NATS/HTTP clients are mocks.

        The voice registry is pointed at ``tmp_path`` and pre-populated with
        one voice named ``test-voice``.
        """
        with patch("tts_streaming.VOICE_MODEL_STORE", str(tmp_path)):
            svc = StreamingTTS()
            svc.nc = AsyncMock()
            svc.js = AsyncMock()
            svc.http_client = AsyncMock()
            svc.is_healthy = True

            # Set up the voice registry with a single test voice on disk.
            voice_dir = tmp_path / "test-voice"
            voice_dir.mkdir()
            (voice_dir / "model.pth").write_bytes(b"w")
            (voice_dir / "config.json").write_text("{}")
            (voice_dir / "model_info.json").write_text(
                json.dumps({"name": "test-voice", "language": "en", "type": "coqui-tts"})
            )
            svc.voice_registry = VoiceRegistry(str(tmp_path))
            svc.voice_registry.refresh()

            # Yield inside the ``with`` so the patch stays active for the test.
            yield svc

    @pytest.mark.asyncio
    async def test_synthesize_default_speaker(self, service, mock_http_client):
        """Synthesis without a speaker sends the default speaker and language."""
        service.http_client = mock_http_client

        result = await service.synthesize("Hello world")

        assert result is not None
        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["text"] == "Hello world"
        assert payload["speaker"] == DEFAULT_SPEAKER
        assert payload["language"] == DEFAULT_LANGUAGE
        assert "model_path" not in payload

    @pytest.mark.asyncio
    async def test_synthesize_custom_voice(self, service, mock_http_client):
        """Synthesis with a registered custom voice includes its model paths."""
        service.http_client = mock_http_client

        result = await service.synthesize("Hello", speaker="test-voice")

        assert result is not None
        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker"] == "test-voice"
        assert "model_path" in payload
        assert payload["model_path"].endswith("model.pth")
        assert "config_path" in payload

    @pytest.mark.asyncio
    async def test_synthesize_ad_hoc_cloning(self, service, mock_http_client):
        """Synthesis with speaker_wav_b64 forwards the reference audio."""
        service.http_client = mock_http_client
        wav_b64 = base64.b64encode(b"fake-audio").decode()

        result = await service.synthesize("Hello", speaker_wav_b64=wav_b64)

        assert result is not None
        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker_wav"] == wav_b64
        assert "model_path" not in payload

    @pytest.mark.asyncio
    async def test_synthesize_unknown_speaker_no_clone(self, service, mock_http_client):
        """An unregistered speaker is forwarded as-is with no model or clone audio."""
        service.http_client = mock_http_client

        await service.synthesize("Hello", speaker="unknown-voice")

        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker"] == "unknown-voice"
        assert "model_path" not in payload
        assert "speaker_wav" not in payload

    @pytest.mark.asyncio
    async def test_synthesize_failure_returns_none(self, service):
        """Synthesis returns None when the HTTP call raises."""
        service.http_client = AsyncMock()
        service.http_client.post = AsyncMock(side_effect=Exception("connection refused"))

        result = await service.synthesize("Hello")

        assert result is None

    @pytest.mark.asyncio
    async def test_stream_audio_publishes_chunks(self, service, sample_audio_bytes):
        """stream_audio publishes chunked messages to the session's audio subject."""
        await service.stream_audio("sess-1", sample_audio_bytes)

        assert service.nc.publish.called
        # At least one chunk must have been published.
        publish_calls = service.nc.publish.call_args_list
        assert len(publish_calls) >= 1

        first_subject = publish_calls[0].args[0]
        assert first_subject == f"{AUDIO_SUBJECT_PREFIX}.sess-1"

        # The final published chunk is the one flagged as last.
        final_chunk = msgpack.unpackb(publish_calls[-1].args[1], raw=False)
        assert final_chunk["is_last"] is True
        assert final_chunk["session_id"] == "sess-1"

    @pytest.mark.asyncio
    async def test_handle_request_empty_text(self, service):
        """A request with empty text publishes an error status."""
        msg = self._request_msg("ai.voice.tts.request.sess-1", {"text": ""})

        await service.handle_request(msg)

        # Keep only the publishes that went to a status subject.
        status_calls = [
            call
            for call in service.nc.publish.call_args_list
            if call.args[0].startswith(STATUS_SUBJECT_PREFIX)
        ]
        assert len(status_calls) >= 1
        status_data = msgpack.unpackb(status_calls[0].args[1], raw=False)
        assert status_data["status"] == "error"

    @pytest.mark.asyncio
    async def test_handle_request_success(self, service, mock_http_client):
        """A successful request publishes to both status and audio subjects."""
        service.http_client = mock_http_client
        msg = self._request_msg(
            "ai.voice.tts.request.sess-2", {"text": "Hello world", "stream": True}
        )

        await service.handle_request(msg)

        subjects = [call.args[0] for call in service.nc.publish.call_args_list]
        # Expect at least one status publish and at least one audio publish.
        assert any(s.startswith(STATUS_SUBJECT_PREFIX) for s in subjects)
        assert any(s.startswith(AUDIO_SUBJECT_PREFIX) for s in subjects)

    @pytest.mark.asyncio
    async def test_handle_request_with_custom_voice(self, service, mock_http_client):
        """A request naming a registered voice sends that voice's model_path."""
        service.http_client = mock_http_client
        msg = self._request_msg(
            "ai.voice.tts.request.sess-3",
            {"text": "Hello", "speaker": "test-voice", "stream": False},
        )

        await service.handle_request(msg)

        # The backend request must carry the trained model's path.
        payload = mock_http_client.post.call_args.kwargs["json"]
        assert payload["speaker"] == "test-voice"
        assert "model_path" in payload

    @pytest.mark.asyncio
    async def test_handle_list_voices(self, service):
        """handle_list_voices responds with the default speaker and custom voices."""
        msg = self._reply_msg("reply-inbox")

        await service.handle_list_voices(msg)

        msg.respond.assert_called_once()
        data = msgpack.unpackb(msg.respond.call_args.args[0], raw=False)
        assert data["default_speaker"] == DEFAULT_SPEAKER
        assert len(data["custom_voices"]) == 1
        assert data["custom_voices"][0]["name"] == "test-voice"

    @pytest.mark.asyncio
    async def test_handle_list_voices_no_reply(self, service):
        """handle_list_voices does not respond when msg.reply is None."""
        msg = self._reply_msg(None)

        await service.handle_list_voices(msg)

        msg.respond.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_refresh_voices(self, service):
        """handle_refresh_voices rescans the store and responds with the count."""
        msg = self._reply_msg("reply-inbox")

        await service.handle_refresh_voices(msg)

        msg.respond.assert_called_once()
        data = msgpack.unpackb(msg.respond.call_args.args[0], raw=False)
        assert data["count"] == 1

    @pytest.mark.asyncio
    async def test_publish_status(self, service):
        """publish_status sends a msgpack status to the session's status subject."""
        await service.publish_status("sess-1", "completed", "Done")

        service.nc.publish.assert_called_once()
        subject, raw = service.nc.publish.call_args.args[:2]
        assert subject == f"{STATUS_SUBJECT_PREFIX}.sess-1"
        data = msgpack.unpackb(raw, raw=False)
        assert data["status"] == "completed"
        assert data["message"] == "Done"

    @pytest.mark.asyncio
    async def test_invalid_subject_format(self, service):
        """A request whose subject has too few segments does not raise."""
        # Subject is missing the trailing ``request.{session_id}`` segments.
        msg = self._request_msg("ai.voice.tts", {"text": "test"})

        # Awaited directly, like the other async tests, rather than through
        # asyncio.get_event_loop().run_until_complete(), which is deprecated
        # when no event loop is running. The test passes if nothing raises.
        await service.handle_request(msg)
|