fix: auto-fix ruff linting errors and remove unsupported upload-artifact
This commit is contained in:
@@ -57,12 +57,6 @@ jobs:
|
|||||||
- name: Run tests with coverage
|
- name: Run tests with coverage
|
||||||
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term
|
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term
|
||||||
|
|
||||||
- name: Upload coverage artifact
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: coverage
|
|
||||||
path: coverage.xml
|
|
||||||
|
|
||||||
release:
|
release:
|
||||||
name: Release
|
name: Release
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ Provides consistent patterns for:
|
|||||||
- Graceful shutdown
|
- Graceful shutdown
|
||||||
- Service client wrappers
|
- Service client wrappers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
from handler_base.handler import Handler
|
from handler_base.handler import Handler
|
||||||
from handler_base.health import HealthServer
|
from handler_base.health import HealthServer
|
||||||
from handler_base.nats_client import NATSClient
|
from handler_base.nats_client import NATSClient
|
||||||
from handler_base.telemetry import setup_telemetry, get_tracer, get_meter
|
from handler_base.telemetry import get_meter, get_tracer, setup_telemetry
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Handler",
|
"Handler",
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
Service client wrappers for AI/ML backends.
|
Service client wrappers for AI/ML backends.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from handler_base.clients.embeddings import EmbeddingsClient
|
from handler_base.clients.embeddings import EmbeddingsClient
|
||||||
from handler_base.clients.reranker import RerankerClient
|
|
||||||
from handler_base.clients.llm import LLMClient
|
from handler_base.clients.llm import LLMClient
|
||||||
from handler_base.clients.tts import TTSClient
|
|
||||||
from handler_base.clients.stt import STTClient
|
|
||||||
from handler_base.clients.milvus import MilvusClient
|
from handler_base.clients.milvus import MilvusClient
|
||||||
|
from handler_base.clients.reranker import RerankerClient
|
||||||
|
from handler_base.clients.stt import STTClient
|
||||||
|
from handler_base.clients.tts import TTSClient
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EmbeddingsClient",
|
"EmbeddingsClient",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Embeddings service client (Infinity/BGE).
|
Embeddings service client (Infinity/BGE).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
|
|||||||
class EmbeddingsClient:
|
class EmbeddingsClient:
|
||||||
"""
|
"""
|
||||||
Client for the embeddings service (Infinity with BGE models).
|
Client for the embeddings service (Infinity with BGE models).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
client = EmbeddingsClient()
|
client = EmbeddingsClient()
|
||||||
embeddings = await client.embed(["Hello world"])
|
embeddings = await client.embed(["Hello world"])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
|
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
|
||||||
self.settings = settings or EmbeddingsSettings()
|
self.settings = settings or EmbeddingsSettings()
|
||||||
self._client = httpx.AsyncClient(
|
self._client = httpx.AsyncClient(
|
||||||
base_url=self.settings.embeddings_url,
|
base_url=self.settings.embeddings_url,
|
||||||
timeout=self.settings.http_timeout,
|
timeout=self.settings.http_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
||||||
async def embed(
|
async def embed(
|
||||||
self,
|
self,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@@ -39,49 +40,49 @@ class EmbeddingsClient:
|
|||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Generate embeddings for a list of texts.
|
Generate embeddings for a list of texts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to embed
|
texts: List of texts to embed
|
||||||
model: Model name (defaults to settings)
|
model: Model name (defaults to settings)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of embedding vectors
|
List of embedding vectors
|
||||||
"""
|
"""
|
||||||
model = model or self.settings.embeddings_model
|
model = model or self.settings.embeddings_model
|
||||||
|
|
||||||
with create_span("embeddings.embed") as span:
|
with create_span("embeddings.embed") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("embeddings.model", model)
|
span.set_attribute("embeddings.model", model)
|
||||||
span.set_attribute("embeddings.batch_size", len(texts))
|
span.set_attribute("embeddings.batch_size", len(texts))
|
||||||
|
|
||||||
response = await self._client.post(
|
response = await self._client.post(
|
||||||
"/embeddings",
|
"/embeddings",
|
||||||
json={"input": texts, "model": model},
|
json={"input": texts, "model": model},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
embeddings = [d["embedding"] for d in result.get("data", [])]
|
embeddings = [d["embedding"] for d in result.get("data", [])]
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
|
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
|
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Generate embedding for a single text.
|
Generate embedding for a single text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to embed
|
text: Text to embed
|
||||||
model: Model name (defaults to settings)
|
model: Model name (defaults to settings)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Embedding vector
|
Embedding vector
|
||||||
"""
|
"""
|
||||||
embeddings = await self.embed([text], model)
|
embeddings = await self.embed([text], model)
|
||||||
return embeddings[0] if embeddings else []
|
return embeddings[0] if embeddings else []
|
||||||
|
|
||||||
async def health(self) -> bool:
|
async def health(self) -> bool:
|
||||||
"""Check if the embeddings service is healthy."""
|
"""Check if the embeddings service is healthy."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
LLM service client (vLLM/OpenAI-compatible).
|
LLM service client (vLLM/OpenAI-compatible).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, AsyncIterator
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -15,33 +16,33 @@ logger = logging.getLogger(__name__)
|
|||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""
|
"""
|
||||||
Client for the LLM service (vLLM with OpenAI-compatible API).
|
Client for the LLM service (vLLM with OpenAI-compatible API).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
client = LLMClient()
|
client = LLMClient()
|
||||||
response = await client.generate("Hello, how are you?")
|
response = await client.generate("Hello, how are you?")
|
||||||
|
|
||||||
# With context for RAG
|
# With context for RAG
|
||||||
response = await client.generate(
|
response = await client.generate(
|
||||||
"What is the capital?",
|
"What is the capital?",
|
||||||
context="France is a country in Europe..."
|
context="France is a country in Europe..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Streaming
|
# Streaming
|
||||||
async for chunk in client.stream("Tell me a story"):
|
async for chunk in client.stream("Tell me a story"):
|
||||||
print(chunk, end="")
|
print(chunk, end="")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[LLMSettings] = None):
|
def __init__(self, settings: Optional[LLMSettings] = None):
|
||||||
self.settings = settings or LLMSettings()
|
self.settings = settings or LLMSettings()
|
||||||
self._client = httpx.AsyncClient(
|
self._client = httpx.AsyncClient(
|
||||||
base_url=self.settings.llm_url,
|
base_url=self.settings.llm_url,
|
||||||
timeout=self.settings.http_timeout,
|
timeout=self.settings.http_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -54,7 +55,7 @@ class LLMClient:
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a response from the LLM.
|
Generate a response from the LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: User prompt/query
|
prompt: User prompt/query
|
||||||
context: Optional context for RAG
|
context: Optional context for RAG
|
||||||
@@ -63,19 +64,19 @@ class LLMClient:
|
|||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
top_p: Top-p sampling
|
top_p: Top-p sampling
|
||||||
stop: Stop sequences
|
stop: Stop sequences
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generated text response
|
Generated text response
|
||||||
"""
|
"""
|
||||||
with create_span("llm.generate") as span:
|
with create_span("llm.generate") as span:
|
||||||
messages = self._build_messages(prompt, context, system_prompt)
|
messages = self._build_messages(prompt, context, system_prompt)
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("llm.model", self.settings.llm_model)
|
span.set_attribute("llm.model", self.settings.llm_model)
|
||||||
span.set_attribute("llm.prompt_length", len(prompt))
|
span.set_attribute("llm.prompt_length", len(prompt))
|
||||||
if context:
|
if context:
|
||||||
span.set_attribute("llm.context_length", len(context))
|
span.set_attribute("llm.context_length", len(context))
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.settings.llm_model,
|
"model": self.settings.llm_model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -85,21 +86,21 @@ class LLMClient:
|
|||||||
}
|
}
|
||||||
if stop:
|
if stop:
|
||||||
payload["stop"] = stop
|
payload["stop"] = stop
|
||||||
|
|
||||||
response = await self._client.post("/v1/chat/completions", json=payload)
|
response = await self._client.post("/v1/chat/completions", json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
content = result["choices"][0]["message"]["content"]
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("llm.response_length", len(content))
|
span.set_attribute("llm.response_length", len(content))
|
||||||
usage = result.get("usage", {})
|
usage = result.get("usage", {})
|
||||||
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
|
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
|
||||||
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
|
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
async def stream(
|
async def stream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -110,19 +111,19 @@ class LLMClient:
|
|||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
"""
|
"""
|
||||||
Stream a response from the LLM.
|
Stream a response from the LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: User prompt/query
|
prompt: User prompt/query
|
||||||
context: Optional context for RAG
|
context: Optional context for RAG
|
||||||
system_prompt: Optional system prompt
|
system_prompt: Optional system prompt
|
||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Text chunks as they're generated
|
Text chunks as they're generated
|
||||||
"""
|
"""
|
||||||
messages = self._build_messages(prompt, context, system_prompt)
|
messages = self._build_messages(prompt, context, system_prompt)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.settings.llm_model,
|
"model": self.settings.llm_model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -130,25 +131,24 @@ class LLMClient:
|
|||||||
"temperature": temperature or self.settings.llm_temperature,
|
"temperature": temperature or self.settings.llm_temperature,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
async with self._client.stream(
|
async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
|
||||||
"POST", "/v1/chat/completions", json=payload
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
if line.startswith("data: "):
|
if line.startswith("data: "):
|
||||||
data = line[6:]
|
data = line[6:]
|
||||||
if data == "[DONE]":
|
if data == "[DONE]":
|
||||||
break
|
break
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
chunk = json.loads(data)
|
chunk = json.loads(data)
|
||||||
delta = chunk["choices"][0].get("delta", {})
|
delta = chunk["choices"][0].get("delta", {})
|
||||||
content = delta.get("content", "")
|
content = delta.get("content", "")
|
||||||
if content:
|
if content:
|
||||||
yield content
|
yield content
|
||||||
|
|
||||||
def _build_messages(
|
def _build_messages(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -157,32 +157,36 @@ class LLMClient:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Build the messages list for the API call."""
|
"""Build the messages list for the API call."""
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
# System prompt
|
# System prompt
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
elif context:
|
elif context:
|
||||||
# Default RAG system prompt
|
# Default RAG system prompt
|
||||||
messages.append({
|
messages.append(
|
||||||
"role": "system",
|
{
|
||||||
"content": (
|
"role": "system",
|
||||||
"You are a helpful assistant. Use the provided context to answer "
|
"content": (
|
||||||
"the user's question. If the context doesn't contain relevant "
|
"You are a helpful assistant. Use the provided context to answer "
|
||||||
"information, say so."
|
"the user's question. If the context doesn't contain relevant "
|
||||||
),
|
"information, say so."
|
||||||
})
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Add context as a separate message if provided
|
# Add context as a separate message if provided
|
||||||
if context:
|
if context:
|
||||||
messages.append({
|
messages.append(
|
||||||
"role": "user",
|
{
|
||||||
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
|
"role": "user",
|
||||||
})
|
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def health(self) -> bool:
|
async def health(self) -> bool:
|
||||||
"""Check if the LLM service is healthy."""
|
"""Check if the LLM service is healthy."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
Milvus vector database client.
|
Milvus vector database client.
|
||||||
"""
|
"""
|
||||||
import logging
|
|
||||||
from typing import Optional, Any
|
|
||||||
|
|
||||||
from pymilvus import connections, Collection, utility
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pymilvus import Collection, connections, utility
|
||||||
|
|
||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
from handler_base.telemetry import create_span
|
from handler_base.telemetry import create_span
|
||||||
@@ -15,42 +16,42 @@ logger = logging.getLogger(__name__)
|
|||||||
class MilvusClient:
|
class MilvusClient:
|
||||||
"""
|
"""
|
||||||
Client for Milvus vector database.
|
Client for Milvus vector database.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
client = MilvusClient()
|
client = MilvusClient()
|
||||||
await client.connect()
|
await client.connect()
|
||||||
results = await client.search(embedding, limit=10)
|
results = await client.search(embedding, limit=10)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[Settings] = None):
|
def __init__(self, settings: Optional[Settings] = None):
|
||||||
self.settings = settings or Settings()
|
self.settings = settings or Settings()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._collection: Optional[Collection] = None
|
self._collection: Optional[Collection] = None
|
||||||
|
|
||||||
async def connect(self, collection_name: Optional[str] = None) -> None:
|
async def connect(self, collection_name: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Connect to Milvus and load collection.
|
Connect to Milvus and load collection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Collection to use (defaults to settings)
|
collection_name: Collection to use (defaults to settings)
|
||||||
"""
|
"""
|
||||||
collection_name = collection_name or self.settings.milvus_collection
|
collection_name = collection_name or self.settings.milvus_collection
|
||||||
|
|
||||||
connections.connect(
|
connections.connect(
|
||||||
alias="default",
|
alias="default",
|
||||||
host=self.settings.milvus_host,
|
host=self.settings.milvus_host,
|
||||||
port=self.settings.milvus_port,
|
port=self.settings.milvus_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
if utility.has_collection(collection_name):
|
if utility.has_collection(collection_name):
|
||||||
self._collection = Collection(collection_name)
|
self._collection = Collection(collection_name)
|
||||||
self._collection.load()
|
self._collection.load()
|
||||||
logger.info(f"Connected to Milvus collection: {collection_name}")
|
logger.info(f"Connected to Milvus collection: {collection_name}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Collection {collection_name} does not exist")
|
logger.warning(f"Collection {collection_name} does not exist")
|
||||||
|
|
||||||
self._connected = True
|
self._connected = True
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close Milvus connection."""
|
"""Close Milvus connection."""
|
||||||
if self._collection:
|
if self._collection:
|
||||||
@@ -58,7 +59,7 @@ class MilvusClient:
|
|||||||
connections.disconnect("default")
|
connections.disconnect("default")
|
||||||
self._connected = False
|
self._connected = False
|
||||||
logger.info("Disconnected from Milvus")
|
logger.info("Disconnected from Milvus")
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
@@ -68,26 +69,26 @@ class MilvusClient:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors.
|
Search for similar vectors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding: Query embedding vector
|
embedding: Query embedding vector
|
||||||
limit: Maximum number of results
|
limit: Maximum number of results
|
||||||
output_fields: Fields to return (default: all)
|
output_fields: Fields to return (default: all)
|
||||||
filter_expr: Optional filter expression
|
filter_expr: Optional filter expression
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of results with 'id', 'distance', and requested fields
|
List of results with 'id', 'distance', and requested fields
|
||||||
"""
|
"""
|
||||||
if not self._collection:
|
if not self._collection:
|
||||||
raise RuntimeError("Not connected to collection")
|
raise RuntimeError("Not connected to collection")
|
||||||
|
|
||||||
with create_span("milvus.search") as span:
|
with create_span("milvus.search") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("milvus.collection", self._collection.name)
|
span.set_attribute("milvus.collection", self._collection.name)
|
||||||
span.set_attribute("milvus.limit", limit)
|
span.set_attribute("milvus.limit", limit)
|
||||||
|
|
||||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||||
|
|
||||||
results = self._collection.search(
|
results = self._collection.search(
|
||||||
data=[embedding],
|
data=[embedding],
|
||||||
anns_field="embedding",
|
anns_field="embedding",
|
||||||
@@ -96,7 +97,7 @@ class MilvusClient:
|
|||||||
output_fields=output_fields,
|
output_fields=output_fields,
|
||||||
expr=filter_expr,
|
expr=filter_expr,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to list of dicts
|
# Convert to list of dicts
|
||||||
hits = []
|
hits = []
|
||||||
for hit in results[0]:
|
for hit in results[0]:
|
||||||
@@ -111,12 +112,12 @@ class MilvusClient:
|
|||||||
if hasattr(hit.entity, field):
|
if hasattr(hit.entity, field):
|
||||||
item[field] = getattr(hit.entity, field)
|
item[field] = getattr(hit.entity, field)
|
||||||
hits.append(item)
|
hits.append(item)
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("milvus.results", len(hits))
|
span.set_attribute("milvus.results", len(hits))
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
async def search_with_texts(
|
async def search_with_texts(
|
||||||
self,
|
self,
|
||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
@@ -126,22 +127,22 @@ class MilvusClient:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Search and return text content with metadata.
|
Search and return text content with metadata.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding: Query embedding
|
embedding: Query embedding
|
||||||
limit: Maximum results
|
limit: Maximum results
|
||||||
text_field: Name of text field in collection
|
text_field: Name of text field in collection
|
||||||
metadata_fields: Additional metadata fields to return
|
metadata_fields: Additional metadata fields to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of results with text and metadata
|
List of results with text and metadata
|
||||||
"""
|
"""
|
||||||
output_fields = [text_field]
|
output_fields = [text_field]
|
||||||
if metadata_fields:
|
if metadata_fields:
|
||||||
output_fields.extend(metadata_fields)
|
output_fields.extend(metadata_fields)
|
||||||
|
|
||||||
return await self.search(embedding, limit, output_fields)
|
return await self.search(embedding, limit, output_fields)
|
||||||
|
|
||||||
async def insert(
|
async def insert(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
@@ -149,34 +150,34 @@ class MilvusClient:
|
|||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Insert vectors with data into the collection.
|
Insert vectors with data into the collection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embeddings: List of embedding vectors
|
embeddings: List of embedding vectors
|
||||||
data: List of dicts with field values
|
data: List of dicts with field values
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of inserted IDs
|
List of inserted IDs
|
||||||
"""
|
"""
|
||||||
if not self._collection:
|
if not self._collection:
|
||||||
raise RuntimeError("Not connected to collection")
|
raise RuntimeError("Not connected to collection")
|
||||||
|
|
||||||
with create_span("milvus.insert") as span:
|
with create_span("milvus.insert") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("milvus.collection", self._collection.name)
|
span.set_attribute("milvus.collection", self._collection.name)
|
||||||
span.set_attribute("milvus.count", len(embeddings))
|
span.set_attribute("milvus.count", len(embeddings))
|
||||||
|
|
||||||
# Build insert data
|
# Build insert data
|
||||||
insert_data = [embeddings]
|
insert_data = [embeddings]
|
||||||
for field in self._collection.schema.fields:
|
for field in self._collection.schema.fields:
|
||||||
if field.name not in ("id", "embedding"):
|
if field.name not in ("id", "embedding"):
|
||||||
field_values = [d.get(field.name) for d in data]
|
field_values = [d.get(field.name) for d in data]
|
||||||
insert_data.append(field_values)
|
insert_data.append(field_values)
|
||||||
|
|
||||||
result = self._collection.insert(insert_data)
|
result = self._collection.insert(insert_data)
|
||||||
self._collection.flush()
|
self._collection.flush()
|
||||||
|
|
||||||
return result.primary_keys
|
return result.primary_keys
|
||||||
|
|
||||||
def health(self) -> bool:
|
def health(self) -> bool:
|
||||||
"""Check if connected to Milvus."""
|
"""Check if connected to Milvus."""
|
||||||
return self._connected and utility.get_connection_addr("default") is not None
|
return self._connected and utility.get_connection_addr("default") is not None
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Reranker service client (Infinity/BGE Reranker).
|
Reranker service client (Infinity/BGE Reranker).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
|
|||||||
class RerankerClient:
|
class RerankerClient:
|
||||||
"""
|
"""
|
||||||
Client for the reranker service (Infinity with BGE Reranker).
|
Client for the reranker service (Infinity with BGE Reranker).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
client = RerankerClient()
|
client = RerankerClient()
|
||||||
reranked = await client.rerank("query", ["doc1", "doc2"])
|
reranked = await client.rerank("query", ["doc1", "doc2"])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[Settings] = None):
|
def __init__(self, settings: Optional[Settings] = None):
|
||||||
self.settings = settings or Settings()
|
self.settings = settings or Settings()
|
||||||
self._client = httpx.AsyncClient(
|
self._client = httpx.AsyncClient(
|
||||||
base_url=self.settings.reranker_url,
|
base_url=self.settings.reranker_url,
|
||||||
timeout=self.settings.http_timeout,
|
timeout=self.settings.http_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
||||||
async def rerank(
|
async def rerank(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -40,12 +41,12 @@ class RerankerClient:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Rerank documents based on relevance to query.
|
Rerank documents based on relevance to query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Query text
|
query: Query text
|
||||||
documents: List of documents to rerank
|
documents: List of documents to rerank
|
||||||
top_k: Number of top results to return (default: all)
|
top_k: Number of top results to return (default: all)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of dicts with 'index', 'score', and 'document' keys,
|
List of dicts with 'index', 'score', and 'document' keys,
|
||||||
sorted by relevance score descending.
|
sorted by relevance score descending.
|
||||||
@@ -55,32 +56,34 @@ class RerankerClient:
|
|||||||
span.set_attribute("reranker.num_documents", len(documents))
|
span.set_attribute("reranker.num_documents", len(documents))
|
||||||
if top_k:
|
if top_k:
|
||||||
span.set_attribute("reranker.top_k", top_k)
|
span.set_attribute("reranker.top_k", top_k)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"documents": documents,
|
"documents": documents,
|
||||||
}
|
}
|
||||||
if top_k:
|
if top_k:
|
||||||
payload["top_n"] = top_k
|
payload["top_n"] = top_k
|
||||||
|
|
||||||
response = await self._client.post("/rerank", json=payload)
|
response = await self._client.post("/rerank", json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
results = result.get("results", [])
|
results = result.get("results", [])
|
||||||
|
|
||||||
# Enrich with original documents
|
# Enrich with original documents
|
||||||
enriched = []
|
enriched = []
|
||||||
for r in results:
|
for r in results:
|
||||||
idx = r.get("index", 0)
|
idx = r.get("index", 0)
|
||||||
enriched.append({
|
enriched.append(
|
||||||
"index": idx,
|
{
|
||||||
"score": r.get("relevance_score", r.get("score", 0)),
|
"index": idx,
|
||||||
"document": documents[idx] if idx < len(documents) else "",
|
"score": r.get("relevance_score", r.get("score", 0)),
|
||||||
})
|
"document": documents[idx] if idx < len(documents) else "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return enriched
|
return enriched
|
||||||
|
|
||||||
async def rerank_with_metadata(
|
async def rerank_with_metadata(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -90,27 +93,27 @@ class RerankerClient:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Rerank documents with metadata, preserving metadata in results.
|
Rerank documents with metadata, preserving metadata in results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Query text
|
query: Query text
|
||||||
documents: List of dicts with text and metadata
|
documents: List of dicts with text and metadata
|
||||||
text_key: Key containing text in each document dict
|
text_key: Key containing text in each document dict
|
||||||
top_k: Number of top results to return
|
top_k: Number of top results to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Reranked documents with original metadata preserved.
|
Reranked documents with original metadata preserved.
|
||||||
"""
|
"""
|
||||||
texts = [d.get(text_key, "") for d in documents]
|
texts = [d.get(text_key, "") for d in documents]
|
||||||
reranked = await self.rerank(query, texts, top_k)
|
reranked = await self.rerank(query, texts, top_k)
|
||||||
|
|
||||||
# Merge back metadata
|
# Merge back metadata
|
||||||
for r in reranked:
|
for r in reranked:
|
||||||
idx = r["index"]
|
idx = r["index"]
|
||||||
if idx < len(documents):
|
if idx < len(documents):
|
||||||
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
|
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
|
||||||
|
|
||||||
return reranked
|
return reranked
|
||||||
|
|
||||||
async def health(self) -> bool:
|
async def health(self) -> bool:
|
||||||
"""Check if the reranker service is healthy."""
|
"""Check if the reranker service is healthy."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
STT service client (Whisper/faster-whisper).
|
STT service client (Whisper/faster-whisper).
|
||||||
"""
|
"""
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
|
|||||||
class STTClient:
|
class STTClient:
|
||||||
"""
|
"""
|
||||||
Client for the STT service (Whisper/faster-whisper).
|
Client for the STT service (Whisper/faster-whisper).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
client = STTClient()
|
client = STTClient()
|
||||||
text = await client.transcribe(audio_bytes)
|
text = await client.transcribe(audio_bytes)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[STTSettings] = None):
|
def __init__(self, settings: Optional[STTSettings] = None):
|
||||||
self.settings = settings or STTSettings()
|
self.settings = settings or STTSettings()
|
||||||
self._client = httpx.AsyncClient(
|
self._client = httpx.AsyncClient(
|
||||||
base_url=self.settings.stt_url,
|
base_url=self.settings.stt_url,
|
||||||
timeout=180.0, # Transcription can be slow
|
timeout=180.0, # Transcription can be slow
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
||||||
async def transcribe(
|
async def transcribe(
|
||||||
self,
|
self,
|
||||||
audio: bytes,
|
audio: bytes,
|
||||||
@@ -42,54 +42,54 @@ class STTClient:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Transcribe audio to text.
|
Transcribe audio to text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio: Audio bytes (WAV, MP3, etc.)
|
audio: Audio bytes (WAV, MP3, etc.)
|
||||||
language: Language code (None for auto-detect)
|
language: Language code (None for auto-detect)
|
||||||
task: "transcribe" or "translate"
|
task: "transcribe" or "translate"
|
||||||
response_format: "json", "text", "srt", "vtt"
|
response_format: "json", "text", "srt", "vtt"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with 'text', 'language', and optional 'segments'
|
Dict with 'text', 'language', and optional 'segments'
|
||||||
"""
|
"""
|
||||||
language = language or self.settings.stt_language
|
language = language or self.settings.stt_language
|
||||||
task = task or self.settings.stt_task
|
task = task or self.settings.stt_task
|
||||||
|
|
||||||
with create_span("stt.transcribe") as span:
|
with create_span("stt.transcribe") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("stt.task", task)
|
span.set_attribute("stt.task", task)
|
||||||
span.set_attribute("stt.audio_size", len(audio))
|
span.set_attribute("stt.audio_size", len(audio))
|
||||||
if language:
|
if language:
|
||||||
span.set_attribute("stt.language", language)
|
span.set_attribute("stt.language", language)
|
||||||
|
|
||||||
files = {"file": ("audio.wav", audio, "audio/wav")}
|
files = {"file": ("audio.wav", audio, "audio/wav")}
|
||||||
data = {
|
data = {
|
||||||
"response_format": response_format,
|
"response_format": response_format,
|
||||||
}
|
}
|
||||||
if language:
|
if language:
|
||||||
data["language"] = language
|
data["language"] = language
|
||||||
|
|
||||||
# Choose endpoint based on task
|
# Choose endpoint based on task
|
||||||
if task == "translate":
|
if task == "translate":
|
||||||
endpoint = "/v1/audio/translations"
|
endpoint = "/v1/audio/translations"
|
||||||
else:
|
else:
|
||||||
endpoint = "/v1/audio/transcriptions"
|
endpoint = "/v1/audio/transcriptions"
|
||||||
|
|
||||||
response = await self._client.post(endpoint, files=files, data=data)
|
response = await self._client.post(endpoint, files=files, data=data)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
if response_format == "text":
|
if response_format == "text":
|
||||||
return {"text": response.text}
|
return {"text": response.text}
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("stt.result_length", len(result.get("text", "")))
|
span.set_attribute("stt.result_length", len(result.get("text", "")))
|
||||||
if result.get("language"):
|
if result.get("language"):
|
||||||
span.set_attribute("stt.detected_language", result["language"])
|
span.set_attribute("stt.detected_language", result["language"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def transcribe_file(
|
async def transcribe_file(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
@@ -98,31 +98,31 @@ class STTClient:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Transcribe an audio file.
|
Transcribe an audio file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to audio file
|
file_path: Path to audio file
|
||||||
language: Language code
|
language: Language code
|
||||||
task: "transcribe" or "translate"
|
task: "transcribe" or "translate"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Transcription result
|
Transcription result
|
||||||
"""
|
"""
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
audio = f.read()
|
audio = f.read()
|
||||||
return await self.transcribe(audio, language, task)
|
return await self.transcribe(audio, language, task)
|
||||||
|
|
||||||
async def translate(self, audio: bytes) -> dict:
|
async def translate(self, audio: bytes) -> dict:
|
||||||
"""
|
"""
|
||||||
Translate audio to English.
|
Translate audio to English.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio: Audio bytes
|
audio: Audio bytes
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Translation result with 'text' key
|
Translation result with 'text' key
|
||||||
"""
|
"""
|
||||||
return await self.transcribe(audio, task="translate")
|
return await self.transcribe(audio, task="translate")
|
||||||
|
|
||||||
async def health(self) -> bool:
|
async def health(self) -> bool:
|
||||||
"""Check if the STT service is healthy."""
|
"""Check if the STT service is healthy."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
TTS service client (Coqui XTTS).
|
TTS service client (Coqui XTTS).
|
||||||
"""
|
"""
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
|
|||||||
class TTSClient:
|
class TTSClient:
|
||||||
"""
|
"""
|
||||||
Client for the TTS service (Coqui XTTS).
|
Client for the TTS service (Coqui XTTS).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
client = TTSClient()
|
client = TTSClient()
|
||||||
audio_bytes = await client.synthesize("Hello world")
|
audio_bytes = await client.synthesize("Hello world")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[TTSSettings] = None):
|
def __init__(self, settings: Optional[TTSSettings] = None):
|
||||||
self.settings = settings or TTSSettings()
|
self.settings = settings or TTSSettings()
|
||||||
self._client = httpx.AsyncClient(
|
self._client = httpx.AsyncClient(
|
||||||
base_url=self.settings.tts_url,
|
base_url=self.settings.tts_url,
|
||||||
timeout=120.0, # TTS can be slow
|
timeout=120.0, # TTS can be slow
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
||||||
async def synthesize(
|
async def synthesize(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -41,39 +41,39 @@ class TTSClient:
|
|||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
Synthesize speech from text.
|
Synthesize speech from text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to synthesize
|
text: Text to synthesize
|
||||||
language: Language code (e.g., "en", "es", "fr")
|
language: Language code (e.g., "en", "es", "fr")
|
||||||
speaker: Speaker ID or reference
|
speaker: Speaker ID or reference
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
WAV audio bytes
|
WAV audio bytes
|
||||||
"""
|
"""
|
||||||
language = language or self.settings.tts_language
|
language = language or self.settings.tts_language
|
||||||
|
|
||||||
with create_span("tts.synthesize") as span:
|
with create_span("tts.synthesize") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("tts.language", language)
|
span.set_attribute("tts.language", language)
|
||||||
span.set_attribute("tts.text_length", len(text))
|
span.set_attribute("tts.text_length", len(text))
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"text": text,
|
"text": text,
|
||||||
"language_id": language,
|
"language_id": language,
|
||||||
}
|
}
|
||||||
if speaker:
|
if speaker:
|
||||||
params["speaker_id"] = speaker
|
params["speaker_id"] = speaker
|
||||||
|
|
||||||
response = await self._client.get("/api/tts", params=params)
|
response = await self._client.get("/api/tts", params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
audio_bytes = response.content
|
audio_bytes = response.content
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("tts.audio_size", len(audio_bytes))
|
span.set_attribute("tts.audio_size", len(audio_bytes))
|
||||||
|
|
||||||
return audio_bytes
|
return audio_bytes
|
||||||
|
|
||||||
async def synthesize_to_file(
|
async def synthesize_to_file(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -83,7 +83,7 @@ class TTSClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Synthesize speech and save to a file.
|
Synthesize speech and save to a file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to synthesize
|
text: Text to synthesize
|
||||||
output_path: Path to save the audio file
|
output_path: Path to save the audio file
|
||||||
@@ -91,10 +91,10 @@ class TTSClient:
|
|||||||
speaker: Speaker ID
|
speaker: Speaker ID
|
||||||
"""
|
"""
|
||||||
audio_bytes = await self.synthesize(text, language, speaker)
|
audio_bytes = await self.synthesize(text, language, speaker)
|
||||||
|
|
||||||
with open(output_path, "wb") as f:
|
with open(output_path, "wb") as f:
|
||||||
f.write(audio_bytes)
|
f.write(audio_bytes)
|
||||||
|
|
||||||
async def get_speakers(self) -> list[dict]:
|
async def get_speakers(self) -> list[dict]:
|
||||||
"""Get available speakers/voices."""
|
"""Get available speakers/voices."""
|
||||||
try:
|
try:
|
||||||
@@ -103,7 +103,7 @@ class TTSClient:
|
|||||||
return response.json()
|
return response.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def health(self) -> bool:
|
async def health(self) -> bool:
|
||||||
"""Check if the TTS service is healthy."""
|
"""Check if the TTS service is healthy."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -3,67 +3,69 @@ Configuration management using Pydantic Settings.
|
|||||||
|
|
||||||
Environment variables are automatically loaded and validated.
|
Environment variables are automatically loaded and validated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""Base settings for all handler services."""
|
"""Base settings for all handler services."""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
extra="ignore",
|
extra="ignore",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Service identification
|
# Service identification
|
||||||
service_name: str = "handler"
|
service_name: str = "handler"
|
||||||
service_version: str = "1.0.0"
|
service_version: str = "1.0.0"
|
||||||
service_namespace: str = "ai-ml"
|
service_namespace: str = "ai-ml"
|
||||||
deployment_env: str = "production"
|
deployment_env: str = "production"
|
||||||
|
|
||||||
# NATS configuration
|
# NATS configuration
|
||||||
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
|
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
|
||||||
nats_user: Optional[str] = None
|
nats_user: Optional[str] = None
|
||||||
nats_password: Optional[str] = None
|
nats_password: Optional[str] = None
|
||||||
nats_queue_group: Optional[str] = None
|
nats_queue_group: Optional[str] = None
|
||||||
|
|
||||||
# Redis/Valkey configuration
|
# Redis/Valkey configuration
|
||||||
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
|
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
|
||||||
redis_password: Optional[str] = None
|
redis_password: Optional[str] = None
|
||||||
|
|
||||||
# Milvus configuration
|
# Milvus configuration
|
||||||
milvus_host: str = "milvus.ai-ml.svc.cluster.local"
|
milvus_host: str = "milvus.ai-ml.svc.cluster.local"
|
||||||
milvus_port: int = 19530
|
milvus_port: int = 19530
|
||||||
milvus_collection: str = "documents"
|
milvus_collection: str = "documents"
|
||||||
|
|
||||||
# Service endpoints
|
# Service endpoints
|
||||||
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
||||||
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
|
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
|
||||||
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
|
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
|
||||||
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
|
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
|
||||||
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
|
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
|
||||||
|
|
||||||
# OpenTelemetry configuration
|
# OpenTelemetry configuration
|
||||||
otel_enabled: bool = True
|
otel_enabled: bool = True
|
||||||
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
|
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
|
||||||
otel_use_http: bool = False
|
otel_use_http: bool = False
|
||||||
|
|
||||||
# HyperDX configuration
|
# HyperDX configuration
|
||||||
hyperdx_enabled: bool = False
|
hyperdx_enabled: bool = False
|
||||||
hyperdx_api_key: Optional[str] = None
|
hyperdx_api_key: Optional[str] = None
|
||||||
hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
|
hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
|
||||||
|
|
||||||
# MLflow configuration
|
# MLflow configuration
|
||||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
mlflow_experiment_name: Optional[str] = None
|
mlflow_experiment_name: Optional[str] = None
|
||||||
mlflow_enabled: bool = True
|
mlflow_enabled: bool = True
|
||||||
|
|
||||||
# Health check configuration
|
# Health check configuration
|
||||||
health_port: int = 8080
|
health_port: int = 8080
|
||||||
health_path: str = "/health"
|
health_path: str = "/health"
|
||||||
ready_path: str = "/ready"
|
ready_path: str = "/ready"
|
||||||
|
|
||||||
# Timeouts (seconds)
|
# Timeouts (seconds)
|
||||||
http_timeout: float = 60.0
|
http_timeout: float = 60.0
|
||||||
nats_timeout: float = 30.0
|
nats_timeout: float = 30.0
|
||||||
@@ -71,14 +73,14 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
class EmbeddingsSettings(Settings):
|
class EmbeddingsSettings(Settings):
|
||||||
"""Settings for embeddings service client."""
|
"""Settings for embeddings service client."""
|
||||||
|
|
||||||
embeddings_model: str = "bge"
|
embeddings_model: str = "bge"
|
||||||
embeddings_batch_size: int = 32
|
embeddings_batch_size: int = 32
|
||||||
|
|
||||||
|
|
||||||
class LLMSettings(Settings):
|
class LLMSettings(Settings):
|
||||||
"""Settings for LLM service client."""
|
"""Settings for LLM service client."""
|
||||||
|
|
||||||
llm_model: str = "default"
|
llm_model: str = "default"
|
||||||
llm_max_tokens: int = 2048
|
llm_max_tokens: int = 2048
|
||||||
llm_temperature: float = 0.7
|
llm_temperature: float = 0.7
|
||||||
@@ -87,13 +89,13 @@ class LLMSettings(Settings):
|
|||||||
|
|
||||||
class TTSSettings(Settings):
|
class TTSSettings(Settings):
|
||||||
"""Settings for TTS service client."""
|
"""Settings for TTS service client."""
|
||||||
|
|
||||||
tts_language: str = "en"
|
tts_language: str = "en"
|
||||||
tts_speaker: Optional[str] = None
|
tts_speaker: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class STTSettings(Settings):
|
class STTSettings(Settings):
|
||||||
"""Settings for STT service client."""
|
"""Settings for STT service client."""
|
||||||
|
|
||||||
stt_language: Optional[str] = None # Auto-detect
|
stt_language: Optional[str] = None # Auto-detect
|
||||||
stt_task: str = "transcribe" # or "translate"
|
stt_task: str = "transcribe" # or "translate"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Base handler class for building NATS-based services.
|
Base handler class for building NATS-based services.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
@@ -12,7 +13,7 @@ from nats.aio.msg import Msg
|
|||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
from handler_base.health import HealthServer
|
from handler_base.health import HealthServer
|
||||||
from handler_base.nats_client import NATSClient
|
from handler_base.nats_client import NATSClient
|
||||||
from handler_base.telemetry import setup_telemetry, create_span
|
from handler_base.telemetry import create_span, setup_telemetry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -20,25 +21,25 @@ logger = logging.getLogger(__name__)
|
|||||||
class Handler(ABC):
|
class Handler(ABC):
|
||||||
"""
|
"""
|
||||||
Base class for NATS message handlers.
|
Base class for NATS message handlers.
|
||||||
|
|
||||||
Subclass and implement:
|
Subclass and implement:
|
||||||
- setup(): Initialize your service clients
|
- setup(): Initialize your service clients
|
||||||
- handle_message(): Process incoming messages
|
- handle_message(): Process incoming messages
|
||||||
- teardown(): Clean up resources (optional)
|
- teardown(): Clean up resources (optional)
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
class MyHandler(Handler):
|
class MyHandler(Handler):
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
self.embeddings = EmbeddingsClient()
|
self.embeddings = EmbeddingsClient()
|
||||||
|
|
||||||
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
|
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
|
||||||
result = await self.embeddings.embed(data["text"])
|
result = await self.embeddings.embed(data["text"])
|
||||||
return {"embedding": result}
|
return {"embedding": result}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
MyHandler(subject="my.subject").run()
|
MyHandler(subject="my.subject").run()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
subject: str,
|
subject: str,
|
||||||
@@ -47,7 +48,7 @@ class Handler(ABC):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the handler.
|
Initialize the handler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subject: NATS subject to subscribe to
|
subject: NATS subject to subscribe to
|
||||||
settings: Configuration settings
|
settings: Configuration settings
|
||||||
@@ -56,78 +57,78 @@ class Handler(ABC):
|
|||||||
self.subject = subject
|
self.subject = subject
|
||||||
self.settings = settings or Settings()
|
self.settings = settings or Settings()
|
||||||
self.queue_group = queue_group or self.settings.nats_queue_group
|
self.queue_group = queue_group or self.settings.nats_queue_group
|
||||||
|
|
||||||
self.nats = NATSClient(self.settings)
|
self.nats = NATSClient(self.settings)
|
||||||
self.health_server = HealthServer(self.settings, self._check_ready)
|
self.health_server = HealthServer(self.settings, self._check_ready)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._shutdown_event = asyncio.Event()
|
self._shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def setup(self) -> None:
|
async def setup(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize service clients and resources.
|
Initialize service clients and resources.
|
||||||
|
|
||||||
Called once before starting to handle messages.
|
Called once before starting to handle messages.
|
||||||
Override this to set up your service-specific clients.
|
Override this to set up your service-specific clients.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
|
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
Handle an incoming message.
|
Handle an incoming message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: Raw NATS message
|
msg: Raw NATS message
|
||||||
data: Decoded message data (msgpack unpacked)
|
data: Decoded message data (msgpack unpacked)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional response data. If returned and msg has a reply subject,
|
Optional response data. If returned and msg has a reply subject,
|
||||||
the response will be sent automatically.
|
the response will be sent automatically.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def teardown(self) -> None:
|
async def teardown(self) -> None:
|
||||||
"""
|
"""
|
||||||
Clean up resources.
|
Clean up resources.
|
||||||
|
|
||||||
Called during graceful shutdown.
|
Called during graceful shutdown.
|
||||||
Override to add custom cleanup logic.
|
Override to add custom cleanup logic.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _check_ready(self) -> bool:
|
async def _check_ready(self) -> bool:
|
||||||
"""Check if the service is ready to handle requests."""
|
"""Check if the service is ready to handle requests."""
|
||||||
return self._running and self.nats._nc is not None
|
return self._running and self.nats._nc is not None
|
||||||
|
|
||||||
async def _message_handler(self, msg: Msg) -> None:
|
async def _message_handler(self, msg: Msg) -> None:
|
||||||
"""Internal message handler with tracing and error handling."""
|
"""Internal message handler with tracing and error handling."""
|
||||||
with create_span(f"handle.{self.subject}") as span:
|
with create_span(f"handle.{self.subject}") as span:
|
||||||
try:
|
try:
|
||||||
# Decode message
|
# Decode message
|
||||||
data = NATSClient.decode_msgpack(msg)
|
data = NATSClient.decode_msgpack(msg)
|
||||||
|
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("messaging.destination", msg.subject)
|
span.set_attribute("messaging.destination", msg.subject)
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
request_id = data.get("request_id", data.get("id"))
|
request_id = data.get("request_id", data.get("id"))
|
||||||
if request_id:
|
if request_id:
|
||||||
span.set_attribute("request.id", str(request_id))
|
span.set_attribute("request.id", str(request_id))
|
||||||
|
|
||||||
# Handle message
|
# Handle message
|
||||||
response = await self.handle_message(msg, data)
|
response = await self.handle_message(msg, data)
|
||||||
|
|
||||||
# Send response if applicable
|
# Send response if applicable
|
||||||
if response is not None and msg.reply:
|
if response is not None and msg.reply:
|
||||||
await self.nats.publish(msg.reply, response)
|
await self.nats.publish(msg.reply, response)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error handling message on {msg.subject}")
|
logger.exception(f"Error handling message on {msg.subject}")
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("error", True)
|
span.set_attribute("error", True)
|
||||||
span.set_attribute("error.message", str(e))
|
span.set_attribute("error.message", str(e))
|
||||||
|
|
||||||
# Send error response if reply expected
|
# Send error response if reply expected
|
||||||
if msg.reply:
|
if msg.reply:
|
||||||
error_response = {
|
error_response = {
|
||||||
@@ -136,71 +137,71 @@ class Handler(ABC):
|
|||||||
"type": type(e).__name__,
|
"type": type(e).__name__,
|
||||||
}
|
}
|
||||||
await self.nats.publish(msg.reply, error_response)
|
await self.nats.publish(msg.reply, error_response)
|
||||||
|
|
||||||
def _setup_signals(self) -> None:
|
def _setup_signals(self) -> None:
|
||||||
"""Set up signal handlers for graceful shutdown."""
|
"""Set up signal handlers for graceful shutdown."""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
loop.add_signal_handler(sig, self._handle_signal, sig)
|
loop.add_signal_handler(sig, self._handle_signal, sig)
|
||||||
|
|
||||||
def _handle_signal(self, sig: signal.Signals) -> None:
|
def _handle_signal(self, sig: signal.Signals) -> None:
|
||||||
"""Handle shutdown signal."""
|
"""Handle shutdown signal."""
|
||||||
logger.info(f"Received {sig.name}, initiating graceful shutdown...")
|
logger.info(f"Received {sig.name}, initiating graceful shutdown...")
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
|
|
||||||
async def _run(self) -> None:
|
async def _run(self) -> None:
|
||||||
"""Main async run loop."""
|
"""Main async run loop."""
|
||||||
# Setup telemetry
|
# Setup telemetry
|
||||||
setup_telemetry(self.settings)
|
setup_telemetry(self.settings)
|
||||||
|
|
||||||
# Start health server
|
# Start health server
|
||||||
self.health_server.start()
|
self.health_server.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Connect to NATS
|
# Connect to NATS
|
||||||
await self.nats.connect()
|
await self.nats.connect()
|
||||||
|
|
||||||
# Run user setup
|
# Run user setup
|
||||||
logger.info("Running service setup...")
|
logger.info("Running service setup...")
|
||||||
await self.setup()
|
await self.setup()
|
||||||
|
|
||||||
# Subscribe to subject
|
# Subscribe to subject
|
||||||
await self.nats.subscribe(
|
await self.nats.subscribe(
|
||||||
self.subject,
|
self.subject,
|
||||||
self._message_handler,
|
self._message_handler,
|
||||||
queue=self.queue_group,
|
queue=self.queue_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
logger.info(f"Handler ready, listening on {self.subject}")
|
logger.info(f"Handler ready, listening on {self.subject}")
|
||||||
|
|
||||||
# Wait for shutdown signal
|
# Wait for shutdown signal
|
||||||
await self._shutdown_event.wait()
|
await self._shutdown_event.wait()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("Fatal error in handler")
|
logger.exception("Fatal error in handler")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
# Graceful shutdown
|
# Graceful shutdown
|
||||||
logger.info("Shutting down...")
|
logger.info("Shutting down...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.teardown()
|
await self.teardown()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error during teardown: {e}")
|
logger.warning(f"Error during teardown: {e}")
|
||||||
|
|
||||||
await self.nats.close()
|
await self.nats.close()
|
||||||
self.health_server.stop()
|
self.health_server.stop()
|
||||||
|
|
||||||
logger.info("Shutdown complete")
|
logger.info("Shutdown complete")
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
Run the handler.
|
Run the handler.
|
||||||
|
|
||||||
This is the main entry point. It sets up signal handlers
|
This is the main entry point. It sets up signal handlers
|
||||||
and runs the async event loop.
|
and runs the async event loop.
|
||||||
"""
|
"""
|
||||||
@@ -209,12 +210,12 @@ class Handler(ABC):
|
|||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
|
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
|
||||||
|
|
||||||
# Run the async loop
|
# Run the async loop
|
||||||
asyncio.run(self._run_with_signals())
|
asyncio.run(self._run_with_signals())
|
||||||
|
|
||||||
async def _run_with_signals(self) -> None:
|
async def _run_with_signals(self) -> None:
|
||||||
"""Run with signal handling."""
|
"""Run with signal handling."""
|
||||||
self._setup_signals()
|
self._setup_signals()
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ HTTP health check server.
|
|||||||
|
|
||||||
Provides /health and /ready endpoints for Kubernetes probes.
|
Provides /health and /ready endpoints for Kubernetes probes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from typing import Callable, Optional, Awaitable
|
|
||||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
||||||
import threading
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
from typing import Awaitable, Callable, Optional
|
||||||
|
|
||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
|
|
||||||
@@ -17,16 +18,16 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class HealthHandler(BaseHTTPRequestHandler):
|
class HealthHandler(BaseHTTPRequestHandler):
|
||||||
"""HTTP request handler for health checks."""
|
"""HTTP request handler for health checks."""
|
||||||
|
|
||||||
# Class-level state
|
# Class-level state
|
||||||
ready_check: Optional[Callable[[], Awaitable[bool]]] = None
|
ready_check: Optional[Callable[[], Awaitable[bool]]] = None
|
||||||
health_path: str = "/health"
|
health_path: str = "/health"
|
||||||
ready_path: str = "/ready"
|
ready_path: str = "/ready"
|
||||||
|
|
||||||
def log_message(self, format, *args):
|
def log_message(self, format, *args):
|
||||||
"""Suppress default logging."""
|
"""Suppress default logging."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
"""Handle GET requests for health/ready endpoints."""
|
"""Handle GET requests for health/ready endpoints."""
|
||||||
if self.path == self.health_path:
|
if self.path == self.health_path:
|
||||||
@@ -35,7 +36,7 @@ class HealthHandler(BaseHTTPRequestHandler):
|
|||||||
self._handle_ready()
|
self._handle_ready()
|
||||||
else:
|
else:
|
||||||
self._respond_not_found()
|
self._respond_not_found()
|
||||||
|
|
||||||
def _handle_ready(self):
|
def _handle_ready(self):
|
||||||
"""Check readiness and respond."""
|
"""Check readiness and respond."""
|
||||||
# Access via class to avoid method binding issues
|
# Access via class to avoid method binding issues
|
||||||
@@ -43,7 +44,7 @@ class HealthHandler(BaseHTTPRequestHandler):
|
|||||||
if ready_check is None:
|
if ready_check is None:
|
||||||
self._respond_ok({"status": "ready"})
|
self._respond_ok({"status": "ready"})
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run the async check in a new event loop
|
# Run the async check in a new event loop
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
@@ -51,7 +52,7 @@ class HealthHandler(BaseHTTPRequestHandler):
|
|||||||
is_ready = loop.run_until_complete(ready_check())
|
is_ready = loop.run_until_complete(ready_check())
|
||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
if is_ready:
|
if is_ready:
|
||||||
self._respond_ok({"status": "ready"})
|
self._respond_ok({"status": "ready"})
|
||||||
else:
|
else:
|
||||||
@@ -59,19 +60,19 @@ class HealthHandler(BaseHTTPRequestHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Readiness check failed")
|
logger.exception("Readiness check failed")
|
||||||
self._respond_unavailable({"status": "error", "message": str(e)})
|
self._respond_unavailable({"status": "error", "message": str(e)})
|
||||||
|
|
||||||
def _respond_ok(self, data: dict):
|
def _respond_ok(self, data: dict):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Content-Type", "application/json")
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(json.dumps(data).encode())
|
self.wfile.write(json.dumps(data).encode())
|
||||||
|
|
||||||
def _respond_unavailable(self, data: dict):
|
def _respond_unavailable(self, data: dict):
|
||||||
self.send_response(503)
|
self.send_response(503)
|
||||||
self.send_header("Content-Type", "application/json")
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(json.dumps(data).encode())
|
self.wfile.write(json.dumps(data).encode())
|
||||||
|
|
||||||
def _respond_not_found(self):
|
def _respond_not_found(self):
|
||||||
self.send_response(404)
|
self.send_response(404)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
@@ -80,14 +81,14 @@ class HealthHandler(BaseHTTPRequestHandler):
|
|||||||
class HealthServer:
|
class HealthServer:
|
||||||
"""
|
"""
|
||||||
Background HTTP server for health checks.
|
Background HTTP server for health checks.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
server = HealthServer(settings)
|
server = HealthServer(settings)
|
||||||
server.start()
|
server.start()
|
||||||
# ... run your service ...
|
# ... run your service ...
|
||||||
server.stop()
|
server.stop()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
settings: Optional[Settings] = None,
|
settings: Optional[Settings] = None,
|
||||||
@@ -97,24 +98,24 @@ class HealthServer:
|
|||||||
self.ready_check = ready_check
|
self.ready_check = ready_check
|
||||||
self._server: Optional[HTTPServer] = None
|
self._server: Optional[HTTPServer] = None
|
||||||
self._thread: Optional[threading.Thread] = None
|
self._thread: Optional[threading.Thread] = None
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Start the health check server in a background thread."""
|
"""Start the health check server in a background thread."""
|
||||||
# Configure handler class
|
# Configure handler class
|
||||||
HealthHandler.ready_check = self.ready_check
|
HealthHandler.ready_check = self.ready_check
|
||||||
HealthHandler.health_path = self.settings.health_path
|
HealthHandler.health_path = self.settings.health_path
|
||||||
HealthHandler.ready_path = self.settings.ready_path
|
HealthHandler.ready_path = self.settings.ready_path
|
||||||
|
|
||||||
# Create and start server
|
# Create and start server
|
||||||
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
|
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
|
||||||
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
|
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Health server started on port {self.settings.health_port} "
|
f"Health server started on port {self.settings.health_port} "
|
||||||
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
|
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the health check server."""
|
"""Stop the health check server."""
|
||||||
if self._server:
|
if self._server:
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
NATS client wrapper with connection management and utilities.
|
NATS client wrapper with connection management and utilities.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional, Awaitable
|
from typing import Any, Awaitable, Callable, Optional
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import nats
|
import nats
|
||||||
@@ -20,34 +20,34 @@ logger = logging.getLogger(__name__)
|
|||||||
class NATSClient:
|
class NATSClient:
|
||||||
"""
|
"""
|
||||||
NATS client with automatic connection management.
|
NATS client with automatic connection management.
|
||||||
|
|
||||||
Supports:
|
Supports:
|
||||||
- Core NATS pub/sub
|
- Core NATS pub/sub
|
||||||
- JetStream for persistence
|
- JetStream for persistence
|
||||||
- Queue groups for load balancing
|
- Queue groups for load balancing
|
||||||
- Msgpack serialization
|
- Msgpack serialization
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, settings: Optional[Settings] = None):
|
def __init__(self, settings: Optional[Settings] = None):
|
||||||
self.settings = settings or Settings()
|
self.settings = settings or Settings()
|
||||||
self._nc: Optional[Client] = None
|
self._nc: Optional[Client] = None
|
||||||
self._js: Optional[JetStreamContext] = None
|
self._js: Optional[JetStreamContext] = None
|
||||||
self._subscriptions: list = []
|
self._subscriptions: list = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nc(self) -> Client:
|
def nc(self) -> Client:
|
||||||
"""Get the NATS client, raising if not connected."""
|
"""Get the NATS client, raising if not connected."""
|
||||||
if self._nc is None:
|
if self._nc is None:
|
||||||
raise RuntimeError("NATS client not connected. Call connect() first.")
|
raise RuntimeError("NATS client not connected. Call connect() first.")
|
||||||
return self._nc
|
return self._nc
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def js(self) -> JetStreamContext:
|
def js(self) -> JetStreamContext:
|
||||||
"""Get JetStream context, raising if not connected."""
|
"""Get JetStream context, raising if not connected."""
|
||||||
if self._js is None:
|
if self._js is None:
|
||||||
raise RuntimeError("JetStream not initialized. Call connect() first.")
|
raise RuntimeError("JetStream not initialized. Call connect() first.")
|
||||||
return self._js
|
return self._js
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Connect to NATS server."""
|
"""Connect to NATS server."""
|
||||||
connect_opts = {
|
connect_opts = {
|
||||||
@@ -55,16 +55,16 @@ class NATSClient:
|
|||||||
"reconnect_time_wait": 2,
|
"reconnect_time_wait": 2,
|
||||||
"max_reconnect_attempts": -1, # Infinite
|
"max_reconnect_attempts": -1, # Infinite
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.settings.nats_user and self.settings.nats_password:
|
if self.settings.nats_user and self.settings.nats_password:
|
||||||
connect_opts["user"] = self.settings.nats_user
|
connect_opts["user"] = self.settings.nats_user
|
||||||
connect_opts["password"] = self.settings.nats_password
|
connect_opts["password"] = self.settings.nats_password
|
||||||
|
|
||||||
logger.info(f"Connecting to NATS at {self.settings.nats_url}")
|
logger.info(f"Connecting to NATS at {self.settings.nats_url}")
|
||||||
self._nc = await nats.connect(**connect_opts)
|
self._nc = await nats.connect(**connect_opts)
|
||||||
self._js = self._nc.jetstream()
|
self._js = self._nc.jetstream()
|
||||||
logger.info("Connected to NATS")
|
logger.info("Connected to NATS")
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close NATS connection gracefully."""
|
"""Close NATS connection gracefully."""
|
||||||
if self._nc:
|
if self._nc:
|
||||||
@@ -74,13 +74,13 @@ class NATSClient:
|
|||||||
await sub.drain()
|
await sub.drain()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error draining subscription: {e}")
|
logger.warning(f"Error draining subscription: {e}")
|
||||||
|
|
||||||
await self._nc.drain()
|
await self._nc.drain()
|
||||||
await self._nc.close()
|
await self._nc.close()
|
||||||
self._nc = None
|
self._nc = None
|
||||||
self._js = None
|
self._js = None
|
||||||
logger.info("NATS connection closed")
|
logger.info("NATS connection closed")
|
||||||
|
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
self,
|
self,
|
||||||
subject: str,
|
subject: str,
|
||||||
@@ -89,24 +89,24 @@ class NATSClient:
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Subscribe to a subject with a handler function.
|
Subscribe to a subject with a handler function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subject: NATS subject to subscribe to
|
subject: NATS subject to subscribe to
|
||||||
handler: Async function to handle messages
|
handler: Async function to handle messages
|
||||||
queue: Optional queue group for load balancing
|
queue: Optional queue group for load balancing
|
||||||
"""
|
"""
|
||||||
queue = queue or self.settings.nats_queue_group
|
queue = queue or self.settings.nats_queue_group
|
||||||
|
|
||||||
if queue:
|
if queue:
|
||||||
sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
|
sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
|
||||||
logger.info(f"Subscribed to {subject} (queue: {queue})")
|
logger.info(f"Subscribed to {subject} (queue: {queue})")
|
||||||
else:
|
else:
|
||||||
sub = await self.nc.subscribe(subject, cb=handler)
|
sub = await self.nc.subscribe(subject, cb=handler)
|
||||||
logger.info(f"Subscribed to {subject}")
|
logger.info(f"Subscribed to {subject}")
|
||||||
|
|
||||||
self._subscriptions.append(sub)
|
self._subscriptions.append(sub)
|
||||||
return sub
|
return sub
|
||||||
|
|
||||||
async def publish(
|
async def publish(
|
||||||
self,
|
self,
|
||||||
subject: str,
|
subject: str,
|
||||||
@@ -115,7 +115,7 @@ class NATSClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Publish a message to a subject.
|
Publish a message to a subject.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subject: NATS subject to publish to
|
subject: NATS subject to publish to
|
||||||
data: Data to publish (will be serialized)
|
data: Data to publish (will be serialized)
|
||||||
@@ -124,15 +124,16 @@ class NATSClient:
|
|||||||
with create_span("nats.publish") as span:
|
with create_span("nats.publish") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("messaging.destination", subject)
|
span.set_attribute("messaging.destination", subject)
|
||||||
|
|
||||||
if use_msgpack:
|
if use_msgpack:
|
||||||
payload = msgpack.packb(data, use_bin_type=True)
|
payload = msgpack.packb(data, use_bin_type=True)
|
||||||
else:
|
else:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
payload = json.dumps(data).encode()
|
payload = json.dumps(data).encode()
|
||||||
|
|
||||||
await self.nc.publish(subject, payload)
|
await self.nc.publish(subject, payload)
|
||||||
|
|
||||||
async def request(
|
async def request(
|
||||||
self,
|
self,
|
||||||
subject: str,
|
subject: str,
|
||||||
@@ -142,43 +143,46 @@ class NATSClient:
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Send a request and wait for response.
|
Send a request and wait for response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subject: NATS subject to send request to
|
subject: NATS subject to send request to
|
||||||
data: Request data
|
data: Request data
|
||||||
timeout: Response timeout in seconds
|
timeout: Response timeout in seconds
|
||||||
use_msgpack: Whether to use msgpack serialization
|
use_msgpack: Whether to use msgpack serialization
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decoded response data
|
Decoded response data
|
||||||
"""
|
"""
|
||||||
timeout = timeout or self.settings.nats_timeout
|
timeout = timeout or self.settings.nats_timeout
|
||||||
|
|
||||||
with create_span("nats.request") as span:
|
with create_span("nats.request") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("messaging.destination", subject)
|
span.set_attribute("messaging.destination", subject)
|
||||||
|
|
||||||
if use_msgpack:
|
if use_msgpack:
|
||||||
payload = msgpack.packb(data, use_bin_type=True)
|
payload = msgpack.packb(data, use_bin_type=True)
|
||||||
else:
|
else:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
payload = json.dumps(data).encode()
|
payload = json.dumps(data).encode()
|
||||||
|
|
||||||
response = await self.nc.request(subject, payload, timeout=timeout)
|
response = await self.nc.request(subject, payload, timeout=timeout)
|
||||||
|
|
||||||
if use_msgpack:
|
if use_msgpack:
|
||||||
return msgpack.unpackb(response.data, raw=False)
|
return msgpack.unpackb(response.data, raw=False)
|
||||||
else:
|
else:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
return json.loads(response.data.decode())
|
return json.loads(response.data.decode())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode_msgpack(msg: Msg) -> Any:
|
def decode_msgpack(msg: Msg) -> Any:
|
||||||
"""Decode a msgpack message."""
|
"""Decode a msgpack message."""
|
||||||
return msgpack.unpackb(msg.data, raw=False)
|
return msgpack.unpackb(msg.data, raw=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode_json(msg: Msg) -> Any:
|
def decode_json(msg: Msg) -> Any:
|
||||||
"""Decode a JSON message."""
|
"""Decode a JSON message."""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
return json.loads(msg.data.decode())
|
return json.loads(msg.data.decode())
|
||||||
|
|||||||
@@ -3,26 +3,27 @@ OpenTelemetry setup for tracing and metrics.
|
|||||||
|
|
||||||
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
|
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from opentelemetry import trace, metrics
|
from opentelemetry import metrics, trace
|
||||||
from opentelemetry.sdk.trace import TracerProvider
|
|
||||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
|
||||||
from opentelemetry.sdk.metrics import MeterProvider
|
|
||||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
|
||||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
|
||||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||||
OTLPSpanExporter as OTLPSpanExporterHTTP,
|
|
||||||
)
|
|
||||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||||
OTLPMetricExporter as OTLPMetricExporterHTTP,
|
OTLPMetricExporter as OTLPMetricExporterHTTP,
|
||||||
)
|
)
|
||||||
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||||
|
OTLPSpanExporter as OTLPSpanExporterHTTP,
|
||||||
|
)
|
||||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||||
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
||||||
|
from opentelemetry.sdk.metrics import MeterProvider
|
||||||
|
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||||
|
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
|
||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
|
|
||||||
@@ -39,35 +40,37 @@ def setup_telemetry(
|
|||||||
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
|
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
|
||||||
"""
|
"""
|
||||||
Initialize OpenTelemetry tracing and metrics.
|
Initialize OpenTelemetry tracing and metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
settings: Configuration settings. If None, loads from environment.
|
settings: Configuration settings. If None, loads from environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (tracer, meter) or (None, None) if disabled.
|
Tuple of (tracer, meter) or (None, None) if disabled.
|
||||||
"""
|
"""
|
||||||
global _tracer, _meter, _initialized
|
global _tracer, _meter, _initialized
|
||||||
|
|
||||||
if _initialized:
|
if _initialized:
|
||||||
return _tracer, _meter
|
return _tracer, _meter
|
||||||
|
|
||||||
if settings is None:
|
if settings is None:
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
if not settings.otel_enabled:
|
if not settings.otel_enabled:
|
||||||
logger.info("OpenTelemetry disabled")
|
logger.info("OpenTelemetry disabled")
|
||||||
_initialized = True
|
_initialized = True
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Create resource with service information
|
# Create resource with service information
|
||||||
resource = Resource.create({
|
resource = Resource.create(
|
||||||
SERVICE_NAME: settings.service_name,
|
{
|
||||||
SERVICE_VERSION: settings.service_version,
|
SERVICE_NAME: settings.service_name,
|
||||||
SERVICE_NAMESPACE: settings.service_namespace,
|
SERVICE_VERSION: settings.service_version,
|
||||||
"deployment.environment": settings.deployment_env,
|
SERVICE_NAMESPACE: settings.service_namespace,
|
||||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
"deployment.environment": settings.deployment_env,
|
||||||
})
|
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Determine endpoint and exporter type
|
# Determine endpoint and exporter type
|
||||||
if settings.hyperdx_enabled and settings.hyperdx_api_key:
|
if settings.hyperdx_enabled and settings.hyperdx_api_key:
|
||||||
# HyperDX uses HTTP with API key header
|
# HyperDX uses HTTP with API key header
|
||||||
@@ -80,7 +83,7 @@ def setup_telemetry(
|
|||||||
headers = None
|
headers = None
|
||||||
use_http = settings.otel_use_http
|
use_http = settings.otel_use_http
|
||||||
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
|
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
|
||||||
|
|
||||||
# Setup tracing
|
# Setup tracing
|
||||||
if use_http:
|
if use_http:
|
||||||
trace_exporter = OTLPSpanExporterHTTP(
|
trace_exporter = OTLPSpanExporterHTTP(
|
||||||
@@ -91,11 +94,11 @@ def setup_telemetry(
|
|||||||
trace_exporter = OTLPSpanExporter(
|
trace_exporter = OTLPSpanExporter(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
tracer_provider = TracerProvider(resource=resource)
|
tracer_provider = TracerProvider(resource=resource)
|
||||||
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
|
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
|
||||||
trace.set_tracer_provider(tracer_provider)
|
trace.set_tracer_provider(tracer_provider)
|
||||||
|
|
||||||
# Setup metrics
|
# Setup metrics
|
||||||
if use_http:
|
if use_http:
|
||||||
metric_exporter = OTLPMetricExporterHTTP(
|
metric_exporter = OTLPMetricExporterHTTP(
|
||||||
@@ -106,25 +109,25 @@ def setup_telemetry(
|
|||||||
metric_exporter = OTLPMetricExporter(
|
metric_exporter = OTLPMetricExporter(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
metric_reader = PeriodicExportingMetricReader(
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
metric_exporter,
|
metric_exporter,
|
||||||
export_interval_millis=60000,
|
export_interval_millis=60000,
|
||||||
)
|
)
|
||||||
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||||
metrics.set_meter_provider(meter_provider)
|
metrics.set_meter_provider(meter_provider)
|
||||||
|
|
||||||
# Instrument libraries
|
# Instrument libraries
|
||||||
HTTPXClientInstrumentor().instrument()
|
HTTPXClientInstrumentor().instrument()
|
||||||
LoggingInstrumentor().instrument(set_logging_format=True)
|
LoggingInstrumentor().instrument(set_logging_format=True)
|
||||||
|
|
||||||
# Create tracer and meter for this service
|
# Create tracer and meter for this service
|
||||||
_tracer = trace.get_tracer(settings.service_name, settings.service_version)
|
_tracer = trace.get_tracer(settings.service_name, settings.service_version)
|
||||||
_meter = metrics.get_meter(settings.service_name, settings.service_version)
|
_meter = metrics.get_meter(settings.service_name, settings.service_version)
|
||||||
|
|
||||||
logger.info(f"OpenTelemetry initialized for {settings.service_name}")
|
logger.info(f"OpenTelemetry initialized for {settings.service_name}")
|
||||||
_initialized = True
|
_initialized = True
|
||||||
|
|
||||||
return _tracer, _meter
|
return _tracer, _meter
|
||||||
|
|
||||||
|
|
||||||
@@ -141,7 +144,7 @@ def get_meter() -> Optional[metrics.Meter]:
|
|||||||
def create_span(name: str, **kwargs):
|
def create_span(name: str, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a new span.
|
Create a new span.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
with create_span("my_operation") as span:
|
with create_span("my_operation") as span:
|
||||||
span.set_attribute("key", "value")
|
span.set_attribute("key", "value")
|
||||||
@@ -150,5 +153,6 @@ def create_span(name: str, **kwargs):
|
|||||||
if _tracer is None:
|
if _tracer is None:
|
||||||
# Return a no-op context manager
|
# Return a no-op context manager
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
return _tracer.start_as_current_span(name, **kwargs)
|
return _tracer.start_as_current_span(name, **kwargs)
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
Pytest configuration and fixtures.
|
Pytest configuration and fixtures.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import AsyncGenerator
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
# Set test environment variables before importing handler_base
|
# Set test environment variables before importing handler_base
|
||||||
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
|
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
|
||||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
|
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
|
||||||
@@ -29,6 +28,7 @@ def event_loop():
|
|||||||
def settings():
|
def settings():
|
||||||
"""Create test settings."""
|
"""Create test settings."""
|
||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
|
|
||||||
return Settings(
|
return Settings(
|
||||||
service_name="test-service",
|
service_name="test-service",
|
||||||
service_version="1.0.0-test",
|
service_version="1.0.0-test",
|
||||||
@@ -56,7 +56,7 @@ def mock_nats_message():
|
|||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.subject = "test.subject"
|
msg.subject = "test.subject"
|
||||||
msg.reply = "test.reply"
|
msg.reply = "test.reply"
|
||||||
msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack
|
msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,44 +1,43 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for service clients.
|
Unit tests for service clients.
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingsClient:
|
class TestEmbeddingsClient:
|
||||||
"""Tests for EmbeddingsClient."""
|
"""Tests for EmbeddingsClient."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def embeddings_client(self, mock_httpx_client):
|
def embeddings_client(self, mock_httpx_client):
|
||||||
"""Create an EmbeddingsClient with mocked HTTP."""
|
"""Create an EmbeddingsClient with mocked HTTP."""
|
||||||
from handler_base.clients.embeddings import EmbeddingsClient
|
from handler_base.clients.embeddings import EmbeddingsClient
|
||||||
|
|
||||||
client = EmbeddingsClient()
|
client = EmbeddingsClient()
|
||||||
client._client = mock_httpx_client
|
client._client = mock_httpx_client
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
|
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
|
||||||
"""Test embedding a single text."""
|
"""Test embedding a single text."""
|
||||||
# Setup mock response
|
# Setup mock response
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
|
||||||
"data": [{"embedding": sample_embedding, "index": 0}]
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_httpx_client.post.return_value = mock_response
|
mock_httpx_client.post.return_value = mock_response
|
||||||
|
|
||||||
result = await embeddings_client.embed_single("Hello world")
|
result = await embeddings_client.embed_single("Hello world")
|
||||||
|
|
||||||
assert result == sample_embedding
|
assert result == sample_embedding
|
||||||
mock_httpx_client.post.assert_called_once()
|
mock_httpx_client.post.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
|
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
|
||||||
"""Test embedding multiple texts."""
|
"""Test embedding multiple texts."""
|
||||||
texts = ["Hello", "World"]
|
texts = ["Hello", "World"]
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
"data": [
|
"data": [
|
||||||
@@ -48,41 +47,41 @@ class TestEmbeddingsClient:
|
|||||||
}
|
}
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_httpx_client.post.return_value = mock_response
|
mock_httpx_client.post.return_value = mock_response
|
||||||
|
|
||||||
result = await embeddings_client.embed(texts)
|
result = await embeddings_client.embed(texts)
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert all(len(e) == len(sample_embedding) for e in result)
|
assert all(len(e) == len(sample_embedding) for e in result)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_health_check(self, embeddings_client, mock_httpx_client):
|
async def test_health_check(self, embeddings_client, mock_httpx_client):
|
||||||
"""Test health check."""
|
"""Test health check."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_httpx_client.get.return_value = mock_response
|
mock_httpx_client.get.return_value = mock_response
|
||||||
|
|
||||||
result = await embeddings_client.health()
|
result = await embeddings_client.health()
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
class TestRerankerClient:
|
class TestRerankerClient:
|
||||||
"""Tests for RerankerClient."""
|
"""Tests for RerankerClient."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reranker_client(self, mock_httpx_client):
|
def reranker_client(self, mock_httpx_client):
|
||||||
"""Create a RerankerClient with mocked HTTP."""
|
"""Create a RerankerClient with mocked HTTP."""
|
||||||
from handler_base.clients.reranker import RerankerClient
|
from handler_base.clients.reranker import RerankerClient
|
||||||
|
|
||||||
client = RerankerClient()
|
client = RerankerClient()
|
||||||
client._client = mock_httpx_client
|
client._client = mock_httpx_client
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
|
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
|
||||||
"""Test reranking documents."""
|
"""Test reranking documents."""
|
||||||
texts = [d["text"] for d in sample_documents]
|
texts = [d["text"] for d in sample_documents]
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
"results": [
|
"results": [
|
||||||
@@ -93,9 +92,9 @@ class TestRerankerClient:
|
|||||||
}
|
}
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_httpx_client.post.return_value = mock_response
|
mock_httpx_client.post.return_value = mock_response
|
||||||
|
|
||||||
result = await reranker_client.rerank("What is ML?", texts)
|
result = await reranker_client.rerank("What is ML?", texts)
|
||||||
|
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert result[0]["score"] == 0.95
|
assert result[0]["score"] == 0.95
|
||||||
assert result[0]["index"] == 1
|
assert result[0]["index"] == 1
|
||||||
@@ -103,53 +102,48 @@ class TestRerankerClient:
|
|||||||
|
|
||||||
class TestLLMClient:
|
class TestLLMClient:
|
||||||
"""Tests for LLMClient."""
|
"""Tests for LLMClient."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def llm_client(self, mock_httpx_client):
|
def llm_client(self, mock_httpx_client):
|
||||||
"""Create an LLMClient with mocked HTTP."""
|
"""Create an LLMClient with mocked HTTP."""
|
||||||
from handler_base.clients.llm import LLMClient
|
from handler_base.clients.llm import LLMClient
|
||||||
|
|
||||||
client = LLMClient()
|
client = LLMClient()
|
||||||
client._client = mock_httpx_client
|
client._client = mock_httpx_client
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate(self, llm_client, mock_httpx_client):
|
async def test_generate(self, llm_client, mock_httpx_client):
|
||||||
"""Test generating a response."""
|
"""Test generating a response."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
"choices": [
|
"choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
|
||||||
{"message": {"content": "Hello! I'm an AI assistant."}}
|
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
|
||||||
],
|
|
||||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20}
|
|
||||||
}
|
}
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_httpx_client.post.return_value = mock_response
|
mock_httpx_client.post.return_value = mock_response
|
||||||
|
|
||||||
result = await llm_client.generate("Hello")
|
result = await llm_client.generate("Hello")
|
||||||
|
|
||||||
assert result == "Hello! I'm an AI assistant."
|
assert result == "Hello! I'm an AI assistant."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_with_context(self, llm_client, mock_httpx_client):
|
async def test_generate_with_context(self, llm_client, mock_httpx_client):
|
||||||
"""Test generating with RAG context."""
|
"""Test generating with RAG context."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
"choices": [
|
"choices": [{"message": {"content": "Based on the context..."}}],
|
||||||
{"message": {"content": "Based on the context..."}}
|
"usage": {},
|
||||||
],
|
|
||||||
"usage": {}
|
|
||||||
}
|
}
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_httpx_client.post.return_value = mock_response
|
mock_httpx_client.post.return_value = mock_response
|
||||||
|
|
||||||
result = await llm_client.generate(
|
result = await llm_client.generate(
|
||||||
"What is Python?",
|
"What is Python?", context="Python is a programming language."
|
||||||
context="Python is a programming language."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "Based on the context" in result
|
assert "Based on the context" in result
|
||||||
|
|
||||||
# Verify context was included in the request
|
# Verify context was included in the request
|
||||||
call_args = mock_httpx_client.post.call_args
|
call_args = mock_httpx_client.post.call_args
|
||||||
messages = call_args.kwargs["json"]["messages"]
|
messages = call_args.kwargs["json"]["messages"]
|
||||||
|
|||||||
@@ -1,46 +1,45 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for handler_base.config module.
|
Unit tests for handler_base.config module.
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
class TestSettings:
|
class TestSettings:
|
||||||
"""Tests for Settings configuration."""
|
"""Tests for Settings configuration."""
|
||||||
|
|
||||||
def test_default_settings(self, settings):
|
def test_default_settings(self, settings):
|
||||||
"""Test that default settings are loaded correctly."""
|
"""Test that default settings are loaded correctly."""
|
||||||
assert settings.service_name == "test-service"
|
assert settings.service_name == "test-service"
|
||||||
assert settings.service_version == "1.0.0-test"
|
assert settings.service_version == "1.0.0-test"
|
||||||
assert settings.otel_enabled is False
|
assert settings.otel_enabled is False
|
||||||
|
|
||||||
def test_settings_from_env(self, monkeypatch):
|
def test_settings_from_env(self, monkeypatch):
|
||||||
"""Test that settings can be loaded from environment variables."""
|
"""Test that settings can be loaded from environment variables."""
|
||||||
monkeypatch.setenv("SERVICE_NAME", "env-service")
|
monkeypatch.setenv("SERVICE_NAME", "env-service")
|
||||||
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
|
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
|
||||||
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
|
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
|
||||||
|
|
||||||
# Need to reimport to pick up env changes
|
# Need to reimport to pick up env changes
|
||||||
from handler_base.config import Settings
|
from handler_base.config import Settings
|
||||||
|
|
||||||
s = Settings()
|
s = Settings()
|
||||||
|
|
||||||
assert s.service_name == "env-service"
|
assert s.service_name == "env-service"
|
||||||
assert s.service_version == "2.0.0"
|
assert s.service_version == "2.0.0"
|
||||||
assert s.nats_url == "nats://custom:4222"
|
assert s.nats_url == "nats://custom:4222"
|
||||||
|
|
||||||
def test_embeddings_settings(self):
|
def test_embeddings_settings(self):
|
||||||
"""Test EmbeddingsSettings extends base correctly."""
|
"""Test EmbeddingsSettings extends base correctly."""
|
||||||
from handler_base.config import EmbeddingsSettings
|
from handler_base.config import EmbeddingsSettings
|
||||||
|
|
||||||
s = EmbeddingsSettings()
|
s = EmbeddingsSettings()
|
||||||
assert hasattr(s, "embeddings_model")
|
assert hasattr(s, "embeddings_model")
|
||||||
assert hasattr(s, "embeddings_batch_size")
|
assert hasattr(s, "embeddings_batch_size")
|
||||||
assert s.embeddings_model == "bge"
|
assert s.embeddings_model == "bge"
|
||||||
|
|
||||||
def test_llm_settings(self):
|
def test_llm_settings(self):
|
||||||
"""Test LLMSettings has expected defaults."""
|
"""Test LLMSettings has expected defaults."""
|
||||||
from handler_base.config import LLMSettings
|
from handler_base.config import LLMSettings
|
||||||
|
|
||||||
s = LLMSettings()
|
s = LLMSettings()
|
||||||
assert s.llm_max_tokens == 2048
|
assert s.llm_max_tokens == 2048
|
||||||
assert s.llm_temperature == 0.7
|
assert s.llm_temperature == 0.7
|
||||||
|
|||||||
@@ -1,101 +1,101 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for handler_base.health module.
|
Unit tests for handler_base.health module.
|
||||||
"""
|
"""
|
||||||
import pytest
|
|
||||||
import json
|
import json
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from http.client import HTTPConnection
|
from http.client import HTTPConnection
|
||||||
from unittest.mock import AsyncMock
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
class TestHealthServer:
|
class TestHealthServer:
|
||||||
"""Tests for HealthServer."""
|
"""Tests for HealthServer."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def health_server(self, settings):
|
def health_server(self, settings):
|
||||||
"""Create a HealthServer instance."""
|
"""Create a HealthServer instance."""
|
||||||
from handler_base.health import HealthServer
|
from handler_base.health import HealthServer
|
||||||
|
|
||||||
# Use a random high port to avoid conflicts
|
# Use a random high port to avoid conflicts
|
||||||
settings.health_port = 18080
|
settings.health_port = 18080
|
||||||
return HealthServer(settings)
|
return HealthServer(settings)
|
||||||
|
|
||||||
def test_start_stop(self, health_server):
|
def test_start_stop(self, health_server):
|
||||||
"""Test starting and stopping the health server."""
|
"""Test starting and stopping the health server."""
|
||||||
health_server.start()
|
health_server.start()
|
||||||
time.sleep(0.1) # Give server time to start
|
time.sleep(0.1) # Give server time to start
|
||||||
|
|
||||||
# Verify server is running
|
# Verify server is running
|
||||||
assert health_server._server is not None
|
assert health_server._server is not None
|
||||||
assert health_server._thread is not None
|
assert health_server._thread is not None
|
||||||
assert health_server._thread.is_alive()
|
assert health_server._thread.is_alive()
|
||||||
|
|
||||||
health_server.stop()
|
health_server.stop()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
assert health_server._server is None
|
assert health_server._server is None
|
||||||
|
|
||||||
def test_health_endpoint(self, health_server):
|
def test_health_endpoint(self, health_server):
|
||||||
"""Test the /health endpoint."""
|
"""Test the /health endpoint."""
|
||||||
health_server.start()
|
health_server.start()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
conn = HTTPConnection("localhost", 18080, timeout=5)
|
||||||
conn.request("GET", "/health")
|
conn.request("GET", "/health")
|
||||||
response = conn.getresponse()
|
response = conn.getresponse()
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
data = json.loads(response.read().decode())
|
data = json.loads(response.read().decode())
|
||||||
assert data["status"] == "healthy"
|
assert data["status"] == "healthy"
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
health_server.stop()
|
health_server.stop()
|
||||||
|
|
||||||
def test_ready_endpoint_default(self, health_server):
|
def test_ready_endpoint_default(self, health_server):
|
||||||
"""Test the /ready endpoint with no custom check."""
|
"""Test the /ready endpoint with no custom check."""
|
||||||
health_server.start()
|
health_server.start()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
conn = HTTPConnection("localhost", 18080, timeout=5)
|
||||||
conn.request("GET", "/ready")
|
conn.request("GET", "/ready")
|
||||||
response = conn.getresponse()
|
response = conn.getresponse()
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
data = json.loads(response.read().decode())
|
data = json.loads(response.read().decode())
|
||||||
assert data["status"] == "ready"
|
assert data["status"] == "ready"
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
health_server.stop()
|
health_server.stop()
|
||||||
|
|
||||||
def test_ready_endpoint_with_check(self, settings):
|
def test_ready_endpoint_with_check(self, settings):
|
||||||
"""Test /ready endpoint with custom readiness check."""
|
"""Test /ready endpoint with custom readiness check."""
|
||||||
from handler_base.health import HealthServer
|
from handler_base.health import HealthServer
|
||||||
|
|
||||||
ready_flag = [False] # Use list to allow mutation in closure
|
ready_flag = [False] # Use list to allow mutation in closure
|
||||||
|
|
||||||
async def check_ready():
|
async def check_ready():
|
||||||
return ready_flag[0]
|
return ready_flag[0]
|
||||||
|
|
||||||
settings.health_port = 18081
|
settings.health_port = 18081
|
||||||
server = HealthServer(settings, ready_check=check_ready)
|
server = HealthServer(settings, ready_check=check_ready)
|
||||||
server.start()
|
server.start()
|
||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = HTTPConnection("localhost", 18081, timeout=5)
|
conn = HTTPConnection("localhost", 18081, timeout=5)
|
||||||
|
|
||||||
# Should be not ready initially
|
# Should be not ready initially
|
||||||
conn.request("GET", "/ready")
|
conn.request("GET", "/ready")
|
||||||
response = conn.getresponse()
|
response = conn.getresponse()
|
||||||
response.read() # Consume response body
|
response.read() # Consume response body
|
||||||
assert response.status == 503
|
assert response.status == 503
|
||||||
|
|
||||||
# Mark as ready
|
# Mark as ready
|
||||||
ready_flag[0] = True
|
ready_flag[0] = True
|
||||||
|
|
||||||
# Need new connection after consuming response
|
# Need new connection after consuming response
|
||||||
conn.close()
|
conn.close()
|
||||||
conn = HTTPConnection("localhost", 18081, timeout=5)
|
conn = HTTPConnection("localhost", 18081, timeout=5)
|
||||||
@@ -105,17 +105,17 @@ class TestHealthServer:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
server.stop()
|
server.stop()
|
||||||
|
|
||||||
def test_404_for_unknown_path(self, health_server):
|
def test_404_for_unknown_path(self, health_server):
|
||||||
"""Test that unknown paths return 404."""
|
"""Test that unknown paths return 404."""
|
||||||
health_server.start()
|
health_server.start()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
conn = HTTPConnection("localhost", 18080, timeout=5)
|
||||||
conn.request("GET", "/unknown")
|
conn.request("GET", "/unknown")
|
||||||
response = conn.getresponse()
|
response = conn.getresponse()
|
||||||
|
|
||||||
assert response.status == 404
|
assert response.status == 404
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
@@ -1,48 +1,52 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for handler_base.nats_client module.
|
Unit tests for handler_base.nats_client module.
|
||||||
"""
|
"""
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
class TestNATSClient:
|
class TestNATSClient:
|
||||||
"""Tests for NATSClient."""
|
"""Tests for NATSClient."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def nats_client(self, settings):
|
def nats_client(self, settings):
|
||||||
"""Create a NATSClient instance."""
|
"""Create a NATSClient instance."""
|
||||||
from handler_base.nats_client import NATSClient
|
from handler_base.nats_client import NATSClient
|
||||||
|
|
||||||
return NATSClient(settings)
|
return NATSClient(settings)
|
||||||
|
|
||||||
def test_init(self, nats_client, settings):
|
def test_init(self, nats_client, settings):
|
||||||
"""Test NATSClient initialization."""
|
"""Test NATSClient initialization."""
|
||||||
assert nats_client.settings == settings
|
assert nats_client.settings == settings
|
||||||
assert nats_client._nc is None
|
assert nats_client._nc is None
|
||||||
assert nats_client._js is None
|
assert nats_client._js is None
|
||||||
|
|
||||||
def test_decode_msgpack(self, nats_client):
|
def test_decode_msgpack(self, nats_client):
|
||||||
"""Test msgpack decoding."""
|
"""Test msgpack decoding."""
|
||||||
data = {"query": "hello", "request_id": "123"}
|
data = {"query": "hello", "request_id": "123"}
|
||||||
encoded = msgpack.packb(data, use_bin_type=True)
|
encoded = msgpack.packb(data, use_bin_type=True)
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.data = encoded
|
msg.data = encoded
|
||||||
|
|
||||||
result = nats_client.decode_msgpack(msg)
|
result = nats_client.decode_msgpack(msg)
|
||||||
assert result == data
|
assert result == data
|
||||||
|
|
||||||
def test_decode_json(self, nats_client):
|
def test_decode_json(self, nats_client):
|
||||||
"""Test JSON decoding."""
|
"""Test JSON decoding."""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
data = {"query": "hello"}
|
data = {"query": "hello"}
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.data = json.dumps(data).encode()
|
msg.data = json.dumps(data).encode()
|
||||||
|
|
||||||
result = nats_client.decode_json(msg)
|
result = nats_client.decode_json(msg)
|
||||||
assert result == data
|
assert result == data
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect(self, nats_client):
|
async def test_connect(self, nats_client):
|
||||||
"""Test NATS connection."""
|
"""Test NATS connection."""
|
||||||
@@ -51,30 +55,30 @@ class TestNATSClient:
|
|||||||
mock_js = MagicMock()
|
mock_js = MagicMock()
|
||||||
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
|
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
|
||||||
mock_nats.connect = AsyncMock(return_value=mock_nc)
|
mock_nats.connect = AsyncMock(return_value=mock_nc)
|
||||||
|
|
||||||
await nats_client.connect()
|
await nats_client.connect()
|
||||||
|
|
||||||
assert nats_client._nc == mock_nc
|
assert nats_client._nc == mock_nc
|
||||||
assert nats_client._js == mock_js
|
assert nats_client._js == mock_js
|
||||||
mock_nats.connect.assert_called_once()
|
mock_nats.connect.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_publish(self, nats_client):
|
async def test_publish(self, nats_client):
|
||||||
"""Test publishing a message."""
|
"""Test publishing a message."""
|
||||||
mock_nc = AsyncMock()
|
mock_nc = AsyncMock()
|
||||||
nats_client._nc = mock_nc
|
nats_client._nc = mock_nc
|
||||||
|
|
||||||
data = {"key": "value"}
|
data = {"key": "value"}
|
||||||
await nats_client.publish("test.subject", data)
|
await nats_client.publish("test.subject", data)
|
||||||
|
|
||||||
mock_nc.publish.assert_called_once()
|
mock_nc.publish.assert_called_once()
|
||||||
call_args = mock_nc.publish.call_args
|
call_args = mock_nc.publish.call_args
|
||||||
assert call_args.args[0] == "test.subject"
|
assert call_args.args[0] == "test.subject"
|
||||||
|
|
||||||
# Verify msgpack encoding
|
# Verify msgpack encoding
|
||||||
decoded = msgpack.unpackb(call_args.args[1], raw=False)
|
decoded = msgpack.unpackb(call_args.args[1], raw=False)
|
||||||
assert decoded == data
|
assert decoded == data
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subscribe(self, nats_client):
|
async def test_subscribe(self, nats_client):
|
||||||
"""Test subscribing to a subject."""
|
"""Test subscribing to a subject."""
|
||||||
@@ -82,10 +86,10 @@ class TestNATSClient:
|
|||||||
mock_sub = MagicMock()
|
mock_sub = MagicMock()
|
||||||
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
|
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
|
||||||
nats_client._nc = mock_nc
|
nats_client._nc = mock_nc
|
||||||
|
|
||||||
handler = AsyncMock()
|
handler = AsyncMock()
|
||||||
await nats_client.subscribe("test.subject", handler, queue="test-queue")
|
await nats_client.subscribe("test.subject", handler, queue="test-queue")
|
||||||
|
|
||||||
mock_nc.subscribe.assert_called_once()
|
mock_nc.subscribe.assert_called_once()
|
||||||
call_kwargs = mock_nc.subscribe.call_args.kwargs
|
call_kwargs = mock_nc.subscribe.call_args.kwargs
|
||||||
assert call_kwargs["queue"] == "test-queue"
|
assert call_kwargs["queue"] == "test-queue"
|
||||||
|
|||||||
Reference in New Issue
Block a user