ray-serve/ray_serve/serve_whisper.py
fix: make mlflow_logger import optional with no-op fallback
The strixhalo LLM worker uses py_executable pointing at the Docker
image venv, which doesn't have the updated ray-serve-apps package.
Wrap all InferenceLogger imports in try/except and guard usage with
None checks so apps degrade gracefully without MLflow logging.

"""
Ray Serve deployment for faster-whisper STT.
Runs on: elminster (RTX 2070 8GB, CUDA)
"""
import base64
import io
import os
import time
from typing import Any
from ray import serve
# InferenceLogger may be missing from older environments (see commit message),
# so fall back to None and guard every use: no MLflow logging, but no crash.
try:
    from ray_serve.mlflow_logger import InferenceLogger
except ImportError:
    InferenceLogger = None

@serve.deployment(name="WhisperDeployment", num_replicas=1)
class WhisperDeployment:
    def __init__(self):
        import torch
        from faster_whisper import WhisperModel

        self.model_size = os.environ.get("MODEL_SIZE", "large-v3")

        # Detect device and compute type
        if torch.cuda.is_available():
            self.device = "cuda"
            self.compute_type = "float16"
        else:
            self.device = "cpu"
            self.compute_type = "int8"

        print(f"Loading Whisper model: {self.model_size}")
        print(f"Using device: {self.device}, compute_type: {self.compute_type}")

        self.model = WhisperModel(
            self.model_size,
            device=self.device,
            compute_type=self.compute_type,
        )
        print("Whisper model loaded successfully")

        # MLflow metrics
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-whisper",
                run_name=f"whisper-{self.model_size}",
                tags={
                    "model.name": f"whisper-{self.model_size}",
                    "model.framework": "faster-whisper",
                    "device": self.device,
                },
                flush_every=5,
            )
            self._mlflow.initialize(
                params={
                    "model_size": self.model_size,
                    "device": self.device,
                    "compute_type": self.compute_type,
                }
            )
        else:
            self._mlflow = None

    async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        Handle transcription requests.

        Expected request format:
        {
            "audio": "base64_encoded_audio_data",
            "audio_format": "wav",
            "language": "en",
            "task": "transcribe",
            "response_format": "json",
            "word_timestamps": false
        }

        Alternative with file path:
        {
            "file": "/path/to/audio.wav",
            ...
        }
        """
        _start = time.time()
        language = request.get("language")
        task = request.get("task", "transcribe")  # transcribe or translate
        response_format = request.get("response_format", "json")
        word_timestamps = request.get("word_timestamps", False)

        # Get audio data
        audio_input = None
        if "audio" in request:
            # Base64 encoded audio
            audio_bytes = base64.b64decode(request["audio"])
            audio_input = io.BytesIO(audio_bytes)
        elif "file" in request:
            # File path
            audio_input = request["file"]
        elif "audio_bytes" in request:
            # Raw bytes
            audio_input = io.BytesIO(request["audio_bytes"])
        else:
            return {
                "error": "No audio data provided. Use 'audio' (base64), 'file' (path), or 'audio_bytes'",
            }
        # Transcribe; faster-whisper returns a lazy generator, so the actual
        # work happens while iterating over the segments below
        segments, info = self.model.transcribe(
            audio_input,
            language=language,
            task=task,
            word_timestamps=word_timestamps,
            vad_filter=True,
        )

        # Collect segments
        segment_list = []
        full_text = ""
        for segment in segments:
            seg_data = {
                "id": segment.id,
                "start": segment.start,
                "end": segment.end,
                "text": segment.text,
            }
            if word_timestamps and segment.words:
                seg_data["words"] = [
                    {
                        "word": word.word,
                        "start": word.start,
                        "end": word.end,
                        "probability": word.probability,
                    }
                    for word in segment.words
                ]
            segment_list.append(seg_data)
            full_text += segment.text
        _latency = time.time() - _start

        # Log to MLflow before branching on response format, so text and
        # verbose_json responses are recorded too (previously the early
        # returns below skipped logging for those formats)
        if self._mlflow:
            self._mlflow.log_request(
                latency_s=_latency,
                audio_duration_s=info.duration,
                segments=len(segment_list),
                realtime_factor=_latency / info.duration if info.duration > 0 else 0,
            )

        # Build response based on format
        if response_format == "text":
            return {"text": full_text.strip()}

        if response_format == "verbose_json":
            return {
                "task": task,
                "language": info.language,
                "duration": info.duration,
                "text": full_text.strip(),
                "segments": segment_list,
            }

        # Default JSON format (OpenAI-compatible)
        return {
            "text": full_text.strip(),
            "language": info.language,
            "duration": info.duration,
            "model": self.model_size,
        }

app = WhisperDeployment.bind()
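
# Example client usage: a minimal sketch, not part of the deployment.
# __call__ expects a plain dict rather than an HTTP Request object, so the
# app is meant to be invoked through a Serve handle. Assumes a recent
# Ray 2.x, where serve.run() returns a handle; "sample.wav" is a
# hypothetical placeholder.
#
#   import base64
#   from ray import serve
#   from ray_serve.serve_whisper import app
#
#   handle = serve.run(app)
#   with open("sample.wav", "rb") as f:
#       payload = {
#           "audio": base64.b64encode(f.read()).decode("utf-8"),
#           "audio_format": "wav",
#           "response_format": "verbose_json",
#       }
#   print(handle.remote(payload).result()["text"])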