feat: add MLflow inference logging to all Ray Serve apps
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 16s
- Add mlflow_logger.py: lightweight REST-based MLflow logger (no mlflow dep)
- Instrument serve_llm.py with latency, token counts, tokens/sec metrics
- Instrument serve_embeddings.py with latency, batch_size, total_tokens
- Instrument serve_whisper.py with latency, audio_duration, realtime_factor
- Instrument serve_tts.py with latency, audio_duration, text_chars
- Instrument serve_reranker.py with latency, num_pairs, top_k
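For reviewers, a minimal usage sketch of the new logger. The constructor arguments, metric names, and the initialize/log_request/flush methods below are taken from this diff; the experiment name, run name, and metric values are illustrative only:

    # Illustrative wiring of InferenceLogger, mirroring the instrumentation
    # in this commit. Experiment/run names and values here are made up;
    # initialize()/log_request()/flush() are the methods the diff adds.
    import time

    from ray_serve.mlflow_logger import InferenceLogger

    mlflow_log = InferenceLogger(
        experiment_name="ray-serve-demo",   # hypothetical experiment
        run_name="demo-model",
        tags={"model.name": "demo/model", "device": "cpu"},
        flush_every=10,                     # buffer 10 requests per HTTP POST
    )
    # Creates the experiment and one persistent run via the MLflow REST API;
    # on failure the logger disables itself instead of raising.
    mlflow_log.initialize(params={"model_id": "demo/model"})

    _start = time.time()
    # ... handle one inference request ...
    mlflow_log.log_request(latency_s=time.time() - _start, batch_size=1)

    mlflow_log.flush()  # force-send anything still buffered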
ray_serve/mlflow_logger.py (new file, 211 lines)
@@ -0,0 +1,211 @@
"""
Lightweight MLflow metrics logger using the REST API.

Avoids importing the heavyweight mlflow package — uses only stdlib
urllib so it works inside any Ray Serve actor without extra pip deps.

Each deployment creates **one persistent MLflow run** on startup and
logs per-request metrics with an incrementing step counter. This
gives time-series charts in the MLflow UI. The run is terminated
when the actor shuts down (or left RUNNING if the process crashes).
"""

import atexit
import json
import logging
import os
import threading
import time
import urllib.error
import urllib.request
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any

logger = logging.getLogger(__name__)


@dataclass
class _RunState:
    run_id: str
    experiment_id: str
    step: int = 0


class InferenceLogger:
    """Per-deployment MLflow metrics logger backed by the REST API.

    Parameters
    ----------
    experiment_name:
        MLflow experiment name (created if missing).
    run_name:
        Human-readable run name shown in the MLflow UI.
    tracking_uri:
        MLflow tracking server. Defaults to ``MLFLOW_TRACKING_URI`` env var
        or the in-cluster service address.
    tags:
        Extra tags attached to the run (e.g. model name, GPU, node).
    flush_every:
        Batch this many metric points before flushing (reduces HTTP calls).
    """

    def __init__(
        self,
        experiment_name: str,
        run_name: str,
        tracking_uri: str | None = None,
        tags: dict[str, str] | None = None,
        flush_every: int = 1,
    ):
        self._base = (
            tracking_uri
            or os.environ.get("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80")
        ).rstrip("/")
        self._experiment_name = experiment_name
        self._run_name = run_name
        self._tags = tags or {}
        self._flush_every = max(1, flush_every)

        self._state: _RunState | None = None
        self._buffer: list[dict[str, Any]] = []
        self._lock = threading.Lock()
        self._enabled = True

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def initialize(self, params: dict[str, str] | None = None) -> None:
        """Create experiment + run. Safe to call from ``__init__``."""
        try:
            exp_id = self._get_or_create_experiment()
            run_id = self._create_run(exp_id)
            self._state = _RunState(run_id=run_id, experiment_id=exp_id)
            if params:
                self._log_params(params)
            # Register cleanup
            atexit.register(self._end_run)
            logger.info("MLflow run started: %s (experiment=%s)", run_id, self._experiment_name)
        except Exception:
            logger.warning("MLflow init failed — metrics will not be logged", exc_info=True)
            self._enabled = False

    def log_request(self, **metrics: float) -> None:
        """Log one set of metrics for a single inference request.

        Metrics are buffered and flushed every ``flush_every`` calls.
        """
        if not self._enabled or not self._state:
            return

        with self._lock:
            self._state.step += 1
            step = self._state.step
            ts = int(time.time() * 1000)

            for key, value in metrics.items():
                self._buffer.append(
                    {"key": key, "value": value, "timestamp": ts, "step": step}
                )

            if len(self._buffer) >= self._flush_every * len(metrics):
                self._flush()

    def flush(self) -> None:
        """Force-flush any buffered metrics."""
        with self._lock:
            self._flush()

    # ------------------------------------------------------------------
    # REST helpers
    # ------------------------------------------------------------------

    def _post(self, path: str, body: dict) -> dict:
        url = f"{self._base}/api/2.0/mlflow/{path}"
        data = json.dumps(body).encode()
        req = urllib.request.Request(
            url, data=data, headers={"Content-Type": "application/json"}, method="POST"
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            return json.loads(resp.read().decode())

    def _get_or_create_experiment(self) -> str:
        try:
            resp = self._post(
                "experiments/get-by-name",
                {"experiment_name": self._experiment_name},
            )
            return resp["experiment"]["experiment_id"]
        except urllib.error.HTTPError:
            resp = self._post(
                "experiments/create",
                {"name": self._experiment_name},
            )
            return resp["experiment_id"]

    def _create_run(self, experiment_id: str) -> str:
        tags = [
            {"key": k, "value": v}
            for k, v in {
                "mlflow.runName": self._run_name,
                "mlflow.source.type": "LOCAL",
                "hostname": os.environ.get("HOSTNAME", "unknown"),
                "namespace": os.environ.get("POD_NAMESPACE", "unknown"),
                **self._tags,
            }.items()
        ]
        resp = self._post(
            "runs/create",
            {
                "experiment_id": experiment_id,
                "run_name": self._run_name,
                "start_time": int(time.time() * 1000),
                "tags": tags,
            },
        )
        return resp["run"]["info"]["run_id"]

    def _log_params(self, params: dict[str, str]) -> None:
        if not self._state:
            return
        param_list = [{"key": k, "value": str(v)[:500]} for k, v in params.items()]
        try:
            self._post(
                "runs/log-batch",
                {"run_id": self._state.run_id, "params": param_list},
            )
        except Exception:
            logger.debug("Failed to log params", exc_info=True)

    def _flush(self) -> None:
        """Send buffered metrics in a single `log-batch` call."""
        if not self._buffer or not self._state:
            return
        batch = self._buffer[:]
        self._buffer.clear()
        try:
            self._post(
                "runs/log-batch",
                {"run_id": self._state.run_id, "metrics": batch},
            )
        except Exception:
            logger.debug("Failed to flush %d metrics", len(batch), exc_info=True)

    def _end_run(self) -> None:
        """Mark the MLflow run as FINISHED."""
        if not self._state:
            return
        self._flush()
        try:
            self._post(
                "runs/update",
                {
                    "run_id": self._state.run_id,
                    "status": "FINISHED",
                    "end_time": int(time.time() * 1000),
                },
            )
            logger.info("MLflow run %s ended", self._state.run_id)
        except Exception:
            logger.debug("Failed to end MLflow run", exc_info=True)
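For reference, a sketch of the request body that _flush() POSTs to /api/2.0/mlflow/runs/log-batch. The run ID and metric values are hypothetical; the key/value/timestamp/step fields are exactly what log_request() buffers above:

    # Shape of the runs/log-batch payload assembled by _flush().
    # run_id and the metric values are made-up examples.
    payload = {
        "run_id": "0123456789abcdef",
        "metrics": [
            {"key": "latency_s", "value": 0.42, "timestamp": 1700000000000, "step": 17},
            {"key": "batch_size", "value": 8.0, "timestamp": 1700000000000, "step": 17},
        ],
    }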
ray_serve/serve_embeddings.py
@@ -4,10 +4,13 @@ Runs on: drizzt (Radeon 680M iGPU, ROCm)
 """

 import os
+import time
 from typing import Any

 from ray import serve
+
+from ray_serve.mlflow_logger import InferenceLogger


 @serve.deployment(name="EmbeddingsDeployment", num_replicas=1)
 class EmbeddingsDeployment:
@@ -33,6 +36,17 @@ class EmbeddingsDeployment:

         print(f"Model loaded. Embedding dimension: {self.embedding_dim}")

+        # MLflow metrics
+        self._mlflow = InferenceLogger(
+            experiment_name="ray-serve-embeddings",
+            run_name=f"embeddings-{self.model_id.split('/')[-1]}",
+            tags={"model.name": self.model_id, "model.framework": "sentence-transformers", "device": self.device},
+            flush_every=10,
+        )
+        self._mlflow.initialize(
+            params={"model_id": self.model_id, "embedding_dim": str(self.embedding_dim), "device": self.device}
+        )
+
     async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
         """
         Handle OpenAI-compatible embedding requests.
@@ -46,6 +60,8 @@ class EmbeddingsDeployment:
         """
         input_data = request.get("input", "")

+        _start = time.time()
+
         # Handle both single string and list of strings
         texts = [input_data] if isinstance(input_data, str) else input_data

@@ -69,6 +85,13 @@ class EmbeddingsDeployment:
             )
             total_tokens += len(text.split())

+        # Log to MLflow
+        self._mlflow.log_request(
+            latency_s=time.time() - _start,
+            batch_size=len(texts),
+            total_tokens=total_tokens,
+        )
+
         # Return OpenAI-compatible response
         return {
             "object": "list",
ray_serve/serve_llm.py
@@ -10,6 +10,8 @@ from typing import Any

 from ray import serve
+
+from ray_serve.mlflow_logger import InferenceLogger


 @serve.deployment(name="LLMDeployment", num_replicas=1)
 class LLMDeployment:
@@ -37,6 +39,21 @@ class LLMDeployment:
         self.SamplingParams = SamplingParams
         print(f"Model {self.model_id} async engine created")

+        # MLflow metrics
+        self._mlflow = InferenceLogger(
+            experiment_name="ray-serve-llm",
+            run_name=f"llm-{self.model_id.split('/')[-1]}",
+            tags={"model.name": self.model_id, "model.framework": "vllm", "gpu": "strixhalo"},
+            flush_every=5,
+        )
+        self._mlflow.initialize(
+            params={
+                "model_id": self.model_id,
+                "max_model_len": str(self.max_model_len),
+                "gpu_memory_utilization": str(self.gpu_memory_utilization),
+            }
+        )
+
     async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
         """
         Handle OpenAI-compatible chat completion requests.
@@ -67,11 +84,27 @@ class LLMDeployment:
             stop=stop,
         )

+        start_time = time.time()
         request_id = uuid.uuid4().hex
         final_result = None
         async for result in self.engine.generate(prompt, sampling_params, request_id):
             final_result = result
         generated_text = final_result.outputs[0].text
+        latency = time.time() - start_time

+        prompt_tokens = len(prompt.split())
+        completion_tokens = len(generated_text.split())
+
+        # Log to MLflow
+        self._mlflow.log_request(
+            latency_s=latency,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            tokens_per_second=completion_tokens / latency if latency > 0 else 0,
+            temperature=temperature,
+            max_tokens_requested=max_tokens,
+        )
+
         # Return OpenAI-compatible response
         return {
@@ -90,9 +123,9 @@ class LLMDeployment:
                 }
             ],
             "usage": {
-                "prompt_tokens": len(prompt.split()),
-                "completion_tokens": len(generated_text.split()),
-                "total_tokens": len(prompt.split()) + len(generated_text.split()),
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
             },
         }
ray_serve/serve_reranker.py
@@ -4,10 +4,13 @@ Runs on: drizzt (Radeon 680M iGPU, ROCm) or danilo (Intel i915 iGPU, OpenVINO/IP
 """

 import os
+import time
 from typing import Any

 from ray import serve
+
+from ray_serve.mlflow_logger import InferenceLogger


 @serve.deployment(name="RerankerDeployment", num_replicas=1)
 class RerankerDeployment:
@@ -58,6 +61,17 @@ class RerankerDeployment:

         print("Reranker model loaded successfully")

+        # MLflow metrics
+        self._mlflow = InferenceLogger(
+            experiment_name="ray-serve-reranker",
+            run_name=f"reranker-{self.model_id.split('/')[-1]}",
+            tags={"model.name": self.model_id, "model.framework": "sentence-transformers", "device": self.device},
+            flush_every=10,
+        )
+        self._mlflow.initialize(
+            params={"model_id": self.model_id, "device": self.device, "use_ipex": str(self.use_ipex)}
+        )
+
     async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
         """
         Handle reranking requests.
@@ -75,6 +89,8 @@ class RerankerDeployment:
             "pairs": [["query", "doc1"], ["query", "doc2"]]
         }
         """
+        _start = time.time()
+
         # Handle pairs format
         if "pairs" in request:
             pairs = request["pairs"]
@@ -89,6 +105,11 @@ class RerankerDeployment:
                     }
                 )

+            self._mlflow.log_request(
+                latency_s=time.time() - _start,
+                num_pairs=len(pairs),
+            )
+
             return {
                 "object": "list",
                 "results": results,
@@ -131,6 +152,14 @@ class RerankerDeployment:
         # Apply top_k
         results = results[:top_k]

+        # Log to MLflow
+        self._mlflow.log_request(
+            latency_s=time.time() - _start,
+            num_pairs=len(pairs),
+            num_documents=len(documents),
+            top_k=top_k,
+        )
+
         return {
             "object": "list",
             "results": results,
ray_serve/serve_tts.py
@@ -6,10 +6,13 @@ Runs on: elminster (RTX 2070 8GB, CUDA)
 import base64
 import io
 import os
+import time
 from typing import Any

 from ray import serve

+from ray_serve.mlflow_logger import InferenceLogger
+

 @serve.deployment(name="TTSDeployment", num_replicas=1)
 class TTSDeployment:
@@ -32,6 +35,15 @@ class TTSDeployment:

         print("TTS model loaded successfully")

+        # MLflow metrics
+        self._mlflow = InferenceLogger(
+            experiment_name="ray-serve-tts",
+            run_name=f"tts-{self.model_name.split('/')[-1]}",
+            tags={"model.name": self.model_name, "model.framework": "coqui-tts", "gpu": str(self.use_gpu)},
+            flush_every=5,
+        )
+        self._mlflow.initialize(params={"model_name": self.model_name, "use_gpu": str(self.use_gpu)})
+
     async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
         """
         Handle text-to-speech requests.
@@ -49,6 +61,7 @@ class TTSDeployment:
         import numpy as np
         from scipy.io import wavfile

+        _start = time.time()
         text = request.get("text", "")
         speaker = request.get("speaker")
         language = request.get("language")
@@ -88,10 +101,20 @@ class TTSDeployment:
         wavfile.write(buffer, sample_rate, wav_int16)
         audio_bytes = buffer.getvalue()

+        duration = len(wav) / sample_rate
+
+        # Log to MLflow
+        self._mlflow.log_request(
+            latency_s=time.time() - _start,
+            audio_duration_s=duration,
+            text_chars=len(text),
+            realtime_factor=(time.time() - _start) / duration if duration > 0 else 0,
+        )
+
         response = {
             "model": self.model_name,
             "sample_rate": sample_rate,
-            "duration": len(wav) / sample_rate,
+            "duration": duration,
             "format": output_format,
         }
ray_serve/serve_whisper.py
@@ -6,10 +6,13 @@ Runs on: elminster (RTX 2070 8GB, CUDA)
 import base64
 import io
 import os
+import time
 from typing import Any

 from ray import serve

+from ray_serve.mlflow_logger import InferenceLogger
+

 @serve.deployment(name="WhisperDeployment", num_replicas=1)
 class WhisperDeployment:
@@ -38,6 +41,17 @@ class WhisperDeployment:

         print("Whisper model loaded successfully")

+        # MLflow metrics
+        self._mlflow = InferenceLogger(
+            experiment_name="ray-serve-whisper",
+            run_name=f"whisper-{self.model_size}",
+            tags={"model.name": f"whisper-{self.model_size}", "model.framework": "faster-whisper", "device": self.device},
+            flush_every=5,
+        )
+        self._mlflow.initialize(
+            params={"model_size": self.model_size, "device": self.device, "compute_type": self.compute_type}
+        )
+
     async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
         """
         Handle transcription requests.
@@ -59,6 +73,7 @@ class WhisperDeployment:
         }
         """

+        _start = time.time()
         language = request.get("language")
         task = request.get("task", "transcribe")  # transcribe or translate
         response_format = request.get("response_format", "json")
@@ -130,6 +145,14 @@ class WhisperDeployment:
                 "segments": segment_list,
             }

+        # Log to MLflow
+        self._mlflow.log_request(
+            latency_s=time.time() - _start,
+            audio_duration_s=info.duration,
+            segments=len(segment_list),
+            realtime_factor=(time.time() - _start) / info.duration if info.duration > 0 else 0,
+        )
+
         # Default JSON format (OpenAI-compatible)
         return {
             "text": full_text.strip(),