diff --git a/handler/handler.go b/handler/handler.go index d1c6f9b..d1b904a 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -21,6 +21,12 @@ import ( // data is the msgpack-decoded map. Return a response map (or nil for no reply). type MessageHandler func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) +// TypedMessageHandler processes the raw NATS message without pre-decoding to +// map[string]any. Services unmarshal msg.Data into their own typed structs, +// avoiding the double-decode overhead. Return any msgpack-serialisable value +// (a typed struct, map, or nil for no reply). +type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (any, error) + // SetupFunc is called once before the handler starts processing messages. type SetupFunc func(ctx context.Context) error @@ -35,10 +41,11 @@ type Handler struct { Subject string QueueGroup string - onSetup SetupFunc - onTeardown TeardownFunc - onMessage MessageHandler - running bool + onSetup SetupFunc + onTeardown TeardownFunc + onMessage MessageHandler + onTypedMessage TypedMessageHandler + running bool } // New creates a Handler for the given NATS subject. @@ -70,6 +77,11 @@ func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn } // OnMessage registers the message handler callback. func (h *Handler) OnMessage(fn MessageHandler) { h.onMessage = fn } +// OnTypedMessage registers a typed message handler. It replaces OnMessage — +// wrapHandler will skip the map[string]any decode and let the callback +// unmarshal msg.Data directly. +func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn } + // Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT. func (h *Handler) Run() error { // Structured logging @@ -119,7 +131,7 @@ func (h *Handler) Run() error { } // Subscribe - if h.onMessage == nil { + if h.onMessage == nil && h.onTypedMessage == nil { return fmt.Errorf("no message handler registered") } if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil { @@ -148,8 +160,41 @@ func (h *Handler) Run() error { return nil } -// wrapHandler creates a nats.MsgHandler that decodes msgpack and dispatches to the user handler. +// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback. +// If OnTypedMessage was used, msg.Data is passed directly without map decode. +// If OnMessage was used, msg.Data is decoded to map[string]any first. func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler { + if h.onTypedMessage != nil { + return h.wrapTypedHandler(ctx) + } + return h.wrapMapHandler(ctx) +} + +// wrapTypedHandler dispatches to the TypedMessageHandler (no map decode). +func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler { + return func(msg *nats.Msg) { + response, err := h.onTypedMessage(ctx, msg) + if err != nil { + slog.Error("handler error", "subject", msg.Subject, "error", err) + if msg.Reply != "" { + _ = h.NATS.Publish(msg.Reply, map[string]any{ + "error": true, + "message": err.Error(), + "type": fmt.Sprintf("%T", err), + }) + } + return + } + if response != nil && msg.Reply != "" { + if err := h.NATS.Publish(msg.Reply, response); err != nil { + slog.Error("failed to publish reply", "error", err) + } + } + } +} + +// wrapMapHandler dispatches to the legacy MessageHandler (decodes to map first). +func (h *Handler) wrapMapHandler(ctx context.Context) nats.MsgHandler { return func(msg *nats.Msg) { data, err := natsutil.DecodeMsgpackMap(msg.Data) if err != nil { diff --git a/handler/handler_test.go b/handler/handler_test.go index d7b21ee..84a1eb0 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -73,6 +73,19 @@ func TestCallbackRegistration(t *testing.T) { } } +func TestTypedMessageRegistration(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + return map[string]any{"ok": true}, nil + }) + + if h.onTypedMessage == nil { + t.Error("onTypedMessage should not be nil after registration") + } +} + // ──────────────────────────────────────────────────────────────────────────── // wrapHandler dispatch tests (unit test the message decode + dispatch logic) // ──────────────────────────────────────────────────────────────────────────── @@ -173,6 +186,71 @@ func TestWrapHandler_NilResponse(t *testing.T) { }) } +// ──────────────────────────────────────────────────────────────────────────── +// wrapHandler dispatch tests — typed handler path +// ──────────────────────────────────────────────────────────────────────────── + +func TestWrapTypedHandler_ValidMessage(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + type testReq struct { + RequestID string `msgpack:"request_id"` + Message string `msgpack:"message"` + } + + var received testReq + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + if err := msgpack.Unmarshal(msg.Data, &received); err != nil { + return nil, err + } + return map[string]any{"status": "ok"}, nil + }) + + encoded, _ := msgpack.Marshal(map[string]any{ + "request_id": "typed-001", + "message": "hello typed", + }) + + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{Subject: "ai.test", Data: encoded}) + + if received.RequestID != "typed-001" { + t.Errorf("RequestID = %q", received.RequestID) + } + if received.Message != "hello typed" { + t.Errorf("Message = %q", received.Message) + } +} + +func TestWrapTypedHandler_Error(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + return nil, context.DeadlineExceeded + }) + + encoded, _ := msgpack.Marshal(map[string]any{"key": "val"}) + handler := h.wrapHandler(context.Background()) + + // Should not panic. + handler(&nats.Msg{Subject: "ai.test", Data: encoded}) +} + +func TestWrapTypedHandler_NilResponse(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + return nil, nil + }) + + encoded, _ := msgpack.Marshal(map[string]any{"x": 1}) + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{Subject: "ai.test", Data: encoded}) +} + // ──────────────────────────────────────────────────────────────────────────── // Benchmark: message decode + dispatch overhead // ──────────────────────────────────────────────────────────────────────────── @@ -199,3 +277,35 @@ func BenchmarkWrapHandler(b *testing.B) { handler(msg) } } + +func BenchmarkWrapTypedHandler(b *testing.B) { + type benchReq struct { + RequestID string `msgpack:"request_id"` + Message string `msgpack:"message"` + Premium bool `msgpack:"premium"` + TopK int `msgpack:"top_k"` + } + + cfg := config.Load() + h := New("ai.test", cfg) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { + var req benchReq + msgpack.Unmarshal(msg.Data, &req) + return map[string]any{"ok": true}, nil + }) + + payload := map[string]any{ + "request_id": "bench-001", + "message": "What is the capital of France?", + "premium": true, + "top_k": 10, + } + encoded, _ := msgpack.Marshal(payload) + handler := h.wrapHandler(context.Background()) + msg := &nats.Msg{Subject: "ai.test", Data: encoded} + + b.ResetTimer() + for b.Loop() { + handler(msg) + } +} diff --git a/natsutil/natsutil.go b/natsutil/natsutil.go index a8289fe..747d66e 100644 --- a/natsutil/natsutil.go +++ b/natsutil/natsutil.go @@ -124,6 +124,14 @@ func DecodeMsgpack(msg *nats.Msg, dest any) error { return msgpack.Unmarshal(msg.Data, dest) } +// Decode is a generic helper that unmarshals msgpack bytes into T. +// Usage: req, err := natsutil.Decode[messages.ChatRequest](msg.Data) +func Decode[T any](data []byte) (T, error) { + var v T + err := msgpack.Unmarshal(data, &v) + return v, err +} + // DecodeMsgpackMap decodes msgpack data into a generic map. func DecodeMsgpackMap(data []byte) (map[string]any, error) { var m map[string]any