feat: CI pipeline, lint fixes, and Renovate config
Some checks failed
CI / Docker Build & Push (push) Has been skipped
CI / Notify (push) Successful in 1s
CI / Lint (push) Failing after 9s
CI / Test (push) Successful in 50s
CI / Release (push) Has been skipped

- pyproject.toml with ruff/pytest config (setuptools<81 pin)
- Full test suite (26 tests)
- Gitea Actions CI (lint, test, docker, notify)
- Ruff lint/format fixes across source files
- Renovate config for automated dependency updates

Ref: ADR-0057
This commit is contained in:
2026-02-13 15:33:35 -05:00
parent 8fc5eb1193
commit 55cd657364
10 changed files with 1655 additions and 279 deletions

View File

@@ -11,46 +11,47 @@ Real-time Speech-to-Text service that processes live audio streams from NATS:
This enables faster response times by processing audio as it arrives rather than
waiting for complete audio upload.
"""
import asyncio
import base64
import contextlib
import logging
import os
import signal
import time
import struct
from typing import Dict, Optional, List, Tuple
from io import BytesIO
import httpx
import msgpack
import nats
import nats.js
from nats.aio.msg import Msg
import numpy as np
import webrtcvad
from nats.aio.msg import Msg
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
# Configure logging.
# Diff residue had both the old multi-line kwargs and the new one-line form
# inside a single basicConfig() call; keep exactly one well-formed call.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger used throughout the service.
logger = logging.getLogger("stt-streaming")
# Initialize OpenTelemetry
def setup_telemetry():
"""Initialize OpenTelemetry tracing and metrics with HyperDX support."""
@@ -59,98 +60,120 @@ def setup_telemetry():
if not otel_enabled:
logger.info("OpenTelemetry disabled")
return None, None
# OTEL configuration
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
otel_endpoint = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT",
"http://opentelemetry-collector.observability.svc.cluster.local:4317",
)
service_name = os.environ.get("OTEL_SERVICE_NAME", "stt-streaming")
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
# HyperDX configuration
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
# Create resource with service information
resource = Resource.create({
SERVICE_NAME: service_name,
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
SERVICE_NAMESPACE: service_namespace,
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
"host.name": os.environ.get("HOSTNAME", "unknown"),
})
resource = Resource.create(
{
SERVICE_NAME: service_name,
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
SERVICE_NAMESPACE: service_namespace,
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
"host.name": os.environ.get("HOSTNAME", "unknown"),
}
)
# Setup tracing
trace_provider = TracerProvider(resource=resource)
if use_hyperdx:
# Use HTTP exporter for HyperDX with API key header
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
headers = {"authorization": hyperdx_api_key}
otlp_span_exporter = OTLPSpanExporterHTTP(
endpoint=f"{hyperdx_endpoint}/v1/traces",
headers=headers
endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
)
otlp_metric_exporter = OTLPMetricExporterHTTP(
endpoint=f"{hyperdx_endpoint}/v1/metrics",
headers=headers
endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
)
else:
# Use gRPC exporter for standard OTEL collector
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
trace.set_tracer_provider(trace_provider)
# Setup metrics
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
metric_reader = PeriodicExportingMetricReader(
otlp_metric_exporter, export_interval_millis=60000
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
# Instrument HTTPX
HTTPXClientInstrumentor().instrument()
# Instrument logging
LoggingInstrumentor().instrument(set_logging_format=True)
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
# Return tracer and meter for the service
tracer = trace.get_tracer(__name__)
meter = metrics.get_meter(__name__)
return tracer, meter
# Configuration from environment.
# NOTE: the diff residue contained each reflowed constant twice (old single-line
# form immediately followed by the new wrapped form); keep each exactly once.
WHISPER_URL = os.environ.get("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")

# NATS subjects for streaming
STREAM_SUBJECT_PREFIX = "ai.voice.stream"  # Full subject: ai.voice.stream.{session_id}
TRANSCRIPTION_SUBJECT_PREFIX = (
    "ai.voice.transcription"  # Full subject: ai.voice.transcription.{session_id}
)

# Streaming parameters
BUFFER_SIZE_BYTES = int(
    os.environ.get("STT_BUFFER_SIZE_BYTES", "512000")
)  # ~5 seconds at 16kHz 16-bit
CHUNK_TIMEOUT_SECONDS = float(
    os.environ.get("STT_CHUNK_TIMEOUT", "2.0")
)  # Process after 2s of silence
MAX_BUFFER_SIZE_BYTES = int(os.environ.get("STT_MAX_BUFFER_SIZE", "5120000"))  # ~50 seconds max

# Audio constants
AUDIO_SAMPLE_MAX_INT16 = 32768.0  # Maximum value for 16-bit signed integer audio
VAD_VOICE_RATIO_THRESHOLD = float(
    os.environ.get("STT_VAD_VOICE_RATIO", "0.3")
)  # Min ratio of voice frames

# Voice Activity Detection (VAD) parameters
ENABLE_VAD = os.environ.get("STT_ENABLE_VAD", "true").lower() == "true"
VAD_AGGRESSIVENESS = int(
    os.environ.get("STT_VAD_AGGRESSIVENESS", "2")
)  # 0-3, higher = more aggressive
VAD_FRAME_DURATION_MS = int(os.environ.get("STT_VAD_FRAME_DURATION", "30"))  # 10, 20, or 30 ms

# Audio threshold for interrupt detection (when LLM is responding)
ENABLE_INTERRUPT_DETECTION = (
    os.environ.get("STT_ENABLE_INTERRUPT_DETECTION", "true").lower() == "true"
)
AUDIO_LEVEL_THRESHOLD = float(os.environ.get("STT_AUDIO_LEVEL_THRESHOLD", "0.02"))  # RMS threshold
INTERRUPT_DURATION_THRESHOLD = float(
    os.environ.get("STT_INTERRUPT_DURATION", "0.5")
)  # Seconds of speech to trigger

# Speaker diarization
ENABLE_SPEAKER_DIARIZATION = (
    os.environ.get("STT_ENABLE_SPEAKER_DIARIZATION", "false").lower() == "true"
)

# Session states
SESSION_STATE_LISTENING = "listening"
@@ -160,17 +183,17 @@ SESSION_STATE_RESPONDING = "responding"
def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float:
"""
Calculate RMS (Root Mean Square) audio level.
Args:
audio_data: Raw audio bytes
sample_width: Bytes per sample (2 for 16-bit audio)
Returns:
RMS level normalized to 0.0-1.0 range
"""
if len(audio_data) < sample_width:
return 0.0
# Convert bytes to numpy array of int16 samples
try:
samples = np.frombuffer(audio_data, dtype=np.int16)
@@ -186,30 +209,30 @@ def calculate_audio_rms(audio_data: bytes, sample_width: int = 2) -> float:
def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
"""
Detect if audio contains voice using WebRTC VAD.
Args:
audio_data: Raw PCM audio bytes (16-bit, mono)
sample_rate: Audio sample rate (8000, 16000, 32000, or 48000)
Returns:
True if voice is detected, False otherwise
"""
if not ENABLE_VAD:
return True # Assume voice present if VAD disabled
try:
vad = webrtcvad.Vad(VAD_AGGRESSIVENESS)
# WebRTC VAD requires specific frame sizes
# Frame duration must be 10, 20, or 30 ms
frame_size = int(sample_rate * VAD_FRAME_DURATION_MS / 1000) * 2 # *2 for 16-bit samples
# Process audio in frames
voice_frames = 0
total_frames = 0
for i in range(0, len(audio_data) - frame_size, frame_size):
frame = audio_data[i:i + frame_size]
frame = audio_data[i : i + frame_size]
if len(frame) == frame_size:
try:
is_speech = vad.is_speech(frame, sample_rate)
@@ -219,14 +242,14 @@ def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
except Exception as e:
logger.debug(f"VAD frame processing error: {e}")
continue
if total_frames == 0:
return False
# Consider voice detected if voice ratio exceeds threshold
voice_ratio = voice_frames / total_frames
return voice_ratio > VAD_VOICE_RATIO_THRESHOLD
except Exception as e:
logger.warning(f"VAD error: {e}")
return True # Default to voice present on error
@@ -234,7 +257,7 @@ def detect_voice_activity(audio_data: bytes, sample_rate: int = 16000) -> bool:
class AudioBuffer:
"""Manages audio chunks for a streaming session with VAD and speaker tracking."""
def __init__(self, session_id: str):
self.session_id = session_id
self.chunks = []
@@ -247,46 +270,54 @@ class AudioBuffer:
self.interrupt_start_time = None # Track when interrupt detection started
self.has_voice_activity = False # Track if voice was detected in recent chunks
self._last_chunk_vad_result = None # Cache VAD result for last chunk
def add_chunk(self, audio_data: bytes) -> None:
    """Add an audio chunk to the buffer and check for voice activity.

    Args:
        audio_data: Raw audio bytes for one streamed chunk.
    """
    self.chunks.append(audio_data)
    self.total_bytes += len(audio_data)
    self.last_chunk_time = time.time()
    # Run VAD once per chunk and cache the verdict so check_interrupt()
    # can reuse it without reprocessing the same audio.
    has_voice = detect_voice_activity(audio_data)
    self.has_voice_activity = has_voice
    self._last_chunk_vad_result = has_voice
    # Diff residue duplicated this debug line (old one-liner + new wrapped
    # form), which would have logged every chunk twice; emit it once.
    logger.debug(
        f"Session {self.session_id}: Added chunk, total {self.total_bytes} bytes, voice={has_voice}"
    )
def check_interrupt(self, audio_data: bytes) -> bool:
"""
Check if audio indicates an interrupt during responding state.
Uses cached VAD result if available.
Returns:
True if interrupt detected, False otherwise
"""
if not ENABLE_INTERRUPT_DETECTION:
return False
if self.state != SESSION_STATE_RESPONDING:
return False
# Calculate audio level
rms_level = calculate_audio_rms(audio_data)
# Use cached VAD result if available to avoid duplicate processing
has_voice = self._last_chunk_vad_result if self._last_chunk_vad_result is not None else detect_voice_activity(audio_data)
has_voice = (
self._last_chunk_vad_result
if self._last_chunk_vad_result is not None
else detect_voice_activity(audio_data)
)
# Check if audio exceeds threshold and contains voice
if rms_level >= AUDIO_LEVEL_THRESHOLD and has_voice:
if self.interrupt_start_time is None:
self.interrupt_start_time = time.time()
logger.info(f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})")
logger.info(
f"Session {self.session_id}: Potential interrupt detected (RMS={rms_level:.3f})"
)
# Check if interrupt has lasted long enough
elapsed = time.time() - self.interrupt_start_time
if elapsed >= INTERRUPT_DURATION_THRESHOLD:
@@ -295,9 +326,9 @@ class AudioBuffer:
else:
# Reset interrupt timer if audio drops below threshold
self.interrupt_start_time = None
return False
def set_state(self, state: str) -> None:
"""Set the session state (listening or responding)."""
if state in (SESSION_STATE_LISTENING, SESSION_STATE_RESPONDING):
@@ -307,15 +338,18 @@ class AudioBuffer:
logger.info(f"Session {self.session_id}: State changed from {old_state} to {state}")
# Reset interrupt tracking when changing states
self.interrupt_start_time = None
def should_process(self) -> bool:
    """Determine if the buffered audio should be transcribed now.

    Returns:
        True when the buffer is full, has timed out with pending data,
        or exceeds the safety limit; False for silent/empty buffers.
    """
    # Diff residue contained both the pre- and post-refactor form of this
    # VAD gate; keep the single combined conditional.
    # Skip silent buffers unless the buffer is full or has timed out.
    if (
        ENABLE_VAD
        and not self.has_voice_activity
        and self.total_bytes < BUFFER_SIZE_BYTES
        and time.time() - self.last_chunk_time < CHUNK_TIMEOUT_SECONDS
    ):
        return False
    # Process if buffer size threshold reached
    if self.total_bytes >= BUFFER_SIZE_BYTES:
        return True
    # Process if we timed out with pending audio
    if time.time() - self.last_chunk_time > CHUNK_TIMEOUT_SECONDS and self.total_bytes > 0:
        return True
    # Process if buffer is too large (safety limit)
    return self.total_bytes >= MAX_BUFFER_SIZE_BYTES
def get_audio(self) -> bytes:
    """Return all buffered chunks concatenated into one bytes object."""
    # Diff residue left two adjacent return statements (old quote style
    # followed by the reformatted one); the second was unreachable.
    return b"".join(self.chunks)
def clear(self) -> None:
    """Reset buffered audio after a transcription pass.

    Bumps the sequence number for the next segment, empties the chunk
    list, zeroes the byte counter, and drops the cached VAD verdict.
    """
    self.sequence = self.sequence + 1
    self.chunks = []
    self.total_bytes = 0
    # The cached VAD result belongs to the audio we just flushed.
    self._last_chunk_vad_result = None
def mark_complete(self) -> None:
    """Flag this stream as finished so it can be flushed and cleaned up."""
    self.is_complete = True
@@ -345,12 +377,12 @@ class AudioBuffer:
class StreamingSTT:
"""Streaming Speech-to-Text service."""
def __init__(self):
self.nc = None
self.js = None
self.http_client = None
self.sessions: Dict[str, AudioBuffer] = {}
self.sessions: dict[str, AudioBuffer] = {}
self.running = True
self.processing_tasks = {}
self.is_healthy = False
@@ -358,32 +390,32 @@ class StreamingSTT:
self.meter = None
self.stream_counter = None
self.transcription_duration = None
async def setup(self):
"""Initialize connections."""
# Initialize OpenTelemetry
self.tracer, self.meter = setup_telemetry()
# Create metrics if OTEL is enabled
if self.meter:
self.stream_counter = self.meter.create_counter(
name="stt_streams_total",
description="Total number of STT streams processed",
unit="1"
unit="1",
)
self.transcription_duration = self.meter.create_histogram(
name="stt_transcription_duration_seconds",
description="Duration of STT transcription",
unit="s"
unit="s",
)
# NATS connection
self.nc = await nats.connect(NATS_URL)
logger.info(f"Connected to NATS at {NATS_URL}")
# Initialize JetStream context
self.js = self.nc.jetstream()
# Create or update stream for voice stream messages
try:
stream_config = nats.js.api.StreamConfig(
@@ -398,21 +430,20 @@ class StreamingSTT:
except Exception as e:
# Stream might already exist
logger.info(f"JetStream stream setup: {e}")
# HTTP client for Whisper service
self.http_client = httpx.AsyncClient(timeout=180.0)
logger.info("HTTP client initialized")
# Mark as healthy once connections are established
self.is_healthy = True
async def transcribe(self, audio_bytes: bytes) -> Optional[str]:
async def transcribe(self, audio_bytes: bytes) -> str | None:
"""Transcribe audio using Whisper."""
try:
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
response = await self.http_client.post(
f"{WHISPER_URL}/v1/audio/transcriptions",
files=files
f"{WHISPER_URL}/v1/audio/transcriptions", files=files
)
response.raise_for_status()
result = response.json()
@@ -422,22 +453,24 @@ class StreamingSTT:
except Exception as e:
logger.error(f"Transcription failed: {e}")
return None
async def process_buffer(self, session_id: str):
"""Process accumulated audio buffer for a session."""
buffer = self.sessions.get(session_id)
if not buffer:
return
audio_data = buffer.get_audio()
if not audio_data:
return
logger.info(f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}")
logger.info(
f"Processing {len(audio_data)} bytes for session {session_id}, sequence {buffer.sequence}"
)
# Transcribe
transcript = await self.transcribe(audio_data)
if transcript:
# Publish transcription result using msgpack binary format
result = {
@@ -449,24 +482,25 @@ class StreamingSTT:
"timestamp": time.time(),
"speaker_id": buffer.speaker_id,
"has_voice_activity": buffer.has_voice_activity,
"state": buffer.state
"state": buffer.state,
}
await self.nc.publish(
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
msgpack.packb(result)
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", msgpack.packb(result)
)
logger.info(f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})")
logger.info(
f"Published transcription for session {session_id} (seq {buffer.sequence}, speaker={buffer.speaker_id})"
)
# Clear buffer after processing
buffer.clear()
# Clean up completed sessions asynchronously
if buffer.is_complete:
logger.info(f"Session {session_id} completed")
# Schedule cleanup task to avoid blocking
asyncio.create_task(self._cleanup_session(session_id))
async def _cleanup_session(self, session_id: str):
"""Clean up a completed session after a delay."""
# Keep session for a bit in case of late messages
@@ -476,34 +510,34 @@ class StreamingSTT:
logger.info(f"Cleaned up session: {session_id}")
if session_id in self.processing_tasks:
del self.processing_tasks[session_id]
async def monitor_buffer(self, session_id: str):
    """Poll a session's audio buffer and flush it once it is ready.

    Runs until the service stops or the session disappears, checking
    the buffer roughly ten times per second.
    """
    while self.running and session_id in self.sessions:
        buf = self.sessions.get(session_id)
        if not buf:
            break
        if buf.should_process():
            await self.process_buffer(session_id)
        # Throttle the polling loop so it does not busy-wait.
        await asyncio.sleep(0.1)
async def handle_stream_message(self, msg: Msg):
"""Handle incoming audio stream message."""
try:
# Extract session_id from subject: ai.voice.stream.{session_id}
subject_parts = msg.subject.split('.')
subject_parts = msg.subject.split(".")
if len(subject_parts) < 4:
logger.warning(f"Invalid subject format: {msg.subject}")
return
session_id = subject_parts[3]
# Parse message using msgpack binary format
data = msgpack.unpackb(msg.data, raw=False)
# Handle control messages
if data.get("type") == "start":
logger.info(f"Starting stream session: {session_id}")
@@ -520,19 +554,19 @@ class StreamingSTT:
task = asyncio.create_task(self.monitor_buffer(session_id))
self.processing_tasks[session_id] = task
return
if data.get("type") == "state_change":
logger.info(f"State change for session {session_id}")
buffer = self.sessions.get(session_id)
if buffer:
new_state = data.get("state", SESSION_STATE_LISTENING)
buffer.set_state(new_state)
# If switching to listening mode, reset any interrupt tracking
if new_state == SESSION_STATE_LISTENING:
buffer.interrupt_start_time = None
return
if data.get("type") == "end":
logger.info(f"Ending stream session: {session_id}")
buffer = self.sessions.get(session_id)
@@ -542,15 +576,15 @@ class StreamingSTT:
if buffer.total_bytes > 0:
await self.process_buffer(session_id)
return
# Handle audio chunk
if data.get("type") == "chunk":
audio_b64 = data.get("audio_b64", "")
if not audio_b64:
return
audio_bytes = base64.b64decode(audio_b64)
# Create session if it doesn't exist (handle missing start message)
# Check both sessions and processing_tasks to avoid race conditions
if session_id not in self.sessions:
@@ -560,9 +594,9 @@ class StreamingSTT:
if session_id not in self.processing_tasks:
task = asyncio.create_task(self.monitor_buffer(session_id))
self.processing_tasks[session_id] = task
buffer = self.sessions[session_id]
# Check for interrupt if in responding state
if buffer.check_interrupt(audio_bytes):
# Publish interrupt notification
@@ -570,57 +604,56 @@ class StreamingSTT:
"session_id": session_id,
"type": "interrupt",
"timestamp": time.time(),
"speaker_id": buffer.speaker_id
"speaker_id": buffer.speaker_id,
}
await self.nc.publish(
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}",
msgpack.packb(interrupt_msg)
f"{TRANSCRIPTION_SUBJECT_PREFIX}.{session_id}", msgpack.packb(interrupt_msg)
)
logger.info(f"Published interrupt notification for session {session_id}")
# Automatically switch back to listening mode
buffer.set_state(SESSION_STATE_LISTENING)
# Add chunk to buffer
buffer.add_chunk(audio_bytes)
except Exception as e:
logger.error(f"Error handling stream message: {e}", exc_info=True)
async def run(self):
    """Main run loop: connect, subscribe, and serve until shutdown.

    Subscribes to the wildcard audio-stream subject, installs
    SIGTERM/SIGINT handlers that flip ``self.running``, idles until
    shutdown, then cancels per-session monitors and closes connections.
    """
    await self.setup()

    # Note: STT streaming uses regular NATS subscribe (not pull-based JetStream consumer)
    # because it handles real-time ephemeral audio streams with wildcard subscriptions.
    # The stream audio chunks are not meant to be persisted long-term or replayed.
    # However, the transcription RESULTS are published to JetStream for persistence.
    sub = await self.nc.subscribe(f"{STREAM_SUBJECT_PREFIX}.>", cb=self.handle_stream_message)
    logger.info(f"Subscribed to {STREAM_SUBJECT_PREFIX}.>")

    # Handle shutdown
    def signal_handler():
        self.running = False

    # Inside a coroutine, get_running_loop() is the correct call;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    # Keep running until a signal flips self.running
    while self.running:
        await asyncio.sleep(1)

    # Cleanup
    logger.info("Shutting down...")
    # Cancel all monitoring tasks and wait for them to complete
    for task in self.processing_tasks.values():
        task.cancel()
    # Wait for all tasks to complete or be cancelled
    if self.processing_tasks:
        await asyncio.gather(*self.processing_tasks.values(), return_exceptions=True)
    await sub.unsubscribe()
    await self.nc.close()
    await self.http_client.aclose()