feat: custom voice support, CI pipeline, and Renovate config
- VoiceRegistry for trained voices from Kubeflow pipeline - Custom voice routing in synthesize() - NATS subjects for listing/refreshing voices - pyproject.toml with ruff/pytest config - Full test suite (26 tests) - Gitea Actions CI (lint, test, docker, notify) - Renovate config for automated dependency updates Ref: ADR-0056, ADR-0057
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
52
tests/conftest.py
Normal file
52
tests/conftest.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Pytest configuration and fixtures for tts-module tests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Set test environment variables before importing
|
||||
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
|
||||
os.environ.setdefault("XTTS_URL", "http://localhost:8000")
|
||||
os.environ.setdefault("OTEL_ENABLED", "false")
|
||||
os.environ.setdefault("VOICE_MODEL_STORE", "/tmp/test-voice-models")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_audio_bytes():
|
||||
"""Sample audio bytes for testing (silent 16-bit PCM)."""
|
||||
return bytes([0x00] * 4096)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nats():
|
||||
"""Mock NATS connection."""
|
||||
nc = AsyncMock()
|
||||
nc.publish = AsyncMock()
|
||||
nc.subscribe = AsyncMock()
|
||||
nc.close = AsyncMock()
|
||||
return nc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client():
|
||||
"""Mock httpx async client."""
|
||||
client = AsyncMock()
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
response.content = bytes([0x00] * 2048)
|
||||
response.raise_for_status = MagicMock()
|
||||
client.post = AsyncMock(return_value=response)
|
||||
client.aclose = AsyncMock()
|
||||
return client
|
||||
416
tests/test_tts_streaming.py
Normal file
416
tests/test_tts_streaming.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Unit tests for TTS streaming service.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import msgpack
|
||||
import pytest
|
||||
|
||||
from tts_streaming import (
|
||||
AUDIO_SUBJECT_PREFIX,
|
||||
DEFAULT_LANGUAGE,
|
||||
DEFAULT_SPEAKER,
|
||||
STATUS_SUBJECT_PREFIX,
|
||||
CustomVoice,
|
||||
StreamingTTS,
|
||||
VoiceRegistry,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VoiceRegistry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCustomVoice:
|
||||
"""Tests for CustomVoice dataclass."""
|
||||
|
||||
def test_defaults(self):
|
||||
voice = CustomVoice(
|
||||
name="test-voice",
|
||||
model_path="/models/tts/custom/test-voice/model.pth",
|
||||
config_path="/models/tts/custom/test-voice/config.json",
|
||||
created_at="2026-02-13T00:00:00",
|
||||
)
|
||||
assert voice.name == "test-voice"
|
||||
assert voice.language == "en"
|
||||
assert voice.model_type == "coqui-tts"
|
||||
|
||||
def test_custom_fields(self):
|
||||
voice = CustomVoice(
|
||||
name="german-voice",
|
||||
model_path="/m/model.pth",
|
||||
config_path="",
|
||||
created_at="2026-01-01",
|
||||
language="de",
|
||||
model_type="custom-vits",
|
||||
)
|
||||
assert voice.language == "de"
|
||||
assert voice.model_type == "custom-vits"
|
||||
|
||||
|
||||
class TestVoiceRegistry:
|
||||
"""Tests for VoiceRegistry discovery."""
|
||||
|
||||
def test_empty_store(self, tmp_path):
|
||||
"""Registry returns 0 when the store directory is empty."""
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
count = registry.refresh()
|
||||
assert count == 0
|
||||
assert registry.list_voices() == []
|
||||
|
||||
def test_missing_store(self, tmp_path):
|
||||
"""Registry handles a non-existent store gracefully."""
|
||||
registry = VoiceRegistry(str(tmp_path / "does-not-exist"))
|
||||
count = registry.refresh()
|
||||
assert count == 0
|
||||
|
||||
def test_discovers_valid_voice(self, tmp_path):
|
||||
"""Registry discovers a voice directory with model_info.json + model.pth."""
|
||||
voice_dir = tmp_path / "alice"
|
||||
voice_dir.mkdir()
|
||||
(voice_dir / "model.pth").write_bytes(b"fake-weights")
|
||||
(voice_dir / "config.json").write_text("{}")
|
||||
(voice_dir / "model_info.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "alice",
|
||||
"created_at": "2026-02-13T12:00:00",
|
||||
"type": "coqui-tts",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
count = registry.refresh()
|
||||
|
||||
assert count == 1
|
||||
voice = registry.get("alice")
|
||||
assert voice is not None
|
||||
assert voice.name == "alice"
|
||||
assert voice.model_path == str(voice_dir / "model.pth")
|
||||
assert voice.config_path == str(voice_dir / "config.json")
|
||||
|
||||
def test_skips_dir_without_model_pth(self, tmp_path):
|
||||
"""Directories missing model.pth are skipped."""
|
||||
voice_dir = tmp_path / "broken"
|
||||
voice_dir.mkdir()
|
||||
(voice_dir / "model_info.json").write_text(json.dumps({"name": "broken"}))
|
||||
# no model.pth
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
assert registry.refresh() == 0
|
||||
|
||||
def test_skips_dir_without_model_info(self, tmp_path):
|
||||
"""Directories missing model_info.json are skipped."""
|
||||
voice_dir = tmp_path / "no-info"
|
||||
voice_dir.mkdir()
|
||||
(voice_dir / "model.pth").write_bytes(b"data")
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
assert registry.refresh() == 0
|
||||
|
||||
def test_skips_plain_files(self, tmp_path):
|
||||
"""Plain files in the store root are ignored."""
|
||||
(tmp_path / "readme.txt").write_text("hello")
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
assert registry.refresh() == 0
|
||||
|
||||
def test_multiple_voices(self, tmp_path):
|
||||
"""Multiple valid voices are all discovered."""
|
||||
for name in ("alice", "bob", "charlie"):
|
||||
d = tmp_path / name
|
||||
d.mkdir()
|
||||
(d / "model.pth").write_bytes(b"w")
|
||||
(d / "model_info.json").write_text(json.dumps({"name": name}))
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
assert registry.refresh() == 3
|
||||
names = {v["name"] for v in registry.list_voices()}
|
||||
assert names == {"alice", "bob", "charlie"}
|
||||
|
||||
def test_refresh_detects_new_and_removed(self, tmp_path):
|
||||
"""Subsequent refresh picks up additions and removals."""
|
||||
d = tmp_path / "v1"
|
||||
d.mkdir()
|
||||
(d / "model.pth").write_bytes(b"w")
|
||||
(d / "model_info.json").write_text(json.dumps({"name": "v1"}))
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
assert registry.refresh() == 1
|
||||
|
||||
# Add a second voice
|
||||
d2 = tmp_path / "v2"
|
||||
d2.mkdir()
|
||||
(d2 / "model.pth").write_bytes(b"w")
|
||||
(d2 / "model_info.json").write_text(json.dumps({"name": "v2"}))
|
||||
|
||||
assert registry.refresh() == 2
|
||||
assert registry.get("v2") is not None
|
||||
|
||||
# Remove first voice
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(d)
|
||||
assert registry.refresh() == 1
|
||||
assert registry.get("v1") is None
|
||||
|
||||
def test_get_returns_none_for_unknown(self, tmp_path):
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
registry.refresh()
|
||||
assert registry.get("nonexistent") is None
|
||||
|
||||
def test_list_voices_serialization(self, tmp_path):
|
||||
"""list_voices returns dicts suitable for msgpack."""
|
||||
d = tmp_path / "voice1"
|
||||
d.mkdir()
|
||||
(d / "model.pth").write_bytes(b"w")
|
||||
(d / "model_info.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"name": "voice1",
|
||||
"language": "fr",
|
||||
"type": "coqui-tts",
|
||||
"created_at": "2026-02-13",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
registry = VoiceRegistry(str(tmp_path))
|
||||
registry.refresh()
|
||||
voices = registry.list_voices()
|
||||
|
||||
assert len(voices) == 1
|
||||
v = voices[0]
|
||||
assert v["name"] == "voice1"
|
||||
assert v["language"] == "fr"
|
||||
# Verify it's msgpack-serializable
|
||||
packed = msgpack.packb(voices)
|
||||
assert msgpack.unpackb(packed, raw=False) == voices
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StreamingTTS tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingTTS:
|
||||
"""Tests for the StreamingTTS service."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, tmp_path):
|
||||
"""Create a StreamingTTS instance with mocked dependencies."""
|
||||
with patch("tts_streaming.VOICE_MODEL_STORE", str(tmp_path)):
|
||||
svc = StreamingTTS()
|
||||
svc.nc = AsyncMock()
|
||||
svc.js = AsyncMock()
|
||||
svc.http_client = AsyncMock()
|
||||
svc.is_healthy = True
|
||||
|
||||
# Setup voice registry with a test voice
|
||||
d = tmp_path / "test-voice"
|
||||
d.mkdir()
|
||||
(d / "model.pth").write_bytes(b"w")
|
||||
(d / "config.json").write_text("{}")
|
||||
(d / "model_info.json").write_text(
|
||||
json.dumps({"name": "test-voice", "language": "en", "type": "coqui-tts"})
|
||||
)
|
||||
svc.voice_registry = VoiceRegistry(str(tmp_path))
|
||||
svc.voice_registry.refresh()
|
||||
|
||||
yield svc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesize_default_speaker(self, service, mock_http_client):
|
||||
"""Synthesis with default speaker sends basic payload."""
|
||||
service.http_client = mock_http_client
|
||||
|
||||
result = await service.synthesize("Hello world")
|
||||
|
||||
assert result is not None
|
||||
call_kwargs = mock_http_client.post.call_args
|
||||
payload = call_kwargs.kwargs["json"]
|
||||
assert payload["text"] == "Hello world"
|
||||
assert payload["speaker"] == DEFAULT_SPEAKER
|
||||
assert payload["language"] == DEFAULT_LANGUAGE
|
||||
assert "model_path" not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesize_custom_voice(self, service, mock_http_client):
|
||||
"""Synthesis with a registered custom voice includes model_path."""
|
||||
service.http_client = mock_http_client
|
||||
|
||||
result = await service.synthesize("Hello", speaker="test-voice")
|
||||
|
||||
assert result is not None
|
||||
payload = mock_http_client.post.call_args.kwargs["json"]
|
||||
assert payload["speaker"] == "test-voice"
|
||||
assert "model_path" in payload
|
||||
assert payload["model_path"].endswith("model.pth")
|
||||
assert "config_path" in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesize_ad_hoc_cloning(self, service, mock_http_client):
|
||||
"""Synthesis with speaker_wav_b64 uses ad-hoc voice cloning."""
|
||||
service.http_client = mock_http_client
|
||||
wav_b64 = base64.b64encode(b"fake-audio").decode()
|
||||
|
||||
result = await service.synthesize("Hello", speaker_wav_b64=wav_b64)
|
||||
|
||||
assert result is not None
|
||||
payload = mock_http_client.post.call_args.kwargs["json"]
|
||||
assert payload["speaker_wav"] == wav_b64
|
||||
assert "model_path" not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesize_unknown_speaker_no_clone(self, service, mock_http_client):
|
||||
"""Unknown speaker without reference audio falls through to default."""
|
||||
service.http_client = mock_http_client
|
||||
|
||||
await service.synthesize("Hello", speaker="unknown-voice")
|
||||
|
||||
payload = mock_http_client.post.call_args.kwargs["json"]
|
||||
assert payload["speaker"] == "unknown-voice"
|
||||
assert "model_path" not in payload
|
||||
assert "speaker_wav" not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesize_failure_returns_none(self, service):
|
||||
"""Synthesis returns None on HTTP error."""
|
||||
service.http_client = AsyncMock()
|
||||
service.http_client.post = AsyncMock(side_effect=Exception("connection refused"))
|
||||
|
||||
result = await service.synthesize("Hello")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_audio_publishes_chunks(self, service, sample_audio_bytes):
|
||||
"""stream_audio publishes chunked messages to NATS."""
|
||||
await service.stream_audio("sess-1", sample_audio_bytes)
|
||||
|
||||
assert service.nc.publish.called
|
||||
# Verify at least one chunk was published
|
||||
call_args = service.nc.publish.call_args_list
|
||||
assert len(call_args) >= 1
|
||||
|
||||
subject = call_args[0].args[0]
|
||||
assert subject == f"{AUDIO_SUBJECT_PREFIX}.sess-1"
|
||||
|
||||
data = msgpack.unpackb(call_args[-1].args[1], raw=False)
|
||||
assert data["is_last"] is True
|
||||
assert data["session_id"] == "sess-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_request_empty_text(self, service):
|
||||
"""Requests with empty text publish an error status."""
|
||||
msg = MagicMock()
|
||||
msg.subject = "ai.voice.tts.request.sess-1"
|
||||
msg.data = msgpack.packb({"text": ""})
|
||||
|
||||
await service.handle_request(msg)
|
||||
|
||||
# Should publish error status
|
||||
status_calls = [
|
||||
c
|
||||
for c in service.nc.publish.call_args_list
|
||||
if c.args[0].startswith(STATUS_SUBJECT_PREFIX)
|
||||
]
|
||||
assert len(status_calls) >= 1
|
||||
status_data = msgpack.unpackb(status_calls[0].args[1], raw=False)
|
||||
assert status_data["status"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_request_success(self, service, mock_http_client):
|
||||
"""Successful request publishes audio and completed status."""
|
||||
service.http_client = mock_http_client
|
||||
|
||||
msg = MagicMock()
|
||||
msg.subject = "ai.voice.tts.request.sess-2"
|
||||
msg.data = msgpack.packb({"text": "Hello world", "stream": True})
|
||||
|
||||
await service.handle_request(msg)
|
||||
|
||||
subjects = [c.args[0] for c in service.nc.publish.call_args_list]
|
||||
# Should have status (processing), audio chunk(s), and status (completed)
|
||||
assert any(s.startswith(STATUS_SUBJECT_PREFIX) for s in subjects)
|
||||
assert any(s.startswith(AUDIO_SUBJECT_PREFIX) for s in subjects)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_request_with_custom_voice(self, service, mock_http_client):
|
||||
"""Request with custom voice speaker uses trained model."""
|
||||
service.http_client = mock_http_client
|
||||
|
||||
msg = MagicMock()
|
||||
msg.subject = "ai.voice.tts.request.sess-3"
|
||||
msg.data = msgpack.packb({"text": "Hello", "speaker": "test-voice", "stream": False})
|
||||
|
||||
await service.handle_request(msg)
|
||||
|
||||
# Verify XTTS was called with model_path
|
||||
payload = mock_http_client.post.call_args.kwargs["json"]
|
||||
assert payload["speaker"] == "test-voice"
|
||||
assert "model_path" in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_list_voices(self, service):
|
||||
"""handle_list_voices returns voice list via request-reply."""
|
||||
msg = MagicMock()
|
||||
msg.reply = "reply-inbox"
|
||||
msg.respond = AsyncMock()
|
||||
|
||||
await service.handle_list_voices(msg)
|
||||
|
||||
msg.respond.assert_called_once()
|
||||
data = msgpack.unpackb(msg.respond.call_args.args[0], raw=False)
|
||||
assert data["default_speaker"] == DEFAULT_SPEAKER
|
||||
assert len(data["custom_voices"]) == 1
|
||||
assert data["custom_voices"][0]["name"] == "test-voice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_list_voices_no_reply(self, service):
|
||||
"""handle_list_voices does not crash when msg.reply is None."""
|
||||
msg = MagicMock()
|
||||
msg.reply = None
|
||||
msg.respond = AsyncMock()
|
||||
|
||||
await service.handle_list_voices(msg)
|
||||
msg.respond.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_refresh_voices(self, service):
|
||||
"""handle_refresh_voices rescans and returns updated count."""
|
||||
msg = MagicMock()
|
||||
msg.reply = "reply-inbox"
|
||||
msg.respond = AsyncMock()
|
||||
|
||||
await service.handle_refresh_voices(msg)
|
||||
|
||||
msg.respond.assert_called_once()
|
||||
data = msgpack.unpackb(msg.respond.call_args.args[0], raw=False)
|
||||
assert data["count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_status(self, service):
|
||||
"""publish_status sends msgpack status to correct subject."""
|
||||
await service.publish_status("sess-1", "completed", "Done")
|
||||
|
||||
service.nc.publish.assert_called_once()
|
||||
subject = service.nc.publish.call_args.args[0]
|
||||
assert subject == f"{STATUS_SUBJECT_PREFIX}.sess-1"
|
||||
data = msgpack.unpackb(service.nc.publish.call_args.args[1], raw=False)
|
||||
assert data["status"] == "completed"
|
||||
assert data["message"] == "Done"
|
||||
|
||||
def test_invalid_subject_format(self, service):
|
||||
"""Requests with too few subject segments are skipped."""
|
||||
msg = MagicMock()
|
||||
msg.subject = "ai.voice.tts" # Missing request.{session_id}
|
||||
msg.data = msgpack.packb({"text": "test"})
|
||||
|
||||
# Should not raise
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(service.handle_request(msg))
|
||||
Reference in New Issue
Block a user