Files
stt-module/stt_streaming_local.py
Billy D. 8fc5eb1193 feat: add streaming STT service with Whisper backend
- stt_streaming.py: HTTP-based STT using external Whisper service
- stt_streaming_local.py: ROCm-based local Whisper inference
- Voice Activity Detection (VAD) with WebRTC
- Interrupt detection for barge-in support
- Session state management (listening/responding)
- OpenTelemetry instrumentation with HyperDX support
- Dockerfile variants for HTTP and ROCm deployments
2026-02-02 06:23:12 -05:00

512 lines
20 KiB
Python

#!/usr/bin/env python3
"""
Streaming STT Service with Local Whisper on ROCm
Real-time Speech-to-Text service that processes live audio streams from NATS
using local Whisper model running on AMD GPU via ROCm:
1. Subscribe to audio stream subject (ai.voice.stream.{session_id})
2. Buffer and accumulate audio chunks
3. Transcribe locally using Whisper on AMD GPU
4. Publish transcription results to response channel (ai.voice.transcription.{session_id})
This version runs Whisper directly on the AMD GPU using ROCm/PyTorch backend
instead of calling an external Whisper service.
Supports HyperDX for observability via OpenTelemetry.
"""
import asyncio
import base64
import contextlib
import io
import logging
import os
import signal
import tempfile
import threading
import time
from typing import Dict, Optional

from aiohttp import web
import msgpack
import nats
import nats.js
import numpy as np
import soundfile as sf
import torch
import whisper
from nats.aio.msg import Msg

# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.logging import LoggingInstrumentor
# Root logging: INFO and above, one timestamped line per record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger shared by every component of this service.
logger = logging.getLogger(name="stt-streaming-rocm")
def setup_telemetry():
    """Install global OpenTelemetry tracer/meter providers for this service.

    Everything is driven by environment variables:

    * ``OTEL_ENABLED`` (default ``"true"``): any value other than ``"true"``
      (case-insensitive) turns telemetry off.
    * ``HYPERDX_ENABLED`` and ``HYPERDX_API_KEY``: when the flag is ``"true"``
      and the key is non-empty, spans and metrics are sent over OTLP/HTTP to
      ``HYPERDX_ENDPOINT`` with the key in the ``authorization`` header.
      Otherwise they go over insecure OTLP/gRPC to
      ``OTEL_EXPORTER_OTLP_ENDPOINT``.
    * ``OTEL_SERVICE_NAME``, ``OTEL_SERVICE_NAMESPACE``, ``SERVICE_VERSION``,
      ``DEPLOYMENT_ENV`` and ``HOSTNAME`` populate the resource attributes.

    Metrics are exported every 60 s. Logging is instrumented via
    ``LoggingInstrumentor``.

    Returns:
        ``(tracer, meter)`` from the newly installed providers, or
        ``(None, None)`` when telemetry is disabled.
    """
    if os.environ.get("OTEL_ENABLED", "true").lower() != "true":
        logger.info("OpenTelemetry disabled")
        return None, None

    collector_url = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
    svc_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming-rocm")
    svc_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    api_key = os.environ.get("HYPERDX_API_KEY", "")
    hdx_url = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    hdx_on = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and bool(api_key)

    resource = Resource.create({
        SERVICE_NAME: svc_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: svc_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    tracer_provider = TracerProvider(resource=resource)

    if not hdx_on:
        # Plain OTEL collector reachable over gRPC inside the cluster.
        logger.info(f"Configuring OTEL gRPC exporter at {collector_url}")
        span_exporter = OTLPSpanExporter(endpoint=collector_url, insecure=True)
        metric_exporter = OTLPMetricExporter(endpoint=collector_url, insecure=True)
    else:
        # HyperDX ingests OTLP over HTTP and authenticates via a header.
        logger.info(f"Configuring HyperDX exporter at {hdx_url}")
        auth_headers = {"authorization": api_key}
        span_exporter = OTLPSpanExporterHTTP(endpoint=f"{hdx_url}/v1/traces", headers=auth_headers)
        metric_exporter = OTLPMetricExporterHTTP(endpoint=f"{hdx_url}/v1/metrics", headers=auth_headers)

    tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(tracer_provider)

    reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
    metrics.set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))

    LoggingInstrumentor().instrument(set_logging_format=True)

    target = "OTEL Collector"
    if hdx_on:
        target = "HyperDX"
    logger.info(f"OpenTelemetry initialized - destination: {target}, service: {svc_name}")

    return trace.get_tracer(__name__), metrics.get_meter(__name__)
# Configuration from environment
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "medium")
WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cuda")  # "cuda" maps to ROCm on AMD builds of PyTorch
WHISPER_FP16 = os.environ.get("WHISPER_FP16", "true").lower() == "true"

# NATS subjects for streaming
STREAM_SUBJECT_PREFIX = "ai.voice.stream"  # Full subject: ai.voice.stream.{session_id}
TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription"  # Full subject: ai.voice.transcription.{session_id}

# Streaming parameters.
# Byte-to-duration figures below assume 16 kHz mono 16-bit PCM, i.e.
# 32000 bytes per second of audio.
# Buffered size that triggers a transcription pass (default 512000 B = 16 s).
BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000"))
# Idle timeout: transcribe whatever is buffered once no chunk has arrived for
# this many seconds. This is based on chunk arrival, not on audio-level
# silence detection.
CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0"))
# Safety cap on buffered audio (default 5120000 B = 160 s).
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000"))

# Health server port for kserve compatibility
HEALTH_PORT = int(os.environ.get("HEALTH_PORT", "8000"))
class AudioBuffer:
    """Accumulates the audio chunks of one streaming session.

    Attributes:
        session_id: Session this buffer belongs to.
        chunks: Chunk payloads in arrival order.
        total_bytes: Sum of ``len()`` over ``chunks``.
        last_chunk_time: ``time.monotonic()`` reading taken at creation and on
            every ``add_chunk()``; drives the idle timeout.
        is_complete: Set by ``mark_complete()``.
        sequence: Number of ``clear()`` calls so far, starting at 0.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.chunks = []
        self.total_bytes = 0
        # Monotonic clock: a wall-clock step (NTP, manual change) must not
        # spuriously fire or suppress the idle timeout in should_process().
        self.last_chunk_time = time.monotonic()
        self.is_complete = False
        self.sequence = 0

    def add_chunk(self, audio_data: bytes) -> None:
        """Append ``audio_data`` and restart the idle timer."""
        self.chunks.append(audio_data)
        self.total_bytes += len(audio_data)
        self.last_chunk_time = time.monotonic()
        logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes")

    def should_process(self) -> bool:
        """Return True when the buffered audio should be transcribed now.

        This holds when the buffer has reached ``BUFFER_SIZE_BYTES``, or the
        ``MAX_BUFFER_SIZE_BYTES`` safety cap. It also holds when the buffer is
        non-empty and no chunk has arrived for more than
        ``CHUNK_TIMEOUT_SECONDS``.
        """
        if self.total_bytes >= BUFFER_SIZE_BYTES or self.total_bytes >= MAX_BUFFER_SIZE_BYTES:
            return True
        idle_seconds = time.monotonic() - self.last_chunk_time
        return self.total_bytes > 0 and idle_seconds > CHUNK_TIMEOUT_SECONDS

    def get_audio(self) -> bytes:
        """Return all buffered chunks concatenated in arrival order."""
        return b''.join(self.chunks)

    def clear(self) -> None:
        """Drop the buffered audio and advance ``sequence`` by one."""
        self.chunks = []
        self.total_bytes = 0
        self.sequence += 1

    def mark_complete(self) -> None:
        """Flag the stream as finished."""
        self.is_complete = True
class StreamingSTTLocal:
    """Streaming Speech-to-Text service with local Whisper on ROCm.

    Subscribes to ``{STREAM_SUBJECT_PREFIX}.>``; the fourth subject token is
    the session id. Each msgpack message carries a ``type``:

    * ``"start"``: open a fresh ``AudioBuffer`` plus a monitor task for it.
    * ``"chunk"``: base64 audio in ``audio_b64``, appended to the session's
      buffer. The session is auto-created if ``"start"`` was missed.
    * ``"end"``: mark the session complete, transcribe what is left and
      schedule the session for cleanup.

    Transcripts are published as msgpack on
    ``{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}``.
    """

    # Sample rate of the float arrays handed to Whisper. whisper.load_audio
    # resamples to 16 kHz; headerless input is assumed to already be 16 kHz.
    _SAMPLE_RATE = 16000
    # Seconds a completed session is kept in case of late messages.
    _CLEANUP_DELAY_SECONDS = 5
    # Poll interval of each per-session monitor loop.
    _MONITOR_INTERVAL_SECONDS = 0.1

    def __init__(self):
        self.nc = None
        self.js = None
        self.whisper_model = None
        self.sessions: Dict[str, AudioBuffer] = {}
        self.running = True
        # session_id -> monitor task watching that session's current buffer.
        self.processing_tasks: Dict[str, asyncio.Task] = {}
        self.is_healthy = False
        self.tracer = None
        self.meter = None
        self.stream_counter = None
        self.transcription_duration = None
        self.gpu_memory_gauge = None
        # Device inference actually runs on. setup() switches a requested
        # "cuda" device to "cpu" when torch reports no GPU.
        self.device = WHISPER_DEVICE
        # Serializes use of the shared Whisper model across executor threads.
        # openai-whisper attaches kv-cache hooks to the model for the duration
        # of each decode, so overlapping decodes would interfere.
        self._model_lock = threading.Lock()
        # Strong references to fire-and-forget tasks. The event loop only
        # keeps weak references, so unreferenced tasks may be collected early.
        self._background_tasks = set()

    async def setup(self):
        """Initialize telemetry, load the Whisper model and connect to NATS.

        ``is_healthy`` is set only after every step has succeeded.
        """
        self.tracer, self.meter = setup_telemetry()

        if self.meter:
            # Incremented once per transcription pass. It is labelled by model
            # only, so the attribute set stays bounded; per-session labels
            # would grow without limit in a long-running process.
            self.stream_counter = self.meter.create_counter(
                name="stt_streams_total",
                description="Total number of STT streams processed",
                unit="1"
            )
            self.transcription_duration = self.meter.create_histogram(
                name="stt_transcription_duration_seconds",
                description="Duration of STT transcription",
                unit="s"
            )
            self.gpu_memory_gauge = self.meter.create_observable_gauge(
                name="stt_gpu_memory_bytes",
                description="GPU memory usage in bytes",
                callbacks=[self._get_gpu_memory]
            )

        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            logger.info(f"ROCm GPU available: {gpu_name} ({gpu_memory:.1f}GB)")
        else:
            logger.warning("No GPU available, falling back to CPU")
            # Actually fall back; loading onto a missing GPU would fail.
            if self.device.startswith("cuda"):
                self.device = "cpu"

        logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE} on {self.device}")
        start_time = time.perf_counter()
        self.whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=self.device)
        load_time = time.perf_counter() - start_time
        logger.info(f"Whisper model loaded in {load_time:.2f}s")

        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")
        self.js = self.nc.jetstream()

        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_STREAM",
                subjects=["ai.voice.stream.>", "ai.voice.transcription.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # Keep messages for 5 minutes only (streaming is ephemeral)
                storage=nats.js.api.StorageType.MEMORY,  # Use memory for streaming data
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_STREAM")
        except Exception as e:
            # Stream might already exist (possibly with a different config).
            logger.info(f"JetStream stream setup: {e}")

        self.is_healthy = True

    async def health_handler(self, request: web.Request) -> web.Response:
        """Answer kserve health/readiness probes (503 until setup finished)."""
        if self.is_healthy:
            return web.json_response({
                "status": "healthy",
                "model": WHISPER_MODEL_SIZE,
                "device": self.device,
                "nats_connected": self.nc is not None and self.nc.is_connected,
            })
        return web.json_response(
            {"status": "unhealthy", "model": WHISPER_MODEL_SIZE},
            status=503
        )

    async def start_health_server(self) -> web.AppRunner:
        """Serve ``/health``, ``/ready`` and ``/`` on ``HEALTH_PORT``."""
        app = web.Application()
        app.router.add_get("/health", self.health_handler)
        app.router.add_get("/ready", self.health_handler)
        app.router.add_get("/", self.health_handler)
        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, "0.0.0.0", HEALTH_PORT)
        await site.start()
        logger.info(f"Health server started on port {HEALTH_PORT}")
        return runner

    def _get_gpu_memory(self, options):
        """Observable-gauge callback: bytes allocated by torch on device 0."""
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated(0)
            yield metrics.Observation(memory_used, {"device": "0"})

    @staticmethod
    def _decode_audio(audio_bytes: bytes) -> np.ndarray:
        """Return mono float32 samples at 16 kHz for ``audio_bytes``.

        Encoded input (e.g. WAV with a header) goes through whisper's
        ffmpeg-based loader, as before. If ffmpeg cannot decode the bytes,
        they are treated as headerless 16 kHz mono 16-bit little-endian PCM.
        That is the layout this module's byte/duration arithmetic assumes.
        """
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
            tmp.write(audio_bytes)
            tmp.flush()
            try:
                return whisper.load_audio(tmp.name)
            except RuntimeError:
                # whisper.load_audio wraps ffmpeg failures in RuntimeError.
                logger.debug("ffmpeg could not decode buffer; treating it as raw 16 kHz s16le PCM")
        usable = len(audio_bytes) - (len(audio_bytes) % 2)  # drop a dangling odd byte
        pcm = np.frombuffer(audio_bytes[:usable], dtype=np.int16)
        return pcm.astype(np.float32) / 32768.0

    def transcribe(self, audio_bytes: bytes) -> Optional[str]:
        """Transcribe one buffer of audio with the local Whisper model.

        This is a blocking call; ``process_buffer`` runs it in an executor
        thread. Decoding and inference are serialized with ``_model_lock``.

        Args:
            audio_bytes: Encoded audio that ffmpeg can decode, or headerless
                16 kHz mono 16-bit PCM.

        Returns:
            The stripped transcript. ``""`` if there was no audio at all, and
            ``None`` if decoding or inference failed (the error is logged).
        """
        try:
            with self._model_lock:
                start_time = time.perf_counter()
                audio = self._decode_audio(audio_bytes)
                if audio.size == 0:
                    return ""
                result = self.whisper_model.transcribe(
                    audio,
                    fp16=WHISPER_FP16 and self.device.startswith("cuda"),
                    language="en",  # Can be made configurable
                )
                duration = time.perf_counter() - start_time

            transcript = result.get("text", "").strip()
            audio_duration = len(audio) / self._SAMPLE_RATE
            rtf = duration / audio_duration if audio_duration > 0 else 0
            logger.info(f"Transcribed in {duration:.2f}s (RTF: {rtf:.2f}): {transcript[:100]}...")

            if self.transcription_duration:
                self.transcription_duration.record(duration, {"model": WHISPER_MODEL_SIZE})
            return transcript
        except Exception as e:
            logger.error(f"Transcription failed: {e}", exc_info=True)
            return None

    def _spawn(self, coro) -> asyncio.Task:
        """Start ``coro`` as a task that is referenced until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Forget a finished task and log it if it raised."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            logger.error("Background task failed", exc_info=task.exception())

    def _open_session(self, session_id: str) -> AudioBuffer:
        """Create a new buffer for ``session_id`` and a monitor bound to it.

        A monitor for a previous buffer under the same id stops on its own,
        because it only runs while its own buffer is the registered one.
        """
        buffer = AudioBuffer(session_id)
        self.sessions[session_id] = buffer
        self.processing_tasks[session_id] = self._spawn(self.monitor_buffer(session_id, buffer))
        return buffer

    async def process_buffer(self, session_id: str):
        """Transcribe and publish whatever audio is buffered for a session.

        The audio is taken and the buffer cleared *before* awaiting Whisper.
        Chunks arriving during transcription therefore accumulate for the next
        pass instead of being discarded, and concurrent callers never
        transcribe the same audio twice. Once the session is complete, cleanup
        is scheduled even if nothing was left to transcribe.
        """
        buffer = self.sessions.get(session_id)
        if buffer is None:
            return

        audio_data = buffer.get_audio()
        if audio_data:
            sequence = buffer.sequence
            buffer.clear()
            await self._transcribe_and_publish(session_id, buffer, audio_data, sequence)

        if buffer.is_complete:
            logger.info(f"Session {session_id} completed")
            self._spawn(self._cleanup_session(session_id, buffer))

    async def _transcribe_and_publish(self, session_id: str, buffer: AudioBuffer,
                                      audio_data: bytes, sequence: int):
        """Run Whisper off the event loop and publish a non-empty transcript."""
        logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {sequence}")

        if self.stream_counter:
            self.stream_counter.add(1, {"model": WHISPER_MODEL_SIZE})

        loop = asyncio.get_running_loop()
        transcript = await loop.run_in_executor(None, self.transcribe, audio_data)
        if not transcript:
            return

        # Final only if the stream has ended and no later slice of audio was
        # taken while this one was transcribing (each clear() bumps sequence).
        is_final = buffer.is_complete and buffer.sequence == sequence + 1
        result = {
            "session_id": session_id,
            "transcript": transcript,
            "sequence": sequence,
            "is_partial": not is_final,
            "is_final": is_final,
            "timestamp": time.time(),
            "model": WHISPER_MODEL_SIZE,
            "device": self.device,
        }
        await self.nc.publish(
            f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
            msgpack.packb(result)
        )
        logger.info(f"Published transcription for session {session_id} (seq {sequence})")

    async def _cleanup_session(self, session_id: str, buffer: Optional[AudioBuffer] = None):
        """Forget a completed session after a grace period.

        Args:
            session_id: Session to remove.
            buffer: The completed buffer. When given, nothing is removed if a
                newer "start" has since registered a different buffer under
                the same id.
        """
        # Keep the session briefly in case of late messages.
        await asyncio.sleep(self._CLEANUP_DELAY_SECONDS)
        current = self.sessions.get(session_id)
        if buffer is not None and current is not None and current is not buffer:
            return
        if current is not None:
            del self.sessions[session_id]
            logger.info(f"Cleaned up session: {session_id}")
        self.processing_tasks.pop(session_id, None)

    async def monitor_buffer(self, session_id: str, buffer: Optional[AudioBuffer] = None):
        """Poll one session buffer and process it whenever it is ready.

        Runs while the service is running and ``buffer`` (default: the buffer
        currently registered for ``session_id``) is still the registered one.
        Processing errors are logged and do not stop the loop.
        """
        owned = buffer if buffer is not None else self.sessions.get(session_id)
        if owned is None:
            return
        while self.running and self.sessions.get(session_id) is owned:
            if owned.should_process():
                try:
                    await self.process_buffer(session_id)
                except Exception as e:
                    logger.error(f"Error processing session {session_id}: {e}", exc_info=True)
            await asyncio.sleep(self._MONITOR_INTERVAL_SECONDS)

    async def handle_stream_message(self, msg: Msg):
        """Route one ``start``/``chunk``/``end`` message to its session.

        Slow work (the final transcription on ``end``) is spawned as a task,
        so this subscription callback never blocks delivery for other sessions.
        """
        try:
            # Extract session_id from subject: ai.voice.stream.{session_id}
            subject_parts = msg.subject.split('.')
            if len(subject_parts) < 4:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return
            session_id = subject_parts[3]

            data = msgpack.unpackb(msg.data, raw=False)
            if not isinstance(data, dict):
                logger.warning(f"Ignoring non-map message on {msg.subject}")
                return
            msg_type = data.get("type")

            if msg_type == "start":
                logger.info(f"Starting stream session: {session_id}")
                self._open_session(session_id)
                return

            if msg_type == "end":
                logger.info(f"Ending stream session: {session_id}")
                buffer = self.sessions.get(session_id)
                if buffer:
                    buffer.mark_complete()
                    # Always run a last pass: it transcribes any remaining
                    # audio and schedules cleanup even for an empty buffer.
                    self._spawn(self.process_buffer(session_id))
                return

            if msg_type == "chunk":
                audio_b64 = data.get("audio_b64", "")
                if not audio_b64:
                    return
                audio_bytes = base64.b64decode(audio_b64)
                buffer = self.sessions.get(session_id)
                if buffer is None:
                    # Handle a missing start message.
                    logger.info(f"Auto-creating session: {session_id}")
                    buffer = self._open_session(session_id)
                buffer.add_chunk(audio_bytes)
        except Exception as e:
            logger.error(f"Error handling stream message: {e}", exc_info=True)

    async def run(self):
        """Run the service until SIGTERM/SIGINT, then shut down cleanly."""
        await self.setup()
        health_runner = await self.start_health_server()

        sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
        logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")

        stop_requested = asyncio.Event()

        def signal_handler():
            self.running = False
            # Fail readiness probes while tearing down.
            self.is_healthy = False
            stop_requested.set()

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        await stop_requested.wait()

        logger.info("Shutting down...")
        # Stop intake first so no new sessions or tasks appear during teardown.
        await sub.unsubscribe()
        pending = list(self._background_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        await self.nc.close()
        await health_runner.cleanup()
        logger.info("Shutdown complete")
if __name__ == "__main__":
    # Build the service and drive it on a fresh event loop until shutdown.
    asyncio.run(StreamingSTTLocal().run())