feat: CI pipeline, lint fixes, and Renovate config
Some checks failed
CI / Docker Build & Push (push) Has been skipped
CI / Notify (push) Successful in 1s
CI / Lint (push) Failing after 9s
CI / Test (push) Successful in 50s
CI / Release (push) Has been skipped

- pyproject.toml with ruff/pytest config (setuptools<81 pin)
- Full test suite (26 tests)
- Gitea Actions CI (lint, test, docker, notify)
- Ruff lint/format fixes across source files
- Renovate config for automated dependency updates

Ref: ADR-0057
This commit is contained in:
2026-02-13 15:33:35 -05:00
parent 8fc5eb1193
commit 55cd657364
10 changed files with 1655 additions and 279 deletions

View File

@@ -15,44 +15,43 @@ instead of calling an external Whisper service.
Supports HyperDX for observability via OpenTelemetry.
"""
import asyncio
import base64
import contextlib
import io
import logging
import os
import signal
import tempfile
import time
from typing import Dict, Optional
from aiohttp import web
import msgpack
import nats
import nats.js
import numpy as np
import soundfile as sf
import torch
import whisper
from aiohttp import web
from nats.aio.msg import Msg
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
# Configure logging for the whole service; every module-level logger below
# inherits this format.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("stt-streaming-rocm")
def setup_telemetry():
    """Configure OpenTelemetry tracing and metrics export.

    Exports either to HyperDX (HTTP OTLP with API-key header) or to a
    standard OTEL collector (gRPC OTLP), based on environment variables.

    Returns:
        (tracer, meter) on success, or (None, None) when OTEL is disabled.
    """
    # NOTE(review): the enable flag is assigned just above this hunk in the
    # full file; reconstructed from the conventional env var — confirm.
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    # OTEL configuration
    otel_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming-rocm")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    # HyperDX configuration; HyperDX is used only when explicitly enabled
    # AND an API key is present.
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    # Create resource with service information
    resource = Resource.create(
        {
            SERVICE_NAME: service_name,
            SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
            SERVICE_NAMESPACE: service_namespace,
            "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
            "host.name": os.environ.get("HOSTNAME", "unknown"),
        }
    )

    # Setup tracing
    trace_provider = TracerProvider(resource=resource)
    if use_hyperdx:
        # Use HTTP exporter for HyperDX with API key header
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
        )
    else:
        # Use gRPC exporter for standard OTEL collector
        logger.info(f"Configuring OTEL gRPC exporter at {otel_endpoint}")
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    # Setup metrics: periodic push every 60s
    metric_reader = PeriodicExportingMetricReader(
        otlp_metric_exporter, export_interval_millis=60000
    )
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    # Instrument logging so log records carry trace context
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    # Return tracer and meter for the service
    tracer = trace.get_tracer(__name__)
    meter = metrics.get_meter(__name__)
    return tracer, meter
# Run Whisper in fp16 (only effective when the model is on GPU)
WHISPER_FP16 = os.environ.get("WHISPER_FP16", "true").lower() == "true"

# NATS subjects for streaming
STREAM_SUBJECT_PREFIX = "ai.voice.stream"  # Full subject: ai.voice.stream.{session_id}
TRANSCRIPTION_SUBJECT_PREFIX = (
    "ai.voice.transcription"  # Full subject: ai.voice.transcription.{session_id}
)

# Streaming parameters
BUFFER_SIZE_BYTES = int(
    os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")
)  # ~5 seconds at 16kHz 16-bit
CHUNK_TIMEOUT_SECONDS = float(
    os.environ.get("STT_CHUNK_TIMEOUT", "2.0")
)  # Process after 2s of silence
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000"))  # ~50 seconds max

# Health server port for kserve compatibility
HEALTH_PORT = int(os.environ.get("HEALTH_PORT", "8000"))
class AudioBuffer:
    """Manages audio chunks for a streaming session.

    Accumulates raw audio bytes until a size or silence threshold is hit,
    at which point the caller drains it via get_audio()/clear().
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        # Raw audio chunks, in arrival order
        self.chunks: list[bytes] = []
        # NOTE(review): the diff hides the byte-counter initializer here;
        # reconstructed as 0 — confirm against the full file.
        self.total_bytes = 0
        self.last_chunk_time = time.time()
        self.is_complete = False
        # Monotonic counter of processed flushes for this session
        self.sequence = 0

    def add_chunk(self, audio_data: bytes) -> None:
        """Add an audio chunk to the buffer."""
        self.chunks.append(audio_data)
        self.total_bytes += len(audio_data)
        self.last_chunk_time = time.time()
        logger.debug(f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes")

    def should_process(self) -> bool:
        """Determine if buffer should be processed now."""
        # Process if buffer size threshold reached
        # NOTE(review): this first check is hidden in the diff hunk;
        # reconstructed from the surrounding comments — confirm.
        if self.total_bytes >= BUFFER_SIZE_BYTES:
            return True
        # Process after a quiet period, but only if there is pending audio
        if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0:
            return True
        # Process if buffer is too large (safety limit)
        return self.total_bytes >= MAX_BUFFER_SIZE_BYTES

    def get_audio(self) -> bytes:
        """Get concatenated audio data."""
        return b"".join(self.chunks)

    def clear(self) -> None:
        """Clear the buffer after processing and bump the sequence number."""
        self.chunks = []
        self.total_bytes = 0
        self.sequence += 1

    def mark_complete(self) -> None:
        """Mark stream as complete."""
        self.is_complete = True
class StreamingSTTLocal:
    """Streaming Speech-to-Text service with local Whisper on ROCm."""

    def __init__(self):
        # NATS connection and JetStream context (created in setup())
        self.nc = None
        self.js = None
        # Whisper model handle (loaded in setup())
        self.whisper_model = None
        # Active streaming sessions keyed by session_id
        self.sessions: dict[str, AudioBuffer] = {}
        self.running = True
        # Per-session buffer-monitor tasks keyed by session_id
        self.processing_tasks = {}
        self.is_healthy = False
        # NOTE(review): hidden hunk lines here most likely initialize the
        # OTEL tracer/meter handles — reconstructed; confirm against the file.
        self.tracer = None
        self.meter = None
        # OTEL instruments (created in setup() when OTEL is enabled)
        self.stream_counter = None
        self.transcription_duration = None
        self.gpu_memory_gauge = None
async def setup(self):
"""Initialize connections and load model."""
# Initialize OpenTelemetry
self.tracer, self.meter = setup_telemetry()
# Create metrics if OTEL is enabled
if self.meter:
self.stream_counter = self.meter.create_counter(
name="stt_streams_total",
description="Total number of STT streams processed",
unit="1"
unit="1",
)
self.transcription_duration = self.meter.create_histogram(
name="stt_transcription_duration_seconds",
description="Duration of STT transcription",
unit="s"
unit="s",
)
self.gpu_memory_gauge = self.meter.create_observable_gauge(
name="stt_gpu_memory_bytes",
description="GPU memory usage in bytes",
callbacks=[self._get_gpu_memory]
callbacks=[self._get_gpu_memory],
)
# Check GPU availability
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
@@ -238,21 +246,21 @@ class StreamingSTTLocal:
logger.info(f"ROCm GPU available: {gpu_name} ({gpu_memory:.1f}GB)")
else:
logger.warning("No GPU available, falling back to CPU")
# Load Whisper model
logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE}")
start_time = time.time()
self.whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=WHISPER_DEVICE)
load_time = time.time() - start_time
logger.info(f"Whisper model loaded in {load_time:.2f}s")
# NATS connection
self.nc = await nats.connect(NATS_URL)
logger.info(f"Connected to NATS at {NATS_URL}")
# Initialize JetStream context
self.js = self.nc.jetstream()
# Create or update stream for voice stream messages
try:
stream_config = nats.js.api.StreamConfig(
@@ -267,100 +275,103 @@ class StreamingSTTLocal:
except Exception as e:
# Stream might already exist
logger.info(f"JetStream stream setup: {e}")
# Mark as healthy once connections are established
self.is_healthy = True
async def health_handler(self, request: web.Request) -> web.Response:
"""Handle health check requests for kserve compatibility."""
if self.is_healthy:
return web.json_response({
"status": "healthy",
"model": WHISPER_MODEL_SIZE,
"device": WHISPER_DEVICE,
"nats_connected": self.nc is not None and self.nc.is_connected,
})
return web.json_response(
{
"status": "healthy",
"model": WHISPER_MODEL_SIZE,
"device": WHISPER_DEVICE,
"nats_connected": self.nc is not None and self.nc.is_connected,
}
)
else:
return web.json_response(
{"status": "unhealthy", "model": WHISPER_MODEL_SIZE},
status=503
{"status": "unhealthy", "model": WHISPER_MODEL_SIZE}, status=503
)
async def start_health_server(self) -> web.AppRunner:
"""Start HTTP health server for kserve agent sidecar."""
app = web.Application()
app.router.add_get("/health", self.health_handler)
app.router.add_get("/ready", self.health_handler)
app.router.add_get("/", self.health_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "0.0.0.0", HEALTH_PORT)
await site.start()
logger.info(f"Health server started on port {HEALTH_PORT}")
return runner
def _get_gpu_memory(self, options):
"""Callback for GPU memory gauge."""
if torch.cuda.is_available():
memory_used = torch.cuda.memory_allocated(0)
yield metrics.Observation(memory_used, {"device": "0"})
def transcribe(self, audio_bytes: bytes) -> Optional[str]:
def transcribe(self, audio_bytes: bytes) -> str | None:
"""Transcribe audio using local Whisper model."""
start_time = time.time()
try:
# Write audio to temp file (Whisper needs file path or numpy array)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
tmp.write(audio_bytes)
tmp.flush()
# Transcribe with Whisper
result = self.whisper_model.transcribe(
tmp.name,
fp16=WHISPER_FP16 and WHISPER_DEVICE == "cuda",
language="en", # Can be made configurable
)
transcript = result.get("text", "").strip()
duration = time.time() - start_time
audio_duration = len(audio_bytes) / (16000 * 2) # Assuming 16kHz 16-bit
rtf = duration / audio_duration if audio_duration > 0 else 0
logger.info(f"Transcribed in {duration:.2f}s (RTF: {rtf:.2f}): {transcript[:100]}...")
# Record metrics
if self.transcription_duration:
self.transcription_duration.record(duration, {"model": WHISPER_MODEL_SIZE})
return transcript
except Exception as e:
logger.error(f"Transcription failed: {e}", exc_info=True)
return None
async def process_buffer(self, session_id: str):
"""Process accumulated audio buffer for a session."""
buffer = self.sessions.get(session_id)
if not buffer:
return
audio_data = buffer.get_audio()
if not audio_data:
return
logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}")
logger.info(
f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}"
)
# Record stream counter
if self.stream_counter:
self.stream_counter.add(1, {"session_id": session_id})
# Transcribe in thread pool to avoid blocking event loop
loop = asyncio.get_event_loop()
transcript = await loop.run_in_executor(None, self.transcribe, audio_data)
if transcript:
# Publish transcription result using msgpack binary format
result = {
@@ -373,22 +384,21 @@ class StreamingSTTLocal:
"model": WHISPER_MODEL_SIZE,
"device": WHISPER_DEVICE,
}
await self.nc.publish(
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
msgpack.packb(result)
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", msgpack.packb(result)
)
logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence})")
# Clear buffer after processing
buffer.clear()
# Clean up completed sessions asynchronously
if buffer.is_complete:
logger.info(f"Session {session_id} completed")
# Schedule cleanup task to avoid blocking
asyncio.create_task(self._cleanup_session(session_id))
async def _cleanup_session(self, session_id: str):
"""Clean up a completed session after a delay."""
# Keep session for a bit in case of late messages
@@ -398,34 +408,34 @@ class StreamingSTTLocal:
logger.info(f"Cleaned up session: {session_id}")
if session_id in self.processing_tasks:
del self.processing_tasks[session_id]
async def monitor_buffer(self, session_id: str):
"""Monitor buffer and trigger processing when needed."""
while self.running and session_id in self.sessions:
buffer = self.sessions.get(session_id)
if not buffer:
break
if buffer.should_process():
await self.process_buffer(session_id)
# Don't spin too fast
await asyncio.sleep(0.1)
async def handle_stream_message(self, msg: Msg):
"""Handle incoming audio stream message."""
try:
# Extract session_id from subject: ai.voice.stream.{session_id}
subject_parts = msg.subject.split('.')
subject_parts = msg.subject.split(".")
if len(subject_parts) < 4:
logger.warning(f"Invalid subject format: {msg.subject}")
return
session_id = subject_parts[3]
# Parse message using msgpack binary format
data = msgpack.unpackb(msg.data, raw=False)
# Handle control messages
if data.get("type") == "start":
logger.info(f"Starting stream session: {session_id}")
@@ -434,7 +444,7 @@ class StreamingSTTLocal:
task = asyncio.create_task(self.monitor_buffer(session_id))
self.processing_tasks[session_id] = task
return
if data.get("type") == "end":
logger.info(f"Ending stream session: {session_id}")
buffer = self.sessions.get(session_id)
@@ -444,15 +454,15 @@ class StreamingSTTLocal:
if buffer.total_bytes > 0:
await self.process_buffer(session_id)
return
# Handle audio chunk
if data.get("type") == "chunk":
audio_b64 = data.get("audio_b64", "")
if not audio_b64:
return
audio_bytes = base64.b64decode(audio_b64)
# Create session if it doesn't exist (handle missing start message)
if session_id not in self.sessions:
logger.info(f"Auto-creating session: {session_id}")
@@ -460,46 +470,46 @@ class StreamingSTTLocal:
if session_id not in self.processing_tasks:
task = asyncio.create_task(self.monitor_buffer(session_id))
self.processing_tasks[session_id] = task
# Add chunk to buffer
self.sessions[session_id].add_chunk(audio_bytes)
except Exception as e:
logger.error(f"Error handling stream message: {e}", exc_info=True)
async def run(self):
"""Main run loop."""
await self.setup()
# Start health server for kserve compatibility
health_runner = await self.start_health_server()
# Subscribe to voice stream
sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")
# Handle shutdown
def signal_handler():
self.running = False
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Keep running
while self.running:
await asyncio.sleep(1)
# Cleanup
logger.info("Shutting down...")
# Cancel all monitoring tasks and wait for them to complete
for task in self.processing_tasks.values():
task.cancel()
if self.processing_tasks:
await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True)
await sub.unsubscribe()
await self.nc.close()
await health_runner.cleanup()