"""
STT service client (Whisper/faster-whisper).
"""
|
|
|
|
import asyncio
import logging
from typing import Optional

import httpx

from handler_base.config import STTSettings
from handler_base.telemetry import create_span
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class STTClient:
    """
    Async HTTP client for the STT service (Whisper/faster-whisper).

    Wraps an ``httpx.AsyncClient`` bound to ``settings.stt_url`` and posts
    audio to the ``/v1/audio/transcriptions`` and ``/v1/audio/translations``
    endpoints.

    Usage:
        client = STTClient()
        result = await client.transcribe(audio_bytes)
        await client.close()

        # Or let the context manager close the HTTP client:
        async with STTClient() as client:
            result = await client.transcribe(audio_bytes)
    """

    # Response formats that are returned as a plain-text body instead of
    # JSON (this matches the OpenAI-style audio API the endpoints follow).
    _TEXT_FORMATS = frozenset({"text", "srt", "vtt"})

    def __init__(self, settings: Optional[STTSettings] = None):
        """
        Create the client.

        Args:
            settings: STT settings; a default ``STTSettings()`` is built
                when omitted.
        """
        self.settings = settings or STTSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.stt_url,
            timeout=180.0,  # Transcription can be slow
        )

    async def __aenter__(self) -> "STTClient":
        """Enter ``async with``; returns this client."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Leave ``async with``; closes the HTTP client, never suppresses errors."""
        await self.close()

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    @staticmethod
    def _build_request(
        language: Optional[str],
        task: Optional[str],
        response_format: str,
    ) -> tuple:
        """Return ``(endpoint, form_data)`` for the given request options."""
        data = {"response_format": response_format}
        if language:
            data["language"] = language

        # Only "translate" goes to the translation endpoint; any other task
        # value (including None) is treated as a transcription.
        if task == "translate":
            endpoint = "/v1/audio/translations"
        else:
            endpoint = "/v1/audio/transcriptions"
        return endpoint, data

    @classmethod
    def _parse_response(cls, response, response_format: str) -> dict:
        """Turn the HTTP response into a result dict for ``response_format``."""
        # "text", "srt" and "vtt" bodies are not JSON, so decoding them with
        # response.json() would raise; wrap the raw body instead.
        if response_format in cls._TEXT_FORMATS:
            return {"text": response.text}
        return response.json()

    async def transcribe(
        self,
        audio: bytes,
        language: Optional[str] = None,
        task: Optional[str] = None,
        response_format: str = "json",
    ) -> dict:
        """
        Transcribe audio to text.

        Args:
            audio: Audio bytes (WAV, MP3, etc.)
            language: Language code; falls back to ``settings.stt_language``
                (None/empty means the field is not sent to the service)
            task: "transcribe" or "translate"; falls back to
                ``settings.stt_task``
            response_format: "json", "text", "srt", "vtt"

        Returns:
            For JSON formats, the decoded response body (dict with 'text',
            'language', and optional 'segments'). For "text", "srt" and
            "vtt", a dict ``{"text": <raw response body>}``.

        Raises:
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx
                status.
        """
        language = language or self.settings.stt_language
        task = task or self.settings.stt_task

        with create_span("stt.transcribe") as span:
            if span:
                span.set_attribute("stt.task", task)
                span.set_attribute("stt.audio_size", len(audio))
                if language:
                    span.set_attribute("stt.language", language)

            # The upload is always labelled as WAV, whatever the real encoding.
            files = {"file": ("audio.wav", audio, "audio/wav")}
            endpoint, data = self._build_request(language, task, response_format)

            response = await self._client.post(endpoint, files=files, data=data)
            response.raise_for_status()

            result = self._parse_response(response, response_format)

            if span:
                # Recorded for every format, including the plain-text ones.
                span.set_attribute("stt.result_length", len(result.get("text") or ""))
                if result.get("language"):
                    span.set_attribute("stt.detected_language", result["language"])

            return result

    async def transcribe_file(
        self,
        file_path: str,
        language: Optional[str] = None,
        task: Optional[str] = None,
        response_format: str = "json",
    ) -> dict:
        """
        Transcribe an audio file.

        Args:
            file_path: Path to audio file
            language: Language code
            task: "transcribe" or "translate"
            response_format: "json", "text", "srt", "vtt"

        Returns:
            Transcription result (see ``transcribe``)

        Raises:
            OSError: If the file cannot be opened or read.
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx
                status.
        """

        def _read_bytes() -> bytes:
            with open(file_path, "rb") as f:
                return f.read()

        # Read on the default executor so a large file does not block the
        # event loop while other coroutines are waiting.
        loop = asyncio.get_running_loop()
        audio = await loop.run_in_executor(None, _read_bytes)
        return await self.transcribe(audio, language, task, response_format)

    async def translate(self, audio: bytes) -> dict:
        """
        Translate audio to English.

        Args:
            audio: Audio bytes

        Returns:
            Translation result with 'text' key
        """
        return await self.transcribe(audio, task="translate")

    async def health(self) -> bool:
        """
        Check if the STT service is healthy.

        Returns:
            True when ``GET /health`` answers with status 200; False for any
            other status or when the request fails for any reason.
        """
        try:
            response = await self._client.get("/health")
        except Exception:
            # Deliberately best-effort: a probe must never raise, but leave a
            # trace so persistent failures can be diagnosed.
            logger.debug("STT health check failed", exc_info=True)
            return False
        return response.status_code == 200
|