All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 11s
The strixhalo LLM worker uses py_executable pointing to the Docker image venv which doesn't have the updated ray-serve-apps package. Wrap all InferenceLogger imports in try/except and guard usage with None checks so apps degrade gracefully without MLflow logging.
173 lines
5.0 KiB
Python
173 lines
5.0 KiB
Python
"""
|
|
Ray Serve deployment for faster-whisper STT.
|
|
Runs on: elminster (RTX 2070 8GB, CUDA)
|
|
"""
|
|
|
|
import base64
|
|
import io
|
|
import os
|
|
import time
|
|
from typing import Any
|
|
|
|
from ray import serve
|
|
|
|
try:
|
|
from ray_serve.mlflow_logger import InferenceLogger
|
|
except ImportError:
|
|
InferenceLogger = None
|
|
|
|
|
|
@serve.deployment(name="WhisperDeployment", num_replicas=1)
class WhisperDeployment:
    """Ray Serve deployment wrapping a faster-whisper model for speech-to-text.

    MLflow request logging is optional: when the ``ray_serve`` package (and
    therefore ``InferenceLogger``) is unavailable in the worker venv, the
    deployment degrades gracefully and simply skips logging.
    """

    def __init__(self):
        # Heavy imports deferred to replica startup so this module imports
        # cleanly on nodes that lack torch / faster-whisper.
        import torch
        from faster_whisper import WhisperModel

        self.model_size = os.environ.get("MODEL_SIZE", "large-v3")

        # Detect device and compute type: fp16 on CUDA, int8 on CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
            self.compute_type = "float16"
        else:
            self.device = "cpu"
            self.compute_type = "int8"

        print(f"Loading Whisper model: {self.model_size}")
        print(f"Using device: {self.device}, compute_type: {self.compute_type}")

        self.model = WhisperModel(
            self.model_size,
            device=self.device,
            compute_type=self.compute_type,
        )

        print("Whisper model loaded successfully")

        # MLflow metrics — InferenceLogger is None when its import failed,
        # in which case _mlflow stays None and logging is skipped.
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-whisper",
                run_name=f"whisper-{self.model_size}",
                tags={"model.name": f"whisper-{self.model_size}", "model.framework": "faster-whisper", "device": self.device},
                flush_every=5,
            )
            self._mlflow.initialize(
                params={"model_size": self.model_size, "device": self.device, "compute_type": self.compute_type}
            )
        else:
            self._mlflow = None

    @staticmethod
    def _resolve_audio_input(request: dict[str, Any]):
        """Return a transcribable audio source from *request*, or None.

        Accepts, in priority order: 'audio' (base64 string), 'file'
        (filesystem path), or 'audio_bytes' (raw bytes).
        """
        if "audio" in request:
            # Base64 encoded audio
            return io.BytesIO(base64.b64decode(request["audio"]))
        if "file" in request:
            # File path
            return request["file"]
        if "audio_bytes" in request:
            # Raw bytes
            return io.BytesIO(request["audio_bytes"])
        return None

    async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        Handle transcription requests.

        Expected request format:
            {
                "audio": "base64_encoded_audio_data",
                "audio_format": "wav",
                "language": "en",
                "task": "transcribe",
                "response_format": "json",
                "word_timestamps": false
            }

        Alternative with file path:
            {
                "file": "/path/to/audio.wav",
                ...
            }

        Returns a dict shaped by ``response_format`` ("text",
        "verbose_json", or the default OpenAI-compatible "json"), or an
        ``{"error": ...}`` dict when no audio source is provided.
        """
        _start = time.time()
        language = request.get("language")
        task = request.get("task", "transcribe")  # transcribe or translate
        response_format = request.get("response_format", "json")
        word_timestamps = request.get("word_timestamps", False)

        # Get audio data
        audio_input = self._resolve_audio_input(request)
        if audio_input is None:
            return {
                "error": "No audio data provided. Use 'audio' (base64), 'file' (path), or 'audio_bytes'",
            }

        # Transcribe
        segments, info = self.model.transcribe(
            audio_input,
            language=language,
            task=task,
            word_timestamps=word_timestamps,
            vad_filter=True,
        )

        # Collect segments. faster-whisper returns a lazy generator, so the
        # actual decoding work happens during this iteration.
        segment_list = []
        full_text = ""

        for segment in segments:
            seg_data = {
                "id": segment.id,
                "start": segment.start,
                "end": segment.end,
                "text": segment.text,
            }

            if word_timestamps and segment.words:
                seg_data["words"] = [
                    {
                        "word": word.word,
                        "start": word.start,
                        "end": word.end,
                        "probability": word.probability,
                    }
                    for word in segment.words
                ]

            segment_list.append(seg_data)
            full_text += segment.text

        # Log to MLflow BEFORE building the response so every response
        # format is recorded. (Previously the "text" and "verbose_json"
        # branches returned early and skipped logging entirely.)
        if self._mlflow:
            latency_s = time.time() - _start  # computed once, reused below
            self._mlflow.log_request(
                latency_s=latency_s,
                audio_duration_s=info.duration,
                segments=len(segment_list),
                realtime_factor=latency_s / info.duration if info.duration > 0 else 0,
            )

        # Build response based on format
        if response_format == "text":
            return {"text": full_text.strip()}

        if response_format == "verbose_json":
            return {
                "task": task,
                "language": info.language,
                "duration": info.duration,
                "text": full_text.strip(),
                "segments": segment_list,
            }

        # Default JSON format (OpenAI-compatible)
        return {
            "text": full_text.strip(),
            "language": info.language,
            "duration": info.duration,
            "model": self.model_size,
        }
|
|
|
|
|
|
app = WhisperDeployment.bind()
|