feat: add py.typed, Ray handles for clients, and pre-commit config
- Add py.typed marker for PEP 561 type hint support
- Add ray_utils module for Ray handle detection and caching
- Update all clients (Embeddings, LLM, TTS, STT, Reranker) to use Ray handles when running inside a Ray cluster for faster internal calls
- Add .pre-commit-config.yaml with ruff and standard hooks
- Add pre-commit and ray[serve] to optional dependencies
- Bump ruff version to 0.4.0
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
"""
|
||||
Embeddings service client (Infinity/BGE).
|
||||
|
||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from handler_base.config import EmbeddingsSettings
|
||||
from handler_base.ray_utils import get_ray_handle
|
||||
from handler_base.telemetry import create_span
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,17 +20,33 @@ class EmbeddingsClient:
|
||||
"""
|
||||
Client for the embeddings service (Infinity with BGE models).
|
||||
|
||||
When running inside Ray, automatically uses Ray handles for faster
|
||||
internal communication. Falls back to HTTP for external calls.
|
||||
|
||||
Usage:
|
||||
client = EmbeddingsClient()
|
||||
embeddings = await client.embed(["Hello world"])
|
||||
"""
|
||||
|
||||
# Ray Serve deployment configuration
|
||||
RAY_DEPLOYMENT_NAME = "EmbeddingsDeployment"
|
||||
RAY_APP_NAME = "embeddings"
|
||||
|
||||
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
    """Initialize the client from *settings* (defaults constructed if omitted)."""
    cfg = settings or EmbeddingsSettings()
    self.settings = cfg
    self._client = httpx.AsyncClient(
        base_url=cfg.embeddings_url,
        timeout=cfg.http_timeout,
    )
    # Ray handle state: resolved lazily, at most once, by _get_ray_handle().
    self._ray_handle: Optional[Any] = None
    self._ray_checked = False
|
||||
|
||||
def _get_ray_handle(self) -> Optional[Any]:
    """Return the Ray Serve handle, resolving it on the first call only."""
    if self._ray_checked:
        return self._ray_handle
    # One-time lookup; the result (possibly None) is memoized for this client.
    self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
    self._ray_checked = True
    return self._ray_handle
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
@@ -55,6 +74,23 @@ class EmbeddingsClient:
|
||||
span.set_attribute("embeddings.model", model)
|
||||
span.set_attribute("embeddings.batch_size", len(texts))
|
||||
|
||||
# Try Ray handle first (faster internal path)
|
||||
handle = self._get_ray_handle()
|
||||
if handle:
|
||||
try:
|
||||
if span:
|
||||
span.set_attribute("embeddings.transport", "ray")
|
||||
result = await handle.embed.remote(texts, model)
|
||||
if span and result:
|
||||
span.set_attribute("embeddings.dimensions", len(result[0]) if result else 0)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
||||
|
||||
# HTTP fallback
|
||||
if span:
|
||||
span.set_attribute("embeddings.transport", "http")
|
||||
|
||||
response = await self._client.post(
|
||||
"/embeddings",
|
||||
json={"input": texts, "model": model},
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""
|
||||
LLM service client (vLLM/OpenAI-compatible).
|
||||
|
||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import AsyncIterator, Optional
|
||||
from typing import Any, AsyncIterator, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from handler_base.config import LLMSettings
|
||||
from handler_base.ray_utils import get_ray_handle
|
||||
from handler_base.telemetry import create_span
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,6 +20,9 @@ class LLMClient:
|
||||
"""
|
||||
Client for the LLM service (vLLM with OpenAI-compatible API).
|
||||
|
||||
When running inside Ray, automatically uses Ray handles for faster
|
||||
internal communication. Falls back to HTTP for external calls.
|
||||
|
||||
Usage:
|
||||
client = LLMClient()
|
||||
response = await client.generate("Hello, how are you?")
|
||||
@@ -32,12 +38,25 @@ class LLMClient:
|
||||
print(chunk, end="")
|
||||
"""
|
||||
|
||||
# Ray Serve deployment configuration
|
||||
RAY_DEPLOYMENT_NAME = "VLLMDeployment"
|
||||
RAY_APP_NAME = "llm"
|
||||
|
||||
def __init__(self, settings: Optional[LLMSettings] = None):
    """Initialize the client from *settings* (defaults constructed if omitted)."""
    cfg = settings or LLMSettings()
    self.settings = cfg
    self._client = httpx.AsyncClient(
        base_url=cfg.llm_url,
        timeout=cfg.http_timeout,
    )
    # Ray handle state: resolved lazily, at most once, by _get_ray_handle().
    self._ray_handle: Optional[Any] = None
    self._ray_checked = False
|
||||
|
||||
def _get_ray_handle(self) -> Optional[Any]:
    """Return the Ray Serve handle, resolving it on the first call only."""
    if self._ray_checked:
        return self._ray_handle
    # One-time lookup; the result (possibly None) is memoized for this client.
    self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
    self._ray_checked = True
    return self._ray_handle
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
@@ -87,6 +106,24 @@ class LLMClient:
|
||||
if stop:
|
||||
payload["stop"] = stop
|
||||
|
||||
# Try Ray handle first (faster internal path)
|
||||
handle = self._get_ray_handle()
|
||||
if handle:
|
||||
try:
|
||||
if span:
|
||||
span.set_attribute("llm.transport", "ray")
|
||||
result = await handle.remote(payload)
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
if span:
|
||||
span.set_attribute("llm.response_length", len(content))
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
||||
|
||||
# HTTP fallback
|
||||
if span:
|
||||
span.set_attribute("llm.transport", "http")
|
||||
|
||||
response = await self._client.post("/v1/chat/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""
|
||||
Reranker service client (Infinity/BGE Reranker).
|
||||
|
||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from handler_base.config import Settings
|
||||
from handler_base.ray_utils import get_ray_handle
|
||||
from handler_base.telemetry import create_span
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,17 +20,33 @@ class RerankerClient:
|
||||
"""
|
||||
Client for the reranker service (Infinity with BGE Reranker).
|
||||
|
||||
When running inside Ray, automatically uses Ray handles for faster
|
||||
internal communication. Falls back to HTTP for external calls.
|
||||
|
||||
Usage:
|
||||
client = RerankerClient()
|
||||
reranked = await client.rerank("query", ["doc1", "doc2"])
|
||||
"""
|
||||
|
||||
# Ray Serve deployment configuration
|
||||
RAY_DEPLOYMENT_NAME = "RerankerDeployment"
|
||||
RAY_APP_NAME = "reranker"
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
    """Initialize the client from *settings* (defaults constructed if omitted)."""
    cfg = settings or Settings()
    self.settings = cfg
    self._client = httpx.AsyncClient(
        base_url=cfg.reranker_url,
        timeout=cfg.http_timeout,
    )
    # Ray handle state: resolved lazily, at most once, by _get_ray_handle().
    self._ray_handle: Optional[Any] = None
    self._ray_checked = False
|
||||
|
||||
def _get_ray_handle(self) -> Optional[Any]:
    """Return the Ray Serve handle, resolving it on the first call only."""
    if self._ray_checked:
        return self._ray_handle
    # One-time lookup; the result (possibly None) is memoized for this client.
    self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
    self._ray_checked = True
    return self._ray_handle
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
@@ -64,6 +83,32 @@ class RerankerClient:
|
||||
if top_k:
|
||||
payload["top_n"] = top_k
|
||||
|
||||
# Try Ray handle first (faster internal path)
|
||||
handle = self._get_ray_handle()
|
||||
if handle:
|
||||
try:
|
||||
if span:
|
||||
span.set_attribute("reranker.transport", "ray")
|
||||
results = await handle.rerank.remote(query, documents, top_k)
|
||||
# Enrich with original documents
|
||||
enriched = []
|
||||
for r in results:
|
||||
idx = r.get("index", 0)
|
||||
enriched.append(
|
||||
{
|
||||
"index": idx,
|
||||
"score": r.get("relevance_score", r.get("score", 0)),
|
||||
"document": documents[idx] if idx < len(documents) else "",
|
||||
}
|
||||
)
|
||||
return enriched
|
||||
except Exception as e:
|
||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
||||
|
||||
# HTTP fallback
|
||||
if span:
|
||||
span.set_attribute("reranker.transport", "http")
|
||||
|
||||
response = await self._client.post("/rerank", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""
|
||||
STT service client (Whisper/faster-whisper).
|
||||
|
||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from handler_base.config import STTSettings
|
||||
from handler_base.ray_utils import get_ray_handle
|
||||
from handler_base.telemetry import create_span
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,17 +20,33 @@ class STTClient:
|
||||
"""
|
||||
Client for the STT service (Whisper/faster-whisper).
|
||||
|
||||
When running inside Ray, automatically uses Ray handles for faster
|
||||
internal communication. Falls back to HTTP for external calls.
|
||||
|
||||
Usage:
|
||||
client = STTClient()
|
||||
text = await client.transcribe(audio_bytes)
|
||||
"""
|
||||
|
||||
# Ray Serve deployment configuration
|
||||
RAY_DEPLOYMENT_NAME = "WhisperDeployment"
|
||||
RAY_APP_NAME = "whisper"
|
||||
|
||||
def __init__(self, settings: Optional[STTSettings] = None):
    """Initialize the client from *settings* (defaults constructed if omitted)."""
    cfg = settings or STTSettings()
    self.settings = cfg
    # Generous fixed timeout: transcription of long audio can be slow.
    self._client = httpx.AsyncClient(
        base_url=cfg.stt_url,
        timeout=180.0,
    )
    # Ray handle state: resolved lazily, at most once, by _get_ray_handle().
    self._ray_handle: Optional[Any] = None
    self._ray_checked = False
|
||||
|
||||
def _get_ray_handle(self) -> Optional[Any]:
    """Return the Ray Serve handle, resolving it on the first call only."""
    if self._ray_checked:
        return self._ray_handle
    # One-time lookup; the result (possibly None) is memoized for this client.
    self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
    self._ray_checked = True
    return self._ray_handle
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
@@ -62,6 +81,25 @@ class STTClient:
|
||||
if language:
|
||||
span.set_attribute("stt.language", language)
|
||||
|
||||
# Try Ray handle first (faster internal path)
|
||||
handle = self._get_ray_handle()
|
||||
if handle:
|
||||
try:
|
||||
if span:
|
||||
span.set_attribute("stt.transport", "ray")
|
||||
result = await handle.transcribe.remote(audio, language, task)
|
||||
if span:
|
||||
span.set_attribute("stt.result_length", len(result.get("text", "")))
|
||||
if result.get("language"):
|
||||
span.set_attribute("stt.detected_language", result["language"])
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
||||
|
||||
# HTTP fallback
|
||||
if span:
|
||||
span.set_attribute("stt.transport", "http")
|
||||
|
||||
files = {"file": ("audio.wav", audio, "audio/wav")}
|
||||
data = {
|
||||
"response_format": response_format,
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""
|
||||
TTS service client (Coqui XTTS).
|
||||
|
||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from handler_base.config import TTSSettings
|
||||
from handler_base.ray_utils import get_ray_handle
|
||||
from handler_base.telemetry import create_span
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,17 +20,33 @@ class TTSClient:
|
||||
"""
|
||||
Client for the TTS service (Coqui XTTS).
|
||||
|
||||
When running inside Ray, automatically uses Ray handles for faster
|
||||
internal communication. Falls back to HTTP for external calls.
|
||||
|
||||
Usage:
|
||||
client = TTSClient()
|
||||
audio_bytes = await client.synthesize("Hello world")
|
||||
"""
|
||||
|
||||
# Ray Serve deployment configuration
|
||||
RAY_DEPLOYMENT_NAME = "TTSDeployment"
|
||||
RAY_APP_NAME = "tts"
|
||||
|
||||
def __init__(self, settings: Optional[TTSSettings] = None):
    """Initialize the client from *settings* (defaults constructed if omitted)."""
    cfg = settings or TTSSettings()
    self.settings = cfg
    # Generous fixed timeout: speech synthesis can be slow.
    self._client = httpx.AsyncClient(
        base_url=cfg.tts_url,
        timeout=120.0,
    )
    # Ray handle state: resolved lazily, at most once, by _get_ray_handle().
    self._ray_handle: Optional[Any] = None
    self._ray_checked = False
|
||||
|
||||
def _get_ray_handle(self) -> Optional[Any]:
    """Return the Ray Serve handle, resolving it on the first call only."""
    if self._ray_checked:
        return self._ray_handle
    # One-time lookup; the result (possibly None) is memoized for this client.
    self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
    self._ray_checked = True
    return self._ray_handle
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
@@ -64,6 +83,23 @@ class TTSClient:
|
||||
if speaker:
|
||||
params["speaker_id"] = speaker
|
||||
|
||||
# Try Ray handle first (faster internal path)
|
||||
handle = self._get_ray_handle()
|
||||
if handle:
|
||||
try:
|
||||
if span:
|
||||
span.set_attribute("tts.transport", "ray")
|
||||
audio_bytes = await handle.synthesize.remote(text, language, speaker)
|
||||
if span:
|
||||
span.set_attribute("tts.audio_size", len(audio_bytes))
|
||||
return audio_bytes
|
||||
except Exception as e:
|
||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
||||
|
||||
# HTTP fallback
|
||||
if span:
|
||||
span.set_attribute("tts.transport", "http")
|
||||
|
||||
response = await self._client.get("/api/tts", params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
0
handler_base/py.typed
Normal file
0
handler_base/py.typed
Normal file
70
handler_base/ray_utils.py
Normal file
70
handler_base/ray_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Ray integration utilities for handler-base clients.
|
||||
|
||||
When running inside a Ray cluster, clients can use Ray Serve handles
|
||||
for faster internal communication (gRPC instead of HTTP).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)


# Process-wide caches: Serve handles keyed by "app/deployment", plus the
# one-shot Ray-availability probe (None means "not yet checked").
_ray_handles: dict[str, Any] = {}
_ray_available: Optional[bool] = None


def is_ray_available() -> bool:
    """Check if we're running inside a Ray cluster (result cached after first probe)."""
    global _ray_available
    if _ray_available is None:
        try:
            import ray
        except ImportError:
            # Ray not installed: definitely not inside a cluster.
            _ray_available = False
            return False
        _ray_available = ray.is_initialized()
        if _ray_available:
            logger.info("Ray detected - will use Ray handles for internal calls")
    return _ray_available


def get_ray_handle(deployment_name: str, app_name: str) -> Optional[Any]:
    """
    Get a Ray Serve deployment handle for internal calls.

    Args:
        deployment_name: Name of the Ray Serve deployment
        app_name: Name of the Ray Serve application

    Returns:
        DeploymentHandle if available, None otherwise
    """
    if not is_ray_available():
        return None

    cache_key = f"{app_name}/{deployment_name}"
    cached = _ray_handles.get(cache_key)
    if cached is not None:
        return cached

    try:
        from ray import serve

        handle = serve.get_deployment_handle(deployment_name, app_name=app_name)
    except Exception as e:
        # Deployment not yet deployed / Serve not running: degrade to HTTP.
        logger.debug(f"Could not get Ray handle for {cache_key}: {e}")
        return None
    _ray_handles[cache_key] = handle
    logger.debug(f"Got Ray handle for {cache_key}")
    return handle


def clear_ray_handles() -> None:
    """Clear cached Ray handles (useful for testing)."""
    global _ray_available
    _ray_handles.clear()
    _ray_available = None
|
||||
Reference in New Issue
Block a user