fix: auto-fix ruff linting errors and remove unsupported upload-artifact
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""
|
||||
Service client wrappers for AI/ML backends.
|
||||
"""
|
||||
|
||||
from handler_base.clients.embeddings import EmbeddingsClient
|
||||
from handler_base.clients.reranker import RerankerClient
|
||||
from handler_base.clients.llm import LLMClient
|
||||
from handler_base.clients.tts import TTSClient
|
||||
from handler_base.clients.stt import STTClient
|
||||
from handler_base.clients.milvus import MilvusClient
|
||||
from handler_base.clients.reranker import RerankerClient
|
||||
from handler_base.clients.stt import STTClient
|
||||
from handler_base.clients.tts import TTSClient
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingsClient",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Embeddings service client (Infinity/BGE).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
|
||||
class EmbeddingsClient:
|
||||
"""
|
||||
Client for the embeddings service (Infinity with BGE models).
|
||||
|
||||
|
||||
Usage:
|
||||
client = EmbeddingsClient()
|
||||
embeddings = await client.embed(["Hello world"])
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
    """Set up the client with an async HTTP session bound to the embeddings service.

    Args:
        settings: Pre-built settings; a default EmbeddingsSettings is created when omitted.
    """
    self.settings = EmbeddingsSettings() if not settings else settings
    self._client = httpx.AsyncClient(
        timeout=self.settings.http_timeout,
        base_url=self.settings.embeddings_url,
    )
|
||||
|
||||
|
||||
async def close(self) -> None:
    """Release the underlying HTTP client and its connections."""
    client = self._client
    await client.aclose()
|
||||
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
texts: list[str],
|
||||
@@ -39,49 +40,49 @@ class EmbeddingsClient:
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of texts.
|
||||
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
model: Model name (defaults to settings)
|
||||
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
model = model or self.settings.embeddings_model
|
||||
|
||||
|
||||
with create_span("embeddings.embed") as span:
|
||||
if span:
|
||||
span.set_attribute("embeddings.model", model)
|
||||
span.set_attribute("embeddings.batch_size", len(texts))
|
||||
|
||||
|
||||
response = await self._client.post(
|
||||
"/embeddings",
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
result = response.json()
|
||||
embeddings = [d["embedding"] for d in result.get("data", [])]
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
|
||||
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
    """
    Embed one text and return its vector.

    Args:
        text: Text to embed
        model: Model name (defaults to settings)

    Returns:
        Embedding vector (empty list if the service returned nothing)
    """
    vectors = await self.embed([text], model)
    if not vectors:
        return []
    return vectors[0]
|
||||
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Check if the embeddings service is healthy."""
|
||||
try:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
LLM service client (vLLM/OpenAI-compatible).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, AsyncIterator
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -15,33 +16,33 @@ logger = logging.getLogger(__name__)
|
||||
class LLMClient:
|
||||
"""
|
||||
Client for the LLM service (vLLM with OpenAI-compatible API).
|
||||
|
||||
|
||||
Usage:
|
||||
client = LLMClient()
|
||||
response = await client.generate("Hello, how are you?")
|
||||
|
||||
|
||||
# With context for RAG
|
||||
response = await client.generate(
|
||||
"What is the capital?",
|
||||
context="France is a country in Europe..."
|
||||
)
|
||||
|
||||
|
||||
# Streaming
|
||||
async for chunk in client.stream("Tell me a story"):
|
||||
print(chunk, end="")
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[LLMSettings] = None):
    """Set up the client with an async HTTP session bound to the LLM service.

    Args:
        settings: Pre-built settings; a default LLMSettings is created when omitted.
    """
    self.settings = LLMSettings() if not settings else settings
    self._client = httpx.AsyncClient(
        timeout=self.settings.http_timeout,
        base_url=self.settings.llm_url,
    )
|
||||
|
||||
|
||||
async def close(self) -> None:
    """Release the underlying HTTP client and its connections."""
    client = self._client
    await client.aclose()
|
||||
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -54,7 +55,7 @@ class LLMClient:
|
||||
) -> str:
|
||||
"""
|
||||
Generate a response from the LLM.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: User prompt/query
|
||||
context: Optional context for RAG
|
||||
@@ -63,19 +64,19 @@ class LLMClient:
|
||||
temperature: Sampling temperature
|
||||
top_p: Top-p sampling
|
||||
stop: Stop sequences
|
||||
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
with create_span("llm.generate") as span:
|
||||
messages = self._build_messages(prompt, context, system_prompt)
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("llm.model", self.settings.llm_model)
|
||||
span.set_attribute("llm.prompt_length", len(prompt))
|
||||
if context:
|
||||
span.set_attribute("llm.context_length", len(context))
|
||||
|
||||
|
||||
payload = {
|
||||
"model": self.settings.llm_model,
|
||||
"messages": messages,
|
||||
@@ -85,21 +86,21 @@ class LLMClient:
|
||||
}
|
||||
if stop:
|
||||
payload["stop"] = stop
|
||||
|
||||
|
||||
response = await self._client.post("/v1/chat/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("llm.response_length", len(content))
|
||||
usage = result.get("usage", {})
|
||||
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
|
||||
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
|
||||
|
||||
|
||||
return content
|
||||
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -110,19 +111,19 @@ class LLMClient:
|
||||
) -> AsyncIterator[str]:
|
||||
"""
|
||||
Stream a response from the LLM.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: User prompt/query
|
||||
context: Optional context for RAG
|
||||
system_prompt: Optional system prompt
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Sampling temperature
|
||||
|
||||
|
||||
Yields:
|
||||
Text chunks as they're generated
|
||||
"""
|
||||
messages = self._build_messages(prompt, context, system_prompt)
|
||||
|
||||
|
||||
payload = {
|
||||
"model": self.settings.llm_model,
|
||||
"messages": messages,
|
||||
@@ -130,25 +131,24 @@ class LLMClient:
|
||||
"temperature": temperature or self.settings.llm_temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
async with self._client.stream(
|
||||
"POST", "/v1/chat/completions", json=payload
|
||||
) as response:
|
||||
|
||||
async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
if data == "[DONE]":
|
||||
break
|
||||
|
||||
|
||||
import json
|
||||
|
||||
chunk = json.loads(data)
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -157,32 +157,36 @@ class LLMClient:
|
||||
) -> list[dict]:
|
||||
"""Build the messages list for the API call."""
|
||||
messages = []
|
||||
|
||||
|
||||
# System prompt
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
elif context:
|
||||
# Default RAG system prompt
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful assistant. Use the provided context to answer "
|
||||
"the user's question. If the context doesn't contain relevant "
|
||||
"information, say so."
|
||||
),
|
||||
})
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful assistant. Use the provided context to answer "
|
||||
"the user's question. If the context doesn't contain relevant "
|
||||
"information, say so."
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Add context as a separate message if provided
|
||||
if context:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Check if the LLM service is healthy."""
|
||||
try:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""
|
||||
Milvus vector database client.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
|
||||
from pymilvus import connections, Collection, utility
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pymilvus import Collection, connections, utility
|
||||
|
||||
from handler_base.config import Settings
|
||||
from handler_base.telemetry import create_span
|
||||
@@ -15,42 +16,42 @@ logger = logging.getLogger(__name__)
|
||||
class MilvusClient:
|
||||
"""
|
||||
Client for Milvus vector database.
|
||||
|
||||
|
||||
Usage:
|
||||
client = MilvusClient()
|
||||
await client.connect()
|
||||
results = await client.search(embedding, limit=10)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
    """Prepare an unconnected Milvus client; call connect() before use.

    Args:
        settings: Pre-built settings; a default Settings is created when omitted.
    """
    self.settings = Settings() if not settings else settings
    # No connection is opened here; connect() flips this flag and loads a collection.
    self._collection: Optional[Collection] = None
    self._connected = False
|
||||
|
||||
|
||||
async def connect(self, collection_name: Optional[str] = None) -> None:
    """
    Open the Milvus connection and load the target collection if it exists.

    Args:
        collection_name: Collection to use (defaults to settings)
    """
    collection_name = collection_name or self.settings.milvus_collection

    connections.connect(
        alias="default",
        host=self.settings.milvus_host,
        port=self.settings.milvus_port,
    )

    if not utility.has_collection(collection_name):
        # Missing collection: still mark connected, but leave self._collection unset.
        logger.warning(f"Collection {collection_name} does not exist")
    else:
        collection = Collection(collection_name)
        collection.load()
        self._collection = collection
        logger.info(f"Connected to Milvus collection: {collection_name}")

    self._connected = True
|
||||
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Milvus connection."""
|
||||
if self._collection:
|
||||
@@ -58,7 +59,7 @@ class MilvusClient:
|
||||
connections.disconnect("default")
|
||||
self._connected = False
|
||||
logger.info("Disconnected from Milvus")
|
||||
|
||||
|
||||
async def search(
|
||||
self,
|
||||
embedding: list[float],
|
||||
@@ -68,26 +69,26 @@ class MilvusClient:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
|
||||
Args:
|
||||
embedding: Query embedding vector
|
||||
limit: Maximum number of results
|
||||
output_fields: Fields to return (default: all)
|
||||
filter_expr: Optional filter expression
|
||||
|
||||
|
||||
Returns:
|
||||
List of results with 'id', 'distance', and requested fields
|
||||
"""
|
||||
if not self._collection:
|
||||
raise RuntimeError("Not connected to collection")
|
||||
|
||||
|
||||
with create_span("milvus.search") as span:
|
||||
if span:
|
||||
span.set_attribute("milvus.collection", self._collection.name)
|
||||
span.set_attribute("milvus.limit", limit)
|
||||
|
||||
|
||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||
|
||||
|
||||
results = self._collection.search(
|
||||
data=[embedding],
|
||||
anns_field="embedding",
|
||||
@@ -96,7 +97,7 @@ class MilvusClient:
|
||||
output_fields=output_fields,
|
||||
expr=filter_expr,
|
||||
)
|
||||
|
||||
|
||||
# Convert to list of dicts
|
||||
hits = []
|
||||
for hit in results[0]:
|
||||
@@ -111,12 +112,12 @@ class MilvusClient:
|
||||
if hasattr(hit.entity, field):
|
||||
item[field] = getattr(hit.entity, field)
|
||||
hits.append(item)
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("milvus.results", len(hits))
|
||||
|
||||
|
||||
return hits
|
||||
|
||||
|
||||
async def search_with_texts(
|
||||
self,
|
||||
embedding: list[float],
|
||||
@@ -126,22 +127,22 @@ class MilvusClient:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search and return text content with metadata.
|
||||
|
||||
|
||||
Args:
|
||||
embedding: Query embedding
|
||||
limit: Maximum results
|
||||
text_field: Name of text field in collection
|
||||
metadata_fields: Additional metadata fields to return
|
||||
|
||||
|
||||
Returns:
|
||||
List of results with text and metadata
|
||||
"""
|
||||
output_fields = [text_field]
|
||||
if metadata_fields:
|
||||
output_fields.extend(metadata_fields)
|
||||
|
||||
|
||||
return await self.search(embedding, limit, output_fields)
|
||||
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
@@ -149,34 +150,34 @@ class MilvusClient:
|
||||
) -> list[int]:
|
||||
"""
|
||||
Insert vectors with data into the collection.
|
||||
|
||||
|
||||
Args:
|
||||
embeddings: List of embedding vectors
|
||||
data: List of dicts with field values
|
||||
|
||||
|
||||
Returns:
|
||||
List of inserted IDs
|
||||
"""
|
||||
if not self._collection:
|
||||
raise RuntimeError("Not connected to collection")
|
||||
|
||||
|
||||
with create_span("milvus.insert") as span:
|
||||
if span:
|
||||
span.set_attribute("milvus.collection", self._collection.name)
|
||||
span.set_attribute("milvus.count", len(embeddings))
|
||||
|
||||
|
||||
# Build insert data
|
||||
insert_data = [embeddings]
|
||||
for field in self._collection.schema.fields:
|
||||
if field.name not in ("id", "embedding"):
|
||||
field_values = [d.get(field.name) for d in data]
|
||||
insert_data.append(field_values)
|
||||
|
||||
|
||||
result = self._collection.insert(insert_data)
|
||||
self._collection.flush()
|
||||
|
||||
|
||||
return result.primary_keys
|
||||
|
||||
|
||||
def health(self) -> bool:
    """Report whether this client holds a live default Milvus connection."""
    if not self._connected:
        return False
    return utility.get_connection_addr("default") is not None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Reranker service client (Infinity/BGE Reranker).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
|
||||
class RerankerClient:
|
||||
"""
|
||||
Client for the reranker service (Infinity with BGE Reranker).
|
||||
|
||||
|
||||
Usage:
|
||||
client = RerankerClient()
|
||||
reranked = await client.rerank("query", ["doc1", "doc2"])
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
    """Set up the client with an async HTTP session bound to the reranker service.

    Args:
        settings: Pre-built settings; a default Settings is created when omitted.
    """
    self.settings = Settings() if not settings else settings
    self._client = httpx.AsyncClient(
        timeout=self.settings.http_timeout,
        base_url=self.settings.reranker_url,
    )
|
||||
|
||||
|
||||
async def close(self) -> None:
    """Release the underlying HTTP client and its connections."""
    client = self._client
    await client.aclose()
|
||||
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
query: str,
|
||||
@@ -40,12 +41,12 @@ class RerankerClient:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Rerank documents based on relevance to query.
|
||||
|
||||
|
||||
Args:
|
||||
query: Query text
|
||||
documents: List of documents to rerank
|
||||
top_k: Number of top results to return (default: all)
|
||||
|
||||
|
||||
Returns:
|
||||
List of dicts with 'index', 'score', and 'document' keys,
|
||||
sorted by relevance score descending.
|
||||
@@ -55,32 +56,34 @@ class RerankerClient:
|
||||
span.set_attribute("reranker.num_documents", len(documents))
|
||||
if top_k:
|
||||
span.set_attribute("reranker.top_k", top_k)
|
||||
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
if top_k:
|
||||
payload["top_n"] = top_k
|
||||
|
||||
|
||||
response = await self._client.post("/rerank", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
result = response.json()
|
||||
results = result.get("results", [])
|
||||
|
||||
|
||||
# Enrich with original documents
|
||||
enriched = []
|
||||
for r in results:
|
||||
idx = r.get("index", 0)
|
||||
enriched.append({
|
||||
"index": idx,
|
||||
"score": r.get("relevance_score", r.get("score", 0)),
|
||||
"document": documents[idx] if idx < len(documents) else "",
|
||||
})
|
||||
|
||||
enriched.append(
|
||||
{
|
||||
"index": idx,
|
||||
"score": r.get("relevance_score", r.get("score", 0)),
|
||||
"document": documents[idx] if idx < len(documents) else "",
|
||||
}
|
||||
)
|
||||
|
||||
return enriched
|
||||
|
||||
|
||||
async def rerank_with_metadata(
|
||||
self,
|
||||
query: str,
|
||||
@@ -90,27 +93,27 @@ class RerankerClient:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Rerank documents with metadata, preserving metadata in results.
|
||||
|
||||
|
||||
Args:
|
||||
query: Query text
|
||||
documents: List of dicts with text and metadata
|
||||
text_key: Key containing text in each document dict
|
||||
top_k: Number of top results to return
|
||||
|
||||
|
||||
Returns:
|
||||
Reranked documents with original metadata preserved.
|
||||
"""
|
||||
texts = [d.get(text_key, "") for d in documents]
|
||||
reranked = await self.rerank(query, texts, top_k)
|
||||
|
||||
|
||||
# Merge back metadata
|
||||
for r in reranked:
|
||||
idx = r["index"]
|
||||
if idx < len(documents):
|
||||
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Check if the reranker service is healthy."""
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
STT service client (Whisper/faster-whisper).
|
||||
"""
|
||||
import io
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
|
||||
class STTClient:
|
||||
"""
|
||||
Client for the STT service (Whisper/faster-whisper).
|
||||
|
||||
|
||||
Usage:
|
||||
client = STTClient()
|
||||
text = await client.transcribe(audio_bytes)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[STTSettings] = None):
    """Set up the client with an async HTTP session bound to the STT service.

    Args:
        settings: Pre-built settings; a default STTSettings is created when omitted.
    """
    self.settings = STTSettings() if not settings else settings
    # Generous fixed timeout: transcription of long audio can be slow.
    self._client = httpx.AsyncClient(
        timeout=180.0,
        base_url=self.settings.stt_url,
    )
|
||||
|
||||
|
||||
async def close(self) -> None:
    """Release the underlying HTTP client and its connections."""
    client = self._client
    await client.aclose()
|
||||
|
||||
|
||||
async def transcribe(
|
||||
self,
|
||||
audio: bytes,
|
||||
@@ -42,54 +42,54 @@ class STTClient:
|
||||
) -> dict:
|
||||
"""
|
||||
Transcribe audio to text.
|
||||
|
||||
|
||||
Args:
|
||||
audio: Audio bytes (WAV, MP3, etc.)
|
||||
language: Language code (None for auto-detect)
|
||||
task: "transcribe" or "translate"
|
||||
response_format: "json", "text", "srt", "vtt"
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'text', 'language', and optional 'segments'
|
||||
"""
|
||||
language = language or self.settings.stt_language
|
||||
task = task or self.settings.stt_task
|
||||
|
||||
|
||||
with create_span("stt.transcribe") as span:
|
||||
if span:
|
||||
span.set_attribute("stt.task", task)
|
||||
span.set_attribute("stt.audio_size", len(audio))
|
||||
if language:
|
||||
span.set_attribute("stt.language", language)
|
||||
|
||||
|
||||
files = {"file": ("audio.wav", audio, "audio/wav")}
|
||||
data = {
|
||||
"response_format": response_format,
|
||||
}
|
||||
if language:
|
||||
data["language"] = language
|
||||
|
||||
|
||||
# Choose endpoint based on task
|
||||
if task == "translate":
|
||||
endpoint = "/v1/audio/translations"
|
||||
else:
|
||||
endpoint = "/v1/audio/transcriptions"
|
||||
|
||||
|
||||
response = await self._client.post(endpoint, files=files, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
if response_format == "text":
|
||||
return {"text": response.text}
|
||||
|
||||
|
||||
result = response.json()
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("stt.result_length", len(result.get("text", "")))
|
||||
if result.get("language"):
|
||||
span.set_attribute("stt.detected_language", result["language"])
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def transcribe_file(
|
||||
self,
|
||||
file_path: str,
|
||||
@@ -98,31 +98,31 @@ class STTClient:
|
||||
) -> dict:
|
||||
"""
|
||||
Transcribe an audio file.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to audio file
|
||||
language: Language code
|
||||
task: "transcribe" or "translate"
|
||||
|
||||
|
||||
Returns:
|
||||
Transcription result
|
||||
"""
|
||||
with open(file_path, "rb") as f:
|
||||
audio = f.read()
|
||||
return await self.transcribe(audio, language, task)
|
||||
|
||||
|
||||
async def translate(self, audio: bytes) -> dict:
    """
    Translate spoken audio into English text.

    Thin wrapper over transcribe() with task="translate".

    Args:
        audio: Audio bytes

    Returns:
        Translation result with 'text' key
    """
    result = await self.transcribe(audio, task="translate")
    return result
|
||||
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Check if the STT service is healthy."""
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
TTS service client (Coqui XTTS).
|
||||
"""
|
||||
import io
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
|
||||
class TTSClient:
|
||||
"""
|
||||
Client for the TTS service (Coqui XTTS).
|
||||
|
||||
|
||||
Usage:
|
||||
client = TTSClient()
|
||||
audio_bytes = await client.synthesize("Hello world")
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[TTSSettings] = None):
    """Set up the client with an async HTTP session bound to the TTS service.

    Args:
        settings: Pre-built settings; a default TTSSettings is created when omitted.
    """
    self.settings = TTSSettings() if not settings else settings
    # Generous fixed timeout: speech synthesis can be slow.
    self._client = httpx.AsyncClient(
        timeout=120.0,
        base_url=self.settings.tts_url,
    )
|
||||
|
||||
|
||||
async def close(self) -> None:
    """Release the underlying HTTP client and its connections."""
    client = self._client
    await client.aclose()
|
||||
|
||||
|
||||
async def synthesize(
|
||||
self,
|
||||
text: str,
|
||||
@@ -41,39 +41,39 @@ class TTSClient:
|
||||
) -> bytes:
|
||||
"""
|
||||
Synthesize speech from text.
|
||||
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
language: Language code (e.g., "en", "es", "fr")
|
||||
speaker: Speaker ID or reference
|
||||
|
||||
|
||||
Returns:
|
||||
WAV audio bytes
|
||||
"""
|
||||
language = language or self.settings.tts_language
|
||||
|
||||
|
||||
with create_span("tts.synthesize") as span:
|
||||
if span:
|
||||
span.set_attribute("tts.language", language)
|
||||
span.set_attribute("tts.text_length", len(text))
|
||||
|
||||
|
||||
params = {
|
||||
"text": text,
|
||||
"language_id": language,
|
||||
}
|
||||
if speaker:
|
||||
params["speaker_id"] = speaker
|
||||
|
||||
|
||||
response = await self._client.get("/api/tts", params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
audio_bytes = response.content
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("tts.audio_size", len(audio_bytes))
|
||||
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
async def synthesize_to_file(
|
||||
self,
|
||||
text: str,
|
||||
@@ -83,7 +83,7 @@ class TTSClient:
|
||||
) -> None:
|
||||
"""
|
||||
Synthesize speech and save to a file.
|
||||
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
output_path: Path to save the audio file
|
||||
@@ -91,10 +91,10 @@ class TTSClient:
|
||||
speaker: Speaker ID
|
||||
"""
|
||||
audio_bytes = await self.synthesize(text, language, speaker)
|
||||
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
|
||||
async def get_speakers(self) -> list[dict]:
|
||||
"""Get available speakers/voices."""
|
||||
try:
|
||||
@@ -103,7 +103,7 @@ class TTSClient:
|
||||
return response.json()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Check if the TTS service is healthy."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user