Files
handler-base/handler_base/nats_client.py
Billy D. 99c97b7973 feat: Add handler-base library for NATS AI/ML services
- Handler base class with graceful shutdown and signal handling
- NATSClient with JetStream and msgpack serialization
- Pydantic Settings for environment configuration
- HealthServer for Kubernetes probes
- OpenTelemetry telemetry setup
- Service clients: STT, TTS, LLM, Embeddings, Reranker, Milvus
2026-02-01 20:36:00 -05:00

185 lines
5.7 KiB
Python

"""
NATS client wrapper with connection management and utilities.
"""
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, Optional

import msgpack
import nats
from nats.aio.client import Client
from nats.aio.msg import Msg
from nats.js import JetStreamContext

from handler_base.config import Settings
from handler_base.telemetry import create_span
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class NATSClient:
    """
    NATS client with automatic connection management.

    Supports:
    - Core NATS pub/sub
    - JetStream for persistence
    - Queue groups for load balancing
    - Msgpack serialization (JSON available as an alternative)

    Lifecycle: ``await connect()``, then ``subscribe`` / ``publish`` /
    ``request``, then ``await close()`` on shutdown.
    """

    def __init__(self, settings: Optional[Settings] = None):
        """
        Create an unconnected client.

        Args:
            settings: Service configuration; a default ``Settings()`` is
                constructed when omitted.
        """
        self.settings = settings or Settings()
        self._nc: Optional[Client] = None
        self._js: Optional[JetStreamContext] = None
        # Subscriptions created through subscribe(); drained by close().
        self._subscriptions: list = []

    @property
    def nc(self) -> Client:
        """Get the NATS client, raising if not connected."""
        if self._nc is None:
            raise RuntimeError("NATS client not connected. Call connect() first.")
        return self._nc

    @property
    def js(self) -> JetStreamContext:
        """Get JetStream context, raising if not connected."""
        if self._js is None:
            raise RuntimeError("JetStream not initialized. Call connect() first.")
        return self._js

    async def connect(self) -> None:
        """
        Connect to the NATS server and create the JetStream context.

        Reconnection is retried forever (``max_reconnect_attempts=-1``) with a
        2 second wait between attempts. Credentials are sent only when both
        ``nats_user`` and ``nats_password`` are set.

        Calling this while a connection that is not closed already exists is a
        no-op. This prevents the existing connection and its subscriptions
        from being orphaned.
        """
        if self._nc is not None and not self._nc.is_closed:
            logger.debug("Already connected to NATS; connect() is a no-op")
            return

        connect_opts = {
            "servers": self.settings.nats_url,
            "reconnect_time_wait": 2,
            "max_reconnect_attempts": -1,  # Infinite
        }
        if self.settings.nats_user and self.settings.nats_password:
            connect_opts["user"] = self.settings.nats_user
            connect_opts["password"] = self.settings.nats_password

        logger.info("Connecting to NATS at %s", self.settings.nats_url)
        # Any subscriptions left over from a previous (closed) connection are
        # dead; don't carry them into the new one.
        self._subscriptions = []
        self._nc = await nats.connect(**connect_opts)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS")

    async def close(self) -> None:
        """
        Close the NATS connection gracefully.

        Drains each tracked subscription, then drains and closes the
        connection. Errors while draining an individual subscription are only
        logged. Internal state (connection, JetStream context, subscription
        list) is reset even if draining the connection raises, so the client
        can be reconnected afterwards.

        Does nothing if the client is not connected.
        """
        nc = self._nc
        if nc is None:
            return

        # Take ownership of the list so stale subscriptions never leak into a
        # later connect()/close() cycle.
        subscriptions, self._subscriptions = self._subscriptions, []
        try:
            # Drain subscriptions first
            for sub in subscriptions:
                try:
                    await sub.drain()
                except Exception as e:
                    logger.warning("Error draining subscription: %s", e)
            await nc.drain()
        finally:
            # nats-py's drain() is documented to close the connection when it
            # finishes; close() here guarantees closure if drain() raised.
            await nc.close()
            self._nc = None
            self._js = None
        logger.info("NATS connection closed")

    async def subscribe(
        self,
        subject: str,
        handler: Callable[[Msg], Awaitable[None]],
        queue: Optional[str] = None,
    ):
        """
        Subscribe to a subject with a handler function.

        Args:
            subject: NATS subject to subscribe to.
            handler: Async function called with each received message.
            queue: Optional queue group for load balancing. A falsy value
                falls back to ``settings.nats_queue_group``. If that is also
                falsy, a plain (non-queue) subscription is made.

        Returns:
            The subscription object returned by the underlying NATS client.
            It is also tracked so that close() drains it.

        Raises:
            RuntimeError: If the client is not connected.
        """
        queue = queue or self.settings.nats_queue_group
        if queue:
            sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
            logger.info("Subscribed to %s (queue: %s)", subject, queue)
        else:
            sub = await self.nc.subscribe(subject, cb=handler)
            logger.info("Subscribed to %s", subject)
        self._subscriptions.append(sub)
        return sub

    async def publish(
        self,
        subject: str,
        data: Any,
        use_msgpack: bool = True,
    ) -> None:
        """
        Publish a message to a subject.

        Args:
            subject: NATS subject to publish to.
            data: Data to publish (serialized before sending).
            use_msgpack: Serialize with msgpack (True) or JSON (False).

        Raises:
            RuntimeError: If the client is not connected.
        """
        with create_span("nats.publish") as span:
            # create_span may yield a falsy value (presumably when tracing is
            # disabled), so attributes are only set when a span exists.
            if span:
                span.set_attribute("messaging.destination", subject)
            payload = self._encode(data, use_msgpack)
            await self.nc.publish(subject, payload)

    async def request(
        self,
        subject: str,
        data: Any,
        timeout: Optional[float] = None,
        use_msgpack: bool = True,
    ) -> Any:
        """
        Send a request and wait for the response.

        Args:
            subject: NATS subject to send the request to.
            data: Request data.
            timeout: Response timeout in seconds. ``None`` uses
                ``settings.nats_timeout``. Any explicit value, including 0,
                is honoured.
            use_msgpack: Use msgpack (True) or JSON (False) for both the
                request and the response.

        Returns:
            The decoded response data.

        Raises:
            RuntimeError: If the client is not connected.
            Errors raised by the NATS client's ``request`` (for example on
            timeout) propagate unchanged.
        """
        if timeout is None:
            timeout = self.settings.nats_timeout
        with create_span("nats.request") as span:
            if span:
                span.set_attribute("messaging.destination", subject)
            payload = self._encode(data, use_msgpack)
            response = await self.nc.request(subject, payload, timeout=timeout)
            return self._decode(response.data, use_msgpack)

    @staticmethod
    def _encode(data: Any, use_msgpack: bool) -> bytes:
        """Serialize ``data`` with msgpack (bin type enabled) or as UTF-8 JSON."""
        if use_msgpack:
            return msgpack.packb(data, use_bin_type=True)
        return json.dumps(data).encode()

    @staticmethod
    def _decode(payload: bytes, use_msgpack: bool) -> Any:
        """Deserialize a msgpack (str-decoding) or UTF-8 JSON payload."""
        if use_msgpack:
            return msgpack.unpackb(payload, raw=False)
        return json.loads(payload.decode())

    @staticmethod
    def decode_msgpack(msg: Msg) -> Any:
        """Decode a msgpack message."""
        return NATSClient._decode(msg.data, use_msgpack=True)

    @staticmethod
    def decode_json(msg: Msg) -> Any:
        """Decode a JSON message."""
        return NATSClient._decode(msg.data, use_msgpack=False)