diff --git a/e2e_test.go b/e2e_test.go index 02d0bad..a689e9b 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -13,7 +13,7 @@ import ( "git.daviestechlabs.io/daviestechlabs/handler-base/clients" "git.daviestechlabs.io/daviestechlabs/handler-base/messages" - "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" ) // ──────────────────────────────────────────────────────────────────────────── @@ -167,33 +167,33 @@ func TestChatPipeline_LLMTimeout(t *testing.T) { } func TestChatPipeline_TypedDecoding(t *testing.T) { - // Verify typed struct decoding from msgpack (same path as OnTypedMessage). - raw := map[string]any{ - "request_id": "req-e2e-001", - "user_id": "user-1", - "message": "hello", - "premium": true, - "enable_rag": false, - "enable_streaming": false, - "system_prompt": "Be brief.", + // Verify typed struct decoding from proto (same path as OnTypedMessage). + original := &messages.ChatRequest{ + RequestId: "req-e2e-001", + UserId: "user-1", + Message: "hello", + Premium: true, + EnableRag: false, + EnableStreaming: false, + SystemPrompt: "Be brief.", } - data, _ := msgpack.Marshal(raw) + data, _ := proto.Marshal(original) var req messages.ChatRequest - if err := msgpack.Unmarshal(data, &req); err != nil { + if err := proto.Unmarshal(data, &req); err != nil { t.Fatal(err) } - if req.RequestID != "req-e2e-001" { - t.Errorf("RequestID = %q", req.RequestID) + if req.RequestId != "req-e2e-001" { + t.Errorf("RequestID = %q", req.RequestId) } - if req.UserID != "user-1" { - t.Errorf("UserID = %q", req.UserID) + if req.UserId != "user-1" { + t.Errorf("UserID = %q", req.UserId) } - if req.EffectiveQuery() != "hello" { - t.Errorf("query = %q", req.EffectiveQuery()) + if messages.EffectiveQuery(&req) != "hello" { + t.Errorf("query = %q", messages.EffectiveQuery(&req)) } - if req.EnableRAG { + if req.EnableRag { t.Error("EnableRAG should be false") } if req.SystemPrompt != "Be brief." { diff --git a/go.mod b/go.mod index 5c8828c..0d414b7 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module git.daviestechlabs.io/daviestechlabs/chat-handler go 1.25.1 require ( - git.daviestechlabs.io/daviestechlabs/handler-base v0.1.5 + git.daviestechlabs.io/daviestechlabs/handler-base v1.0.0 github.com/nats-io/nats.go v1.48.0 - github.com/vmihailenco/msgpack/v5 v5.4.1 + google.golang.org/protobuf v1.36.11 ) require ( @@ -19,7 +19,6 @@ require ( github.com/klauspost/compress v1.18.0 // indirect github.com/nats-io/nkeys v0.4.11 // indirect github.com/nats-io/nuid v1.0.1 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.40.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 // indirect @@ -37,5 +36,4 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/grpc v1.78.0 // indirect - google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index efceef0..5d7dfc1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -git.daviestechlabs.io/daviestechlabs/handler-base v0.1.5 h1:DqYZpeluTXh5QKqdVFgN8YIMh4Ycqzw5E9+5FTNDFCA= -git.daviestechlabs.io/daviestechlabs/handler-base v0.1.5/go.mod h1:M3HgvUDWnRn7cX3BE8l+HvoCUYtmRr5OoumB+hnRHoE= +git.daviestechlabs.io/daviestechlabs/handler-base v1.0.0 h1:pB3ehOKaDYQfbyRBKQXrB9curqSFteLrDveoElRKnBY= +git.daviestechlabs.io/daviestechlabs/handler-base v1.0.0/go.mod h1:zocOHFt8yY3cW4+Xi37sNr5Tw7KcjGFSZqgWYxPWyqA= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -33,10 +33,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= -github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= diff --git a/main.go b/main.go index cd1b106..6227842 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "git.daviestechlabs.io/daviestechlabs/handler-base/handler" "git.daviestechlabs.io/daviestechlabs/handler-base/messages" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" + "google.golang.org/protobuf/proto" ) func main() { @@ -45,23 +46,23 @@ func main() { h := handler.New("ai.chat.user.*.message", cfg) - h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { - req, err := natsutil.Decode[messages.ChatRequest](msg.Data) - if err != nil { + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + var req messages.ChatRequest + if err := natsutil.Decode(msg.Data, &req); err != nil { slog.Error("decode failed", "error", err) return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil } - query := req.EffectiveQuery() - requestID := req.RequestID + query := messages.EffectiveQuery(&req) + requestID := req.RequestId if requestID == "" { requestID = "unknown" } - userID := req.UserID + userID := req.UserId if userID == "" { userID = "unknown" } - enableRAG := req.EnableRAG + enableRAG := req.EnableRag if !enableRAG && req.Premium { enableRAG = true } @@ -71,13 +72,13 @@ func main() { } topK := req.TopK if topK == 0 { - topK = ragTopK + topK = int32(ragTopK) } collection := req.Collection if collection == "" { collection = ragCollection } - reqEnableTTS := req.EnableTTS || enableTTS + reqEnableTTS := req.EnableTts || enableTTS systemPrompt := req.SystemPrompt responseSubject := req.ResponseSubject if responseSubject == "" { @@ -159,18 +160,19 @@ func main() { // 5. Generate LLM response (streaming when requested) var responseText string + var err error if req.EnableStreaming { streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID) responseText, err = llm.StreamGenerate(ctx, query, contextText, systemPrompt, func(token string) { _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ - RequestID: requestID, + RequestId: requestID, Type: "chunk", Content: token, Timestamp: messages.Timestamp(), }) }) _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ - RequestID: requestID, + RequestId: requestID, Type: "done", Done: true, Timestamp: messages.Timestamp(), @@ -181,7 +183,7 @@ func main() { if err != nil { slog.Error("LLM generation failed", "error", err) return &messages.ChatResponse{ - UserID: userID, + UserId: userID, Success: false, Error: err.Error(), }, nil @@ -199,15 +201,15 @@ func main() { } result := &messages.ChatResponse{ - UserID: userID, + UserId: userID, Response: responseText, ResponseText: responseText, - UsedRAG: usedRAG, + UsedRag: usedRAG, Success: true, Audio: audio, } if includeSources { - result.RAGSources = ragSources + result.RagSources = ragSources } // Publish to the response subject the frontend is waiting on diff --git a/main_test.go b/main_test.go index d87d4c5..bda9d6c 100644 --- a/main_test.go +++ b/main_test.go @@ -5,28 +5,31 @@ import ( "testing" "git.daviestechlabs.io/daviestechlabs/handler-base/messages" - "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" ) func TestChatRequestDecode(t *testing.T) { - // Verify a msgpack-encoded map decodes cleanly into typed struct. - raw := map[string]any{ - "request_id": "req-1", - "user_id": "user-1", - "message": "hello", - "premium": true, - "top_k": 10, + // Verify a proto-encoded struct round-trips cleanly. + original := &messages.ChatRequest{ + RequestId: "req-1", + UserId: "user-1", + Message: "hello", + Premium: true, + TopK: 10, } - data, _ := msgpack.Marshal(raw) - var req messages.ChatRequest - if err := msgpack.Unmarshal(data, &req); err != nil { + data, err := proto.Marshal(original) + if err != nil { t.Fatal(err) } - if req.RequestID != "req-1" { - t.Errorf("RequestID = %q", req.RequestID) + var req messages.ChatRequest + if err := proto.Unmarshal(data, &req); err != nil { + t.Fatal(err) } - if req.EffectiveQuery() != "hello" { - t.Errorf("EffectiveQuery = %q", req.EffectiveQuery()) + if req.RequestId != "req-1" { + t.Errorf("RequestID = %q", req.RequestId) + } + if messages.EffectiveQuery(&req) != "hello" { + t.Errorf("EffectiveQuery = %q", messages.EffectiveQuery(&req)) } if !req.Premium { t.Error("Premium should be true") @@ -38,20 +41,20 @@ func TestChatRequestDecode(t *testing.T) { func TestChatResponseRoundtrip(t *testing.T) { resp := &messages.ChatResponse{ - UserID: "user-1", + UserId: "user-1", Response: "answer", Success: true, Audio: []byte{0x01, 0x02, 0x03}, } - data, err := msgpack.Marshal(resp) + data, err := proto.Marshal(resp) if err != nil { t.Fatal(err) } var decoded messages.ChatResponse - if err := msgpack.Unmarshal(data, &decoded); err != nil { + if err := proto.Unmarshal(data, &decoded); err != nil { t.Fatal(err) } - if decoded.UserID != "user-1" || !decoded.Success { + if decoded.UserId != "user-1" || !decoded.Success { t.Errorf("decoded = %+v", decoded) } if len(decoded.Audio) != 3 {