feat: add streaming STT service with Whisper backend
- stt_streaming.py: HTTP-based STT using external Whisper service
- stt_streaming_local.py: ROCm-based local Whisper inference
- Voice Activity Detection (VAD) with WebRTC
- Interrupt detection for barge-in support
- Session state management (listening/responding)
- OpenTelemetry instrumentation with HyperDX support
- Dockerfile variants for HTTP and ROCm deployments
This commit is contained in:
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy requirements and install dependencies
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY stt_streaming.py .
|
||||||
|
COPY healthcheck.py .
|
||||||
|
|
||||||
|
# Run the service
|
||||||
|
CMD ["python", "stt_streaming.py"]
|
||||||
52
Dockerfile.rocm
Normal file
52
Dockerfile.rocm
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# STT Streaming Service with ROCm for AMD GPU Whisper inference
|
||||||
|
# Targets AMD Strix Halo (gfx1151 / RDNA 3.5) but includes RDNA 3 compatibility
|
||||||
|
#
|
||||||
|
# Uses OpenAI Whisper with PyTorch ROCm backend
|
||||||
|
#
|
||||||
|
FROM docker.io/rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.9.1 AS base
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
ffmpeg \
|
||||||
|
libsndfile1 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# WORKAROUND: ROCm/ROCm#5853 - Standard PyTorch ROCm wheels cause segfault in
|
||||||
|
# libhsa-runtime64.so during VRAM allocation on gfx1151 (Strix Halo).
|
||||||
|
# TheRock nightly builds work correctly. Install BEFORE other deps since
|
||||||
|
# openai-whisper depends on torch.
|
||||||
|
RUN pip install --no-cache-dir --break-system-packages \
|
||||||
|
--index-url https://rocm.nightlies.amd.com/v2/gfx1151/ \
|
||||||
|
torch torchaudio torchvision --force-reinstall
|
||||||
|
|
||||||
|
# Install Python dependencies for STT streaming
|
||||||
|
# Use pip directly (more reliable than uv in this base image)
|
||||||
|
COPY requirements-rocm.txt .
|
||||||
|
RUN pip install --no-cache-dir --break-system-packages -r requirements-rocm.txt
|
||||||
|
|
||||||
|
# Download Whisper model at build time for faster startup
|
||||||
|
# Using medium model for good accuracy/speed balance
|
||||||
|
ARG WHISPER_MODEL=medium
|
||||||
|
ENV WHISPER_MODEL_SIZE=${WHISPER_MODEL}
|
||||||
|
|
||||||
|
# Pre-download the model during build (whisper is installed as openai-whisper)
|
||||||
|
# Use python3 to ensure correct interpreter
|
||||||
|
RUN python3 -c "import whisper; whisper.load_model('${WHISPER_MODEL}')" || echo "Model will be downloaded at runtime"
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY stt_streaming_local.py .
|
||||||
|
COPY healthcheck.py .
|
||||||
|
|
||||||
|
# Set ROCm environment for AMD Strix Halo (gfx1151 / RDNA 3.5)
|
||||||
|
ENV HIP_VISIBLE_DEVICES=0
|
||||||
|
ENV HSA_ENABLE_SDMA=0
|
||||||
|
# Ensure PyTorch uses ROCm with expandable segments for large models
|
||||||
|
ENV PYTORCH_HIP_ALLOC_CONF=expandable_segments:True,max_split_size_mb:512
|
||||||
|
# Target gfx1151 (Strix Halo) - ROCm 7.1+ has native support
|
||||||
|
# Falls back to runtime override if kernels not available
|
||||||
|
ENV ROCM_TARGET_LST=gfx1151,gfx1100
|
||||||
|
|
||||||
|
# Run the service
|
||||||
|
CMD ["python", "stt_streaming_local.py"]
|
||||||
182
README.md
182
README.md
@@ -1,2 +1,182 @@
|
|||||||
# stt-module
|
# Streaming STT Module
|
||||||
|
|
||||||
|
A dedicated Speech-to-Text (STT) service that processes live audio streams from NATS for faster transcription responses.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This module enables real-time speech-to-text processing by accepting audio chunks as they arrive rather than waiting for complete audio files. This significantly reduces latency in voice assistant applications.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Live Audio Streaming**: Accepts audio chunks via NATS as they're captured
|
||||||
|
- **Incremental Processing**: Transcribes audio as soon as sufficient data is buffered
|
||||||
|
- **Session Management**: Handles multiple concurrent streaming sessions
|
||||||
|
- **Automatic Buffer Management**: Processes audio based on size thresholds or timeout
|
||||||
|
- **Partial Results**: Publishes transcription results progressively during long streams
|
||||||
|
- **Voice Activity Detection (VAD)**: Detects speech vs silence to optimize processing
|
||||||
|
- **Interrupt Detection**: Detects when user speaks during LLM response and switches back to listening mode
|
||||||
|
- **Speaker Tracking**: Support for speaker identification in multi-speaker scenarios
|
||||||
|
- **State Management**: Tracks listening/responding states for proper interrupt handling
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────┐
|
||||||
|
│ Audio Source │ (Frontend, Mobile App, etc.)
|
||||||
|
│ (Microphone) │
|
||||||
|
└────────┬────────┘
|
||||||
|
│ Chunks
|
||||||
|
▼
|
||||||
|
┌─────────────────┐
|
||||||
|
│ NATS Subject │ ai.voice.stream.{session_id}
|
||||||
|
│ Audio Stream │
|
||||||
|
└────────┬────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────┐
|
||||||
|
│ STT Streaming │ (This Service)
|
||||||
|
│ Service │ - Buffers chunks
|
||||||
|
│ │ - Transcribes when ready
|
||||||
|
└────────┬────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────┐
|
||||||
|
│ NATS Subject │ ai.voice.transcription.{session_id}
|
||||||
|
│ Transcription │
|
||||||
|
└─────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Variants
|
||||||
|
|
||||||
|
### stt_streaming.py (HTTP Backend)
|
||||||
|
Uses an external Whisper service via HTTP. Lightweight container, delegates GPU inference to a separate service.
|
||||||
|
|
||||||
|
### stt_streaming_local.py (ROCm Backend)
|
||||||
|
Runs Whisper locally on AMD GPU using ROCm/PyTorch. Single container with embedded model.
|
||||||
|
|
||||||
|
## NATS Message Protocol
|
||||||
|
|
||||||
|
### Audio Stream Input (ai.voice.stream.{session_id})
|
||||||
|
|
||||||
|
All messages use **msgpack** binary encoding.
|
||||||
|
|
||||||
|
**Start Stream:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"type": "start",
|
||||||
|
"session_id": "unique-session-id",
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"channels": 1,
|
||||||
|
"state": "listening", # Optional: "listening" or "responding"
|
||||||
|
"speaker_id": "speaker-1" # Optional: identifier for speaker tracking
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Audio Chunk:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"type": "chunk",
|
||||||
|
"audio_b64": "base64-encoded-audio-data",
|
||||||
|
"timestamp": 1234567890.123
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**State Change:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"type": "state_change",
|
||||||
|
"state": "responding" # "listening" or "responding"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**End Stream:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"type": "end"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Transcription Output (ai.voice.transcription.{session_id})
|
||||||
|
|
||||||
|
**Transcription Result:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"session_id": "unique-session-id",
|
||||||
|
"transcript": "transcribed text",
|
||||||
|
"sequence": 0,
|
||||||
|
"is_partial": False,
|
||||||
|
"is_final": True,
|
||||||
|
"timestamp": 1234567890.123,
|
||||||
|
"speaker_id": "speaker-1", # If provided in start message
|
||||||
|
"has_voice_activity": True,
|
||||||
|
"state": "listening"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Interrupt Notification:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"session_id": "unique-session-id",
|
||||||
|
"type": "interrupt",
|
||||||
|
"timestamp": 1234567890.123,
|
||||||
|
"speaker_id": "speaker-1"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
| Variable | Default | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `NATS_URL` | `nats://nats.ai-ml.svc.cluster.local:4222` | NATS server URL |
|
||||||
|
| `WHISPER_URL` | `http://whisper-predictor.ai-ml.svc.cluster.local` | Whisper service URL (HTTP variant) |
|
||||||
|
| `WHISPER_MODEL_SIZE` | `medium` | Whisper model size (ROCm variant) |
|
||||||
|
| `WHISPER_DEVICE` | `cuda` | PyTorch device (ROCm variant) |
|
||||||
|
| `STT_BUFFER_SIZE_BYTES` | `512000` | Buffer size before processing (~5s) |
|
||||||
|
| `STT_CHUNK_TIMEOUT` | `2.0` | Seconds of silence before processing |
|
||||||
|
| `STT_ENABLE_VAD` | `true` | Enable voice activity detection |
|
||||||
|
| `STT_VAD_AGGRESSIVENESS` | `2` | VAD aggressiveness (0-3) |
|
||||||
|
| `STT_ENABLE_INTERRUPT_DETECTION` | `true` | Enable interrupt detection |
|
||||||
|
| `OTEL_ENABLED` | `true` | Enable OpenTelemetry |
|
||||||
|
| `HYPERDX_ENABLED` | `false` | Enable HyperDX observability |
|
||||||
|
|
||||||
|
## Building
|
||||||
|
|
||||||
|
### HTTP Variant
|
||||||
|
```bash
|
||||||
|
docker build -t stt-module:latest .
|
||||||
|
```
|
||||||
|
|
||||||
|
### ROCm Variant (AMD GPU)
|
||||||
|
```bash
|
||||||
|
docker build -f Dockerfile.rocm -t stt-module:rocm --build-arg WHISPER_MODEL=medium .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Port-forward NATS
|
||||||
|
kubectl port-forward -n ai-ml svc/nats 4222:4222
|
||||||
|
|
||||||
|
# Start a session
|
||||||
|
python -c "
|
||||||
|
import nats
|
||||||
|
import msgpack
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
nc = await nats.connect('nats://localhost:4222')
|
||||||
|
await nc.publish('ai.voice.stream.test-session', msgpack.packb({'type': 'start'}))
|
||||||
|
# Send audio chunks...
|
||||||
|
await nc.publish('ai.voice.stream.test-session', msgpack.packb({'type': 'end'}))
|
||||||
|
await nc.close()
|
||||||
|
|
||||||
|
asyncio.run(test())
|
||||||
|
"
|
||||||
|
|
||||||
|
# Subscribe to transcriptions
|
||||||
|
nats sub "ai.voice.transcription.>"
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
|
|||||||
28
healthcheck.py
Normal file
28
healthcheck.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Health check script for Kubernetes probes
|
||||||
|
Verifies NATS connectivity
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import nats
|
||||||
|
|
||||||
|
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
|
||||||
|
|
||||||
|
|
||||||
|
async def check_health():
|
||||||
|
"""Check if service can connect to NATS."""
|
||||||
|
try:
|
||||||
|
nc = await asyncio.wait_for(nats.connect(NATS_URL), timeout=5.0)
|
||||||
|
await nc.close()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Health check failed: {e}", file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
result = asyncio.run(check_health())
|
||||||
|
sys.exit(0 if result else 1)
|
||||||
24
requirements-rocm.txt
Normal file
24
requirements-rocm.txt
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Core dependencies
|
||||||
|
nats-py>=2.0.0,<3.0.0
|
||||||
|
msgpack
|
||||||
|
|
||||||
|
# Whisper for local STT inference (uses PyTorch already in base image)
|
||||||
|
openai-whisper>=20231117
|
||||||
|
|
||||||
|
# Audio processing
|
||||||
|
soundfile
|
||||||
|
numpy
|
||||||
|
|
||||||
|
# OpenTelemetry core
|
||||||
|
opentelemetry-api
|
||||||
|
opentelemetry-sdk
|
||||||
|
opentelemetry-exporter-otlp-proto-grpc
|
||||||
|
opentelemetry-exporter-otlp-proto-http
|
||||||
|
opentelemetry-instrumentation-logging
|
||||||
|
|
||||||
|
# HyperDX support (uses OTLP protocol)
|
||||||
|
# HyperDX is compatible with standard OTEL exporters, just needs API key header
|
||||||
|
opentelemetry-sdk-extension-aws # For additional context propagation
|
||||||
|
|
||||||
|
# HTTP health server for kserve compatibility
|
||||||
|
aiohttp
|
||||||
20
requirements.txt
Normal file
20
requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
nats-py>=2.0.0,<3.0.0
|
||||||
|
httpx>=0.20.0,<1.0.0
|
||||||
|
msgpack
|
||||||
|
|
||||||
|
# Audio processing
|
||||||
|
numpy>=1.20.0,<2.0.0
|
||||||
|
webrtcvad>=2.0.10
|
||||||
|
# pyannote.audio>=3.1.0 # Optional: for advanced speaker diarization
|
||||||
|
|
||||||
|
# OpenTelemetry core
|
||||||
|
opentelemetry-api
|
||||||
|
opentelemetry-sdk
|
||||||
|
|
||||||
|
# OTEL exporters (gRPC for local collector, HTTP for HyperDX)
|
||||||
|
opentelemetry-exporter-otlp-proto-grpc
|
||||||
|
opentelemetry-exporter-otlp-proto-http
|
||||||
|
|
||||||
|
# OTEL instrumentation
|
||||||
|
opentelemetry-instrumentation-httpx
|
||||||
|
opentelemetry-instrumentation-logging
|
||||||
632
stt_streaming.py
Normal file
632
stt_streaming.py
Normal file
@@ -0,0 +1,632 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Streaming STT Service
|
||||||
|
|
||||||
|
Real-time Speech-to-Text service that processes live audio streams from NATS:
|
||||||
|
1. Subscribe to audio stream subject (ai.voice.stream.{session_id})
|
||||||
|
2. Buffer and accumulate audio chunks
|
||||||
|
3. Transcribe when buffer reaches threshold or stream ends
|
||||||
|
4. Publish transcription results to response channel (ai.voice.transcription.{session_id})
|
||||||
|
|
||||||
|
This enables faster response times by processing audio as it arrives rather than
|
||||||
|
waiting for complete audio upload.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
import struct
|
||||||
|
from typing import Dict, Optional, List, Tuple
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import msgpack
|
||||||
|
import nats
|
||||||
|
import nats.js
|
||||||
|
from nats.aio.msg import Msg
|
||||||
|
import numpy as np
|
||||||
|
import webrtcvad
|
||||||
|
|
||||||
|
# OpenTelemetry imports
|
||||||
|
from opentelemetry import trace, metrics
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
from opentelemetry.sdk.metrics import MeterProvider
|
||||||
|
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||||
|
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||||
|
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
|
||||||
|
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
||||||
|
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||||
|
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("stt-streaming")
|
||||||
|
|
||||||
|
# Initialize OpenTelemetry
|
||||||
|
def setup_telemetry():
|
||||||
|
"""Initialize OpenTelemetry tracing and metrics with HyperDX support."""
|
||||||
|
# Check if OTEL is enabled
|
||||||
|
otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
|
||||||
|
if not otel_enabled:
|
||||||
|
logger.info("OpenTelemetry disabled")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# OTEL configuration
|
||||||
|
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
|
||||||
|
service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming")
|
||||||
|
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
|
||||||
|
|
||||||
|
# HyperDX configuration
|
||||||
|
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
|
||||||
|
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
|
||||||
|
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
|
||||||
|
|
||||||
|
# Create resource with service information
|
||||||
|
resource = Resource.create({
|
||||||
|
SERVICE_NAME: service_name,
|
||||||
|
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
|
||||||
|
SERVICE_NAMESPACE: service_namespace,
|
||||||
|
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||||
|
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Setup tracing
|
||||||
|
trace_provider = TracerProvider(resource=resource)
|
||||||
|
|
||||||
|
if use_hyperdx:
|
||||||
|
# Use HTTP exporter for HyperDX with API key header
|
||||||
|
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
|
||||||
|
headers = {"authorization": hyperdx_api_key}
|
||||||
|
otlp_span_exporter = OTLPSpanExporterHTTP(
|
||||||
|
endpoint=f"{hyperdx_endpoint}/v1/traces",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
otlp_metric_exporter = OTLPMetricExporterHTTP(
|
||||||
|
endpoint=f"{hyperdx_endpoint}/v1/metrics",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use gRPC exporter for standard OTEL collector
|
||||||
|
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
|
||||||
|
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
|
||||||
|
|
||||||
|
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
|
||||||
|
trace.set_tracer_provider(trace_provider)
|
||||||
|
|
||||||
|
# Setup metrics
|
||||||
|
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
|
||||||
|
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||||
|
metrics.set_meter_provider(meter_provider)
|
||||||
|
|
||||||
|
# Instrument HTTPX
|
||||||
|
HTTPXClientInstrumentor().instrument()
|
||||||
|
|
||||||
|
# Instrument logging
|
||||||
|
LoggingInstrumentor().instrument(set_logging_format=True)
|
||||||
|
|
||||||
|
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
|
||||||
|
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
|
||||||
|
|
||||||
|
# Return tracer and meter for the service
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
meter = metrics.get_meter(__name__)
|
||||||
|
|
||||||
|
return tracer, meter
|
||||||
|
|
||||||
|
# Configuration from environment
|
||||||
|
WHISPER_URL = os.environ.get("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
|
||||||
|
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
|
||||||
|
|
||||||
|
# NATS subjects for streaming
|
||||||
|
STREAM_SUBJECT_PREFIX = "ai.voice.stream" # Full subject: ai.voice.stream.{session_id}
|
||||||
|
TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription" # Full subject: ai.voice.transcription.{session_id}
|
||||||
|
|
||||||
|
# Streaming parameters
|
||||||
|
BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")) # ~5 seconds at 16kHz 16-bit
|
||||||
|
CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0")) # Process after 2s of silence
|
||||||
|
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000")) # ~50 seconds max
|
||||||
|
|
||||||
|
# Audio constants
|
||||||
|
AUDIO_SAMPLE_MAX_INT16 = 32768.0 # Maximum value for 16-bit signed integer audio
|
||||||
|
VAD_VOICE_RATIO_THRESHOLD = float(os.environ.get("STT_VAD_VOICE_RATIO", "0.3")) # Min ratio of voice frames
|
||||||
|
|
||||||
|
# Voice Activity Detection (VAD) parameters
|
||||||
|
ENABLE_VAD = os.environ.get("STT_ENABLE_VAD", "true").lower() == "true"
|
||||||
|
VAD_AGGRESSIVENESS = int(os.environ.get("STT_VAD_AGGRESSIVENESS", "2")) # 0-3, higher = more aggressive
|
||||||
|
VAD_FRAME_DURATION_MS = int(os.environ.get("STT_VAD_FRAME_DURATION", "30")) # 10, 20, or 30 ms
|
||||||
|
|
||||||
|
# Audio threshold for interrupt detection (when LLM is responding)
|
||||||
|
ENABLE_INTERRUPT_DETECTION = os.environ.get("STT_ENABLE_INTERRUPT_DETECTION", "true").lower() == "true"
|
||||||
|
AUDIO_LEVEL_THRESHOLD = float(os.environ.get("STT_AUDIO_LEVEL_THRESHOLD", "0.02")) # RMS threshold
|
||||||
|
INTERRUPT_DURATION_THRESHOLD = float(os.environ.get("STT_INTERRUPT_DURATION", "0.5")) # Seconds of speech to trigger
|
||||||
|
|
||||||
|
# Speaker diarization
|
||||||
|
ENABLE_SPEAKER_DIARIZATION = os.environ.get("STT_ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true"
|
||||||
|
|
||||||
|
# Session states
|
||||||
|
SESSION_STATE_LISTENING = "listening"
|
||||||
|
SESSION_STATE_RESPONDING = "responding"
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float:
|
||||||
|
"""
|
||||||
|
Calculate RMS (Root Mean Square) audio level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: Raw audio bytes
|
||||||
|
sample_width: Bytes per sample (2 for 16-bit audio)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RMS level normalized to 0.0-1.0 range
|
||||||
|
"""
|
||||||
|
if len(audio_data) < sample_width:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Convert bytes to numpy array of int16 samples
|
||||||
|
try:
|
||||||
|
samples = np.frombuffer(audio_data, dtype=np.int16)
|
||||||
|
# Calculate RMS and normalize
|
||||||
|
rms = np.sqrt(np.mean(samples.astype(np.float32) ** 2))
|
||||||
|
# Normalize to 0-1 range using defined constant
|
||||||
|
return float(rms / AUDIO_SAMPLE_MAX_INT16)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error calculating RMS: {e}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
|
||||||
|
"""
|
||||||
|
Detect if audio contains voice using WebRTC VAD.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: Raw PCM audio bytes (16-bit, mono)
|
||||||
|
sample_rate: Audio sample rate (8000, 16000, 32000, or 48000)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if voice is detected, False otherwise
|
||||||
|
"""
|
||||||
|
if not ENABLE_VAD:
|
||||||
|
return True # Assume voice present if VAD disabled
|
||||||
|
|
||||||
|
try:
|
||||||
|
vad = webrtcvad.Vad(VAD_AGGRESSIVENESS)
|
||||||
|
|
||||||
|
# WebRTC VAD requires specific frame sizes
|
||||||
|
# Frame duration must be 10, 20, or 30 ms
|
||||||
|
frame_size = int(sample_rate * VAD_FRAME_DURATION_MS / 1000) * 2 # *2 for 16-bit samples
|
||||||
|
|
||||||
|
# Process audio in frames
|
||||||
|
voice_frames = 0
|
||||||
|
total_frames = 0
|
||||||
|
|
||||||
|
for i in range(0, len(audio_data) - frame_size, frame_size):
|
||||||
|
frame = audio_data[i:i + frame_size]
|
||||||
|
if len(frame) == frame_size:
|
||||||
|
try:
|
||||||
|
is_speech = vad.is_speech(frame, sample_rate)
|
||||||
|
if is_speech:
|
||||||
|
voice_frames += 1
|
||||||
|
total_frames += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"VAD frame processing error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if total_frames == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Consider voice detected if voice ratio exceeds threshold
|
||||||
|
voice_ratio = voice_frames / total_frames
|
||||||
|
return voice_ratio > VAD_VOICE_RATIO_THRESHOLD
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"VAD error: {e}")
|
||||||
|
return True # Default to voice present on error
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBuffer:
|
||||||
|
"""Manages audio chunks for a streaming session with VAD and speaker tracking."""
|
||||||
|
|
||||||
|
def __init__(self, session_id: str):
|
||||||
|
self.session_id = session_id
|
||||||
|
self.chunks = []
|
||||||
|
self.total_bytes = 0
|
||||||
|
self.last_chunk_time = time.time()
|
||||||
|
self.is_complete = False
|
||||||
|
self.sequence = 0
|
||||||
|
self.state = SESSION_STATE_LISTENING # Current session state
|
||||||
|
self.speaker_id = None # For speaker diarization
|
||||||
|
self.interrupt_start_time = None # Track when interrupt detection started
|
||||||
|
self.has_voice_activity = False # Track if voice was detected in recent chunks
|
||||||
|
self._last_chunk_vad_result = None # Cache VAD result for last chunk
|
||||||
|
|
||||||
|
def add_chunk(self, audio_data: bytes) -> None:
|
||||||
|
"""Add an audio chunk to the buffer and check for voice activity."""
|
||||||
|
self.chunks.append(audio_data)
|
||||||
|
self.total_bytes += len(audio_data)
|
||||||
|
self.last_chunk_time = time.time()
|
||||||
|
|
||||||
|
# Check for voice activity in this chunk and cache result
|
||||||
|
has_voice = detect_voice_activity(audio_data)
|
||||||
|
self.has_voice_activity = has_voice
|
||||||
|
self._last_chunk_vad_result = has_voice
|
||||||
|
|
||||||
|
logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes, voice={has_voice}")
|
||||||
|
|
||||||
|
def check_interrupt(self, audio_data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Check if audio indicates an interrupt during responding state.
|
||||||
|
Uses cached VAD result if available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if interrupt detected, False otherwise
|
||||||
|
"""
|
||||||
|
if not ENABLE_INTERRUPT_DETECTION:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.state != SESSION_STATE_RESPONDING:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Calculate audio level
|
||||||
|
rms_level = calculate_audio_rms(audio_data)
|
||||||
|
|
||||||
|
# Use cached VAD result if available to avoid duplicate processing
|
||||||
|
has_voice = self._last_chunk_vad_result if self._last_chunk_vad_result is not None else detect_voice_activity(audio_data)
|
||||||
|
|
||||||
|
# Check if audio exceeds threshold and contains voice
|
||||||
|
if rms_level >= AUDIO_LEVEL_THRESHOLD and has_voice:
|
||||||
|
if self.interrupt_start_time is None:
|
||||||
|
self.interrupt_start_time = time.time()
|
||||||
|
logger.info(f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})")
|
||||||
|
|
||||||
|
# Check if interrupt has lasted long enough
|
||||||
|
elapsed = time.time() - self.interrupt_start_time
|
||||||
|
if elapsed >= INTERRUPT_DURATION_THRESHOLD:
|
||||||
|
logger.info(f"Session {self.session_id}: Interrupt confirmed after {elapsed:.1f}s")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# Reset interrupt timer if audio drops below threshold
|
||||||
|
self.interrupt_start_time = None
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def set_state(self, state: str) -> None:
|
||||||
|
"""Set the session state (listening or responding)."""
|
||||||
|
if state in (SESSION_STATE_LISTENING, SESSION_STATE_RESPONDING):
|
||||||
|
old_state = self.state
|
||||||
|
self.state = state
|
||||||
|
if old_state != state:
|
||||||
|
logger.info(f"Session {self.session_id}: State changed from {old_state} to {state}")
|
||||||
|
# Reset interrupt tracking when changing states
|
||||||
|
self.interrupt_start_time = None
|
||||||
|
|
||||||
|
def should_process(self) -> bool:
|
||||||
|
"""Determine if buffer should be processed now."""
|
||||||
|
# Don't process if no voice activity detected (unless buffer is full or timed out)
|
||||||
|
if ENABLE_VAD and not self.has_voice_activity:
|
||||||
|
# Still process if buffer is very large or has timed out
|
||||||
|
if self.total_bytes < BUFFER_SIZE_BYTES and time.time() - self.last_chunk_time < CHUNK_TIMEOUT_SECONDS:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Process if buffer size threshold reached
|
||||||
|
if self.total_bytes >= BUFFER_SIZE_BYTES:
|
||||||
|
return True
|
||||||
|
# Process if no chunks received for timeout duration
|
||||||
|
if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0:
|
||||||
|
return True
|
||||||
|
# Process if buffer is too large (safety limit)
|
||||||
|
if self.total_bytes >= MAX_BUFFER_SIZE_BYTES:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_audio(self) -> bytes:
|
||||||
|
"""Get concatenated audio data."""
|
||||||
|
return b''.join(self.chunks)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear the buffer after processing."""
|
||||||
|
self.chunks = []
|
||||||
|
self.total_bytes = 0
|
||||||
|
self.sequence += 1
|
||||||
|
self._last_chunk_vad_result = None # Clear cached VAD result
|
||||||
|
|
||||||
|
def mark_complete(self) -> None:
|
||||||
|
"""Mark stream as complete."""
|
||||||
|
self.is_complete = True
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingSTT:
|
||||||
|
"""Streaming Speech-to-Text service."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.nc = None
|
||||||
|
self.js = None
|
||||||
|
self.http_client = None
|
||||||
|
self.sessions: Dict[str, AudioBuffer] = {}
|
||||||
|
self.running = True
|
||||||
|
self.processing_tasks = {}
|
||||||
|
self.is_healthy = False
|
||||||
|
self.tracer = None
|
||||||
|
self.meter = None
|
||||||
|
self.stream_counter = None
|
||||||
|
self.transcription_duration = None
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""Initialize connections."""
|
||||||
|
# Initialize OpenTelemetry
|
||||||
|
self.tracer, self.meter = setup_telemetry()
|
||||||
|
|
||||||
|
# Create metrics if OTEL is enabled
|
||||||
|
if self.meter:
|
||||||
|
self.stream_counter = self.meter.create_counter(
|
||||||
|
name="stt_streams_total",
|
||||||
|
description="Total number of STT streams processed",
|
||||||
|
unit="1"
|
||||||
|
)
|
||||||
|
self.transcription_duration = self.meter.create_histogram(
|
||||||
|
name="stt_transcription_duration_seconds",
|
||||||
|
description="Duration of STT transcription",
|
||||||
|
unit="s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NATS connection
|
||||||
|
self.nc = await nats.connect(NATS_URL)
|
||||||
|
logger.info(f"Connected to NATS at {NATS_URL}")
|
||||||
|
|
||||||
|
# Initialize JetStream context
|
||||||
|
self.js = self.nc.jetstream()
|
||||||
|
|
||||||
|
# Create or update stream for voice stream messages
|
||||||
|
try:
|
||||||
|
stream_config = nats.js.api.StreamConfig(
|
||||||
|
name="AI_VOICE_STREAM",
|
||||||
|
subjects=["ai.voice.stream.>", "ai.voice.transcription.>"],
|
||||||
|
retention=nats.js.api.RetentionPolicy.LIMITS,
|
||||||
|
max_age=300, # Keep messages for 5 minutes only (streaming is ephemeral)
|
||||||
|
storage=nats.js.api.StorageType.MEMORY, # Use memory for streaming data
|
||||||
|
)
|
||||||
|
await self.js.add_stream(stream_config)
|
||||||
|
logger.info("Created/updated JetStream stream: AI_VOICE_STREAM")
|
||||||
|
except Exception as e:
|
||||||
|
# Stream might already exist
|
||||||
|
logger.info(f"JetStream stream setup: {e}")
|
||||||
|
|
||||||
|
# HTTP client for Whisper service
|
||||||
|
self.http_client = httpx.AsyncClient(timeout=180.0)
|
||||||
|
logger.info("HTTP client initialized")
|
||||||
|
|
||||||
|
# Mark as healthy once connections are established
|
||||||
|
self.is_healthy = True
|
||||||
|
|
||||||
|
async def transcribe(self, audio_bytes: bytes) -> Optional[str]:
    """Transcribe audio using Whisper.

    Posts the raw WAV bytes to the external Whisper service and returns
    the transcript text, or ``None`` on any failure (network, HTTP
    error, bad JSON). Failures are logged, never raised.
    """
    try:
        upload = {"file": ("audio.wav", audio_bytes, "audio/wav")}
        response = await self.http_client.post(
            f"{WHISPER_URL}/v1/audio/transcriptions",
            files=upload
        )
        response.raise_for_status()
        payload = response.json()
        transcript = payload.get("text", "")
        logger.info(f"Transcribed: {transcript[:100]}...")
        return transcript
    except Exception as e:
        logger.error(f"Transcription failed: {e}")
        return None
|
||||||
|
async def process_buffer(self, session_id: str):
    """Process accumulated audio buffer for a session.

    Transcribes whatever audio the session has buffered, publishes the
    result (msgpack) to the transcription subject, then clears the
    buffer. Completed sessions are scheduled for delayed cleanup.
    """
    buffer = self.sessions.get(session_id)
    if not buffer:
        return

    audio_data = buffer.get_audio()
    if not audio_data:
        return

    logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}")

    # Transcribe via the external Whisper service (returns None on failure)
    transcript = await self.transcribe(audio_data)

    if transcript:
        # Publish transcription result using msgpack binary format
        result = {
            "session_id": session_id,
            "transcript": transcript,
            "sequence": buffer.sequence,
            "is_partial": not buffer.is_complete,
            "is_final": buffer.is_complete,
            "timestamp": time.time(),
            "speaker_id": buffer.speaker_id,
            "has_voice_activity": buffer.has_voice_activity,
            "state": buffer.state
        }

        await self.nc.publish(
            f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
            msgpack.packb(result)
        )
        logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})")

    # Clear buffer after processing.
    # NOTE(review): cleared even when transcription failed, so a bad
    # buffer is dropped rather than retried forever — confirm intended.
    buffer.clear()

    # Clean up completed sessions asynchronously
    if buffer.is_complete:
        logger.info(f"Session {session_id} completed")
        # Schedule cleanup task to avoid blocking
        asyncio.create_task(self._cleanup_session(session_id))
||||||
|
async def _cleanup_session(self, session_id: str):
    """Drop a completed session's state after a short grace period.

    The delay gives late-arriving messages a chance to be handled
    before the session and its monitoring task are forgotten.
    """
    await asyncio.sleep(5)
    removed = self.sessions.pop(session_id, None)
    if removed is not None:
        logger.info(f"Cleaned up session: {session_id}")
    self.processing_tasks.pop(session_id, None)
|
||||||
|
async def monitor_buffer(self, session_id: str):
    """Poll one session's buffer and flush it whenever it is ready.

    Runs until the service stops or the session disappears; checks
    every 100ms so the loop never busy-spins.
    """
    while self.running:
        buf = self.sessions.get(session_id)
        if not buf:
            break
        if buf.should_process():
            await self.process_buffer(session_id)
        # Don't spin too fast
        await asyncio.sleep(0.1)
|
||||||
|
async def handle_stream_message(self, msg: Msg):
    """Handle incoming audio stream message.

    Dispatches on the msgpack ``type`` field:
    - "start": create a session buffer + its monitor task
    - "state_change": update session state (listening/responding)
    - "end": mark complete and flush remaining audio
    - "chunk": decode base64 audio, run interrupt detection, buffer it
    All errors are caught and logged so one bad message cannot kill
    the subscription callback.
    """
    try:
        # Extract session_id from subject: ai.voice.stream.{session_id}
        subject_parts = msg.subject.split('.')
        if len(subject_parts) < 4:
            logger.warning(f"Invalid subject format: {msg.subject}")
            return

        session_id = subject_parts[3]

        # Parse message using msgpack binary format
        data = msgpack.unpackb(msg.data, raw=False)

        # Handle control messages
        if data.get("type") == "start":
            logger.info(f"Starting stream session: {session_id}")
            self.sessions[session_id] = AudioBuffer(session_id)
            # Set initial state if provided
            initial_state = data.get("state", SESSION_STATE_LISTENING)
            self.sessions[session_id].set_state(initial_state)
            # Store speaker_id if provided
            speaker_id = data.get("speaker_id")
            if speaker_id:
                self.sessions[session_id].speaker_id = speaker_id
                logger.info(f"Session {session_id}: Speaker ID set to {speaker_id}")
            # Start monitoring task for this session
            task = asyncio.create_task(self.monitor_buffer(session_id))
            self.processing_tasks[session_id] = task
            return

        if data.get("type") == "state_change":
            logger.info(f"State change for session {session_id}")
            buffer = self.sessions.get(session_id)
            if buffer:
                new_state = data.get("state", SESSION_STATE_LISTENING)
                buffer.set_state(new_state)
                # If switching to listening mode, reset any interrupt tracking
                if new_state == SESSION_STATE_LISTENING:
                    buffer.interrupt_start_time = None
            return

        if data.get("type") == "end":
            logger.info(f"Ending stream session: {session_id}")
            buffer = self.sessions.get(session_id)
            if buffer:
                buffer.mark_complete()
                # Process any remaining audio
                if buffer.total_bytes > 0:
                    await self.process_buffer(session_id)
            return

        # Handle audio chunk
        if data.get("type") == "chunk":
            audio_b64 = data.get("audio_b64", "")
            if not audio_b64:
                return

            audio_bytes = base64.b64decode(audio_b64)

            # Create session if it doesn't exist (handle missing start message)
            # Check both sessions and processing_tasks to avoid race conditions
            if session_id not in self.sessions:
                logger.info(f"Auto-creating session: {session_id}")
                self.sessions[session_id] = AudioBuffer(session_id)
                # Only create monitoring task if not already exists
                if session_id not in self.processing_tasks:
                    task = asyncio.create_task(self.monitor_buffer(session_id))
                    self.processing_tasks[session_id] = task

            buffer = self.sessions[session_id]

            # Check for interrupt if in responding state (barge-in support)
            if buffer.check_interrupt(audio_bytes):
                # Publish interrupt notification
                interrupt_msg = {
                    "session_id": session_id,
                    "type": "interrupt",
                    "timestamp": time.time(),
                    "speaker_id": buffer.speaker_id
                }
                await self.nc.publish(
                    f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
                    msgpack.packb(interrupt_msg)
                )
                logger.info(f"Published interrupt notification for session {session_id}")

                # Automatically switch back to listening mode
                buffer.set_state(SESSION_STATE_LISTENING)

            # Add chunk to buffer
            buffer.add_chunk(audio_bytes)

    except Exception as e:
        logger.error(f"Error handling stream message: {e}", exc_info=True)
|
||||||
|
async def run(self):
    """Main run loop.

    Sets up connections, subscribes to the audio stream, installs
    signal handlers for graceful shutdown, then idles until
    ``self.running`` is cleared. On shutdown, cancels all per-session
    monitor tasks and closes NATS and the HTTP client.
    """
    await self.setup()

    # Note: STT streaming uses regular NATS subscribe (not pull-based JetStream consumer)
    # because it handles real-time ephemeral audio streams with wildcard subscriptions.
    # The stream audio chunks are not meant to be persisted long-term or replayed.
    # However, the transcription RESULTS are published to JetStream for persistence.
    sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
    logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")

    # Handle shutdown
    def signal_handler():
        self.running = False

    # Fix: get_event_loop() is deprecated inside a running coroutine
    # (Python 3.10+) and can bind the wrong loop; get_running_loop()
    # always returns the loop asyncio.run() created.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    # Keep running until a signal clears the flag
    while self.running:
        await asyncio.sleep(1)

    # Cleanup
    logger.info("Shutting down...")

    # Cancel all monitoring tasks and wait for them to complete
    for task in self.processing_tasks.values():
        task.cancel()

    # Wait for all tasks to complete or be cancelled
    if self.processing_tasks:
        await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True)

    await sub.unsubscribe()
    await self.nc.close()
    await self.http_client.aclose()
    logger.info("Shutdown complete")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Entry point: build the service and drive it on a fresh event loop.
    asyncio.run(StreamingSTT().run())
||||||
511
stt_streaming_local.py
Normal file
511
stt_streaming_local.py
Normal file
@@ -0,0 +1,511 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Streaming STT Service with Local Whisper on ROCm
|
||||||
|
|
||||||
|
Real-time Speech-to-Text service that processes live audio streams from NATS
|
||||||
|
using local Whisper model running on AMD GPU via ROCm:
|
||||||
|
|
||||||
|
1. Subscribe to audio stream subject (ai.voice.stream.{session_id})
|
||||||
|
2. Buffer and accumulate audio chunks
|
||||||
|
3. Transcribe locally using Whisper on AMD GPU
|
||||||
|
4. Publish transcription results to response channel (ai.voice.transcription.{session_id})
|
||||||
|
|
||||||
|
This version runs Whisper directly on the AMD GPU using ROCm/PyTorch backend
|
||||||
|
instead of calling an external Whisper service.
|
||||||
|
|
||||||
|
Supports HyperDX for observability via OpenTelemetry.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import contextlib
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
import msgpack
|
||||||
|
import nats
|
||||||
|
import nats.js
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import whisper
|
||||||
|
from nats.aio.msg import Msg
|
||||||
|
|
||||||
|
# OpenTelemetry imports
|
||||||
|
from opentelemetry import trace, metrics
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
from opentelemetry.sdk.metrics import MeterProvider
|
||||||
|
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||||
|
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||||
|
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
|
||||||
|
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
||||||
|
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("stt-streaming-rocm")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Returns ``(tracer, meter)``, or ``(None, None)`` when OTEL_ENABLED
    is false. Exporter selection: HTTP exporters aimed at HyperDX when
    HYPERDX_ENABLED is true AND an API key is set; otherwise gRPC
    exporters aimed at the standard OTEL collector endpoint.
    """
    # Check if OTEL is enabled (default: enabled)
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    # OTEL configuration
    otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
    service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming-rocm")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    # HyperDX configuration — only used when both the flag and key are set
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    # Create resource with service identification attributes
    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    # Setup tracing
    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        # Use HTTP exporter for HyperDX with API key header
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces",
            headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics",
            headers=headers
        )
    else:
        # Use gRPC exporter for standard OTEL collector (in-cluster, so insecure)
        logger.info(f"Configuring OTEL gRPC exporter at {otel_endpoint}")
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    # Setup metrics — exported once a minute
    metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    # Instrument logging so log records carry trace context
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    # Return tracer and meter for the service
    tracer = trace.get_tracer(__name__)
    meter = metrics.get_meter(__name__)

    return tracer, meter
||||||
|
|
||||||
|
|
||||||
|
# Configuration from environment
|
||||||
|
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
|
||||||
|
WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "medium")
|
||||||
|
WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cuda") # cuda uses ROCm on AMD
|
||||||
|
WHISPER_FP16 = os.environ.get("WHISPER_FP16", "true").lower() == "true"
|
||||||
|
|
||||||
|
# NATS subjects for streaming
|
||||||
|
STREAM_SUBJECT_PREFIX = "ai.voice.stream" # Full subject: ai.voice.stream.{session_id}
|
||||||
|
TRANSCRIPTION_SUBJECT_PREFIX = "ai.voice.transcription" # Full subject: ai.voice.transcription.{session_id}
|
||||||
|
|
||||||
|
# Streaming parameters
|
||||||
|
BUFFER_SIZE_BYTES = int(os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")) # ~5 seconds at 16kHz 16-bit
|
||||||
|
CHUNK_TIMEOUT_SECONDS = float(os.environ.get("STT_CHUNK_TIMEOUT", "2.0")) # Process after 2s of silence
|
||||||
|
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000")) # ~50 seconds max
|
||||||
|
|
||||||
|
# Health server port for kserve compatibility
|
||||||
|
HEALTH_PORT = int(os.environ.get("HEALTH_PORT", "8000"))
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBuffer:
    """Accumulates raw audio chunks for a single streaming session."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        # Raw audio byte chunks in arrival order
        self.chunks = []
        # Running total of buffered bytes
        self.total_bytes = 0
        # Wall-clock time of the most recent chunk (used for silence timeout)
        self.last_chunk_time = time.time()
        # Set once the stream's "end" message arrives
        self.is_complete = False
        # Incremented on every flush; lets consumers order partial results
        self.sequence = 0

    def add_chunk(self, audio_data: bytes) -> None:
        """Append one audio chunk and refresh the arrival timestamp."""
        self.chunks.append(audio_data)
        self.total_bytes += len(audio_data)
        self.last_chunk_time = time.time()
        logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes")

    def should_process(self) -> bool:
        """Decide whether the buffered audio is ready to be transcribed."""
        # Enough audio accumulated for a normal flush
        if self.total_bytes >= BUFFER_SIZE_BYTES:
            return True
        # The stream went quiet while holding audio — flush on silence
        idle = time.time() - self.last_chunk_time
        if self.total_bytes > 0 and idle > CHUNK_TIMEOUT_SECONDS:
            return True
        # Hard safety cap on buffer growth
        return self.total_bytes >= MAX_BUFFER_SIZE_BYTES

    def get_audio(self) -> bytes:
        """Return all buffered chunks joined into a single bytes object."""
        return b''.join(self.chunks)

    def clear(self) -> None:
        """Empty the buffer after a flush and advance the sequence number."""
        self.chunks = []
        self.total_bytes = 0
        self.sequence += 1

    def mark_complete(self) -> None:
        """Flag this session's stream as finished."""
        self.is_complete = True
||||||
|
|
||||||
|
class StreamingSTTLocal:
    """Streaming Speech-to-Text service with local Whisper on ROCm.

    Consumes audio chunks from NATS (``ai.voice.stream.{session_id}``),
    buffers them per session, transcribes with a locally loaded Whisper
    model, and publishes msgpack results to
    ``ai.voice.transcription.{session_id}``.
    """

    def __init__(self):
        self.nc = None                      # NATS client connection
        self.js = None                      # JetStream context
        self.whisper_model = None           # Whisper model, loaded in setup()
        self.sessions: Dict[str, AudioBuffer] = {}  # session_id -> audio buffer
        self.running = True                 # main-loop flag, cleared by signals
        self.processing_tasks = {}          # session_id -> monitor task
        self.is_healthy = False             # reported by the health endpoint
        self.tracer = None                  # OTEL tracer (if enabled)
        self.meter = None                   # OTEL meter (if enabled)
        self.stream_counter = None          # counter of processed buffers
        self.transcription_duration = None  # histogram of transcribe latency
        self.gpu_memory_gauge = None        # observable GPU memory gauge

    async def setup(self):
        """Initialize telemetry, load the Whisper model, and connect to NATS."""
        # Initialize OpenTelemetry (returns (None, None) when disabled)
        self.tracer, self.meter = setup_telemetry()

        # Create metrics if OTEL is enabled
        if self.meter:
            self.stream_counter = self.meter.create_counter(
                name="stt_streams_total",
                description="Total number of STT streams processed",
                unit="1"
            )
            self.transcription_duration = self.meter.create_histogram(
                name="stt_transcription_duration_seconds",
                description="Duration of STT transcription",
                unit="s"
            )
            self.gpu_memory_gauge = self.meter.create_observable_gauge(
                name="stt_gpu_memory_bytes",
                description="GPU memory usage in bytes",
                callbacks=[self._get_gpu_memory]
            )

        # Check GPU availability (torch.cuda maps to ROCm/HIP on AMD builds)
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            logger.info(f"ROCm GPU available: {gpu_name} ({gpu_memory:.1f}GB)")
        else:
            logger.warning("No GPU available, falling back to CPU")

        # Load Whisper model once at startup (blocking, can take a while)
        logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE}")
        start_time = time.time()
        self.whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=WHISPER_DEVICE)
        load_time = time.time() - start_time
        logger.info(f"Whisper model loaded in {load_time:.2f}s")

        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # Initialize JetStream context
        self.js = self.nc.jetstream()

        # Create or update stream for voice stream messages; failure is
        # best-effort (the stream usually already exists) and only logged.
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_STREAM",
                subjects=["ai.voice.stream.>", "ai.voice.transcription.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # Keep messages for 5 minutes only (streaming is ephemeral)
                storage=nats.js.api.StorageType.MEMORY,  # Use memory for streaming data
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_STREAM")
        except Exception as e:
            # Stream might already exist
            logger.info(f"JetStream stream setup: {e}")

        # Mark as healthy once connections are established
        self.is_healthy = True

    async def health_handler(self, request: web.Request) -> web.Response:
        """Handle health check requests for kserve compatibility."""
        if self.is_healthy:
            return web.json_response({
                "status": "healthy",
                "model": WHISPER_MODEL_SIZE,
                "device": WHISPER_DEVICE,
                "nats_connected": self.nc is not None and self.nc.is_connected,
            })
        else:
            # 503 keeps the pod out of rotation until setup() finishes
            return web.json_response(
                {"status": "unhealthy", "model": WHISPER_MODEL_SIZE},
                status=503
            )

    async def start_health_server(self) -> web.AppRunner:
        """Start HTTP health server for the kserve agent sidecar."""
        app = web.Application()
        app.router.add_get("/health", self.health_handler)
        app.router.add_get("/ready", self.health_handler)
        app.router.add_get("/", self.health_handler)

        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, "0.0.0.0", HEALTH_PORT)
        await site.start()
        logger.info(f"Health server started on port {HEALTH_PORT}")
        return runner

    def _get_gpu_memory(self, options):
        """Callback for the GPU memory observable gauge."""
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated(0)
            yield metrics.Observation(memory_used, {"device": "0"})

    def transcribe(self, audio_bytes: bytes) -> Optional[str]:
        """Transcribe audio using the local Whisper model.

        Blocking — callers run it in an executor. Returns the stripped
        transcript, or None on failure.
        """
        start_time = time.time()
        try:
            # Write audio to a temp file (Whisper needs a file path or numpy
            # array). The transcribe call stays inside the `with` because
            # delete=True removes the file on close.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
                tmp.write(audio_bytes)
                tmp.flush()
                result = self.whisper_model.transcribe(
                    tmp.name,
                    fp16=WHISPER_FP16 and WHISPER_DEVICE == "cuda",
                    language="en",  # Can be made configurable
                )

            transcript = result.get("text", "").strip()

            duration = time.time() - start_time
            # assumes 16kHz 16-bit mono PCM — TODO confirm against producer
            audio_duration = len(audio_bytes) / (16000 * 2)
            rtf = duration / audio_duration if audio_duration > 0 else 0

            logger.info(f"Transcribed in {duration:.2f}s (RTF: {rtf:.2f}): {transcript[:100]}...")

            # Record metrics
            if self.transcription_duration:
                self.transcription_duration.record(duration, {"model": WHISPER_MODEL_SIZE})

            return transcript

        except Exception as e:
            logger.error(f"Transcription failed: {e}", exc_info=True)
            return None

    async def process_buffer(self, session_id: str):
        """Transcribe and publish a session's buffered audio, then clear it."""
        buffer = self.sessions.get(session_id)
        if not buffer:
            return

        audio_data = buffer.get_audio()
        if not audio_data:
            return

        logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}")

        # Record stream counter
        if self.stream_counter:
            self.stream_counter.add(1, {"session_id": session_id})

        # Transcribe in thread pool to avoid blocking the event loop.
        # Fix: get_running_loop() instead of the deprecated get_event_loop()
        # inside a coroutine (Python 3.10+).
        loop = asyncio.get_running_loop()
        transcript = await loop.run_in_executor(None, self.transcribe, audio_data)

        if transcript:
            # Publish transcription result using msgpack binary format
            result = {
                "session_id": session_id,
                "transcript": transcript,
                "sequence": buffer.sequence,
                "is_partial": not buffer.is_complete,
                "is_final": buffer.is_complete,
                "timestamp": time.time(),
                "model": WHISPER_MODEL_SIZE,
                "device": WHISPER_DEVICE,
            }

            await self.nc.publish(
                f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
                msgpack.packb(result)
            )
            logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence})")

        # Clear buffer after processing (even on failure, so a bad buffer
        # is dropped rather than retried forever)
        buffer.clear()

        # Clean up completed sessions asynchronously
        if buffer.is_complete:
            logger.info(f"Session {session_id} completed")
            # Schedule cleanup task to avoid blocking
            asyncio.create_task(self._cleanup_session(session_id))

    async def _cleanup_session(self, session_id: str):
        """Clean up a completed session after a delay."""
        # Keep session for a bit in case of late messages
        await asyncio.sleep(5)
        if session_id in self.sessions:
            del self.sessions[session_id]
            logger.info(f"Cleaned up session: {session_id}")
        if session_id in self.processing_tasks:
            del self.processing_tasks[session_id]

    async def monitor_buffer(self, session_id: str):
        """Monitor a session's buffer and trigger processing when ready."""
        while self.running and session_id in self.sessions:
            buffer = self.sessions.get(session_id)
            if not buffer:
                break

            if buffer.should_process():
                await self.process_buffer(session_id)

            # Don't spin too fast
            await asyncio.sleep(0.1)

    async def handle_stream_message(self, msg: Msg):
        """Handle one incoming audio stream message (start/end/chunk)."""
        try:
            # Extract session_id from subject: ai.voice.stream.{session_id}
            subject_parts = msg.subject.split('.')
            if len(subject_parts) < 4:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return

            session_id = subject_parts[3]

            # Parse message using msgpack binary format
            data = msgpack.unpackb(msg.data, raw=False)

            # Handle control messages
            if data.get("type") == "start":
                logger.info(f"Starting stream session: {session_id}")
                self.sessions[session_id] = AudioBuffer(session_id)
                # Start monitoring task for this session
                task = asyncio.create_task(self.monitor_buffer(session_id))
                self.processing_tasks[session_id] = task
                return

            if data.get("type") == "end":
                logger.info(f"Ending stream session: {session_id}")
                buffer = self.sessions.get(session_id)
                if buffer:
                    buffer.mark_complete()
                    # Process any remaining audio
                    if buffer.total_bytes > 0:
                        await self.process_buffer(session_id)
                return

            # Handle audio chunk
            if data.get("type") == "chunk":
                audio_b64 = data.get("audio_b64", "")
                if not audio_b64:
                    return

                audio_bytes = base64.b64decode(audio_b64)

                # Create session if it doesn't exist (handle missing start message)
                if session_id not in self.sessions:
                    logger.info(f"Auto-creating session: {session_id}")
                    self.sessions[session_id] = AudioBuffer(session_id)
                    # Only create monitoring task if not already exists
                    if session_id not in self.processing_tasks:
                        task = asyncio.create_task(self.monitor_buffer(session_id))
                        self.processing_tasks[session_id] = task

                # Add chunk to buffer
                self.sessions[session_id].add_chunk(audio_bytes)

        except Exception as e:
            logger.error(f"Error handling stream message: {e}", exc_info=True)

    async def run(self):
        """Main run loop: serve health checks, consume the stream, shut down cleanly."""
        await self.setup()

        # Start health server for kserve compatibility
        health_runner = await self.start_health_server()

        # Subscribe to voice stream (plain NATS subscribe: ephemeral audio)
        sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
        logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # Fix: get_running_loop() instead of the deprecated get_event_loop()
        # inside a coroutine (Python 3.10+).
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running until a signal clears the flag
        while self.running:
            await asyncio.sleep(1)

        # Cleanup
        logger.info("Shutting down...")

        # Cancel all monitoring tasks and wait for them to complete
        for task in self.processing_tasks.values():
            task.cancel()

        if self.processing_tasks:
            await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True)

        await sub.unsubscribe()
        await self.nc.close()
        await health_runner.cleanup()
        logger.info("Shutdown complete")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Entry point: build the local-Whisper service and drive it on a
    # fresh event loop.
    asyncio.run(StreamingSTTLocal().run())
|
||||||
Reference in New Issue
Block a user