diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f024e3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +.venv/ +__pycache__/ +*.pyc +*.pyo +.pytest_cache/ +.mypy_cache/ +*.egg-info/ +dist/ +build/ +.env +.env.local diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..16ab630 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Copy requirements and install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY stt_streaming.py . +COPY healthcheck.py . + +# Run the service +CMD ["python", "stt_streaming.py"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm new file mode 100644 index 0000000..22abb08 --- /dev/null +++ b/Dockerfile.rocm @@ -0,0 +1,52 @@ +# STT Streaming Service with ROCm for AMD GPU Whisper inference +# Targets AMD Strix Halo (gfx1151 / RDNA 3.5) but includes RDNA 3 compatibility +# +# Uses OpenAI Whisper with PyTorch ROCm backend +# +FROM docker.io/rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.9.1 AS base + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + ffmpeg \ + libsndfile1 \ + && rm -rf /var/lib/apt/lists/* + +# WORKAROUND: ROCm/ROCm#5853 - Standard PyTorch ROCm wheels cause segfault in +# libhsa-runtime64.so during VRAM allocation on gfx1151 (Strix Halo). +# TheRock nightly builds work correctly. Install BEFORE other deps since +# openai-whisper depends on torch. +RUN pip install --no-cache-dir --break-system-packages \ + --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ \ + torch torchaudio torchvision --force-reinstall + +# Install Python dependencies for STT streaming +# Use pip directly (more reliable than uv in this base image) +COPY requirements-rocm.txt . 
+RUN pip install --no-cache-dir --break-system-packages -r requirements-rocm.txt + +# Download Whisper model at build time for faster startup +# Using medium model for good accuracy/speed balance +ARG WHISPER_MODEL=medium +ENV WHISPER_MODEL_SIZE=${WHISPER_MODEL} + +# Pre-download the model during build (whisper is installed as openai-whisper) +# Use python3 to ensure correct interpreter +RUN python3 -c "import whisper; whisper.load_model('${WHISPER_MODEL}')" || echo "Model will be downloaded at runtime" + +# Copy application code +COPY stt_streaming_local.py . +COPY healthcheck.py . + +# Set ROCm environment for AMD Strix Halo (gfx1151 / RDNA 3.5) +ENV HIP_VISIBLE_DEVICES=0 +ENV HSA_ENABLE_SDMA=0 +# Ensure PyTorch uses ROCm with expandable segments for large models +ENV PYTORCH_HIP_ALLOC_CONF=expandable_segments:True,max_split_size_mb:512 +# Target gfx1151 (Strix Halo) - ROCm 7.1+ has native support +# Falls back to runtime override if kernels not available +ENV ROCM_TARGET_LST=gfx1151,gfx1100 + +# Run the service +CMD ["python", "stt_streaming_local.py"] diff --git a/README.md b/README.md index ab06925..c97dfc8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,182 @@ -# stt-module +# Streaming STT Module +A dedicated Speech-to-Text (STT) service that processes live audio streams from NATS for faster transcription responses. + +## Overview + +This module enables real-time speech-to-text processing by accepting audio chunks as they arrive rather than waiting for complete audio files. This significantly reduces latency in voice assistant applications. 
+ +## Features + +- **Live Audio Streaming**: Accepts audio chunks via NATS as they're captured +- **Incremental Processing**: Transcribes audio as soon as sufficient data is buffered +- **Session Management**: Handles multiple concurrent streaming sessions +- **Automatic Buffer Management**: Processes audio based on size thresholds or timeout +- **Partial Results**: Publishes transcription results progressively during long streams +- **Voice Activity Detection (VAD)**: Detects speech vs silence to optimize processing +- **Interrupt Detection**: Detects when user speaks during LLM response and switches back to listening mode +- **Speaker Tracking**: Support for speaker identification in multi-speaker scenarios +- **State Management**: Tracks listening/responding states for proper interrupt handling + +## Architecture + +``` +┌─────────────────┐ +│ Audio Source │ (Frontend, Mobile App, etc.) +│ (Microphone) │ +└────────┬────────┘ + │ Chunks + ▼ +┌─────────────────┐ +│ NATS Subject │ ai.voice.stream.{session_id} +│ Audio Stream │ +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ STT Streaming │ (This Service) +│ Service │ - Buffers chunks +│ │ - Transcribes when ready +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ NATS Subject │ ai.voice.transcription.{session_id} +│ Transcription │ +└─────────────────┘ +``` + +## Variants + +### stt_streaming.py (HTTP Backend) +Uses an external Whisper service via HTTP. Lightweight container, delegates GPU inference to a separate service. + +### stt_streaming_local.py (ROCm Backend) +Runs Whisper locally on AMD GPU using ROCm/PyTorch. Single container with embedded model. + +## NATS Message Protocol + +### Audio Stream Input (ai.voice.stream.{session_id}) + +All messages use **msgpack** binary encoding. 
+ +**Start Stream:** +```python +{ + "type": "start", + "session_id": "unique-session-id", + "sample_rate": 16000, + "channels": 1, + "state": "listening", # Optional: "listening" or "responding" + "speaker_id": "speaker-1" # Optional: identifier for speaker tracking +} +``` + +**Audio Chunk:** +```python +{ + "type": "chunk", + "audio_b64": "base64-encoded-audio-data", + "timestamp": 1234567890.123 +} +``` + +**State Change:** +```python +{ + "type": "state_change", + "state": "responding" # "listening" or "responding" +} +``` + +**End Stream:** +```python +{ + "type": "end" +} +``` + +### Transcription Output (ai.voice.transcription.{session_id}) + +**Transcription Result:** +```python +{ + "session_id": "unique-session-id", + "transcript": "transcribed text", + "sequence": 0, + "is_partial": False, + "is_final": True, + "timestamp": 1234567890.123, + "speaker_id": "speaker-1", # If provided in start message + "has_voice_activity": True, + "state": "listening" +} +``` + +**Interrupt Notification:** +```python +{ + "session_id": "unique-session-id", + "type": "interrupt", + "timestamp": 1234567890.123, + "speaker_id": "speaker-1" +} +``` + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `NATS_URL` | `nats://nats.ai-ml.svc.cluster.local:4222` | NATS server URL | +| `WHISPER_URL` | `http://whisper-predictor.ai-ml.svc.cluster.local` | Whisper service URL (HTTP variant) | +| `WHISPER_MODEL_SIZE` | `medium` | Whisper model size (ROCm variant) | +| `WHISPER_DEVICE` | `cuda` | PyTorch device (ROCm variant) | +| `STT_BUFFER_SIZE_BYTES` | `512000` | Buffer size before processing (~5s) | +| `STT_CHUNK_TIMEOUT` | `2.0` | Seconds of silence before processing | +| `STT_ENABLE_VAD` | `true` | Enable voice activity detection | +| `STT_VAD_AGGRESSIVENESS` | `2` | VAD aggressiveness (0-3) | +| `STT_ENABLE_INTERRUPT_DETECTION` | `true` | Enable interrupt detection | +| `OTEL_ENABLED` | `true` | Enable OpenTelemetry | +| 
`HYPERDX_ENABLED` | `false` | Enable HyperDX observability | + +## Building + +### HTTP Variant +```bash +docker build -t stt-module:latest . +``` + +### ROCm Variant (AMD GPU) +```bash +docker build -f Dockerfile.rocm -t stt-module:rocm --build-arg WHISPER_MODEL=medium . +``` + +## Testing + +```bash +# Port-forward NATS +kubectl port-forward -n ai-ml svc/nats 4222:4222 + +# Start a session +python -c " +import nats +import msgpack +import asyncio + +async def test(): + nc = await nats.connect('nats://localhost:4222') + await nc.publish('ai.voice.stream.test-session', msgpack.packb({'type': 'start'})) + # Send audio chunks... + await nc.publish('ai.voice.stream.test-session', msgpack.packb({'type': 'end'})) + await nc.close() + +asyncio.run(test()) +" + +# Subscribe to transcriptions +nats sub "ai.voice.transcription.>" +``` + +## License + +MIT diff --git a/healthcheck.py b/healthcheck.py new file mode 100644 index 0000000..55a5678 --- /dev/null +++ b/healthcheck.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +""" +Health check script for Kubernetes probes +Verifies NATS connectivity +""" +import sys +import os +import asyncio + +import nats + +NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") + + +async def check_health(): + """Check if service can connect to NATS.""" + try: + nc = await asyncio.wait_for(nats.connect(NATS_URL), timeout=5.0) + await nc.close() + return True + except Exception as e: + print(f"Health check failed: {e}", file=sys.stderr) + return False + + +if __name__ == "__main__": + result = asyncio.run(check_health()) + sys.exit(0 if result else 1) diff --git a/requirements-rocm.txt b/requirements-rocm.txt new file mode 100644 index 0000000..be4298a --- /dev/null +++ b/requirements-rocm.txt @@ -0,0 +1,24 @@ +# Core dependencies +nats-py>=2.0.0,<3.0.0 +msgpack + +# Whisper for local STT inference (uses PyTorch already in base image) +openai-whisper>=20231117 + +# Audio processing +soundfile +numpy + +# OpenTelemetry 
core +opentelemetry-api +opentelemetry-sdk +opentelemetry-exporter-otlp-proto-grpc +opentelemetry-exporter-otlp-proto-http +opentelemetry-instrumentation-logging + +# HyperDX support (uses OTLP protocol) +# HyperDX is compatible with standard OTEL exporters, just needs API key header +opentelemetry-sdk-extension-aws # For additional context propagation + +# HTTP health server for kserve compatibility +aiohttp diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d343dab --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +nats-py>=2.0.0,<3.0.0 +httpx>=0.20.0,<1.0.0 +msgpack + +# Audio processing +numpy>=1.20.0,<2.0.0 +webrtcvad>=2.0.10 +# pyannote.audio>=3.1.0 # Optional: for advanced speaker diarization + +# OpenTelemetry core +opentelemetry-api +opentelemetry-sdk + +# OTEL exporters (gRPC for local collector, HTTP for HyperDX) +opentelemetry-exporter-otlp-proto-grpc +opentelemetry-exporter-otlp-proto-http + +# OTEL instrumentation +opentelemetry-instrumentation-httpx +opentelemetry-instrumentation-logging diff --git a/stt_streaming.py b/stt_streaming.py new file mode 100644 index 0000000..d459500 --- /dev/null +++ b/stt_streaming.py @@ -0,0 +1,632 @@ +#!/usr/bin/env python3 +""" +Streaming STT Service + +Real-time Speech-to-Text service that processes live audio streams from NATS: +1. Subscribe to audio stream subject (ai.voice.stream.{session_id}) +2. Buffer and accumulate audio chunks +3. Transcribe when buffer reaches threshold or stream ends +4. Publish transcription results to response channel (ai.voice.transcription.{session_id}) + +This enables faster response times by processing audio as it arrives rather than +waiting for complete audio upload. 
+""" +import asyncio +import base64 +import contextlib +import logging +import os +import signal +import time +import struct +from typing import Dict, Optional, List, Tuple +from io import BytesIO + +import httpx +import msgpack +import nats +import nats.js +from nats.aio.msg import Msg +import numpy as np +import webrtcvad + +# OpenTelemetry imports +from opentelemetry import trace, metrics +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP +from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.instrumentation.logging import LoggingInstrumentor + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("stt-streaming") + +# Initialize OpenTelemetry +def setup_telemetry(): + """Initialize OpenTelemetry tracing and metrics with HyperDX support.""" + # Check if OTEL is enabled + otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true" + if not otel_enabled: + logger.info("OpenTelemetry disabled") + return None, None + + # OTEL configuration + otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317") + service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming") 
+ service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml") + + # HyperDX configuration + hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "") + hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io") + use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key + + # Create resource with service information + resource = Resource.create({ + SERVICE_NAME: service_name, + SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"), + SERVICE_NAMESPACE: service_namespace, + "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"), + "host.name": os.environ.get("HOSTNAME", "unknown"), + }) + + # Setup tracing + trace_provider = TracerProvider(resource=resource) + + if use_hyperdx: + # Use HTTP exporter for HyperDX with API key header + logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}") + headers = {"authorization": hyperdx_api_key} + otlp_span_exporter = OTLPSpanExporterHTTP( + endpoint=f"{hyperdx_endpoint}/v1/traces", + headers=headers + ) + otlp_metric_exporter = OTLPMetricExporterHTTP( + endpoint=f"{hyperdx_endpoint}/v1/metrics", + headers=headers + ) + else: + # Use gRPC exporter for standard OTEL collector + otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True) + otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True) + + trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter)) + trace.set_tracer_provider(trace_provider) + + # Setup metrics + metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000) + meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(meter_provider) + + # Instrument HTTPX + HTTPXClientInstrumentor().instrument() + + # Instrument logging + LoggingInstrumentor().instrument(set_logging_format=True) + + destination = "HyperDX" if use_hyperdx else "OTEL Collector" + 
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}") + + # Return tracer and meter for the service + tracer = trace.get_tracer(__name__) + meter = metrics.get_meter(__name__) + + return tracer, meter + +# Configuration from environment +WHISPER_URL = os.environ.get("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local") +NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") + +# NATS subjects for streaming +STREAM_SUBJECT_PREFIX = "ai.voice.stream" # Full subject: ai.voice.stream.{session_id} +TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription" # Full subject: ai.voice.transcription.{session_id} + +# Streaming parameters +BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")) # ~5 seconds at 16kHz 16-bit +CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0")) # Process after 2s of silence +MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000")) # ~50 seconds max + +# Audio constants +AUDIO_SAMPLE_MAX_INT16 = 32768.0 # Maximum value for 16-bit signed integer audio +VAD_VOICE_RATIO_THRESHOLD = float(os.environ.get("STT_VAD_VOICE_RATIO", "0.3")) # Min ratio of voice frames + +# Voice Activity Detection (VAD) parameters +ENABLE_VAD = os.environ.get("STT_ENABLE_VAD", "true").lower() == "true" +VAD_AGGRESSIVENESS = int(os.environ.get("STT_VAD_AGGRESSIVENESS", "2")) # 0-3, higher = more aggressive +VAD_FRAME_DURATION_MS = int(os.environ.get("STT_VAD_FRAME_DURATION", "30")) # 10, 20, or 30 ms + +# Audio threshold for interrupt detection (when LLM is responding) +ENABLE_INTERRUPT_DETECTION = os.environ.get("STT_ENABLE_INTERRUPT_DETECTION", "true").lower() == "true" +AUDIO_LEVEL_THRESHOLD = float(os.environ.get("STT_AUDIO_LEVEL_THRESHOLD", "0.02")) # RMS threshold +INTERRUPT_DURATION_THRESHOLD = float(os.environ.get("STT_INTERRUPT_DURATION", "0.5")) # Seconds of speech to trigger + +# Speaker diarization 
+ENABLE_SPEAKER_DIARIZATION = os.environ.get("STT_ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true" + +# Session states +SESSION_STATE_LISTENING = "listening" +SESSION_STATE_RESPONDING = "responding" + + +def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float: + """ + Calculate RMS (Root Mean Square) audio level. + + Args: + audio_data: Raw audio bytes + sample_width: Bytes per sample (2 for 16-bit audio) + + Returns: + RMS level normalized to 0.0-1.0 range + """ + if len(audio_data) < sample_width: + return 0.0 + + # Convert bytes to numpy array of int16 samples + try: + samples = np.frombuffer(audio_data, dtype=np.int16) + # Calculate RMS and normalize + rms = np.sqrt(np.mean(samples.astype(np.float32) ** 2)) + # Normalize to 0-1 range using defined constant + return float(rms / AUDIO_SAMPLE_MAX_INT16) + except Exception as e: + logger.warning(f"Error calculating RMS: {e}") + return 0.0 + + +def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool: + """ + Detect if audio contains voice using WebRTC VAD. 
+ + Args: + audio_data: Raw PCM audio bytes (16-bit, mono) + sample_rate: Audio sample rate (8000, 16000, 32000, or 48000) + + Returns: + True if voice is detected, False otherwise + """ + if not ENABLE_VAD: + return True # Assume voice present if VAD disabled + + try: + vad = webrtcvad.Vad(VAD_AGGRESSIVENESS) + + # WebRTC VAD requires specific frame sizes + # Frame duration must be 10, 20, or 30 ms + frame_size = int(sample_rate * VAD_FRAME_DURATION_MS / 1000) * 2 # *2 for 16-bit samples + + # Process audio in frames + voice_frames = 0 + total_frames = 0 + + for i in range(0, len(audio_data) - frame_size, frame_size): + frame = audio_data[i:i + frame_size] + if len(frame) == frame_size: + try: + is_speech = vad.is_speech(frame, sample_rate) + if is_speech: + voice_frames += 1 + total_frames += 1 + except Exception as e: + logger.debug(f"VAD frame processing error: {e}") + continue + + if total_frames == 0: + return False + + # Consider voice detected if voice ratio exceeds threshold + voice_ratio = voice_frames / total_frames + return voice_ratio > VAD_VOICE_RATIO_THRESHOLD + + except Exception as e: + logger.warning(f"VAD error: {e}") + return True # Default to voice present on error + + +class AudioBuffer: + """Manages audio chunks for a streaming session with VAD and speaker tracking.""" + + def __init__(self, session_id: str): + self.session_id = session_id + self.chunks = [] + self.total_bytes = 0 + self.last_chunk_time = time.time() + self.is_complete = False + self.sequence = 0 + self.state = SESSION_STATE_LISTENING # Current session state + self.speaker_id = None # For speaker diarization + self.interrupt_start_time = None # Track when interrupt detection started + self.has_voice_activity = False # Track if voice was detected in recent chunks + self._last_chunk_vad_result = None # Cache VAD result for last chunk + + def add_chunk(self, audio_data: bytes) -> None: + """Add an audio chunk to the buffer and check for voice activity.""" + 
self.chunks.append(audio_data) + self.total_bytes += len(audio_data) + self.last_chunk_time = time.time() + + # Check for voice activity in this chunk and cache result + has_voice = detect_voice_activity(audio_data) + self.has_voice_activity = has_voice + self._last_chunk_vad_result = has_voice + + logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes, voice={has_voice}") + + def check_interrupt(self, audio_data: bytes) -> bool: + """ + Check if audio indicates an interrupt during responding state. + Uses cached VAD result if available. + + Returns: + True if interrupt detected, False otherwise + """ + if not ENABLE_INTERRUPT_DETECTION: + return False + + if self.state != SESSION_STATE_RESPONDING: + return False + + # Calculate audio level + rms_level = calculate_audio_rms(audio_data) + + # Use cached VAD result if available to avoid duplicate processing + has_voice = self._last_chunk_vad_result if self._last_chunk_vad_result is not None else detect_voice_activity(audio_data) + + # Check if audio exceeds threshold and contains voice + if rms_level >= AUDIO_LEVEL_THRESHOLD and has_voice: + if self.interrupt_start_time is None: + self.interrupt_start_time = time.time() + logger.info(f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})") + + # Check if interrupt has lasted long enough + elapsed = time.time() - self.interrupt_start_time + if elapsed >= INTERRUPT_DURATION_THRESHOLD: + logger.info(f"Session {self.session_id}: Interrupt confirmed after {elapsed:.1f}s") + return True + else: + # Reset interrupt timer if audio drops below threshold + self.interrupt_start_time = None + + return False + + def set_state(self, state: str) -> None: + """Set the session state (listening or responding).""" + if state in (SESSION_STATE_LISTENING, SESSION_STATE_RESPONDING): + old_state = self.state + self.state = state + if old_state != state: + logger.info(f"Session {self.session_id}: State changed from {old_state} to 
{state}") + # Reset interrupt tracking when changing states + self.interrupt_start_time = None + + def should_process(self) -> bool: + """Determine if buffer should be processed now.""" + # Don't process if no voice activity detected (unless buffer is full or timed out) + if ENABLE_VAD and not self.has_voice_activity: + # Still process if buffer is very large or has timed out + if self.total_bytes < BUFFER_SIZE_BYTES and time.time() - self.last_chunk_time < CHUNK_TIMEOUT_SECONDS: + return False + + # Process if buffer size threshold reached + if self.total_bytes >= BUFFER_SIZE_BYTES: + return True + # Process if no chunks received for timeout duration + if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0: + return True + # Process if buffer is too large (safety limit) + if self.total_bytes >= MAX_BUFFER_SIZE_BYTES: + return True + return False + + def get_audio(self) -> bytes: + """Get concatenated audio data.""" + return b''.join(self.chunks) + + def clear(self) -> None: + """Clear the buffer after processing.""" + self.chunks = [] + self.total_bytes = 0 + self.sequence += 1 + self._last_chunk_vad_result = None # Clear cached VAD result + + def mark_complete(self) -> None: + """Mark stream as complete.""" + self.is_complete = True + + +class StreamingSTT: + """Streaming Speech-to-Text service.""" + + def __init__(self): + self.nc = None + self.js = None + self.http_client = None + self.sessions: Dict[str, AudioBuffer] = {} + self.running = True + self.processing_tasks = {} + self.is_healthy = False + self.tracer = None + self.meter = None + self.stream_counter = None + self.transcription_duration = None + + async def setup(self): + """Initialize connections.""" + # Initialize OpenTelemetry + self.tracer, self.meter = setup_telemetry() + + # Create metrics if OTEL is enabled + if self.meter: + self.stream_counter = self.meter.create_counter( + name="stt_streams_total", + description="Total number of STT streams processed", + 
unit="1" + ) + self.transcription_duration = self.meter.create_histogram( + name="stt_transcription_duration_seconds", + description="Duration of STT transcription", + unit="s" + ) + + # NATS connection + self.nc = await nats.connect(NATS_URL) + logger.info(f"Connected to NATS at {NATS_URL}") + + # Initialize JetStream context + self.js = self.nc.jetstream() + + # Create or update stream for voice stream messages + try: + stream_config = nats.js.api.StreamConfig( + name="AI_VOICE_STREAM", + subjects=["ai.voice.stream.>", "ai.voice.transcription.>"], + retention=nats.js.api.RetentionPolicy.LIMITS, + max_age=300, # Keep messages for 5 minutes only (streaming is ephemeral) + storage=nats.js.api.StorageType.MEMORY, # Use memory for streaming data + ) + await self.js.add_stream(stream_config) + logger.info("Created/updated JetStream stream: AI_VOICE_STREAM") + except Exception as e: + # Stream might already exist + logger.info(f"JetStream stream setup: {e}") + + # HTTP client for Whisper service + self.http_client = httpx.AsyncClient(timeout=180.0) + logger.info("HTTP client initialized") + + # Mark as healthy once connections are established + self.is_healthy = True + + async def transcribe(self, audio_bytes: bytes) -> Optional[str]: + """Transcribe audio using Whisper.""" + try: + files = {"file": ("audio.wav", audio_bytes, "audio/wav")} + response = await self.http_client.post( + f"{WHISPER_URL}/v1/audio/transcriptions", + files=files + ) + response.raise_for_status() + result = response.json() + transcript = result.get("text", "") + logger.info(f"Transcribed: {transcript[:100]}...") + return transcript + except Exception as e: + logger.error(f"Transcription failed: {e}") + return None + + async def process_buffer(self, session_id: str): + """Process accumulated audio buffer for a session.""" + buffer = self.sessions.get(session_id) + if not buffer: + return + + audio_data = buffer.get_audio() + if not audio_data: + return + + logger.info(f"Processing 
{len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}") + + # Transcribe + transcript = await self.transcribe(audio_data) + + if transcript: + # Publish transcription result using msgpack binary format + result = { + "session_id": session_id, + "transcript": transcript, + "sequence": buffer.sequence, + "is_partial": not buffer.is_complete, + "is_final": buffer.is_complete, + "timestamp": time.time(), + "speaker_id": buffer.speaker_id, + "has_voice_activity": buffer.has_voice_activity, + "state": buffer.state + } + + await self.nc.publish( + f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", + msgpack.packb(result) + ) + logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})") + + # Clear buffer after processing + buffer.clear() + + # Clean up completed sessions asynchronously + if buffer.is_complete: + logger.info(f"Session {session_id} completed") + # Schedule cleanup task to avoid blocking + asyncio.create_task(self._cleanup_session(session_id)) + + async def _cleanup_session(self, session_id: str): + """Clean up a completed session after a delay.""" + # Keep session for a bit in case of late messages + await asyncio.sleep(5) + if session_id in self.sessions: + del self.sessions[session_id] + logger.info(f"Cleaned up session: {session_id}") + if session_id in self.processing_tasks: + del self.processing_tasks[session_id] + + async def monitor_buffer(self, session_id: str): + """Monitor buffer and trigger processing when needed.""" + while self.running and session_id in self.sessions: + buffer = self.sessions.get(session_id) + if not buffer: + break + + if buffer.should_process(): + await self.process_buffer(session_id) + + # Don't spin too fast + await asyncio.sleep(0.1) + + async def handle_stream_message(self, msg: Msg): + """Handle incoming audio stream message.""" + try: + # Extract session_id from subject: ai.voice.stream.{session_id} + subject_parts = msg.subject.split('.') + 
if len(subject_parts) < 4: + logger.warning(f"Invalid subject format: {msg.subject}") + return + + session_id = subject_parts[3] + + # Parse message using msgpack binary format + data = msgpack.unpackb(msg.data, raw=False) + + # Handle control messages + if data.get("type") == "start": + logger.info(f"Starting stream session: {session_id}") + self.sessions[session_id] = AudioBuffer(session_id) + # Set initial state if provided + initial_state = data.get("state", SESSION_STATE_LISTENING) + self.sessions[session_id].set_state(initial_state) + # Store speaker_id if provided + speaker_id = data.get("speaker_id") + if speaker_id: + self.sessions[session_id].speaker_id = speaker_id + logger.info(f"Session {session_id}: Speaker ID set to {speaker_id}") + # Start monitoring task for this session + task = asyncio.create_task(self.monitor_buffer(session_id)) + self.processing_tasks[session_id] = task + return + + if data.get("type") == "state_change": + logger.info(f"State change for session {session_id}") + buffer = self.sessions.get(session_id) + if buffer: + new_state = data.get("state", SESSION_STATE_LISTENING) + buffer.set_state(new_state) + + # If switching to listening mode, reset any interrupt tracking + if new_state == SESSION_STATE_LISTENING: + buffer.interrupt_start_time = None + return + + if data.get("type") == "end": + logger.info(f"Ending stream session: {session_id}") + buffer = self.sessions.get(session_id) + if buffer: + buffer.mark_complete() + # Process any remaining audio + if buffer.total_bytes > 0: + await self.process_buffer(session_id) + return + + # Handle audio chunk + if data.get("type") == "chunk": + audio_b64 = data.get("audio_b64", "") + if not audio_b64: + return + + audio_bytes = base64.b64decode(audio_b64) + + # Create session if it doesn't exist (handle missing start message) + # Check both sessions and processing_tasks to avoid race conditions + if session_id not in self.sessions: + logger.info(f"Auto-creating session: {session_id}") + 
self.sessions[session_id] = AudioBuffer(session_id) + # Only create monitoring task if not already exists + if session_id not in self.processing_tasks: + task = asyncio.create_task(self.monitor_buffer(session_id)) + self.processing_tasks[session_id] = task + + buffer = self.sessions[session_id] + + # Check for interrupt if in responding state + if buffer.check_interrupt(audio_bytes): + # Publish interrupt notification + interrupt_msg = { + "session_id": session_id, + "type": "interrupt", + "timestamp": time.time(), + "speaker_id": buffer.speaker_id + } + await self.nc.publish( + f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", + msgpack.packb(interrupt_msg) + ) + logger.info(f"Published interrupt notification for session {session_id}") + + # Automatically switch back to listening mode + buffer.set_state(SESSION_STATE_LISTENING) + + # Add chunk to buffer + buffer.add_chunk(audio_bytes) + + except Exception as e: + logger.error(f"Error handling stream message: {e}", exc_info=True) + + async def run(self): + """Main run loop.""" + await self.setup() + + # Note: STT streaming uses regular NATS subscribe (not pull-based JetStream consumer) + # because it handles real-time ephemeral audio streams with wildcard subscriptions. + # The stream audio chunks are not meant to be persisted long-term or replayed. + # However, the transcription RESULTS are published to JetStream for persistence. 
+ sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message) + logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>") + + # Handle shutdown + def signal_handler(): + self.running = False + + loop = asyncio.get_event_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + # Keep running + while self.running: + await asyncio.sleep(1) + + # Cleanup + logger.info("Shutting down...") + + # Cancel all monitoring tasks and wait for them to complete + for task in self.processing_tasks.values(): + task.cancel() + + # Wait for all tasks to complete or be cancelled + if self.processing_tasks: + await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True) + + await sub.unsubscribe() + await self.nc.close() + await self.http_client.aclose() + logger.info("Shutdown complete") + + +if __name__ == "__main__": + service = StreamingSTT() + asyncio.run(service.run()) diff --git a/stt_streaming_local.py b/stt_streaming_local.py new file mode 100644 index 0000000..196d0ae --- /dev/null +++ b/stt_streaming_local.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 +""" +Streaming STT Service with Local Whisper on ROCm + +Real-time Speech-to-Text service that processes live audio streams from NATS +using local Whisper model running on AMD GPU via ROCm: + +1. Subscribe to audio stream subject (ai.voice.stream.{session_id}) +2. Buffer and accumulate audio chunks +3. Transcribe locally using Whisper on AMD GPU +4. Publish transcription results to response channel (ai.voice.transcription.{session_id}) + +This version runs Whisper directly on the AMD GPU using ROCm/PyTorch backend +instead of calling an external Whisper service. + +Supports HyperDX for observability via OpenTelemetry. 
+""" +import asyncio +import base64 +import contextlib +import io +import logging +import os +import signal +import tempfile +import time +from typing import Dict, Optional + +from aiohttp import web +import msgpack +import nats +import nats.js +import numpy as np +import soundfile as sf +import torch +import whisper +from nats.aio.msg import Msg + +# OpenTelemetry imports +from opentelemetry import trace, metrics +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP +from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE +from opentelemetry.instrumentation.logging import LoggingInstrumentor + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("stt-streaming-rocm") + + +def setup_telemetry(): + """Initialize OpenTelemetry tracing and metrics with HyperDX support.""" + # Check if OTEL is enabled + otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true" + if not otel_enabled: + logger.info("OpenTelemetry disabled") + return None, None + + # OTEL configuration + otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317") + service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming-rocm") + service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", 
"ai-ml") + + # HyperDX configuration + hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "") + hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io") + use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key + + # Create resource with service information + resource = Resource.create({ + SERVICE_NAME: service_name, + SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"), + SERVICE_NAMESPACE: service_namespace, + "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"), + "host.name": os.environ.get("HOSTNAME", "unknown"), + }) + + # Setup tracing + trace_provider = TracerProvider(resource=resource) + + if use_hyperdx: + # Use HTTP exporter for HyperDX with API key header + logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}") + headers = {"authorization": hyperdx_api_key} + otlp_span_exporter = OTLPSpanExporterHTTP( + endpoint=f"{hyperdx_endpoint}/v1/traces", + headers=headers + ) + otlp_metric_exporter = OTLPMetricExporterHTTP( + endpoint=f"{hyperdx_endpoint}/v1/metrics", + headers=headers + ) + else: + # Use gRPC exporter for standard OTEL collector + logger.info(f"Configuring OTEL gRPC exporter at {otel_endpoint}") + otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True) + otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True) + + trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter)) + trace.set_tracer_provider(trace_provider) + + # Setup metrics + metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000) + meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(meter_provider) + + # Instrument logging + LoggingInstrumentor().instrument(set_logging_format=True) + + destination = "HyperDX" if use_hyperdx else "OTEL Collector" + logger.info(f"OpenTelemetry initialized - destination: {destination}, 
service: {service_name}") + + # Return tracer and meter for the service + tracer = trace.get_tracer(__name__) + meter = metrics.get_meter(__name__) + + return tracer, meter + + + # Configuration from environment + NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") + WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "medium") + WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cuda") # cuda uses ROCm on AMD + WHISPER_FP16 = os.environ.get("WHISPER_FP16", "true").lower() == "true" + + # NATS subjects for streaming + STREAM_SUBJECT_PREFIX = "ai.voice.stream" # Full subject: ai.voice.stream.{session_id} + TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription" # Full subject: ai.voice.transcription.{session_id} + + # Streaming parameters + BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")) # ~16 seconds at 16kHz 16-bit mono (32000 B/s) + CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0")) # Process after 2s of silence + MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000")) # ~160 seconds max + + # Health server port for kserve compatibility + HEALTH_PORT = int(os.environ.get("HEALTH_PORT", "8000")) + + + class AudioBuffer: + """Manages audio chunks for a streaming session.""" + + def __init__(self, session_id: str): + self.session_id = session_id + self.chunks = [] + self.total_bytes = 0 + self.last_chunk_time = time.time() + self.is_complete = False + self.sequence = 0 + + def add_chunk(self, audio_data: bytes) -> None: + """Add an audio chunk to the buffer.""" + self.chunks.append(audio_data) + self.total_bytes += len(audio_data) + self.last_chunk_time = time.time() + logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes") + + def should_process(self) -> bool: + """Determine if buffer should be processed now.""" + # Process if buffer size threshold reached + if self.total_bytes >= BUFFER_SIZE_BYTES: + return True + # Process if no chunks received for timeout
duration + if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0: + return True + # Process if buffer is too large (safety limit) + if self.total_bytes >= MAX_BUFFER_SIZE_BYTES: + return True + return False + + def get_audio(self) -> bytes: + """Get concatenated audio data.""" + return b''.join(self.chunks) + + def clear(self) -> None: + """Clear the buffer after processing.""" + self.chunks = [] + self.total_bytes = 0 + self.sequence += 1 + + def mark_complete(self) -> None: + """Mark stream as complete.""" + self.is_complete = True + + +class StreamingSTTLocal: + """Streaming Speech-to-Text service with local Whisper on ROCm.""" + + def __init__(self): + self.nc = None + self.js = None + self.whisper_model = None + self.sessions: Dict[str, AudioBuffer] = {} + self.running = True + self.processing_tasks = {} + self.is_healthy = False + self.tracer = None + self.meter = None + self.stream_counter = None + self.transcription_duration = None + self.gpu_memory_gauge = None + + async def setup(self): + """Initialize connections and load model.""" + # Initialize OpenTelemetry + self.tracer, self.meter = setup_telemetry() + + # Create metrics if OTEL is enabled + if self.meter: + self.stream_counter = self.meter.create_counter( + name="stt_streams_total", + description="Total number of STT streams processed", + unit="1" + ) + self.transcription_duration = self.meter.create_histogram( + name="stt_transcription_duration_seconds", + description="Duration of STT transcription", + unit="s" + ) + self.gpu_memory_gauge = self.meter.create_observable_gauge( + name="stt_gpu_memory_bytes", + description="GPU memory usage in bytes", + callbacks=[self._get_gpu_memory] + ) + + # Check GPU availability + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(0) + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + logger.info(f"ROCm GPU available: {gpu_name} ({gpu_memory:.1f}GB)") + else: + logger.warning("No GPU 
available, falling back to CPU") + + # Load Whisper model + logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE}") + start_time = time.time() + self.whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=WHISPER_DEVICE) + load_time = time.time() - start_time + logger.info(f"Whisper model loaded in {load_time:.2f}s") + + # NATS connection + self.nc = await nats.connect(NATS_URL) + logger.info(f"Connected to NATS at {NATS_URL}") + + # Initialize JetStream context + self.js = self.nc.jetstream() + + # Create or update stream for voice stream messages + try: + stream_config = nats.js.api.StreamConfig( + name="AI_VOICE_STREAM", + subjects=["ai.voice.stream.>", "ai.voice.transcription.>"], + retention=nats.js.api.RetentionPolicy.LIMITS, + max_age=300, # Keep messages for 5 minutes only (streaming is ephemeral) + storage=nats.js.api.StorageType.MEMORY, # Use memory for streaming data + ) + await self.js.add_stream(stream_config) + logger.info("Created/updated JetStream stream: AI_VOICE_STREAM") + except Exception as e: + # Stream might already exist + logger.info(f"JetStream stream setup: {e}") + + # Mark as healthy once connections are established + self.is_healthy = True + + async def health_handler(self, request: web.Request) -> web.Response: + """Handle health check requests for kserve compatibility.""" + if self.is_healthy: + return web.json_response({ + "status": "healthy", + "model": WHISPER_MODEL_SIZE, + "device": WHISPER_DEVICE, + "nats_connected": self.nc is not None and self.nc.is_connected, + }) + else: + return web.json_response( + {"status": "unhealthy", "model": WHISPER_MODEL_SIZE}, + status=503 + ) + + async def start_health_server(self) -> web.AppRunner: + """Start HTTP health server for kserve agent sidecar.""" + app = web.Application() + app.router.add_get("/health", self.health_handler) + app.router.add_get("/ready", self.health_handler) + app.router.add_get("/", self.health_handler) + + runner = web.AppRunner(app) + 
await runner.setup() + site = web.TCPSite(runner, "0.0.0.0", HEALTH_PORT) + await site.start() + logger.info(f"Health server started on port {HEALTH_PORT}") + return runner + + def _get_gpu_memory(self, options): + """Callback for GPU memory gauge.""" + if torch.cuda.is_available(): + memory_used = torch.cuda.memory_allocated(0) + yield metrics.Observation(memory_used, {"device": "0"}) + + def transcribe(self, audio_bytes: bytes) -> Optional[str]: + """Transcribe audio using local Whisper model.""" + start_time = time.time() + + try: + # Write audio to temp file (Whisper needs file path or numpy array) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp: + tmp.write(audio_bytes) + tmp.flush() + + # Transcribe with Whisper + result = self.whisper_model.transcribe( + tmp.name, + fp16=WHISPER_FP16 and WHISPER_DEVICE == "cuda", + language="en", # Can be made configurable + ) + + transcript = result.get("text", "").strip() + + duration = time.time() - start_time + audio_duration = len(audio_bytes) / (16000 * 2) # Assuming 16kHz 16-bit + rtf = duration / audio_duration if audio_duration > 0 else 0 + + logger.info(f"Transcribed in {duration:.2f}s (RTF: {rtf:.2f}): {transcript[:100]}...") + + # Record metrics + if self.transcription_duration: + self.transcription_duration.record(duration, {"model": WHISPER_MODEL_SIZE}) + + return transcript + + except Exception as e: + logger.error(f"Transcription failed: {e}", exc_info=True) + return None + + async def process_buffer(self, session_id: str): + """Process accumulated audio buffer for a session.""" + buffer = self.sessions.get(session_id) + if not buffer: + return + + audio_data = buffer.get_audio() + if not audio_data: + return + + logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}") + + # Record stream counter + if self.stream_counter: + self.stream_counter.add(1, {"session_id": session_id}) + + # Transcribe in thread pool to avoid blocking event loop + 
loop = asyncio.get_event_loop() + transcript = await loop.run_in_executor(None, self.transcribe, audio_data) + + if transcript: + # Publish transcription result using msgpack binary format + result = { + "session_id": session_id, + "transcript": transcript, + "sequence": buffer.sequence, + "is_partial": not buffer.is_complete, + "is_final": buffer.is_complete, + "timestamp": time.time(), + "model": WHISPER_MODEL_SIZE, + "device": WHISPER_DEVICE, + } + + await self.nc.publish( + f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", + msgpack.packb(result) + ) + logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence})") + + # Clear buffer after processing + buffer.clear() + + # Clean up completed sessions asynchronously + if buffer.is_complete: + logger.info(f"Session {session_id} completed") + # Schedule cleanup task to avoid blocking + asyncio.create_task(self._cleanup_session(session_id)) + + async def _cleanup_session(self, session_id: str): + """Clean up a completed session after a delay.""" + # Keep session for a bit in case of late messages + await asyncio.sleep(5) + if session_id in self.sessions: + del self.sessions[session_id] + logger.info(f"Cleaned up session: {session_id}") + if session_id in self.processing_tasks: + del self.processing_tasks[session_id] + + async def monitor_buffer(self, session_id: str): + """Monitor buffer and trigger processing when needed.""" + while self.running and session_id in self.sessions: + buffer = self.sessions.get(session_id) + if not buffer: + break + + if buffer.should_process(): + await self.process_buffer(session_id) + + # Don't spin too fast + await asyncio.sleep(0.1) + + async def handle_stream_message(self, msg: Msg): + """Handle incoming audio stream message.""" + try: + # Extract session_id from subject: ai.voice.stream.{session_id} + subject_parts = msg.subject.split('.') + if len(subject_parts) < 4: + logger.warning(f"Invalid subject format: {msg.subject}") + return + + session_id = 
subject_parts[3] + + # Parse message using msgpack binary format + data = msgpack.unpackb(msg.data, raw=False) + + # Handle control messages + if data.get("type") == "start": + logger.info(f"Starting stream session: {session_id}") + self.sessions[session_id] = AudioBuffer(session_id) + # Start monitoring task for this session + task = asyncio.create_task(self.monitor_buffer(session_id)) + self.processing_tasks[session_id] = task + return + + if data.get("type") == "end": + logger.info(f"Ending stream session: {session_id}") + buffer = self.sessions.get(session_id) + if buffer: + buffer.mark_complete() + # Process any remaining audio + if buffer.total_bytes > 0: + await self.process_buffer(session_id) + return + + # Handle audio chunk + if data.get("type") == "chunk": + audio_b64 = data.get("audio_b64", "") + if not audio_b64: + return + + audio_bytes = base64.b64decode(audio_b64) + + # Create session if it doesn't exist (handle missing start message) + if session_id not in self.sessions: + logger.info(f"Auto-creating session: {session_id}") + self.sessions[session_id] = AudioBuffer(session_id) + if session_id not in self.processing_tasks: + task = asyncio.create_task(self.monitor_buffer(session_id)) + self.processing_tasks[session_id] = task + + # Add chunk to buffer + self.sessions[session_id].add_chunk(audio_bytes) + + except Exception as e: + logger.error(f"Error handling stream message: {e}", exc_info=True) + + async def run(self): + """Main run loop.""" + await self.setup() + + # Start health server for kserve compatibility + health_runner = await self.start_health_server() + + # Subscribe to voice stream + sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message) + logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>") + + # Handle shutdown + def signal_handler(): + self.running = False + + loop = asyncio.get_event_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + # Keep 
running + while self.running: + await asyncio.sleep(1) + + # Cleanup + logger.info("Shutting down...") + + # Cancel all monitoring tasks and wait for them to complete + for task in self.processing_tasks.values(): + task.cancel() + + if self.processing_tasks: + await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True) + + await sub.unsubscribe() + await self.nc.close() + await health_runner.cleanup() + logger.info("Shutdown complete") + + +if __name__ == "__main__": + service = StreamingSTTLocal() + asyncio.run(service.run())