- pyproject.toml with ruff/pytest config (setuptools<81 pin) - Full test suite (26 tests) - Gitea Actions CI (lint, test, docker, notify) - Ruff lint/format fixes across source files - Renovate config for automated dependency updates Ref: ADR-0057
263 lines
9.0 KiB
Python
263 lines
9.0 KiB
Python
"""
|
|
Unit tests for STT streaming service.
|
|
"""
|
|
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import msgpack
|
|
import pytest
|
|
|
|
from stt_streaming import (
|
|
TRANSCRIPTION_SUBJECT_PREFIX,
|
|
AudioBuffer,
|
|
StreamingSTT,
|
|
calculate_audio_rms,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Utility function tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCalculateAudioRms:
    """Unit tests for the calculate_audio_rms helper."""

    def test_silence_returns_zero(self, silent_pcm_bytes):
        """Silent PCM input produces an RMS of exactly zero."""
        assert calculate_audio_rms(silent_pcm_bytes) == 0.0

    def test_noisy_signal_above_zero(self, noisy_pcm_bytes):
        """Non-silent PCM input produces a strictly positive RMS."""
        level = calculate_audio_rms(noisy_pcm_bytes)
        assert level > 0.0

    def test_empty_bytes(self):
        """An empty payload is reported as silence."""
        assert calculate_audio_rms(b"") == 0.0

    def test_single_byte(self):
        """A lone byte (presumably too short for a full sample) is reported as silence."""
        assert calculate_audio_rms(b"\x00") == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AudioBuffer tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAudioBuffer:
    """Tests for AudioBuffer session management."""

    def test_init(self):
        """A fresh buffer is empty, incomplete, at sequence 0 and listening."""
        buf = AudioBuffer("sess-1")
        assert buf.session_id == "sess-1"
        assert buf.total_bytes == 0
        assert buf.is_complete is False
        assert buf.sequence == 0
        assert buf.state == "listening"

    def test_add_chunk(self, silent_pcm_bytes):
        """Adding a chunk stores it and updates the byte count."""
        buf = AudioBuffer("sess-1")
        buf.add_chunk(silent_pcm_bytes)
        assert buf.total_bytes == len(silent_pcm_bytes)
        assert len(buf.chunks) == 1

    def test_get_audio_concatenates(self, silent_pcm_bytes):
        """get_audio returns all buffered chunks joined together."""
        buf = AudioBuffer("sess-1")
        buf.add_chunk(silent_pcm_bytes)
        buf.add_chunk(silent_pcm_bytes)
        audio = buf.get_audio()
        assert len(audio) == len(silent_pcm_bytes) * 2

    def test_clear_resets(self, silent_pcm_bytes):
        """clear() drops buffered audio and advances the sequence number."""
        buf = AudioBuffer("sess-1")
        buf.add_chunk(silent_pcm_bytes)
        buf.clear()
        assert buf.total_bytes == 0
        assert buf.chunks == []
        assert buf.sequence == 1  # Incremented after clear

    def test_mark_complete(self):
        """mark_complete() flags the buffer as complete."""
        buf = AudioBuffer("sess-1")
        buf.mark_complete()
        assert buf.is_complete is True

    def test_set_state(self):
        """set_state() switches between the valid states."""
        buf = AudioBuffer("sess-1")
        assert buf.state == "listening"
        buf.set_state("responding")
        assert buf.state == "responding"
        buf.set_state("listening")
        assert buf.state == "listening"

    def test_set_invalid_state_ignored(self):
        """An unknown state name leaves the current state unchanged."""
        buf = AudioBuffer("sess-1")
        buf.set_state("invalid")
        assert buf.state == "listening"  # Unchanged

    @patch("stt_streaming.BUFFER_SIZE_BYTES", 100)
    def test_should_process_when_buffer_full(self):
        """Exceeding the size threshold makes the buffer ready to process."""
        buf = AudioBuffer("sess-1")
        # Explicit payload so the test does not depend on the fixture's size.
        buf.add_chunk(b"\x00" * 200)  # 200 bytes > patched 100-byte threshold
        assert buf.should_process() is True

    def test_should_not_process_when_empty(self):
        """An empty buffer is never ready to process."""
        buf = AudioBuffer("sess-1")
        assert buf.should_process() is False

    # Make the size threshold unreachable so only the timeout branch can
    # return True; otherwise a large fixture could satisfy the buffer-full
    # check and let this test pass without exercising the timeout.
    @patch("stt_streaming.BUFFER_SIZE_BYTES", 10**9)
    @patch("stt_streaming.CHUNK_TIMEOUT_SECONDS", 0.0)
    def test_should_process_on_timeout(self, silent_pcm_bytes):
        """A small buffer whose last chunk is stale is flushed by timeout."""
        buf = AudioBuffer("sess-1")
        buf.add_chunk(silent_pcm_bytes)
        buf.last_chunk_time = time.time() - 10  # Force timeout
        assert buf.should_process() is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# StreamingSTT tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStreamingSTT:
    """Tests for the StreamingSTT service."""

    @pytest.fixture
    def service(self, mock_nats, mock_http_client):
        """StreamingSTT wired to the mocked NATS and HTTP clients, marked healthy."""
        svc = StreamingSTT()
        svc.nc = mock_nats
        svc.js = mock_nats.jetstream()
        svc.http_client = mock_http_client
        svc.is_healthy = True
        return svc

    @pytest.mark.asyncio
    async def test_transcribe_success(self, service):
        """Successful transcription returns text."""
        result = await service.transcribe(b"fake-audio")
        assert result == "Hello world"
        service.http_client.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_transcribe_failure(self, service):
        """Transcription failure returns None."""
        service.http_client.post = AsyncMock(side_effect=Exception("timeout"))
        result = await service.transcribe(b"fake-audio")
        assert result is None

    @pytest.mark.asyncio
    async def test_handle_start_message(self, service):
        """Start message creates a new session buffer."""
        msg = MagicMock()
        msg.subject = "ai.voice.stream.sess-1"
        msg.data = msgpack.packb({"type": "start"})

        await service.handle_stream_message(msg)

        assert "sess-1" in service.sessions
        assert service.sessions["sess-1"].state == "listening"

    @pytest.mark.asyncio
    async def test_handle_start_with_speaker_id(self, service):
        """Start message with speaker_id stores it on the buffer."""
        msg = MagicMock()
        msg.subject = "ai.voice.stream.sess-2"
        msg.data = msgpack.packb({"type": "start", "speaker_id": "user-42"})

        await service.handle_stream_message(msg)
        assert service.sessions["sess-2"].speaker_id == "user-42"

    @pytest.mark.asyncio
    async def test_handle_state_change(self, service):
        """State change message updates the buffer state."""
        # Create session first
        service.sessions["sess-1"] = AudioBuffer("sess-1")

        msg = MagicMock()
        msg.subject = "ai.voice.stream.sess-1"
        msg.data = msgpack.packb({"type": "state_change", "state": "responding"})

        await service.handle_stream_message(msg)
        assert service.sessions["sess-1"].state == "responding"

    @pytest.mark.asyncio
    async def test_handle_audio_chunk(self, service, sample_audio_b64):
        """Audio chunks are added to the session buffer."""
        service.sessions["sess-1"] = AudioBuffer("sess-1")
        # Prevent auto-creation of monitoring task
        service.processing_tasks["sess-1"] = MagicMock()

        msg = MagicMock()
        msg.subject = "ai.voice.stream.sess-1"
        msg.data = msgpack.packb({"type": "chunk", "audio_b64": sample_audio_b64})

        await service.handle_stream_message(msg)
        assert service.sessions["sess-1"].total_bytes > 0

    @pytest.mark.asyncio
    async def test_handle_end_message(self, service):
        """End message triggers processing of remaining audio."""
        buf = AudioBuffer("sess-1")
        buf.add_chunk(b"\x00" * 100)
        service.sessions["sess-1"] = buf

        msg = MagicMock()
        msg.subject = "ai.voice.stream.sess-1"
        msg.data = msgpack.packb({"type": "end"})

        await service.handle_stream_message(msg)

        # Should have published a transcription
        assert service.nc.publish.called

    @pytest.mark.asyncio
    async def test_handle_auto_create_session(self, service, sample_audio_b64):
        """Chunk message auto-creates session when start was missed."""
        # Same guard as test_handle_audio_chunk: stop the service from
        # spawning a background monitoring task that this test never
        # cancels and that would outlive the test's event loop.
        service.processing_tasks["new-sess"] = MagicMock()

        msg = MagicMock()
        msg.subject = "ai.voice.stream.new-sess"
        msg.data = msgpack.packb({"type": "chunk", "audio_b64": sample_audio_b64})

        await service.handle_stream_message(msg)
        assert "new-sess" in service.sessions

    @pytest.mark.asyncio
    async def test_process_buffer_publishes_result(self, service):
        """process_buffer publishes transcription to NATS."""
        buf = AudioBuffer("sess-1")
        buf.add_chunk(b"\x00" * 100)
        service.sessions["sess-1"] = buf

        await service.process_buffer("sess-1")

        # Verify transcription published
        pub_calls = service.nc.publish.call_args_list
        assert len(pub_calls) >= 1
        subject = pub_calls[0].args[0]
        assert subject == f"{TRANSCRIPTION_SUBJECT_PREFIX}.sess-1"
        data = msgpack.unpackb(pub_calls[0].args[1], raw=False)
        assert data["transcript"] == "Hello world"
        assert data["session_id"] == "sess-1"

    @pytest.mark.asyncio
    async def test_process_buffer_no_session(self, service):
        """process_buffer handles missing session gracefully."""
        await service.process_buffer("nonexistent")
        service.nc.publish.assert_not_called()

    @pytest.mark.asyncio
    async def test_process_buffer_empty_audio(self, service):
        """process_buffer skips empty buffers."""
        service.sessions["sess-1"] = AudioBuffer("sess-1")
        await service.process_buffer("sess-1")
        service.nc.publish.assert_not_called()

    @pytest.mark.asyncio
    async def test_invalid_subject(self, service):
        """Messages with invalid subjects are skipped."""
        msg = MagicMock()
        msg.subject = "ai.voice"  # Too few parts
        msg.data = msgpack.packb({"type": "start"})

        # Await on the pytest-asyncio loop instead of driving a loop by hand:
        # asyncio.get_event_loop() with no running loop is deprecated (3.10+).
        await service.handle_stream_message(msg)
        assert len(service.sessions) == 0
|