189 lines
5.5 KiB
Python
189 lines
5.5 KiB
Python
"""
|
|
NATS client wrapper with connection management and utilities.
|
|
"""

import json
import logging
from typing import Any, Awaitable, Callable, Optional

import msgpack
import nats
from nats.aio.client import Client
from nats.aio.msg import Msg
from nats.js import JetStreamContext

from handler_base.config import Settings
from handler_base.telemetry import create_span


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class NATSClient:
    """
    NATS client with automatic connection management.

    Supports:
    - Core NATS pub/sub
    - JetStream for persistence
    - Queue groups for load balancing
    - Msgpack serialization (JSON available as an opt-in alternative)

    Call ``connect()`` before using ``nc``, ``js``, ``subscribe``,
    ``publish`` or ``request``, and ``close()`` to shut down.
    """

    def __init__(self, settings: Optional[Settings] = None):
        """
        Create an unconnected client.

        Args:
            settings: Connection settings. A fresh ``Settings()`` is built
                when omitted.
        """
        self.settings = settings or Settings()
        self._nc: Optional[Client] = None
        self._js: Optional[JetStreamContext] = None
        # Every subscription created through subscribe(); drained in close().
        self._subscriptions: list = []

    @property
    def nc(self) -> Client:
        """
        Get the NATS client, raising if not connected.

        Raises:
            RuntimeError: If connect() has not been called, or close() has.
        """
        if self._nc is None:
            raise RuntimeError("NATS client not connected. Call connect() first.")
        return self._nc

    @property
    def js(self) -> JetStreamContext:
        """
        Get JetStream context, raising if not connected.

        Raises:
            RuntimeError: If connect() has not been called, or close() has.
        """
        if self._js is None:
            raise RuntimeError("JetStream not initialized. Call connect() first.")
        return self._js

    async def connect(self) -> None:
        """
        Connect to the NATS server and create the JetStream context.

        Reconnection attempts are unlimited (``max_reconnect_attempts=-1``).
        User/password credentials are sent only when both
        ``settings.nats_user`` and ``settings.nats_password`` are set.
        """
        connect_opts: dict = {
            "servers": self.settings.nats_url,
            "reconnect_time_wait": 2,
            "max_reconnect_attempts": -1,  # Infinite
        }

        if self.settings.nats_user and self.settings.nats_password:
            connect_opts["user"] = self.settings.nats_user
            connect_opts["password"] = self.settings.nats_password

        logger.info("Connecting to NATS at %s", self.settings.nats_url)
        self._nc = await nats.connect(**connect_opts)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS")

    async def close(self) -> None:
        """
        Close the NATS connection gracefully.

        Drains every subscription created via subscribe(), then drains and
        closes the connection. Drain failures are logged as warnings rather
        than raised, and the client state (connection, JetStream context,
        tracked subscriptions) is always reset, so the instance can be
        reconnected afterwards. Does nothing when not connected.
        """
        nc = self._nc
        if nc is None:
            return

        try:
            # Drain subscriptions first
            for sub in self._subscriptions:
                try:
                    await sub.drain()
                except Exception as e:
                    logger.warning("Error draining subscription: %s", e)

            try:
                await nc.drain()
            except Exception as e:
                # e.g. the connection was already closed; still fall through
                # to close() and the state reset below.
                logger.warning("Error draining NATS connection: %s", e)

            # Ensure the connection ends up closed even if drain() failed.
            await nc.close()
        finally:
            # Reset unconditionally so a failed shutdown never leaves stale
            # handles or re-drains dead subscriptions on the next close().
            self._subscriptions.clear()
            self._nc = None
            self._js = None

        logger.info("NATS connection closed")

    async def subscribe(
        self,
        subject: str,
        handler: Callable[[Msg], Awaitable[None]],
        queue: Optional[str] = None,
    ):
        """
        Subscribe to a subject with a handler function.

        Args:
            subject: NATS subject to subscribe to
            handler: Async function to handle messages
            queue: Optional queue group for load balancing. When falsy,
                ``settings.nats_queue_group`` is used instead (if set).

        Returns:
            The subscription object returned by the underlying client; it is
            also tracked so close() can drain it.

        Raises:
            RuntimeError: If not connected.
        """
        queue = queue or self.settings.nats_queue_group

        if queue:
            sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
            logger.info("Subscribed to %s (queue: %s)", subject, queue)
        else:
            sub = await self.nc.subscribe(subject, cb=handler)
            logger.info("Subscribed to %s", subject)

        self._subscriptions.append(sub)
        return sub

    async def publish(
        self,
        subject: str,
        data: Any,
        use_msgpack: bool = True,
    ) -> None:
        """
        Publish a message to a subject.

        Args:
            subject: NATS subject to publish to
            data: Data to publish (will be serialized)
            use_msgpack: Whether to use msgpack (True) or JSON (False)

        Raises:
            RuntimeError: If not connected.
        """
        with create_span("nats.publish") as span:
            # create_span may yield a falsy value (presumably when tracing
            # is disabled), so only annotate a real span.
            if span:
                span.set_attribute("messaging.destination", subject)

            payload = self._encode(data, use_msgpack)
            await self.nc.publish(subject, payload)

    async def request(
        self,
        subject: str,
        data: Any,
        timeout: Optional[float] = None,
        use_msgpack: bool = True,
    ) -> Any:
        """
        Send a request and wait for response.

        Args:
            subject: NATS subject to send request to
            data: Request data
            timeout: Response timeout in seconds; falls back to
                ``settings.nats_timeout`` when None (or 0).
            use_msgpack: Whether to use msgpack serialization for both the
                request and the response (JSON otherwise)

        Returns:
            Decoded response data

        Raises:
            RuntimeError: If not connected.
        """
        timeout = timeout or self.settings.nats_timeout

        with create_span("nats.request") as span:
            if span:
                span.set_attribute("messaging.destination", subject)

            payload = self._encode(data, use_msgpack)
            response = await self.nc.request(subject, payload, timeout=timeout)
            return self._decode(response.data, use_msgpack)

    @staticmethod
    def decode_msgpack(msg: Msg) -> Any:
        """Decode a msgpack message body (binary strings decoded to str)."""
        return NATSClient._decode(msg.data, use_msgpack=True)

    @staticmethod
    def decode_json(msg: Msg) -> Any:
        """Decode a UTF-8 JSON message body."""
        return NATSClient._decode(msg.data, use_msgpack=False)

    # -- serialization helpers shared by publish/request/decode_* ----------

    @staticmethod
    def _encode(data: Any, use_msgpack: bool) -> bytes:
        """Serialize *data* with msgpack (bin type enabled) or as JSON bytes."""
        if use_msgpack:
            return msgpack.packb(data, use_bin_type=True)
        return json.dumps(data).encode()

    @staticmethod
    def _decode(payload: bytes, use_msgpack: bool) -> Any:
        """Deserialize *payload* produced by the matching _encode mode."""
        if use_msgpack:
            return msgpack.unpackb(payload, raw=False)
        return json.loads(payload.decode())