#!/usr/bin/env python3
"""
Streaming STT Service with Local Whisper on ROCm

Real-time Speech-to-Text service that processes live audio streams from NATS
using local Whisper model running on AMD GPU via ROCm:

1. Subscribe to audio stream subject (ai.voice.stream.{session_id})
2. Buffer and accumulate audio chunks
3. Transcribe locally using Whisper on AMD GPU
4. Publish transcription results to response channel
   (ai.voice.transcription.{session_id})

This version runs Whisper directly on the AMD GPU using ROCm/PyTorch backend
instead of calling an external Whisper service.

Supports HyperDX for observability via OpenTelemetry.
"""

import asyncio
import base64
import contextlib
import io
import logging
import os
import signal
import tempfile
import time
from typing import Dict, Optional

from aiohttp import web
import msgpack
import nats
import nats.js
import numpy as np
import soundfile as sf
import torch
import whisper
from nats.aio.msg import Msg

# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.logging import LoggingInstrumentor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("stt-streaming-rocm")


def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Returns:
        (tracer, meter) when OTEL is enabled, otherwise (None, None).
    """
    # Check if OTEL is enabled
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    # OTEL configuration
    otel_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317"
    )
    service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming-rocm")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    # HyperDX configuration (only used when enabled AND an API key is set)
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    # Create resource with service information
    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    # Setup tracing
    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        # Use HTTP exporter for HyperDX with API key header
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces",
            headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics",
            headers=headers
        )
    else:
        # Use gRPC exporter for standard OTEL collector
        logger.info(f"Configuring OTEL gRPC exporter at {otel_endpoint}")
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    # Setup metrics
    metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    # Instrument logging
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    # Return tracer and meter for the service
    tracer = trace.get_tracer(__name__)
    meter = metrics.get_meter(__name__)
    return tracer, meter


# Configuration from environment
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "medium")
WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cuda")  # cuda uses ROCm on AMD
WHISPER_FP16 = os.environ.get("WHISPER_FP16", "true").lower() == "true"

# NATS subjects for streaming
STREAM_SUBJECT_PREFIX = "ai.voice.stream"  # Full subject: ai.voice.stream.{session_id}
TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription"  # Full subject: ai.voice.transcription.{session_id}

# Streaming parameters
BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000"))  # ~5 seconds at 16kHz 16-bit
CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0"))  # Process after 2s of silence
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000"))  # ~50 seconds max

# Health server port for kserve compatibility
HEALTH_PORT = int(os.environ.get("HEALTH_PORT", "8000"))


class AudioBuffer:
    """Manages audio chunks for a streaming session."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.chunks = []                  # raw audio chunks, in arrival order
        self.total_bytes = 0              # sum of len() of all buffered chunks
        self.last_chunk_time = time.time()
        self.is_complete = False          # set once the "end" control message arrives
        self.sequence = 0                 # increments each time the buffer is flushed

    def add_chunk(self, audio_data: bytes) -> None:
        """Add an audio chunk to the buffer."""
        self.chunks.append(audio_data)
        self.total_bytes += len(audio_data)
        self.last_chunk_time = time.time()
        logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes")

    def should_process(self) -> bool:
        """Determine if buffer should be processed now."""
        # Process if buffer size threshold reached
        if self.total_bytes >= BUFFER_SIZE_BYTES:
            return True
        # Process if no chunks received for timeout duration
        if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0:
            return True
        # Process if buffer is too large (safety limit)
        if self.total_bytes >= MAX_BUFFER_SIZE_BYTES:
            return True
        return False

    def get_audio(self) -> bytes:
        """Get concatenated audio data."""
        return b''.join(self.chunks)

    def clear(self) -> None:
        """Clear the buffer after processing."""
        self.chunks = []
        self.total_bytes = 0
        self.sequence += 1

    def mark_complete(self) -> None:
        """Mark stream as complete."""
        self.is_complete = True


class StreamingSTTLocal:
    """Streaming Speech-to-Text service with local Whisper on ROCm."""

    def __init__(self):
        self.nc = None                               # NATS client connection
        self.js = None                               # JetStream context
        self.whisper_model = None
        self.sessions: Dict[str, AudioBuffer] = {}   # session_id -> AudioBuffer
        self.running = True
        self.processing_tasks = {}                   # session_id -> monitor task
        self.is_healthy = False
        self.tracer = None
        self.meter = None
        self.stream_counter = None
        self.transcription_duration = None
        self.gpu_memory_gauge = None

    async def setup(self):
        """Initialize connections and load model."""
        # Initialize OpenTelemetry
        self.tracer, self.meter = setup_telemetry()

        # Create metrics if OTEL is enabled
        if self.meter:
            self.stream_counter = self.meter.create_counter(
                name="stt_streams_total",
                description="Total number of STT streams processed",
                unit="1"
            )
            self.transcription_duration = self.meter.create_histogram(
                name="stt_transcription_duration_seconds",
                description="Duration of STT transcription",
                unit="s"
            )
            self.gpu_memory_gauge = self.meter.create_observable_gauge(
                name="stt_gpu_memory_bytes",
                description="GPU memory usage in bytes",
                callbacks=[self._get_gpu_memory]
            )

        # Check GPU availability (torch.cuda reports the ROCm device on AMD)
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            logger.info(f"ROCm GPU available: {gpu_name} ({gpu_memory:.1f}GB)")
        else:
            logger.warning("No GPU available, falling back to CPU")

        # Load Whisper model
        logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE}")
        start_time = time.time()
        self.whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=WHISPER_DEVICE)
        load_time = time.time() - start_time
        logger.info(f"Whisper model loaded in {load_time:.2f}s")

        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # Initialize JetStream context
        self.js = self.nc.jetstream()

        # Create or update stream for voice stream messages
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_STREAM",
                subjects=["ai.voice.stream.>", "ai.voice.transcription.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # Keep messages for 5 minutes only (streaming is ephemeral)
                storage=nats.js.api.StorageType.MEMORY,  # Use memory for streaming data
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_STREAM")
        except Exception as e:
            # Stream might already exist
            logger.info(f"JetStream stream setup: {e}")

        # Mark as healthy once connections are established
        self.is_healthy = True

    async def health_handler(self, request: web.Request) -> web.Response:
        """Handle health check requests for kserve compatibility."""
        if self.is_healthy:
            return web.json_response({
                "status": "healthy",
                "model": WHISPER_MODEL_SIZE,
                "device": WHISPER_DEVICE,
                "nats_connected": self.nc is not None and self.nc.is_connected,
            })
        else:
            return web.json_response(
                {"status": "unhealthy", "model": WHISPER_MODEL_SIZE},
                status=503
            )

    async def start_health_server(self) -> web.AppRunner:
        """Start HTTP health server for kserve agent sidecar."""
        app = web.Application()
        app.router.add_get("/health", self.health_handler)
        app.router.add_get("/ready", self.health_handler)
        app.router.add_get("/", self.health_handler)
        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, "0.0.0.0", HEALTH_PORT)
        await site.start()
        logger.info(f"Health server started on port {HEALTH_PORT}")
        return runner

    def _get_gpu_memory(self, options):
        """Callback for GPU memory gauge."""
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated(0)
            yield metrics.Observation(memory_used, {"device": "0"})

    def transcribe(self, audio_bytes: bytes) -> Optional[str]:
        """Transcribe raw PCM audio using the local Whisper model.

        Assumes the stream carries raw 16 kHz, 16-bit, mono PCM — this matches
        the RTF calculation below and the buffer-size comments above. TODO:
        confirm the producer's encoding.

        Returns the transcript text, or None on failure.
        """
        start_time = time.time()
        try:
            # BUGFIX: the previous version wrote the raw bytes into a ".wav"
            # temp file, but raw PCM has no RIFF header, so ffmpeg-based
            # decoding inside Whisper would fail. Decode PCM16 to the float32
            # waveform Whisper's transcribe() accepts directly instead.
            audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0

            # Transcribe with Whisper
            result = self.whisper_model.transcribe(
                audio,
                fp16=WHISPER_FP16 and WHISPER_DEVICE == "cuda",
                language="en",  # Can be made configurable
            )

            transcript = result.get("text", "").strip()
            duration = time.time() - start_time
            audio_duration = len(audio_bytes) / (16000 * 2)  # Assuming 16kHz 16-bit
            rtf = duration / audio_duration if audio_duration > 0 else 0
            logger.info(f"Transcribed in {duration:.2f}s (RTF: {rtf:.2f}): {transcript[:100]}...")

            # Record metrics
            if self.transcription_duration:
                self.transcription_duration.record(duration, {"model": WHISPER_MODEL_SIZE})

            return transcript
        except Exception as e:
            logger.error(f"Transcription failed: {e}", exc_info=True)
            return None

    async def process_buffer(self, session_id: str):
        """Process accumulated audio buffer for a session."""
        buffer = self.sessions.get(session_id)
        if not buffer:
            return

        audio_data = buffer.get_audio()
        if not audio_data:
            return

        logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}")

        # Record stream counter
        if self.stream_counter:
            self.stream_counter.add(1, {"session_id": session_id})

        # Transcribe in thread pool to avoid blocking event loop
        loop = asyncio.get_running_loop()
        transcript = await loop.run_in_executor(None, self.transcribe, audio_data)

        if transcript:
            # Publish transcription result using msgpack binary format
            result = {
                "session_id": session_id,
                "transcript": transcript,
                "sequence": buffer.sequence,
                "is_partial": not buffer.is_complete,
                "is_final": buffer.is_complete,
                "timestamp": time.time(),
                "model": WHISPER_MODEL_SIZE,
                "device": WHISPER_DEVICE,
            }
            await self.nc.publish(
                f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
                msgpack.packb(result)
            )
            logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence})")

        # Clear buffer after processing
        buffer.clear()

        # Clean up completed sessions asynchronously
        if buffer.is_complete:
            logger.info(f"Session {session_id} completed")
            # Schedule cleanup task to avoid blocking
            asyncio.create_task(self._cleanup_session(session_id))

    async def _cleanup_session(self, session_id: str):
        """Clean up a completed session after a delay."""
        # Keep session for a bit in case of late messages
        await asyncio.sleep(5)
        if session_id in self.sessions:
            del self.sessions[session_id]
            logger.info(f"Cleaned up session: {session_id}")
        if session_id in self.processing_tasks:
            del self.processing_tasks[session_id]

    async def monitor_buffer(self, session_id: str):
        """Monitor buffer and trigger processing when needed."""
        while self.running and session_id in self.sessions:
            buffer = self.sessions.get(session_id)
            if not buffer:
                break
            if buffer.should_process():
                await self.process_buffer(session_id)
            # Don't spin too fast
            await asyncio.sleep(0.1)

    async def handle_stream_message(self, msg: Msg):
        """Handle incoming audio stream message."""
        try:
            # Extract session_id from subject: ai.voice.stream.{session_id}
            subject_parts = msg.subject.split('.')
            if len(subject_parts) < 4:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return
            session_id = subject_parts[3]

            # Parse message using msgpack binary format
            data = msgpack.unpackb(msg.data, raw=False)

            # Handle control messages
            if data.get("type") == "start":
                logger.info(f"Starting stream session: {session_id}")
                self.sessions[session_id] = AudioBuffer(session_id)
                # Start monitoring task for this session
                task = asyncio.create_task(self.monitor_buffer(session_id))
                self.processing_tasks[session_id] = task
                return

            if data.get("type") == "end":
                logger.info(f"Ending stream session: {session_id}")
                buffer = self.sessions.get(session_id)
                if buffer:
                    buffer.mark_complete()
                    # Process any remaining audio
                    if buffer.total_bytes > 0:
                        await self.process_buffer(session_id)
                    else:
                        # BUGFIX: with an empty buffer, process_buffer is never
                        # called, so cleanup was never scheduled and the session
                        # (plus its monitor task) leaked forever.
                        asyncio.create_task(self._cleanup_session(session_id))
                return

            # Handle audio chunk
            if data.get("type") == "chunk":
                audio_b64 = data.get("audio_b64", "")
                if not audio_b64:
                    return
                audio_bytes = base64.b64decode(audio_b64)

                # Create session if it doesn't exist (handle missing start message)
                if session_id not in self.sessions:
                    logger.info(f"Auto-creating session: {session_id}")
                    self.sessions[session_id] = AudioBuffer(session_id)
                if session_id not in self.processing_tasks:
                    task = asyncio.create_task(self.monitor_buffer(session_id))
                    self.processing_tasks[session_id] = task

                # Add chunk to buffer
                self.sessions[session_id].add_chunk(audio_bytes)

        except Exception as e:
            logger.error(f"Error handling stream message: {e}", exc_info=True)

    async def run(self):
        """Main run loop."""
        await self.setup()

        # Start health server for kserve compatibility
        health_runner = await self.start_health_server()

        # Subscribe to voice stream
        sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
        logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")

        # Handle shutdown
        def signal_handler():
            self.running = False

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running
        while self.running:
            await asyncio.sleep(1)

        # Cleanup
        logger.info("Shutting down...")
        # Cancel all monitoring tasks and wait for them to complete
        for task in self.processing_tasks.values():
            task.cancel()
        if self.processing_tasks:
            await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True)
        await sub.unsubscribe()
        await self.nc.close()
        await health_runner.cleanup()
        logger.info("Shutdown complete")


if __name__ == "__main__":
    service = StreamingSTTLocal()
    asyncio.run(service.run())