#!/usr/bin/env python3
"""
Streaming STT Service

Real-time Speech-to-Text service that processes live audio streams from NATS:
1. Subscribe to audio stream subject (ai.voice.stream.{session_id})
2. Buffer and accumulate audio chunks
3. Transcribe when buffer reaches threshold or stream ends
4. Publish transcription results to response channel (ai.voice.transcription.{session_id})

This enables faster response times by processing audio as it arrives rather
than waiting for complete audio upload.
"""

import asyncio
import base64
import logging
import os
import signal
import time

import httpx
import msgpack
import nats
import nats.js
import numpy as np
import webrtcvad
from nats.aio.msg import Msg

# OpenTelemetry imports
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
    OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
    OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("stt-streaming")


# Initialize OpenTelemetry
def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Reads all configuration from environment variables. Two export paths:
    - HyperDX (HTTP OTLP with an API-key header) when HYPERDX_ENABLED is
      "true" AND an API key is set;
    - otherwise a plain gRPC OTLP collector endpoint.

    Returns:
        (tracer, meter) on success, or (None, None) when OTEL_ENABLED is
        not "true".
    """
    # Check if OTEL is enabled (default: on)
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    # OTEL configuration
    otel_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    # HyperDX configuration
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    # NOTE: use_hyperdx is a truthy string (the API key) rather than a bool;
    # downstream it is only used in boolean context, so behavior is correct.
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    # Create resource with service information
    resource = Resource.create(
        {
            SERVICE_NAME: service_name,
            SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
            SERVICE_NAMESPACE: service_namespace,
            "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
            "host.name": os.environ.get("HOSTNAME", "unknown"),
        }
    )

    # Setup tracing
    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        # Use HTTP exporter for HyperDX with API key header
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
        )
    else:
        # Use gRPC exporter for standard OTEL collector
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    # Setup metrics (exported once per minute)
    metric_reader = PeriodicExportingMetricReader(
        otlp_metric_exporter, export_interval_millis=60000
    )
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    # Instrument HTTPX
    HTTPXClientInstrumentor().instrument()

    # Instrument logging
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    # Return tracer and meter for the service
    tracer = trace.get_tracer(__name__)
    meter = metrics.get_meter(__name__)
    return tracer, meter
# Configuration from environment
WHISPER_URL = os.environ.get("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")

# NATS subjects for streaming
STREAM_SUBJECT_PREFIX = "ai.voice.stream"  # Full subject: ai.voice.stream.{session_id}
TRANSCRIPTION_SUBJECT_PREFIX = (
    "ai.voice.transcription"  # Full subject: ai.voice.transcription.{session_id}
)

# Streaming parameters
BUFFER_SIZE_BYTES = int(
    os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")
)  # ~16 s at 16 kHz mono 16-bit (32000 B/s); earlier comment said ~5 s — verify intended duration
CHUNK_TIMEOUT_SECONDS = float(
    os.environ.get("STT_CHUNK_TIMEOUT", "2.0")
)  # Process after 2s of silence
MAX_BUFFER_SIZE_BYTES = int(
    os.environ.get("STT_MAX_BUFFER_SIZE", "5120000")
)  # ~160 s at 16 kHz mono 16-bit; earlier comment said ~50 s — verify intended duration

# Audio constants
AUDIO_SAMPLE_MAX_INT16 = 32768.0  # Maximum value for 16-bit signed integer audio
VAD_VOICE_RATIO_THRESHOLD = float(
    os.environ.get("STT_VAD_VOICE_RATIO", "0.3")
)  # Min ratio of voice frames

# Voice Activity Detection (VAD) parameters
ENABLE_VAD = os.environ.get("STT_ENABLE_VAD", "true").lower() == "true"
VAD_AGGRESSIVENESS = int(
    os.environ.get("STT_VAD_AGGRESSIVENESS", "2")
)  # 0-3, higher = more aggressive
VAD_FRAME_DURATION_MS = int(os.environ.get("STT_VAD_FRAME_DURATION", "30"))  # 10, 20, or 30 ms

# Audio threshold for interrupt detection (when LLM is responding)
ENABLE_INTERRUPT_DETECTION = (
    os.environ.get("STT_ENABLE_INTERRUPT_DETECTION", "true").lower() == "true"
)
AUDIO_LEVEL_THRESHOLD = float(os.environ.get("STT_AUDIO_LEVEL_THRESHOLD", "0.02"))  # RMS threshold
INTERRUPT_DURATION_THRESHOLD = float(
    os.environ.get("STT_INTERRUPT_DURATION", "0.5")
)  # Seconds of speech to trigger

# Speaker diarization
ENABLE_SPEAKER_DIARIZATION = (
    os.environ.get("STT_ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true"
)

# Session states
SESSION_STATE_LISTENING = "listening"
SESSION_STATE_RESPONDING = "responding"


def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float:
    """
    Calculate RMS (Root Mean Square) audio level.

    Args:
        audio_data: Raw audio bytes (interpreted as 16-bit signed PCM)
        sample_width: Bytes per sample (2 for 16-bit audio)

    Returns:
        RMS level normalized to 0.0-1.0 range; 0.0 for too-short input or
        on any decode error.
    """
    if len(audio_data) < sample_width:
        return 0.0

    # Convert bytes to numpy array of int16 samples
    try:
        samples = np.frombuffer(audio_data, dtype=np.int16)
        # Square in float32 to avoid int16 overflow, then take the mean root
        rms = np.sqrt(np.mean(samples.astype(np.float32) ** 2))
        # Normalize to 0-1 range using defined constant
        return float(rms / AUDIO_SAMPLE_MAX_INT16)
    except Exception as e:
        logger.warning(f"Error calculating RMS: {e}")
        return 0.0


def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
    """
    Detect if audio contains voice using WebRTC VAD.

    Args:
        audio_data: Raw PCM audio bytes (16-bit, mono)
        sample_rate: Audio sample rate (8000, 16000, 32000, or 48000)

    Returns:
        True if voice is detected, False otherwise. Returns True when VAD
        is disabled or errors out (fail-open), False when no complete frame
        could be extracted.
    """
    if not ENABLE_VAD:
        return True  # Assume voice present if VAD disabled

    try:
        vad = webrtcvad.Vad(VAD_AGGRESSIVENESS)

        # WebRTC VAD requires specific frame sizes
        # Frame duration must be 10, 20, or 30 ms
        frame_size = int(sample_rate * VAD_FRAME_DURATION_MS / 1000) * 2  # *2 for 16-bit samples

        # Process audio in frames.
        # Bug fix: stop bound is len(audio_data) - frame_size + 1 so the final
        # complete frame is included. The previous bound (without +1) dropped
        # the last frame when the length was an exact multiple of frame_size,
        # and returned False for audio exactly one frame long.
        voice_frames = 0
        total_frames = 0

        for i in range(0, len(audio_data) - frame_size + 1, frame_size):
            frame = audio_data[i : i + frame_size]
            if len(frame) == frame_size:
                try:
                    is_speech = vad.is_speech(frame, sample_rate)
                    if is_speech:
                        voice_frames += 1
                    total_frames += 1
                except Exception as e:
                    logger.debug(f"VAD frame processing error: {e}")
                    continue

        if total_frames == 0:
            return False

        # Consider voice detected if voice ratio exceeds threshold
        voice_ratio = voice_frames / total_frames
        return voice_ratio > VAD_VOICE_RATIO_THRESHOLD

    except Exception as e:
        logger.warning(f"VAD error: {e}")
        return True  # Default to voice present on error
class AudioBuffer:
    """Per-session accumulator for streamed audio chunks.

    Tracks total size and arrival time of chunks, runs voice-activity
    detection on each one, watches for user interrupts while the session
    is in the "responding" state, and carries speaker/diarization info.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        # Chunk storage and bookkeeping
        self.chunks = []
        self.total_bytes = 0
        self.sequence = 0
        self.is_complete = False
        self.last_chunk_time = time.time()
        # Conversation state
        self.state = SESSION_STATE_LISTENING  # Current session state
        self.speaker_id = None  # For speaker diarization
        self.interrupt_start_time = None  # When a potential interrupt began
        self.has_voice_activity = False  # Voice detected in the latest chunk
        self._last_chunk_vad_result = None  # Cached VAD verdict for the latest chunk

    def add_chunk(self, audio_data: bytes) -> None:
        """Append one chunk, refresh arrival time, and run VAD on it."""
        self.last_chunk_time = time.time()
        self.chunks.append(audio_data)
        self.total_bytes += len(audio_data)

        # Run VAD once per chunk and cache the verdict for check_interrupt()
        has_voice = detect_voice_activity(audio_data)
        self._last_chunk_vad_result = has_voice
        self.has_voice_activity = has_voice

        logger.debug(
            f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes, voice={has_voice}"
        )

    def check_interrupt(self, audio_data: bytes) -> bool:
        """Return True when sustained loud speech arrives while responding.

        Requires the RMS level to stay at or above AUDIO_LEVEL_THRESHOLD
        with voice present for INTERRUPT_DURATION_THRESHOLD seconds.
        Reuses the cached VAD verdict when one is available.
        """
        if not ENABLE_INTERRUPT_DETECTION:
            return False
        if self.state != SESSION_STATE_RESPONDING:
            return False

        rms_level = calculate_audio_rms(audio_data)
        cached = self._last_chunk_vad_result
        has_voice = cached if cached is not None else detect_voice_activity(audio_data)

        # Quiet audio or no voice: abandon any pending interrupt and bail.
        if rms_level < AUDIO_LEVEL_THRESHOLD or not has_voice:
            self.interrupt_start_time = None
            return False

        if self.interrupt_start_time is None:
            # First loud, voiced chunk — start the interrupt timer.
            self.interrupt_start_time = time.time()
            logger.info(
                f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})"
            )

        elapsed = time.time() - self.interrupt_start_time
        if elapsed >= INTERRUPT_DURATION_THRESHOLD:
            logger.info(f"Session {self.session_id}: Interrupt confirmed after {elapsed:.1f}s")
            return True
        return False

    def set_state(self, state: str) -> None:
        """Switch between listening/responding; unknown states are ignored."""
        if state not in (SESSION_STATE_LISTENING, SESSION_STATE_RESPONDING):
            return
        old_state = self.state
        self.state = state
        if old_state != state:
            logger.info(f"Session {self.session_id}: State changed from {old_state} to {state}")
            # A state switch invalidates any in-progress interrupt timing.
            self.interrupt_start_time = None

    def should_process(self) -> bool:
        """Decide whether the buffered audio should be transcribed now."""
        idle_for = time.time() - self.last_chunk_time

        # Silence gate: with VAD on and no recent voice, hold off until the
        # buffer fills or the chunk timeout elapses.
        if (
            ENABLE_VAD
            and not self.has_voice_activity
            and self.total_bytes < BUFFER_SIZE_BYTES
            and idle_for < CHUNK_TIMEOUT_SECONDS
        ):
            return False

        # Enough audio accumulated.
        if self.total_bytes >= BUFFER_SIZE_BYTES:
            return True

        # Some audio and the stream has gone quiet.
        if idle_for > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0:
            return True

        # Safety valve for oversized buffers.
        return self.total_bytes >= MAX_BUFFER_SIZE_BYTES

    def get_audio(self) -> bytes:
        """Return all buffered chunks joined into one byte string."""
        return b"".join(self.chunks)

    def clear(self) -> None:
        """Reset the buffer for the next utterance and bump the sequence."""
        self.sequence += 1
        self.chunks = []
        self.total_bytes = 0
        self._last_chunk_vad_result = None  # Drop the stale cached VAD verdict

    def mark_complete(self) -> None:
        """Flag that the upstream audio stream has ended."""
        self.is_complete = True
class StreamingSTT:
    """Streaming Speech-to-Text service.

    Subscribes to ai.voice.stream.{session_id}, buffers incoming audio
    chunks per session, transcribes via the Whisper HTTP service, and
    publishes results to ai.voice.transcription.{session_id}.
    """

    def __init__(self):
        self.nc = None  # NATS connection (set in setup)
        self.js = None  # JetStream context (set in setup)
        self.http_client = None  # httpx client for the Whisper service
        self.sessions: dict[str, AudioBuffer] = {}  # session_id -> audio buffer
        self.running = True
        self.processing_tasks = {}  # session_id -> monitor_buffer task
        self.is_healthy = False  # flips True once setup() completes
        self.tracer = None
        self.meter = None
        self.stream_counter = None
        self.transcription_duration = None
        # Strong references to fire-and-forget cleanup tasks. Without these,
        # asyncio may garbage-collect a pending task before it runs.
        self._cleanup_tasks = set()

    async def setup(self):
        """Initialize telemetry, NATS/JetStream, and the Whisper HTTP client."""
        # Initialize OpenTelemetry
        self.tracer, self.meter = setup_telemetry()

        # Create metrics if OTEL is enabled
        if self.meter:
            self.stream_counter = self.meter.create_counter(
                name="stt_streams_total",
                description="Total number of STT streams processed",
                unit="1",
            )
            self.transcription_duration = self.meter.create_histogram(
                name="stt_transcription_duration_seconds",
                description="Duration of STT transcription",
                unit="s",
            )

        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # Initialize JetStream context
        self.js = self.nc.jetstream()

        # Create or update stream for voice stream messages
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_STREAM",
                subjects=["ai.voice.stream.>", "ai.voice.transcription.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # Keep messages for 5 minutes only (streaming is ephemeral)
                storage=nats.js.api.StorageType.MEMORY,  # Use memory for streaming data
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_STREAM")
        except Exception as e:
            # Best-effort: the stream may already exist with a compatible config
            logger.info(f"JetStream stream setup: {e}")

        # HTTP client for Whisper service
        self.http_client = httpx.AsyncClient(timeout=180.0)
        logger.info("HTTP client initialized")

        # Mark as healthy once connections are established
        self.is_healthy = True

    async def transcribe(self, audio_bytes: bytes) -> str | None:
        """Transcribe audio using Whisper.

        Returns the transcript text, or None on any HTTP/parse failure.
        """
        try:
            files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
            response = await self.http_client.post(
                f"{WHISPER_URL}/v1/audio/transcriptions", files=files
            )
            response.raise_for_status()
            result = response.json()
            transcript = result.get("text", "")
            logger.info(f"Transcribed: {transcript[:100]}...")
            return transcript
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            return None

    def _schedule_cleanup(self, session_id: str) -> None:
        """Spawn _cleanup_session in the background, retaining a task reference."""
        task = asyncio.create_task(self._cleanup_session(session_id))
        self._cleanup_tasks.add(task)
        task.add_done_callback(self._cleanup_tasks.discard)

    async def process_buffer(self, session_id: str):
        """Transcribe the session's buffered audio and publish the result.

        Bug fix: completion cleanup now runs even when the buffer is empty.
        Previously an "end" with no buffered audio left the session and its
        monitor task alive forever, because cleanup was only scheduled after
        a successful audio pass.
        """
        buffer = self.sessions.get(session_id)
        if not buffer:
            return

        audio_data = buffer.get_audio()
        if audio_data:
            logger.info(
                f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}"
            )

            # Transcribe
            transcript = await self.transcribe(audio_data)

            if transcript:
                # Publish transcription result using msgpack binary format
                result = {
                    "session_id": session_id,
                    "transcript": transcript,
                    "sequence": buffer.sequence,
                    "is_partial": not buffer.is_complete,
                    "is_final": buffer.is_complete,
                    "timestamp": time.time(),
                    "speaker_id": buffer.speaker_id,
                    "has_voice_activity": buffer.has_voice_activity,
                    "state": buffer.state,
                }
                await self.nc.publish(
                    f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", msgpack.packb(result)
                )
                logger.info(
                    f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})"
                )

            # Clear buffer after processing
            buffer.clear()

        # Clean up completed sessions asynchronously — even when no audio remained
        if buffer.is_complete:
            logger.info(f"Session {session_id} completed")
            self._schedule_cleanup(session_id)

    async def _cleanup_session(self, session_id: str):
        """Remove a completed session and its monitor task after a grace period."""
        # Keep session for a bit in case of late messages
        await asyncio.sleep(5)
        if session_id in self.sessions:
            del self.sessions[session_id]
            logger.info(f"Cleaned up session: {session_id}")
        if session_id in self.processing_tasks:
            del self.processing_tasks[session_id]

    async def monitor_buffer(self, session_id: str):
        """Poll a session's buffer and process it whenever thresholds are met.

        Exits when the service stops or the session is cleaned up.
        """
        while self.running and session_id in self.sessions:
            buffer = self.sessions.get(session_id)
            if not buffer:
                break
            if buffer.should_process():
                await self.process_buffer(session_id)
            # Don't spin too fast
            await asyncio.sleep(0.1)

    async def handle_stream_message(self, msg: Msg):
        """Dispatch one NATS message: start / state_change / end / chunk."""
        try:
            # Extract session_id from subject: ai.voice.stream.{session_id}
            subject_parts = msg.subject.split(".")
            if len(subject_parts) < 4:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return
            session_id = subject_parts[3]

            # Parse message using msgpack binary format
            data = msgpack.unpackb(msg.data, raw=False)

            # Handle control messages
            if data.get("type") == "start":
                logger.info(f"Starting stream session: {session_id}")
                self.sessions[session_id] = AudioBuffer(session_id)

                # Set initial state if provided
                initial_state = data.get("state", SESSION_STATE_LISTENING)
                self.sessions[session_id].set_state(initial_state)

                # Store speaker_id if provided
                speaker_id = data.get("speaker_id")
                if speaker_id:
                    self.sessions[session_id].speaker_id = speaker_id
                    logger.info(f"Session {session_id}: Speaker ID set to {speaker_id}")

                # Start monitoring task for this session
                task = asyncio.create_task(self.monitor_buffer(session_id))
                self.processing_tasks[session_id] = task
                return

            if data.get("type") == "state_change":
                logger.info(f"State change for session {session_id}")
                buffer = self.sessions.get(session_id)
                if buffer:
                    new_state = data.get("state", SESSION_STATE_LISTENING)
                    buffer.set_state(new_state)
                    # If switching to listening mode, reset any interrupt tracking
                    if new_state == SESSION_STATE_LISTENING:
                        buffer.interrupt_start_time = None
                return

            if data.get("type") == "end":
                logger.info(f"Ending stream session: {session_id}")
                buffer = self.sessions.get(session_id)
                if buffer:
                    buffer.mark_complete()
                    # Always run process_buffer: it transcribes any remaining
                    # audio AND schedules cleanup even when the buffer is empty
                    # (previously empty-buffer ends leaked the session).
                    await self.process_buffer(session_id)
                return

            # Handle audio chunk
            if data.get("type") == "chunk":
                audio_b64 = data.get("audio_b64", "")
                if not audio_b64:
                    return
                audio_bytes = base64.b64decode(audio_b64)

                # Create session if it doesn't exist (handle missing start message)
                # Check both sessions and processing_tasks to avoid race conditions
                if session_id not in self.sessions:
                    logger.info(f"Auto-creating session: {session_id}")
                    self.sessions[session_id] = AudioBuffer(session_id)
                    # Only create monitoring task if not already exists
                    if session_id not in self.processing_tasks:
                        task = asyncio.create_task(self.monitor_buffer(session_id))
                        self.processing_tasks[session_id] = task

                buffer = self.sessions[session_id]

                # Check for interrupt if in responding state.
                # NOTE(review): check_interrupt runs before add_chunk, so the
                # cached VAD verdict it reads belongs to the PREVIOUS chunk —
                # confirm this one-chunk lag is acceptable.
                if buffer.check_interrupt(audio_bytes):
                    # Publish interrupt notification
                    interrupt_msg = {
                        "session_id": session_id,
                        "type": "interrupt",
                        "timestamp": time.time(),
                        "speaker_id": buffer.speaker_id,
                    }
                    await self.nc.publish(
                        f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
                        msgpack.packb(interrupt_msg),
                    )
                    logger.info(f"Published interrupt notification for session {session_id}")
                    # Automatically switch back to listening mode
                    buffer.set_state(SESSION_STATE_LISTENING)

                # Add chunk to buffer
                buffer.add_chunk(audio_bytes)

        except Exception as e:
            logger.error(f"Error handling stream message: {e}", exc_info=True)

    async def run(self):
        """Main run loop: subscribe, serve until signalled, then shut down."""
        await self.setup()

        # Note: STT streaming uses regular NATS subscribe (not pull-based JetStream consumer)
        # because it handles real-time ephemeral audio streams with wildcard subscriptions.
        # The stream audio chunks are not meant to be persisted long-term or replayed.
        # However, the transcription RESULTS are published to JetStream for persistence.
        sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
        logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # get_running_loop() is correct inside a coroutine; get_event_loop()
        # is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running
        while self.running:
            await asyncio.sleep(1)

        # Cleanup
        logger.info("Shutting down...")
        self.is_healthy = False

        # Cancel all monitoring and pending cleanup tasks, then wait for them
        pending = list(self.processing_tasks.values()) + list(self._cleanup_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

        await sub.unsubscribe()
        await self.nc.close()
        await self.http_client.aclose()
        logger.info("Shutdown complete")


if __name__ == "__main__":
    service = StreamingSTT()
    asyncio.run(service.run())