feat: CI pipeline, lint fixes, and Renovate config
- pyproject.toml with ruff/pytest config (setuptools<81 pin)
- Full test suite (26 tests)
- Gitea Actions CI (lint, test, docker, notify)
- Ruff lint/format fixes across source files
- Renovate config for automated dependency updates

Ref: ADR-0057
This commit is contained in:
335
stt_streaming.py
335
stt_streaming.py
@@ -11,46 +11,47 @@ Real-time Speech-to-Text service that processes live audio streams from NATS:
|
||||
This enables faster response times by processing audio as it arrives rather than
|
||||
waiting for complete audio upload.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
import struct
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
from io import BytesIO
|
||||
|
||||
import httpx
|
||||
import msgpack
|
||||
import nats
|
||||
import nats.js
|
||||
from nats.aio.msg import Msg
|
||||
import numpy as np
|
||||
import webrtcvad
|
||||
from nats.aio.msg import Msg
|
||||
|
||||
# OpenTelemetry imports
|
||||
from opentelemetry import trace, metrics
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
|
||||
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as OTLPMetricExporterHTTP,
|
||||
)
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
OTLPSpanExporter as OTLPSpanExporterHTTP,
|
||||
)
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("stt-streaming")
|
||||
|
||||
|
||||
# Initialize OpenTelemetry
|
||||
def setup_telemetry():
|
||||
"""Initialize OpenTelemetry tracing and metrics with HyperDX support."""
|
||||
@@ -59,98 +60,120 @@ def setup_telemetry():
|
||||
if not otel_enabled:
|
||||
logger.info("OpenTelemetry disabled")
|
||||
return None, None
|
||||
|
||||
|
||||
# OTEL configuration
|
||||
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
|
||||
otel_endpoint = os.environ.get(
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT",
|
||||
"http://opentelemetry-collector.observability.svc.cluster.local:4317",
|
||||
)
|
||||
service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming")
|
||||
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
|
||||
|
||||
|
||||
# HyperDX configuration
|
||||
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
|
||||
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
|
||||
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
|
||||
|
||||
|
||||
# Create resource with service information
|
||||
resource = Resource.create({
|
||||
SERVICE_NAME: service_name,
|
||||
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
|
||||
SERVICE_NAMESPACE: service_namespace,
|
||||
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||
})
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
SERVICE_NAME: service_name,
|
||||
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
|
||||
SERVICE_NAMESPACE: service_namespace,
|
||||
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||
}
|
||||
)
|
||||
|
||||
# Setup tracing
|
||||
trace_provider = TracerProvider(resource=resource)
|
||||
|
||||
|
||||
if use_hyperdx:
|
||||
# Use HTTP exporter for HyperDX with API key header
|
||||
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
|
||||
headers = {"authorization": hyperdx_api_key}
|
||||
otlp_span_exporter = OTLPSpanExporterHTTP(
|
||||
endpoint=f"{hyperdx_endpoint}/v1/traces",
|
||||
headers=headers
|
||||
endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
|
||||
)
|
||||
otlp_metric_exporter = OTLPMetricExporterHTTP(
|
||||
endpoint=f"{hyperdx_endpoint}/v1/metrics",
|
||||
headers=headers
|
||||
endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
|
||||
)
|
||||
else:
|
||||
# Use gRPC exporter for standard OTEL collector
|
||||
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
|
||||
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
|
||||
|
||||
|
||||
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
|
||||
trace.set_tracer_provider(trace_provider)
|
||||
|
||||
|
||||
# Setup metrics
|
||||
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
otlp_metric_exporter, export_interval_millis=60000
|
||||
)
|
||||
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(meter_provider)
|
||||
|
||||
|
||||
# Instrument HTTPX
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
|
||||
|
||||
# Instrument logging
|
||||
LoggingInstrumentor().instrument(set_logging_format=True)
|
||||
|
||||
|
||||
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
|
||||
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
|
||||
|
||||
|
||||
# Return tracer and meter for the service
|
||||
tracer = trace.get_tracer(__name__)
|
||||
meter = metrics.get_meter(__name__)
|
||||
|
||||
|
||||
return tracer, meter
|
||||
|
||||
|
||||
# Configuration from environment
|
||||
WHISPER_URL = os.environ.get("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
|
||||
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
|
||||
|
||||
# NATS subjects for streaming
|
||||
STREAM_SUBJECT_PREFIX = "ai.voice.stream" # Full subject: ai.voice.stream.{session_id}
|
||||
TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription" # Full subject: ai.voice.transcription.{session_id}
|
||||
TRANSCRIPTION_SUBJECT_PREFIX = (
|
||||
"ai.voice.transcription" # Full subject: ai.voice.transcription.{session_id}
|
||||
)
|
||||
|
||||
# Streaming parameters
|
||||
BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")) # ~5 seconds at 16kHz 16-bit
|
||||
CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0")) # Process after 2s of silence
|
||||
BUFFER_SIZE_BYTES = int(
|
||||
os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")
|
||||
) # ~5 seconds at 16kHz 16-bit
|
||||
CHUNK_TIMEOUT_SECONDS = float(
|
||||
os.environ.get("STT_CHUNK_TIMEOUT", "2.0")
|
||||
) # Process after 2s of silence
|
||||
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000")) # ~50 seconds max
|
||||
|
||||
# Audio constants
|
||||
AUDIO_SAMPLE_MAX_INT16 = 32768.0 # Maximum value for 16-bit signed integer audio
|
||||
VAD_VOICE_RATIO_THRESHOLD = float(os.environ.get("STT_VAD_VOICE_RATIO", "0.3")) # Min ratio of voice frames
|
||||
VAD_VOICE_RATIO_THRESHOLD = float(
|
||||
os.environ.get("STT_VAD_VOICE_RATIO", "0.3")
|
||||
) # Min ratio of voice frames
|
||||
|
||||
# Voice Activity Detection (VAD) parameters
|
||||
ENABLE_VAD = os.environ.get("STT_ENABLE_VAD", "true").lower() == "true"
|
||||
VAD_AGGRESSIVENESS = int(os.environ.get("STT_VAD_AGGRESSIVENESS", "2")) # 0-3, higher = more aggressive
|
||||
VAD_AGGRESSIVENESS = int(
|
||||
os.environ.get("STT_VAD_AGGRESSIVENESS", "2")
|
||||
) # 0-3, higher = more aggressive
|
||||
VAD_FRAME_DURATION_MS = int(os.environ.get("STT_VAD_FRAME_DURATION", "30")) # 10, 20, or 30 ms
|
||||
|
||||
# Audio threshold for interrupt detection (when LLM is responding)
|
||||
ENABLE_INTERRUPT_DETECTION = os.environ.get("STT_ENABLE_INTERRUPT_DETECTION", "true").lower() == "true"
|
||||
ENABLE_INTERRUPT_DETECTION = (
|
||||
os.environ.get("STT_ENABLE_INTERRUPT_DETECTION", "true").lower() == "true"
|
||||
)
|
||||
AUDIO_LEVEL_THRESHOLD = float(os.environ.get("STT_AUDIO_LEVEL_THRESHOLD", "0.02")) # RMS threshold
|
||||
INTERRUPT_DURATION_THRESHOLD = float(os.environ.get("STT_INTERRUPT_DURATION", "0.5")) # Seconds of speech to trigger
|
||||
INTERRUPT_DURATION_THRESHOLD = float(
|
||||
os.environ.get("STT_INTERRUPT_DURATION", "0.5")
|
||||
) # Seconds of speech to trigger
|
||||
|
||||
# Speaker diarization
|
||||
ENABLE_SPEAKER_DIARIZATION = os.environ.get("STT_ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true"
|
||||
ENABLE_SPEAKER_DIARIZATION = (
|
||||
os.environ.get("STT_ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# Session states
|
||||
SESSION_STATE_LISTENING = "listening"
|
||||
@@ -160,17 +183,17 @@ SESSION_STATE_RESPONDING = "responding"
|
||||
def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float:
|
||||
"""
|
||||
Calculate RMS (Root Mean Square) audio level.
|
||||
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
sample_width: Bytes per sample (2 for 16-bit audio)
|
||||
|
||||
|
||||
Returns:
|
||||
RMS level normalized to 0.0-1.0 range
|
||||
"""
|
||||
if len(audio_data) < sample_width:
|
||||
return 0.0
|
||||
|
||||
|
||||
# Convert bytes to numpy array of int16 samples
|
||||
try:
|
||||
samples = np.frombuffer(audio_data, dtype=np.int16)
|
||||
@@ -186,30 +209,30 @@ def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float:
|
||||
def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
|
||||
"""
|
||||
Detect if audio contains voice using WebRTC VAD.
|
||||
|
||||
|
||||
Args:
|
||||
audio_data: Raw PCM audio bytes (16-bit, mono)
|
||||
sample_rate: Audio sample rate (8000, 16000, 32000, or 48000)
|
||||
|
||||
|
||||
Returns:
|
||||
True if voice is detected, False otherwise
|
||||
"""
|
||||
if not ENABLE_VAD:
|
||||
return True # Assume voice present if VAD disabled
|
||||
|
||||
|
||||
try:
|
||||
vad = webrtcvad.Vad(VAD_AGGRESSIVENESS)
|
||||
|
||||
|
||||
# WebRTC VAD requires specific frame sizes
|
||||
# Frame duration must be 10, 20, or 30 ms
|
||||
frame_size = int(sample_rate * VAD_FRAME_DURATION_MS / 1000) * 2 # *2 for 16-bit samples
|
||||
|
||||
|
||||
# Process audio in frames
|
||||
voice_frames = 0
|
||||
total_frames = 0
|
||||
|
||||
|
||||
for i in range(0, len(audio_data) - frame_size, frame_size):
|
||||
frame = audio_data[i:i + frame_size]
|
||||
frame = audio_data[i : i + frame_size]
|
||||
if len(frame) == frame_size:
|
||||
try:
|
||||
is_speech = vad.is_speech(frame, sample_rate)
|
||||
@@ -219,14 +242,14 @@ def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
|
||||
except Exception as e:
|
||||
logger.debug(f"VAD frame processing error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
if total_frames == 0:
|
||||
return False
|
||||
|
||||
|
||||
# Consider voice detected if voice ratio exceeds threshold
|
||||
voice_ratio = voice_frames / total_frames
|
||||
return voice_ratio > VAD_VOICE_RATIO_THRESHOLD
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"VAD error: {e}")
|
||||
return True # Default to voice present on error
|
||||
@@ -234,7 +257,7 @@ def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
|
||||
|
||||
class AudioBuffer:
|
||||
"""Manages audio chunks for a streaming session with VAD and speaker tracking."""
|
||||
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
self.chunks = []
|
||||
@@ -247,46 +270,54 @@ class AudioBuffer:
|
||||
self.interrupt_start_time = None # Track when interrupt detection started
|
||||
self.has_voice_activity = False # Track if voice was detected in recent chunks
|
||||
self._last_chunk_vad_result = None # Cache VAD result for last chunk
|
||||
|
||||
|
||||
def add_chunk(self, audio_data: bytes) -> None:
|
||||
"""Add an audio chunk to the buffer and check for voice activity."""
|
||||
self.chunks.append(audio_data)
|
||||
self.total_bytes += len(audio_data)
|
||||
self.last_chunk_time = time.time()
|
||||
|
||||
|
||||
# Check for voice activity in this chunk and cache result
|
||||
has_voice = detect_voice_activity(audio_data)
|
||||
self.has_voice_activity = has_voice
|
||||
self._last_chunk_vad_result = has_voice
|
||||
|
||||
logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes, voice={has_voice}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes, voice={has_voice}"
|
||||
)
|
||||
|
||||
def check_interrupt(self, audio_data: bytes) -> bool:
|
||||
"""
|
||||
Check if audio indicates an interrupt during responding state.
|
||||
Uses cached VAD result if available.
|
||||
|
||||
|
||||
Returns:
|
||||
True if interrupt detected, False otherwise
|
||||
"""
|
||||
if not ENABLE_INTERRUPT_DETECTION:
|
||||
return False
|
||||
|
||||
|
||||
if self.state != SESSION_STATE_RESPONDING:
|
||||
return False
|
||||
|
||||
|
||||
# Calculate audio level
|
||||
rms_level = calculate_audio_rms(audio_data)
|
||||
|
||||
|
||||
# Use cached VAD result if available to avoid duplicate processing
|
||||
has_voice = self._last_chunk_vad_result if self._last_chunk_vad_result is not None else detect_voice_activity(audio_data)
|
||||
|
||||
has_voice = (
|
||||
self._last_chunk_vad_result
|
||||
if self._last_chunk_vad_result is not None
|
||||
else detect_voice_activity(audio_data)
|
||||
)
|
||||
|
||||
# Check if audio exceeds threshold and contains voice
|
||||
if rms_level >= AUDIO_LEVEL_THRESHOLD and has_voice:
|
||||
if self.interrupt_start_time is None:
|
||||
self.interrupt_start_time = time.time()
|
||||
logger.info(f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})")
|
||||
|
||||
logger.info(
|
||||
f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})"
|
||||
)
|
||||
|
||||
# Check if interrupt has lasted long enough
|
||||
elapsed = time.time() - self.interrupt_start_time
|
||||
if elapsed >= INTERRUPT_DURATION_THRESHOLD:
|
||||
@@ -295,9 +326,9 @@ class AudioBuffer:
|
||||
else:
|
||||
# Reset interrupt timer if audio drops below threshold
|
||||
self.interrupt_start_time = None
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def set_state(self, state: str) -> None:
|
||||
"""Set the session state (listening or responding)."""
|
||||
if state in (SESSION_STATE_LISTENING, SESSION_STATE_RESPONDING):
|
||||
@@ -307,15 +338,18 @@ class AudioBuffer:
|
||||
logger.info(f"Session {self.session_id}: State changed from {old_state} to {state}")
|
||||
# Reset interrupt tracking when changing states
|
||||
self.interrupt_start_time = None
|
||||
|
||||
|
||||
def should_process(self) -> bool:
|
||||
"""Determine if buffer should be processed now."""
|
||||
# Don't process if no voice activity detected (unless buffer is full or timed out)
|
||||
if ENABLE_VAD and not self.has_voice_activity:
|
||||
# Still process if buffer is very large or has timed out
|
||||
if self.total_bytes < BUFFER_SIZE_BYTES and time.time() - self.last_chunk_time < CHUNK_TIMEOUT_SECONDS:
|
||||
return False
|
||||
|
||||
if (
|
||||
ENABLE_VAD
|
||||
and not self.has_voice_activity
|
||||
and self.total_bytes < BUFFER_SIZE_BYTES
|
||||
and time.time() - self.last_chunk_time < CHUNK_TIMEOUT_SECONDS
|
||||
):
|
||||
return False
|
||||
|
||||
# Process if buffer size threshold reached
|
||||
if self.total_bytes >= BUFFER_SIZE_BYTES:
|
||||
return True
|
||||
@@ -323,21 +357,19 @@ class AudioBuffer:
|
||||
if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0:
|
||||
return True
|
||||
# Process if buffer is too large (safety limit)
|
||||
if self.total_bytes >= MAX_BUFFER_SIZE_BYTES:
|
||||
return True
|
||||
return False
|
||||
|
||||
return self.total_bytes >= MAX_BUFFER_SIZE_BYTES
|
||||
|
||||
def get_audio(self) -> bytes:
|
||||
"""Get concatenated audio data."""
|
||||
return b''.join(self.chunks)
|
||||
|
||||
return b"".join(self.chunks)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the buffer after processing."""
|
||||
self.chunks = []
|
||||
self.total_bytes = 0
|
||||
self.sequence += 1
|
||||
self._last_chunk_vad_result = None # Clear cached VAD result
|
||||
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark stream as complete."""
|
||||
self.is_complete = True
|
||||
@@ -345,12 +377,12 @@ class AudioBuffer:
|
||||
|
||||
class StreamingSTT:
|
||||
"""Streaming Speech-to-Text service."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.nc = None
|
||||
self.js = None
|
||||
self.http_client = None
|
||||
self.sessions: Dict[str, AudioBuffer] = {}
|
||||
self.sessions: dict[str, AudioBuffer] = {}
|
||||
self.running = True
|
||||
self.processing_tasks = {}
|
||||
self.is_healthy = False
|
||||
@@ -358,32 +390,32 @@ class StreamingSTT:
|
||||
self.meter = None
|
||||
self.stream_counter = None
|
||||
self.transcription_duration = None
|
||||
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize connections."""
|
||||
# Initialize OpenTelemetry
|
||||
self.tracer, self.meter = setup_telemetry()
|
||||
|
||||
|
||||
# Create metrics if OTEL is enabled
|
||||
if self.meter:
|
||||
self.stream_counter = self.meter.create_counter(
|
||||
name="stt_streams_total",
|
||||
description="Total number of STT streams processed",
|
||||
unit="1"
|
||||
unit="1",
|
||||
)
|
||||
self.transcription_duration = self.meter.create_histogram(
|
||||
name="stt_transcription_duration_seconds",
|
||||
description="Duration of STT transcription",
|
||||
unit="s"
|
||||
unit="s",
|
||||
)
|
||||
|
||||
|
||||
# NATS connection
|
||||
self.nc = await nats.connect(NATS_URL)
|
||||
logger.info(f"Connected to NATS at {NATS_URL}")
|
||||
|
||||
|
||||
# Initialize JetStream context
|
||||
self.js = self.nc.jetstream()
|
||||
|
||||
|
||||
# Create or update stream for voice stream messages
|
||||
try:
|
||||
stream_config = nats.js.api.StreamConfig(
|
||||
@@ -398,21 +430,20 @@ class StreamingSTT:
|
||||
except Exception as e:
|
||||
# Stream might already exist
|
||||
logger.info(f"JetStream stream setup: {e}")
|
||||
|
||||
|
||||
# HTTP client for Whisper service
|
||||
self.http_client = httpx.AsyncClient(timeout=180.0)
|
||||
logger.info("HTTP client initialized")
|
||||
|
||||
|
||||
# Mark as healthy once connections are established
|
||||
self.is_healthy = True
|
||||
|
||||
async def transcribe(self, audio_bytes: bytes) -> Optional[str]:
|
||||
|
||||
async def transcribe(self, audio_bytes: bytes) -> str | None:
|
||||
"""Transcribe audio using Whisper."""
|
||||
try:
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
response = await self.http_client.post(
|
||||
f"{WHISPER_URL}/v1/audio/transcriptions",
|
||||
files=files
|
||||
f"{WHISPER_URL}/v1/audio/transcriptions", files=files
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@@ -422,22 +453,24 @@ class StreamingSTT:
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def process_buffer(self, session_id: str):
|
||||
"""Process accumulated audio buffer for a session."""
|
||||
buffer = self.sessions.get(session_id)
|
||||
if not buffer:
|
||||
return
|
||||
|
||||
|
||||
audio_data = buffer.get_audio()
|
||||
if not audio_data:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}"
|
||||
)
|
||||
|
||||
# Transcribe
|
||||
transcript = await self.transcribe(audio_data)
|
||||
|
||||
|
||||
if transcript:
|
||||
# Publish transcription result using msgpack binary format
|
||||
result = {
|
||||
@@ -449,24 +482,25 @@ class StreamingSTT:
|
||||
"timestamp": time.time(),
|
||||
"speaker_id": buffer.speaker_id,
|
||||
"has_voice_activity": buffer.has_voice_activity,
|
||||
"state": buffer.state
|
||||
"state": buffer.state,
|
||||
}
|
||||
|
||||
|
||||
await self.nc.publish(
|
||||
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
|
||||
msgpack.packb(result)
|
||||
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", msgpack.packb(result)
|
||||
)
|
||||
logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})")
|
||||
|
||||
logger.info(
|
||||
f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})"
|
||||
)
|
||||
|
||||
# Clear buffer after processing
|
||||
buffer.clear()
|
||||
|
||||
|
||||
# Clean up completed sessions asynchronously
|
||||
if buffer.is_complete:
|
||||
logger.info(f"Session {session_id} completed")
|
||||
# Schedule cleanup task to avoid blocking
|
||||
asyncio.create_task(self._cleanup_session(session_id))
|
||||
|
||||
|
||||
async def _cleanup_session(self, session_id: str):
|
||||
"""Clean up a completed session after a delay."""
|
||||
# Keep session for a bit in case of late messages
|
||||
@@ -476,34 +510,34 @@ class StreamingSTT:
|
||||
logger.info(f"Cleaned up session: {session_id}")
|
||||
if session_id in self.processing_tasks:
|
||||
del self.processing_tasks[session_id]
|
||||
|
||||
|
||||
async def monitor_buffer(self, session_id: str):
|
||||
"""Monitor buffer and trigger processing when needed."""
|
||||
while self.running and session_id in self.sessions:
|
||||
buffer = self.sessions.get(session_id)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
|
||||
if buffer.should_process():
|
||||
await self.process_buffer(session_id)
|
||||
|
||||
|
||||
# Don't spin too fast
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def handle_stream_message(self, msg: Msg):
|
||||
"""Handle incoming audio stream message."""
|
||||
try:
|
||||
# Extract session_id from subject: ai.voice.stream.{session_id}
|
||||
subject_parts = msg.subject.split('.')
|
||||
subject_parts = msg.subject.split(".")
|
||||
if len(subject_parts) < 4:
|
||||
logger.warning(f"Invalid subject format: {msg.subject}")
|
||||
return
|
||||
|
||||
|
||||
session_id = subject_parts[3]
|
||||
|
||||
|
||||
# Parse message using msgpack binary format
|
||||
data = msgpack.unpackb(msg.data, raw=False)
|
||||
|
||||
|
||||
# Handle control messages
|
||||
if data.get("type") == "start":
|
||||
logger.info(f"Starting stream session: {session_id}")
|
||||
@@ -520,19 +554,19 @@ class StreamingSTT:
|
||||
task = asyncio.create_task(self.monitor_buffer(session_id))
|
||||
self.processing_tasks[session_id] = task
|
||||
return
|
||||
|
||||
|
||||
if data.get("type") == "state_change":
|
||||
logger.info(f"State change for session {session_id}")
|
||||
buffer = self.sessions.get(session_id)
|
||||
if buffer:
|
||||
new_state = data.get("state", SESSION_STATE_LISTENING)
|
||||
buffer.set_state(new_state)
|
||||
|
||||
|
||||
# If switching to listening mode, reset any interrupt tracking
|
||||
if new_state == SESSION_STATE_LISTENING:
|
||||
buffer.interrupt_start_time = None
|
||||
return
|
||||
|
||||
|
||||
if data.get("type") == "end":
|
||||
logger.info(f"Ending stream session: {session_id}")
|
||||
buffer = self.sessions.get(session_id)
|
||||
@@ -542,15 +576,15 @@ class StreamingSTT:
|
||||
if buffer.total_bytes > 0:
|
||||
await self.process_buffer(session_id)
|
||||
return
|
||||
|
||||
|
||||
# Handle audio chunk
|
||||
if data.get("type") == "chunk":
|
||||
audio_b64 = data.get("audio_b64", "")
|
||||
if not audio_b64:
|
||||
return
|
||||
|
||||
|
||||
audio_bytes = base64.b64decode(audio_b64)
|
||||
|
||||
|
||||
# Create session if it doesn't exist (handle missing start message)
|
||||
# Check both sessions and processing_tasks to avoid race conditions
|
||||
if session_id not in self.sessions:
|
||||
@@ -560,9 +594,9 @@ class StreamingSTT:
|
||||
if session_id not in self.processing_tasks:
|
||||
task = asyncio.create_task(self.monitor_buffer(session_id))
|
||||
self.processing_tasks[session_id] = task
|
||||
|
||||
|
||||
buffer = self.sessions[session_id]
|
||||
|
||||
|
||||
# Check for interrupt if in responding state
|
||||
if buffer.check_interrupt(audio_bytes):
|
||||
# Publish interrupt notification
|
||||
@@ -570,57 +604,56 @@ class StreamingSTT:
|
||||
"session_id": session_id,
|
||||
"type": "interrupt",
|
||||
"timestamp": time.time(),
|
||||
"speaker_id": buffer.speaker_id
|
||||
"speaker_id": buffer.speaker_id,
|
||||
}
|
||||
await self.nc.publish(
|
||||
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
|
||||
msgpack.packb(interrupt_msg)
|
||||
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", msgpack.packb(interrupt_msg)
|
||||
)
|
||||
logger.info(f"Published interrupt notification for session {session_id}")
|
||||
|
||||
|
||||
# Automatically switch back to listening mode
|
||||
buffer.set_state(SESSION_STATE_LISTENING)
|
||||
|
||||
|
||||
# Add chunk to buffer
|
||||
buffer.add_chunk(audio_bytes)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling stream message: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def run(self):
|
||||
"""Main run loop."""
|
||||
await self.setup()
|
||||
|
||||
|
||||
# Note: STT streaming uses regular NATS subscribe (not pull-based JetStream consumer)
|
||||
# because it handles real-time ephemeral audio streams with wildcard subscriptions.
|
||||
# The stream audio chunks are not meant to be persisted long-term or replayed.
|
||||
# However, the transcription RESULTS are published to JetStream for persistence.
|
||||
sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
|
||||
logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")
|
||||
|
||||
|
||||
# Handle shutdown
|
||||
def signal_handler():
|
||||
self.running = False
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
|
||||
# Keep running
|
||||
while self.running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
# Cleanup
|
||||
logger.info("Shutting down...")
|
||||
|
||||
|
||||
# Cancel all monitoring tasks and wait for them to complete
|
||||
for task in self.processing_tasks.values():
|
||||
task.cancel()
|
||||
|
||||
|
||||
# Wait for all tasks to complete or be cancelled
|
||||
if self.processing_tasks:
|
||||
await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True)
|
||||
|
||||
|
||||
await sub.unsubscribe()
|
||||
await self.nc.close()
|
||||
await self.http_client.aclose()
|
||||
|
||||
Reference in New Issue
Block a user