""" NATS client wrapper with connection management and utilities. """ import asyncio import logging from typing import Any, Callable, Optional, Awaitable import msgpack import nats from nats.aio.client import Client from nats.aio.msg import Msg from nats.js import JetStreamContext from handler_base.config import Settings from handler_base.telemetry import create_span logger = logging.getLogger(__name__) class NATSClient: """ NATS client with automatic connection management. Supports: - Core NATS pub/sub - JetStream for persistence - Queue groups for load balancing - Msgpack serialization """ def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() self._nc: Optional[Client] = None self._js: Optional[JetStreamContext] = None self._subscriptions: list = [] @property def nc(self) -> Client: """Get the NATS client, raising if not connected.""" if self._nc is None: raise RuntimeError("NATS client not connected. Call connect() first.") return self._nc @property def js(self) -> JetStreamContext: """Get JetStream context, raising if not connected.""" if self._js is None: raise RuntimeError("JetStream not initialized. Call connect() first.") return self._js async def connect(self) -> None: """Connect to NATS server.""" connect_opts = { "servers": self.settings.nats_url, "reconnect_time_wait": 2, "max_reconnect_attempts": -1, # Infinite } if self.settings.nats_user and self.settings.nats_password: connect_opts["user"] = self.settings.nats_user connect_opts["password"] = self.settings.nats_password logger.info(f"Connecting to NATS at {self.settings.nats_url}") self._nc = await nats.connect(**connect_opts) self._js = self._nc.jetstream() logger.info("Connected to NATS") async def close(self) -> None: """Close NATS connection gracefully.""" if self._nc: # Drain subscriptions first for sub in self._subscriptions: try: await sub.drain() except Exception as e: logger.warning(f"Error draining subscription: {e}") await self._nc.drain() await self._nc.close() self._nc = None self._js = None logger.info("NATS connection closed") async def subscribe( self, subject: str, handler: Callable[[Msg], Awaitable[None]], queue: Optional[str] = None, ): """ Subscribe to a subject with a handler function. Args: subject: NATS subject to subscribe to handler: Async function to handle messages queue: Optional queue group for load balancing """ queue = queue or self.settings.nats_queue_group if queue: sub = await self.nc.subscribe(subject, queue=queue, cb=handler) logger.info(f"Subscribed to {subject} (queue: {queue})") else: sub = await self.nc.subscribe(subject, cb=handler) logger.info(f"Subscribed to {subject}") self._subscriptions.append(sub) return sub async def publish( self, subject: str, data: Any, use_msgpack: bool = True, ) -> None: """ Publish a message to a subject. Args: subject: NATS subject to publish to data: Data to publish (will be serialized) use_msgpack: Whether to use msgpack (True) or JSON (False) """ with create_span("nats.publish") as span: if span: span.set_attribute("messaging.destination", subject) if use_msgpack: payload = msgpack.packb(data, use_bin_type=True) else: import json payload = json.dumps(data).encode() await self.nc.publish(subject, payload) async def request( self, subject: str, data: Any, timeout: Optional[float] = None, use_msgpack: bool = True, ) -> Any: """ Send a request and wait for response. Args: subject: NATS subject to send request to data: Request data timeout: Response timeout in seconds use_msgpack: Whether to use msgpack serialization Returns: Decoded response data """ timeout = timeout or self.settings.nats_timeout with create_span("nats.request") as span: if span: span.set_attribute("messaging.destination", subject) if use_msgpack: payload = msgpack.packb(data, use_bin_type=True) else: import json payload = json.dumps(data).encode() response = await self.nc.request(subject, payload, timeout=timeout) if use_msgpack: return msgpack.unpackb(response.data, raw=False) else: import json return json.loads(response.data.decode()) @staticmethod def decode_msgpack(msg: Msg) -> Any: """Decode a msgpack message.""" return msgpack.unpackb(msg.data, raw=False) @staticmethod def decode_json(msg: Msg) -> Any: """Decode a JSON message.""" import json return json.loads(msg.data.decode())