Compare commits
8 Commits
81581337cd
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| f1dd96a42b | |||
| 13ef1df109 | |||
| 3585d81ff5 | |||
| fba7b62573 | |||
| 6fd0b9a265 | |||
| 8b6232141a | |||
| 9876cb9388 | |||
| 39673d31b8 |
@@ -17,20 +17,22 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up uv
|
- name: Set up Go
|
||||||
run: curl -LsSf https://astral.sh/uv/install.sh | sh && echo "$HOME/.local/bin" >> $GITHUB_PATH
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
cache: true
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Run go vet
|
||||||
run: uv python install 3.13
|
run: go vet ./...
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install golangci-lint
|
||||||
run: uv sync --frozen --extra dev
|
run: |
|
||||||
|
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b "$(go env GOPATH)/bin"
|
||||||
|
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
- name: Run ruff check
|
- name: Run golangci-lint
|
||||||
run: uv run ruff check .
|
run: golangci-lint run ./...
|
||||||
|
|
||||||
- name: Run ruff format check
|
|
||||||
run: uv run ruff format --check .
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test
|
name: Test
|
||||||
@@ -39,23 +41,28 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up uv
|
- name: Set up Go
|
||||||
run: curl -LsSf https://astral.sh/uv/install.sh | sh && echo "$HOME/.local/bin" >> $GITHUB_PATH
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
cache: true
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Verify dependencies
|
||||||
run: uv python install 3.13
|
run: go mod verify
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Build
|
||||||
run: uv sync --frozen --extra dev
|
run: go build -v ./...
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests
|
||||||
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term
|
run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
|
||||||
|
|
||||||
release:
|
release:
|
||||||
name: Release
|
name: Release
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [lint, test]
|
needs: [lint, test]
|
||||||
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.version.outputs.version }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -95,10 +102,32 @@ jobs:
|
|||||||
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
|
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
|
||||||
git push origin ${{ steps.version.outputs.version }}
|
git push origin ${{ steps.version.outputs.version }}
|
||||||
|
|
||||||
|
notify-downstream:
|
||||||
|
name: Notify Downstream
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [release]
|
||||||
|
if: needs.release.result == 'success'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
repo:
|
||||||
|
- chat-handler
|
||||||
|
- pipeline-bridge
|
||||||
|
- tts-module
|
||||||
|
- voice-assistant
|
||||||
|
- stt-module
|
||||||
|
steps:
|
||||||
|
- name: Trigger dependency update
|
||||||
|
run: |
|
||||||
|
curl -s -X POST \
|
||||||
|
-H "Authorization: token ${{ secrets.DISPATCH_TOKEN }}" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"event_type":"handler-base-release","client_payload":{"version":"${{ needs.release.outputs.version }}"}}' \
|
||||||
|
"${{ gitea.server_url }}/api/v1/repos/daviestechlabs/${{ matrix.repo }}/dispatches"
|
||||||
|
|
||||||
notify:
|
notify:
|
||||||
name: Notify
|
name: Notify
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [lint, test, release]
|
needs: [lint, test, release, notify-downstream]
|
||||||
if: always()
|
if: always()
|
||||||
steps:
|
steps:
|
||||||
- name: Notify on success
|
- name: Notify on success
|
||||||
|
|||||||
10
buf.gen.yaml
Normal file
10
buf.gen.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
version: v2
|
||||||
|
managed:
|
||||||
|
enabled: true
|
||||||
|
override:
|
||||||
|
- file_option: go_package_prefix
|
||||||
|
value: git.daviestechlabs.io/daviestechlabs/handler-base/gen
|
||||||
|
plugins:
|
||||||
|
- protoc_builtin: go
|
||||||
|
out: gen
|
||||||
|
opt: paths=source_relative
|
||||||
9
buf.yaml
Normal file
9
buf.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
version: v2
|
||||||
|
modules:
|
||||||
|
- path: proto
|
||||||
|
lint:
|
||||||
|
use:
|
||||||
|
- STANDARD
|
||||||
|
breaking:
|
||||||
|
use:
|
||||||
|
- FILE
|
||||||
@@ -6,16 +6,18 @@
|
|||||||
package clients
|
package clients
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bufio"
|
||||||
"context"
|
"bytes"
|
||||||
"encoding/json"
|
"context"
|
||||||
"fmt"
|
"encoding/json"
|
||||||
"io"
|
"fmt"
|
||||||
"mime/multipart"
|
"io"
|
||||||
"net/http"
|
"mime/multipart"
|
||||||
"net/url"
|
"net/http"
|
||||||
"sync"
|
"net/url"
|
||||||
"time"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ─── Shared transport & buffer pool ─────────────────────────────────────────
|
// ─── Shared transport & buffer pool ─────────────────────────────────────────
|
||||||
@@ -23,393 +25,490 @@ import (
|
|||||||
// SharedTransport is the process-wide HTTP transport used by every service
|
// SharedTransport is the process-wide HTTP transport used by every service
|
||||||
// client. Tweak pool sizes here rather than creating per-client transports.
|
// client. Tweak pool sizes here rather than creating per-client transports.
|
||||||
var SharedTransport = &http.Transport{
|
var SharedTransport = &http.Transport{
|
||||||
MaxIdleConns: 100,
|
MaxIdleConns: 100,
|
||||||
MaxIdleConnsPerHost: 10,
|
MaxIdleConnsPerHost: 10,
|
||||||
IdleConnTimeout: 90 * time.Second,
|
IdleConnTimeout: 90 * time.Second,
|
||||||
DisableCompression: true, // in-cluster traffic; skip gzip overhead
|
DisableCompression: true, // in-cluster traffic; skip gzip overhead
|
||||||
}
|
}
|
||||||
|
|
||||||
// bufPool recycles *bytes.Buffer to avoid per-request allocations.
|
// bufPool recycles *bytes.Buffer to avoid per-request allocations.
|
||||||
var bufPool = sync.Pool{
|
var bufPool = sync.Pool{
|
||||||
New: func() any { return new(bytes.Buffer) },
|
New: func() any { return new(bytes.Buffer) },
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBuf() *bytes.Buffer {
|
func getBuf() *bytes.Buffer {
|
||||||
buf := bufPool.Get().(*bytes.Buffer)
|
buf := bufPool.Get().(*bytes.Buffer)
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
func putBuf(buf *bytes.Buffer) {
|
func putBuf(buf *bytes.Buffer) {
|
||||||
if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
|
if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
bufPool.Put(buf)
|
bufPool.Put(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── httpClient base ────────────────────────────────────────────────────────
|
// ─── httpClient base ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// httpClient is the shared base for all service clients.
|
// httpClient is the shared base for all service clients.
|
||||||
type httpClient struct {
|
type httpClient struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
baseURL string
|
baseURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
|
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
|
||||||
return &httpClient{
|
return &httpClient{
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
Transport: SharedTransport,
|
Transport: SharedTransport,
|
||||||
},
|
},
|
||||||
baseURL: baseURL,
|
baseURL: baseURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
|
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
|
||||||
buf := getBuf()
|
buf := getBuf()
|
||||||
defer putBuf(buf)
|
defer putBuf(buf)
|
||||||
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
||||||
return nil, fmt.Errorf("marshal: %w", err)
|
return nil, fmt.Errorf("marshal: %w", err)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
return h.do(req)
|
return h.do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
|
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
|
||||||
u := h.baseURL + path
|
u := h.baseURL + path
|
||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
u += "?" + params.Encode()
|
u += "?" + params.Encode()
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return h.do(req)
|
return h.do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
|
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
|
||||||
return h.get(ctx, path, params)
|
return h.get(ctx, path, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
|
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
|
||||||
buf := getBuf()
|
buf := getBuf()
|
||||||
defer putBuf(buf)
|
defer putBuf(buf)
|
||||||
w := multipart.NewWriter(buf)
|
w := multipart.NewWriter(buf)
|
||||||
part, err := w.CreateFormFile(fieldName, fileName)
|
part, err := w.CreateFormFile(fieldName, fileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err := part.Write(fileData); err != nil {
|
if _, err := part.Write(fileData); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for k, v := range fields {
|
for k, v := range fields {
|
||||||
_ = w.WriteField(k, v)
|
_ = w.WriteField(k, v)
|
||||||
}
|
}
|
||||||
_ = w.Close()
|
_ = w.Close()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
return h.do(req)
|
return h.do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) do(req *http.Request) ([]byte, error) {
|
func (h *httpClient) do(req *http.Request) ([]byte, error) {
|
||||||
resp, err := h.client.Do(req)
|
resp, err := h.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
buf := getBuf()
|
buf := getBuf()
|
||||||
defer putBuf(buf)
|
defer putBuf(buf)
|
||||||
if _, err := io.Copy(buf, resp.Body); err != nil {
|
if _, err := io.Copy(buf, resp.Body); err != nil {
|
||||||
return nil, fmt.Errorf("read body: %w", err)
|
return nil, fmt.Errorf("read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a copy so the pooled buffer can be safely recycled.
|
||||||
|
body := make([]byte, buf.Len())
|
||||||
|
copy(body, buf.Bytes())
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a copy so the pooled buffer can be safely recycled.
|
// postJSONStream sends a JSON POST and returns the raw *http.Response so the
|
||||||
body := make([]byte, buf.Len())
|
// caller can read the body incrementally (e.g. for SSE streaming). The caller
|
||||||
copy(body, buf.Bytes())
|
// is responsible for closing resp.Body.
|
||||||
|
func (h *httpClient) postJSONStream(ctx context.Context, path string, body any) (*http.Response, error) {
|
||||||
|
buf := getBuf()
|
||||||
|
defer putBuf(buf)
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal: %w", err)
|
||||||
|
}
|
||||||
|
// Copy to a non-pooled buffer so we can safely return the pool buffer.
|
||||||
|
payload := make([]byte, buf.Len())
|
||||||
|
copy(payload, buf.Bytes())
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(payload))
|
||||||
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
return body, nil
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := h.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) healthCheck(ctx context.Context) bool {
|
func (h *httpClient) healthCheck(ctx context.Context) bool {
|
||||||
data, err := h.get(ctx, "/health", nil)
|
data, err := h.get(ctx, "/health", nil)
|
||||||
_ = data
|
_ = data
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Embeddings Client ──────────────────────────────────────────────────────
|
// ─── Embeddings Client ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
|
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
|
||||||
type EmbeddingsClient struct {
|
type EmbeddingsClient struct {
|
||||||
*httpClient
|
*httpClient
|
||||||
Model string
|
Model string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEmbeddingsClient creates an embeddings client.
|
// NewEmbeddingsClient creates an embeddings client.
|
||||||
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
|
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
|
||||||
if model == "" {
|
if model == "" {
|
||||||
model = "bge"
|
model = "bge"
|
||||||
}
|
}
|
||||||
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
|
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Embed generates embeddings for a list of texts.
|
// Embed generates embeddings for a list of texts.
|
||||||
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
|
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
|
||||||
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
|
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
|
||||||
"input": texts,
|
"input": texts,
|
||||||
"model": c.Model,
|
"model": c.Model,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var resp struct {
|
var resp struct {
|
||||||
Data []struct {
|
Data []struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result := make([][]float64, len(resp.Data))
|
result := make([][]float64, len(resp.Data))
|
||||||
for i, d := range resp.Data {
|
for i, d := range resp.Data {
|
||||||
result[i] = d.Embedding
|
result[i] = d.Embedding
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbedSingle generates an embedding for a single text.
|
// EmbedSingle generates an embedding for a single text.
|
||||||
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
|
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
|
||||||
results, err := c.Embed(ctx, []string{text})
|
results, err := c.Embed(ctx, []string{text})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return nil, fmt.Errorf("empty embedding result")
|
return nil, fmt.Errorf("empty embedding result")
|
||||||
}
|
}
|
||||||
return results[0], nil
|
return results[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health checks if the embeddings service is healthy.
|
// Health checks if the embeddings service is healthy.
|
||||||
func (c *EmbeddingsClient) Health(ctx context.Context) bool {
|
func (c *EmbeddingsClient) Health(ctx context.Context) bool {
|
||||||
return c.healthCheck(ctx)
|
return c.healthCheck(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Reranker Client ────────────────────────────────────────────────────────
|
// ─── Reranker Client ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// RerankerClient calls the reranker service (BGE Reranker).
|
// RerankerClient calls the reranker service (BGE Reranker).
|
||||||
type RerankerClient struct {
|
type RerankerClient struct {
|
||||||
*httpClient
|
*httpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRerankerClient creates a reranker client.
|
// NewRerankerClient creates a reranker client.
|
||||||
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
|
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
|
||||||
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
|
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RerankResult represents a reranked document.
|
// RerankResult represents a reranked document.
|
||||||
type RerankResult struct {
|
type RerankResult struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
Score float64 `json:"score"`
|
Score float64 `json:"score"`
|
||||||
Document string `json:"document"`
|
Document string `json:"document"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rerank reranks documents by relevance to the query.
|
// Rerank reranks documents by relevance to the query.
|
||||||
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
|
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
"query": query,
|
"query": query,
|
||||||
"documents": documents,
|
"documents": documents,
|
||||||
}
|
}
|
||||||
if topK > 0 {
|
if topK > 0 {
|
||||||
payload["top_n"] = topK
|
payload["top_n"] = topK
|
||||||
}
|
}
|
||||||
body, err := c.postJSON(ctx, "/rerank", payload)
|
body, err := c.postJSON(ctx, "/rerank", payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var resp struct {
|
var resp struct {
|
||||||
Results []struct {
|
Results []struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
RelevanceScore float64 `json:"relevance_score"`
|
RelevanceScore float64 `json:"relevance_score"`
|
||||||
Score float64 `json:"score"`
|
Score float64 `json:"score"`
|
||||||
} `json:"results"`
|
} `json:"results"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
results := make([]RerankResult, len(resp.Results))
|
results := make([]RerankResult, len(resp.Results))
|
||||||
for i, r := range resp.Results {
|
for i, r := range resp.Results {
|
||||||
score := r.RelevanceScore
|
score := r.RelevanceScore
|
||||||
if score == 0 {
|
if score == 0 {
|
||||||
score = r.Score
|
score = r.Score
|
||||||
}
|
}
|
||||||
doc := ""
|
doc := ""
|
||||||
if r.Index < len(documents) {
|
if r.Index < len(documents) {
|
||||||
doc = documents[r.Index]
|
doc = documents[r.Index]
|
||||||
}
|
}
|
||||||
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
|
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── LLM Client ─────────────────────────────────────────────────────────────
|
// ─── LLM Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// LLMClient calls the vLLM-compatible LLM service.
|
// LLMClient calls the vLLM-compatible LLM service.
|
||||||
type LLMClient struct {
|
type LLMClient struct {
|
||||||
*httpClient
|
*httpClient
|
||||||
Model string
|
Model string
|
||||||
MaxTokens int
|
MaxTokens int
|
||||||
Temperature float64
|
Temperature float64
|
||||||
TopP float64
|
TopP float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLLMClient creates an LLM client.
|
// NewLLMClient creates an LLM client.
|
||||||
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
|
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
|
||||||
return &LLMClient{
|
return &LLMClient{
|
||||||
httpClient: newHTTPClient(baseURL, timeout),
|
httpClient: newHTTPClient(baseURL, timeout),
|
||||||
Model: "default",
|
Model: "default",
|
||||||
MaxTokens: 2048,
|
MaxTokens: 2048,
|
||||||
Temperature: 0.7,
|
Temperature: 0.7,
|
||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatMessage is an OpenAI-compatible message.
|
// ChatMessage is an OpenAI-compatible message.
|
||||||
type ChatMessage struct {
|
type ChatMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate sends a chat completion request and returns the response text.
|
// Generate sends a chat completion request and returns the response text.
|
||||||
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
|
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
|
||||||
messages := buildMessages(prompt, context_, systemPrompt)
|
messages := buildMessages(prompt, context_, systemPrompt)
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
"model": c.Model,
|
"model": c.Model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"max_tokens": c.MaxTokens,
|
"max_tokens": c.MaxTokens,
|
||||||
"temperature": c.Temperature,
|
"temperature": c.Temperature,
|
||||||
"top_p": c.TopP,
|
"top_p": c.TopP,
|
||||||
|
}
|
||||||
|
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var resp struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
return "", fmt.Errorf("no choices in LLM response")
|
||||||
|
}
|
||||||
|
return resp.Choices[0].Message.Content, nil
|
||||||
}
|
}
|
||||||
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
|
|
||||||
if err != nil {
|
// StreamGenerate sends a streaming chat completion request and calls onToken
|
||||||
return "", err
|
// for each content delta received via SSE. Returns the fully assembled text.
|
||||||
}
|
// The onToken callback is invoked synchronously on the calling goroutine; it
|
||||||
var resp struct {
|
// should be fast (e.g. publish a NATS message).
|
||||||
Choices []struct {
|
func (c *LLMClient) StreamGenerate(ctx context.Context, prompt string, context_ string, systemPrompt string, onToken func(token string)) (string, error) {
|
||||||
Message struct {
|
msgs := buildMessages(prompt, context_, systemPrompt)
|
||||||
Content string `json:"content"`
|
payload := map[string]any{
|
||||||
} `json:"message"`
|
"model": c.Model,
|
||||||
} `json:"choices"`
|
"messages": msgs,
|
||||||
}
|
"max_tokens": c.MaxTokens,
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
"temperature": c.Temperature,
|
||||||
return "", err
|
"top_p": c.TopP,
|
||||||
}
|
"stream": true,
|
||||||
if len(resp.Choices) == 0 {
|
}
|
||||||
return "", fmt.Errorf("no choices in LLM response")
|
|
||||||
}
|
resp, err := c.postJSONStream(ctx, "/v1/chat/completions", payload)
|
||||||
return resp.Choices[0].Message.Content, nil
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
var full strings.Builder
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
// SSE lines can be up to 64 KiB for large token batches.
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), 64*1024)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
if data == "[DONE]" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var chunk struct {
|
||||||
|
Choices []struct {
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||||
|
continue // skip malformed chunks
|
||||||
|
}
|
||||||
|
if len(chunk.Choices) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
token := chunk.Choices[0].Delta.Content
|
||||||
|
if token == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
full.WriteString(token)
|
||||||
|
if onToken != nil {
|
||||||
|
onToken(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
// If we already collected some text, return it with the error.
|
||||||
|
if full.Len() > 0 {
|
||||||
|
return full.String(), fmt.Errorf("stream interrupted: %w", err)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("stream read: %w", err)
|
||||||
|
}
|
||||||
|
return full.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
|
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
|
||||||
var msgs []ChatMessage
|
var msgs []ChatMessage
|
||||||
if systemPrompt != "" {
|
if systemPrompt != "" {
|
||||||
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
|
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
|
||||||
} else if ctx != "" {
|
} else if ctx != "" {
|
||||||
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
|
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
|
||||||
}
|
}
|
||||||
if ctx != "" {
|
if ctx != "" {
|
||||||
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
|
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
|
||||||
} else {
|
} else {
|
||||||
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
|
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
|
||||||
}
|
}
|
||||||
return msgs
|
return msgs
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── TTS Client ─────────────────────────────────────────────────────────────
|
// ─── TTS Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// TTSClient calls the TTS service (Coqui XTTS).
|
// TTSClient calls the TTS service (Coqui XTTS).
|
||||||
type TTSClient struct {
|
type TTSClient struct {
|
||||||
*httpClient
|
*httpClient
|
||||||
Language string
|
Language string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTTSClient creates a TTS client.
|
// NewTTSClient creates a TTS client.
|
||||||
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
|
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
|
||||||
if language == "" {
|
if language == "" {
|
||||||
language = "en"
|
language = "en"
|
||||||
}
|
}
|
||||||
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
|
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Synthesize generates audio bytes from text.
|
// Synthesize generates audio bytes from text.
|
||||||
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
|
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
|
||||||
if language == "" {
|
if language == "" {
|
||||||
language = c.Language
|
language = c.Language
|
||||||
}
|
}
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"text": {text},
|
"text": {text},
|
||||||
"language_id": {language},
|
"language_id": {language},
|
||||||
}
|
}
|
||||||
if speaker != "" {
|
if speaker != "" {
|
||||||
params.Set("speaker_id", speaker)
|
params.Set("speaker_id", speaker)
|
||||||
}
|
}
|
||||||
return c.getRaw(ctx, "/api/tts", params)
|
return c.getRaw(ctx, "/api/tts", params)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── STT Client ─────────────────────────────────────────────────────────────
|
// ─── STT Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// STTClient calls the Whisper STT service.
|
// STTClient calls the Whisper STT service.
|
||||||
type STTClient struct {
|
type STTClient struct {
|
||||||
*httpClient
|
*httpClient
|
||||||
Language string
|
Language string
|
||||||
Task string
|
Task string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSTTClient creates an STT client.
|
// NewSTTClient creates an STT client.
|
||||||
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
|
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
|
||||||
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
|
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TranscribeResult is the decoded JSON output of a transcription request.
type TranscribeResult struct {
	// Text is the "text" value of the response.
	Text string `json:"text"`
	// Language is the "language" value of the response; it is omitted from
	// encoded JSON when empty.
	Language string `json:"language,omitempty"`
}
|
||||||
|
|
||||||
// Transcribe sends audio to Whisper and returns the transcription.
|
// Transcribe sends audio to Whisper and returns the transcription.
|
||||||
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
|
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
|
||||||
if language == "" {
|
if language == "" {
|
||||||
language = c.Language
|
language = c.Language
|
||||||
}
|
}
|
||||||
fields := map[string]string{
|
fields := map[string]string{
|
||||||
"response_format": "json",
|
"response_format": "json",
|
||||||
}
|
}
|
||||||
if language != "" {
|
if language != "" {
|
||||||
fields["language"] = language
|
fields["language"] = language
|
||||||
}
|
}
|
||||||
endpoint := "/v1/audio/transcriptions"
|
endpoint := "/v1/audio/transcriptions"
|
||||||
if c.Task == "translate" {
|
if c.Task == "translate" {
|
||||||
endpoint = "/v1/audio/translations"
|
endpoint = "/v1/audio/translations"
|
||||||
}
|
}
|
||||||
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
|
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var result TranscribeResult
|
var result TranscribeResult
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Milvus Client ──────────────────────────────────────────────────────────
|
// ─── Milvus Client ──────────────────────────────────────────────────────────
|
||||||
@@ -417,21 +516,20 @@ return &result, nil
|
|||||||
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
// For the Go port we use the Milvus Go SDK.
type MilvusClient struct {
	// Host is the Milvus host.
	Host string
	// Port is the Milvus port.
	Port int
	// Collection is the collection name.
	Collection string
}
|
||||||
|
|
||||||
// NewMilvusClient creates a Milvus client.
|
// NewMilvusClient creates a Milvus client.
|
||||||
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
|
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
|
||||||
return &MilvusClient{Host: host, Port: port, Collection: collection}
|
return &MilvusClient{Host: host, Port: port, Collection: collection}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchResult is one hit returned by a vector search.
type SearchResult struct {
	// ID is encoded as "id".
	ID int64 `json:"id"`
	// Distance is encoded as "distance".
	Distance float64 `json:"distance"`
	// Score is encoded as "score".
	Score float64 `json:"score"`
	// Fields holds additional per-hit values; it is omitted from encoded JSON
	// when empty.
	Fields map[string]any `json:"fields,omitempty"`
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -90,14 +91,14 @@ func TestEmbeddingsClient_Embed(t *testing.T) {
|
|||||||
t.Errorf("method = %s, want POST", r.Method)
|
t.Errorf("method = %s, want POST", r.Method)
|
||||||
}
|
}
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
input, _ := req["input"].([]any)
|
input, _ := req["input"].([]any)
|
||||||
if len(input) != 2 {
|
if len(input) != 2 {
|
||||||
t.Errorf("input len = %d, want 2", len(input))
|
t.Errorf("input len = %d, want 2", len(input))
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"data": []map[string]any{
|
"data": []map[string]any{
|
||||||
{"embedding": []float64{0.1, 0.2, 0.3}},
|
{"embedding": []float64{0.1, 0.2, 0.3}},
|
||||||
{"embedding": []float64{0.4, 0.5, 0.6}},
|
{"embedding": []float64{0.4, 0.5, 0.6}},
|
||||||
@@ -121,7 +122,7 @@ func TestEmbeddingsClient_Embed(t *testing.T) {
|
|||||||
|
|
||||||
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
|
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"data": []map[string]any{
|
"data": []map[string]any{
|
||||||
{"embedding": []float64{1.0, 2.0}},
|
{"embedding": []float64{1.0, 2.0}},
|
||||||
},
|
},
|
||||||
@@ -141,7 +142,7 @@ func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
|
|||||||
|
|
||||||
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
|
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
|
_ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -175,11 +176,11 @@ func TestEmbeddingsClient_Health(t *testing.T) {
|
|||||||
func TestRerankerClient_Rerank(t *testing.T) {
|
func TestRerankerClient_Rerank(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if req["query"] != "test query" {
|
if req["query"] != "test query" {
|
||||||
t.Errorf("query = %v", req["query"])
|
t.Errorf("query = %v", req["query"])
|
||||||
}
|
}
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"results": []map[string]any{
|
"results": []map[string]any{
|
||||||
{"index": 1, "relevance_score": 0.95},
|
{"index": 1, "relevance_score": 0.95},
|
||||||
{"index": 0, "relevance_score": 0.80},
|
{"index": 0, "relevance_score": 0.80},
|
||||||
@@ -207,7 +208,7 @@ func TestRerankerClient_Rerank(t *testing.T) {
|
|||||||
|
|
||||||
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
|
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"results": []map[string]any{
|
"results": []map[string]any{
|
||||||
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
|
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
|
||||||
},
|
},
|
||||||
@@ -235,13 +236,13 @@ func TestLLMClient_Generate(t *testing.T) {
|
|||||||
t.Errorf("path = %q", r.URL.Path)
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
}
|
}
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
msgs, _ := req["messages"].([]any)
|
msgs, _ := req["messages"].([]any)
|
||||||
if len(msgs) == 0 {
|
if len(msgs) == 0 {
|
||||||
t.Error("no messages in request")
|
t.Error("no messages in request")
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"choices": []map[string]any{
|
"choices": []map[string]any{
|
||||||
{"message": map[string]any{"content": "Paris is the capital of France."}},
|
{"message": map[string]any{"content": "Paris is the capital of France."}},
|
||||||
},
|
},
|
||||||
@@ -262,13 +263,13 @@ func TestLLMClient_Generate(t *testing.T) {
|
|||||||
func TestLLMClient_GenerateWithContext(t *testing.T) {
|
func TestLLMClient_GenerateWithContext(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
msgs, _ := req["messages"].([]any)
|
msgs, _ := req["messages"].([]any)
|
||||||
// Should have system + user message
|
// Should have system + user message
|
||||||
if len(msgs) != 2 {
|
if len(msgs) != 2 {
|
||||||
t.Errorf("expected 2 messages, got %d", len(msgs))
|
t.Errorf("expected 2 messages, got %d", len(msgs))
|
||||||
}
|
}
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"choices": []map[string]any{
|
"choices": []map[string]any{
|
||||||
{"message": map[string]any{"content": "answer with context"}},
|
{"message": map[string]any{"content": "answer with context"}},
|
||||||
},
|
},
|
||||||
@@ -288,7 +289,7 @@ func TestLLMClient_GenerateWithContext(t *testing.T) {
|
|||||||
|
|
||||||
func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
|
_ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -299,6 +300,293 @@ func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// LLM client — StreamGenerate
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// sseChunk returns one SSE "data:" event whose payload is an OpenAI-compatible
// chat.completion.chunk carrying content as choices[0].delta.content.
func sseChunk(content string) string {
	delta := map[string]any{"content": content}
	choice := map[string]any{"delta": delta}
	// The marshal error is deliberately ignored, as the payload is built only
	// from maps and a string.
	payload, _ := json.Marshal(map[string]any{"choices": []map[string]any{choice}})
	return "data: " + string(payload) + "\n\n"
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerate(t *testing.T) {
|
||||||
|
tokens := []string{"Hello", " world", "!"}
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/chat/completions" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if req["stream"] != true {
|
||||||
|
t.Errorf("stream = %v, want true", req["stream"])
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
for _, tok := range tokens {
|
||||||
|
_, _ = w.Write([]byte(sseChunk(tok)))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
var received []string
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "hi", "", "", func(tok string) {
|
||||||
|
received = append(received, tok)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "Hello world!" {
|
||||||
|
t.Errorf("result = %q, want %q", result, "Hello world!")
|
||||||
|
}
|
||||||
|
if len(received) != 3 {
|
||||||
|
t.Fatalf("callback count = %d, want 3", len(received))
|
||||||
|
}
|
||||||
|
if received[0] != "Hello" || received[1] != " world" || received[2] != "!" {
|
||||||
|
t.Errorf("received = %v", received)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
msgs, _ := req["messages"].([]any)
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Errorf("expected system+user, got %d messages", len(msgs))
|
||||||
|
}
|
||||||
|
first, _ := msgs[0].(map[string]any)
|
||||||
|
if first["role"] != "system" || first["content"] != "You are a DM" {
|
||||||
|
t.Errorf("system msg = %v", first)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte(sseChunk("ok")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "ok" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateNilCallback(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte(sseChunk("token")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
// nil callback should not panic
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "hi", "", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "token" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) {
|
||||||
|
// SSE chunks with empty content should be silently skipped.
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
// role-only chunk (no content) — common for first chunk from vLLM
|
||||||
|
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n"))
|
||||||
|
// empty content string
|
||||||
|
_, _ = w.Write([]byte(sseChunk("")))
|
||||||
|
// real token
|
||||||
|
_, _ = w.Write([]byte(sseChunk("hello")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
var count int
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
||||||
|
count++
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "hello" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) {
|
||||||
|
// Malformed JSON should be skipped without error.
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte("data: {invalid json}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"choices\":[]}\n\n")) // empty choices
|
||||||
|
_, _ = w.Write([]byte(sseChunk("good")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "good" {
|
||||||
|
t.Errorf("result = %q, want %q", result, "good")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateHTTPError(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
_, _ = w.Write([]byte("internal server error"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
_, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 500")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "500") {
|
||||||
|
t.Errorf("error should contain 500: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) {
|
||||||
|
started := make(chan struct{})
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
// Send several tokens so the client receives some before cancel.
|
||||||
|
for i := range 20 {
|
||||||
|
_, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i))))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
close(started)
|
||||||
|
// Block until client cancels
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
c := NewLLMClient(ts.URL, 10*time.Second)
|
||||||
|
|
||||||
|
var streamErr error
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
_, streamErr = c.StreamGenerate(ctx, "q", "", "", nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
cancel()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// After cancel the stream should return an error (context canceled or
|
||||||
|
// stream interrupted). The exact partial text depends on timing.
|
||||||
|
if streamErr == nil {
|
||||||
|
t.Error("expected error after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) {
|
||||||
|
// Lines without "data: " prefix should be silently ignored (comments, blank lines, event IDs).
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte(": this is an SSE comment\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: message\n"))
|
||||||
|
_, _ = w.Write([]byte(sseChunk("word")))
|
||||||
|
_, _ = w.Write([]byte("\n")) // blank line
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "word" {
|
||||||
|
t.Errorf("result = %q, want %q", result, "word")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
|
||||||
|
// Verify token ordering and full assembly with many chunks.
|
||||||
|
n := 100
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
for i := range n {
|
||||||
|
tok := fmt.Sprintf("t%d ", i)
|
||||||
|
_, _ = w.Write([]byte(sseChunk(tok)))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
var mu sync.Mutex
|
||||||
|
var order []int
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
||||||
|
var idx int
|
||||||
|
_, _ = fmt.Sscanf(tok, "t%d ", &idx)
|
||||||
|
mu.Lock()
|
||||||
|
order = append(order, idx)
|
||||||
|
mu.Unlock()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Verify all tokens arrived in order
|
||||||
|
if len(order) != n {
|
||||||
|
t.Fatalf("got %d tokens, want %d", len(order), n)
|
||||||
|
}
|
||||||
|
for i, v := range order {
|
||||||
|
if v != i {
|
||||||
|
t.Errorf("order[%d] = %d", i, v)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Quick sanity: result should start with "t0 " and end with last token
|
||||||
|
if !strings.HasPrefix(result, "t0 ") {
|
||||||
|
t.Errorf("result prefix = %q", result[:10])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// TTS client
|
// TTS client
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -312,7 +600,7 @@ func TestTTSClient_Synthesize(t *testing.T) {
|
|||||||
if r.URL.Query().Get("text") != "hello world" {
|
if r.URL.Query().Get("text") != "hello world" {
|
||||||
t.Errorf("text = %q", r.URL.Query().Get("text"))
|
t.Errorf("text = %q", r.URL.Query().Get("text"))
|
||||||
}
|
}
|
||||||
w.Write(expected)
|
_, _ = w.Write(expected)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -331,7 +619,7 @@ func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
|
|||||||
if r.URL.Query().Get("speaker_id") != "alice" {
|
if r.URL.Query().Get("speaker_id") != "alice" {
|
||||||
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
|
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
|
||||||
}
|
}
|
||||||
w.Write([]byte{0x01})
|
_, _ = w.Write([]byte{0x01})
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -365,7 +653,7 @@ func TestSTTClient_Transcribe(t *testing.T) {
|
|||||||
t.Errorf("file size = %d, want 100", len(data))
|
t.Errorf("file size = %d, want 100", len(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -384,7 +672,7 @@ func TestSTTClient_TranscribeTranslate(t *testing.T) {
|
|||||||
if r.URL.Path != "/v1/audio/translations" {
|
if r.URL.Path != "/v1/audio/translations" {
|
||||||
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
|
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
|
||||||
}
|
}
|
||||||
json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
|
_ = json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -406,7 +694,7 @@ func TestSTTClient_TranscribeTranslate(t *testing.T) {
|
|||||||
func TestHTTPError4xx(t *testing.T) {
|
func TestHTTPError4xx(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(422)
|
w.WriteHeader(422)
|
||||||
w.Write([]byte(`{"error": "bad input"}`))
|
_, _ = w.Write([]byte(`{"error": "bad input"}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -423,7 +711,7 @@ func TestHTTPError4xx(t *testing.T) {
|
|||||||
func TestHTTPError5xx(t *testing.T) {
|
func TestHTTPError5xx(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(500)
|
w.WriteHeader(500)
|
||||||
w.Write([]byte("internal server error"))
|
_, _ = w.Write([]byte("internal server error"))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -467,8 +755,8 @@ func TestBuildMessages(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkPostJSON(b *testing.B) {
|
func BenchmarkPostJSON(b *testing.B) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
io.Copy(io.Discard, r.Body)
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
w.Write([]byte(`{"ok":true}`))
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
@@ -482,7 +770,7 @@ func BenchmarkPostJSON(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
c.postJSON(ctx, "/test", payload)
|
_, _ = c.postJSON(ctx, "/test", payload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
364
config/config.go
364
config/config.go
@@ -3,16 +3,16 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Settings holds base configuration for all handler services.
|
// Settings holds base configuration for all handler services.
|
||||||
@@ -20,249 +20,249 @@ import (
|
|||||||
// updated at runtime via WatchSecrets(). All other fields are immutable
|
// updated at runtime via WatchSecrets(). All other fields are immutable
|
||||||
// after Load() returns.
|
// after Load() returns.
|
||||||
type Settings struct {
	// Service identification (immutable).
	ServiceName      string
	ServiceVersion   string
	ServiceNamespace string
	DeploymentEnv    string

	// NATS connection settings (immutable).
	NATSURL        string
	NATSUser       string
	NATSPassword   string
	NATSQueueGroup string

	// Redis/Valkey connection settings (immutable).
	RedisURL      string
	RedisPassword string

	// Milvus connection settings (immutable).
	MilvusHost       string
	MilvusPort       int
	MilvusCollection string

	// OpenTelemetry settings (immutable).
	OTELEnabled  bool
	OTELEndpoint string
	OTELUseHTTP  bool

	// HyperDX settings (immutable).
	HyperDXEnabled  bool
	HyperDXAPIKey   string
	HyperDXEndpoint string

	// MLflow settings (immutable).
	MLflowTrackingURI    string
	MLflowExperimentName string
	MLflowEnabled        bool

	// Health check endpoint settings (immutable).
	HealthPort int
	HealthPath string
	ReadyPath  string

	// Timeouts (immutable).
	HTTPTimeout time.Duration
	NATSTimeout time.Duration

	// Hot-reloadable fields. Read them through the getter methods rather
	// than directly; mu presumably guards them — confirm against the getters.
	mu            sync.RWMutex
	embeddingsURL string
	rerankerURL   string
	llmURL        string
	ttsURL        string
	sttURL        string

	// SecretsPath is the path used for file-based hot reload
	// (Kubernetes secret mounts).
	SecretsPath string
}
|
||||||
|
|
||||||
// Load creates a Settings populated from environment variables with defaults.
|
// Load creates a Settings populated from environment variables with defaults.
|
||||||
func Load() *Settings {
|
func Load() *Settings {
|
||||||
return &Settings{
|
return &Settings{
|
||||||
ServiceName: getEnv("SERVICE_NAME", "handler"),
|
ServiceName: getEnv("SERVICE_NAME", "handler"),
|
||||||
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
|
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
|
||||||
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
|
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
|
||||||
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
|
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
|
||||||
|
|
||||||
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
|
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
|
||||||
NATSUser: getEnv("NATS_USER", ""),
|
NATSUser: getEnv("NATS_USER", ""),
|
||||||
NATSPassword: getEnv("NATS_PASSWORD", ""),
|
NATSPassword: getEnv("NATS_PASSWORD", ""),
|
||||||
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
|
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
|
||||||
|
|
||||||
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
|
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
|
||||||
RedisPassword: getEnv("REDIS_PASSWORD", ""),
|
RedisPassword: getEnv("REDIS_PASSWORD", ""),
|
||||||
|
|
||||||
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
|
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
|
||||||
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
||||||
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
||||||
|
|
||||||
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
||||||
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
||||||
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
||||||
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
||||||
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
|
||||||
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
||||||
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
||||||
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
|
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
|
||||||
|
|
||||||
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
|
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
|
||||||
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
|
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
|
||||||
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
|
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
|
||||||
|
|
||||||
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
|
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
|
||||||
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
|
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
|
||||||
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
|
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
|
||||||
|
|
||||||
HealthPort: getEnvInt("HEALTH_PORT", 8080),
|
HealthPort: getEnvInt("HEALTH_PORT", 8080),
|
||||||
HealthPath: getEnv("HEALTH_PATH", "/health"),
|
HealthPath: getEnv("HEALTH_PATH", "/health"),
|
||||||
ReadyPath: getEnv("READY_PATH", "/ready"),
|
ReadyPath: getEnv("READY_PATH", "/ready"),
|
||||||
|
|
||||||
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
||||||
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
||||||
|
|
||||||
SecretsPath: getEnv("SECRETS_PATH", ""),
|
SecretsPath: getEnv("SECRETS_PATH", ""),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
|
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
|
||||||
func (s *Settings) EmbeddingsURL() string {
|
func (s *Settings) EmbeddingsURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.embeddingsURL
|
return s.embeddingsURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// RerankerURL returns the current reranker service URL (thread-safe).
|
// RerankerURL returns the current reranker service URL (thread-safe).
|
||||||
func (s *Settings) RerankerURL() string {
|
func (s *Settings) RerankerURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.rerankerURL
|
return s.rerankerURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMURL returns the current LLM service URL (thread-safe).
|
// LLMURL returns the current LLM service URL (thread-safe).
|
||||||
func (s *Settings) LLMURL() string {
|
func (s *Settings) LLMURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.llmURL
|
return s.llmURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// TTSURL returns the current TTS service URL (thread-safe).
|
// TTSURL returns the current TTS service URL (thread-safe).
|
||||||
func (s *Settings) TTSURL() string {
|
func (s *Settings) TTSURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.ttsURL
|
return s.ttsURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// STTURL returns the current STT service URL (thread-safe).
|
// STTURL returns the current STT service URL (thread-safe).
|
||||||
func (s *Settings) STTURL() string {
|
func (s *Settings) STTURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.sttURL
|
return s.sttURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// WatchSecrets watches the SecretsPath directory for changes and reloads
|
// WatchSecrets watches the SecretsPath directory for changes and reloads
|
||||||
// hot-reloadable fields. Blocks until ctx is cancelled.
|
// hot-reloadable fields. Blocks until ctx is cancelled.
|
||||||
func (s *Settings) WatchSecrets(ctx context.Context) {
|
func (s *Settings) WatchSecrets(ctx context.Context) {
|
||||||
if s.SecretsPath == "" {
|
if s.SecretsPath == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
watcher, err := fsnotify.NewWatcher()
|
watcher, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("config: failed to create fsnotify watcher", "error", err)
|
slog.Error("config: failed to create fsnotify watcher", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = watcher.Close() }()
|
defer func() { _ = watcher.Close() }()
|
||||||
|
|
||||||
if err := watcher.Add(s.SecretsPath); err != nil {
|
if err := watcher.Add(s.SecretsPath); err != nil {
|
||||||
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
|
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
|
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case event, ok := <-watcher.Events:
|
case event, ok := <-watcher.Events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
}
|
}
|
||||||
case err, ok := <-watcher.Errors:
|
case err, ok := <-watcher.Errors:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
slog.Error("config: fsnotify error", "error", err)
|
slog.Error("config: fsnotify error", "error", err)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
|
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
|
||||||
func (s *Settings) reloadFromSecrets() {
|
func (s *Settings) reloadFromSecrets() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
updated := 0
|
updated := 0
|
||||||
reload := func(filename string, target *string) {
|
reload := func(filename string, target *string) {
|
||||||
path := filepath.Join(s.SecretsPath, filename)
|
path := filepath.Join(s.SecretsPath, filename)
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
val := strings.TrimSpace(string(data))
|
val := strings.TrimSpace(string(data))
|
||||||
if val != "" && val != *target {
|
if val != "" && val != *target {
|
||||||
*target = val
|
*target = val
|
||||||
updated++
|
updated++
|
||||||
slog.Info("config: reloaded secret", "key", filename)
|
slog.Info("config: reloaded secret", "key", filename)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reload("embeddings-url", &s.embeddingsURL)
|
reload("embeddings-url", &s.embeddingsURL)
|
||||||
reload("reranker-url", &s.rerankerURL)
|
reload("reranker-url", &s.rerankerURL)
|
||||||
reload("llm-url", &s.llmURL)
|
reload("llm-url", &s.llmURL)
|
||||||
reload("tts-url", &s.ttsURL)
|
reload("tts-url", &s.ttsURL)
|
||||||
reload("stt-url", &s.sttURL)
|
reload("stt-url", &s.sttURL)
|
||||||
|
|
||||||
if updated > 0 {
|
if updated > 0 {
|
||||||
slog.Info("config: secrets reloaded", "updated", updated)
|
slog.Info("config: secrets reloaded", "updated", updated)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
||||||
|
|
||||||
// getEnvInt returns the environment variable key parsed as a base-10 integer.
// fallback is returned when the variable is unset, empty, or not a valid
// integer.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
|
||||||
|
|
||||||
// getEnvBool returns the environment variable key parsed with
// strconv.ParseBool. fallback is returned when the variable is unset, empty,
// or not a value ParseBool accepts.
func getEnvBool(key string, fallback bool) bool {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.ParseBool(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
|
||||||
|
|
||||||
// getEnvDuration returns the environment variable key interpreted as a
// duration.
//
// A plain number (integer or fractional) is taken as a count of seconds, so
// "30" means 30s and "1.5" means 1.5s. If the value is not a number it is
// tried as a Go duration string such as "45s" or "2m30s". fallback is
// returned when the variable is unset or empty, when neither form parses, or
// when the number is NaN, infinite, or too large to represent.
func getEnvDuration(key string, fallback time.Duration) time.Duration {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}

	if secs, err := strconv.ParseFloat(raw, 64); err == nil {
		// Converting an out-of-range or NaN float to an integer type is
		// implementation-defined in Go, so range-check first. NaN fails both
		// comparisons and therefore also lands on the fallback.
		const limit = float64(1 << 62)
		ns := secs * float64(time.Second)
		if ns > -limit && ns < limit {
			return time.Duration(ns)
		}
		return fallback
	}

	if d, err := time.ParseDuration(raw); err == nil {
		return d
	}
	return fallback
}
|
||||||
|
|||||||
@@ -1,123 +1,123 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoadDefaults(t *testing.T) {
|
func TestLoadDefaults(t *testing.T) {
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.ServiceName != "handler" {
|
if s.ServiceName != "handler" {
|
||||||
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
|
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
|
||||||
}
|
}
|
||||||
if s.HealthPort != 8080 {
|
if s.HealthPort != 8080 {
|
||||||
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
|
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
|
||||||
}
|
}
|
||||||
if s.HTTPTimeout != 60*time.Second {
|
if s.HTTPTimeout != 60*time.Second {
|
||||||
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
|
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadFromEnv(t *testing.T) {
|
func TestLoadFromEnv(t *testing.T) {
|
||||||
t.Setenv("SERVICE_NAME", "test-svc")
|
t.Setenv("SERVICE_NAME", "test-svc")
|
||||||
t.Setenv("HEALTH_PORT", "9090")
|
t.Setenv("HEALTH_PORT", "9090")
|
||||||
t.Setenv("OTEL_ENABLED", "false")
|
t.Setenv("OTEL_ENABLED", "false")
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.ServiceName != "test-svc" {
|
if s.ServiceName != "test-svc" {
|
||||||
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
|
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
|
||||||
}
|
}
|
||||||
if s.HealthPort != 9090 {
|
if s.HealthPort != 9090 {
|
||||||
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
|
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
|
||||||
}
|
}
|
||||||
if s.OTELEnabled {
|
if s.OTELEnabled {
|
||||||
t.Error("expected OTELEnabled false")
|
t.Error("expected OTELEnabled false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestURLGetters(t *testing.T) {
|
func TestURLGetters(t *testing.T) {
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.EmbeddingsURL() == "" {
|
if s.EmbeddingsURL() == "" {
|
||||||
t.Error("EmbeddingsURL should have a default")
|
t.Error("EmbeddingsURL should have a default")
|
||||||
}
|
}
|
||||||
if s.RerankerURL() == "" {
|
if s.RerankerURL() == "" {
|
||||||
t.Error("RerankerURL should have a default")
|
t.Error("RerankerURL should have a default")
|
||||||
}
|
}
|
||||||
if s.LLMURL() == "" {
|
if s.LLMURL() == "" {
|
||||||
t.Error("LLMURL should have a default")
|
t.Error("LLMURL should have a default")
|
||||||
}
|
}
|
||||||
if s.TTSURL() == "" {
|
if s.TTSURL() == "" {
|
||||||
t.Error("TTSURL should have a default")
|
t.Error("TTSURL should have a default")
|
||||||
}
|
}
|
||||||
if s.STTURL() == "" {
|
if s.STTURL() == "" {
|
||||||
t.Error("STTURL should have a default")
|
t.Error("STTURL should have a default")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestURLGettersFromEnv(t *testing.T) {
|
func TestURLGettersFromEnv(t *testing.T) {
|
||||||
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
|
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
|
||||||
t.Setenv("LLM_URL", "http://llm:9000")
|
t.Setenv("LLM_URL", "http://llm:9000")
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.EmbeddingsURL() != "http://embed:8000" {
|
if s.EmbeddingsURL() != "http://embed:8000" {
|
||||||
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
|
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
}
|
}
|
||||||
if s.LLMURL() != "http://llm:9000" {
|
if s.LLMURL() != "http://llm:9000" {
|
||||||
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
|
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadFromSecrets(t *testing.T) {
|
func TestReloadFromSecrets(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
// Write initial secret files
|
// Write initial secret files
|
||||||
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
|
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
|
||||||
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
|
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
s.SecretsPath = dir
|
s.SecretsPath = dir
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
if s.EmbeddingsURL() != "http://old-embed:8000" {
|
if s.EmbeddingsURL() != "http://old-embed:8000" {
|
||||||
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
|
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
}
|
}
|
||||||
if s.LLMURL() != "http://old-llm:9000" {
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
|
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate secret update
|
// Simulate secret update
|
||||||
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
|
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
if s.EmbeddingsURL() != "http://new-embed:8000" {
|
if s.EmbeddingsURL() != "http://new-embed:8000" {
|
||||||
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
|
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
}
|
}
|
||||||
// LLM should remain unchanged
|
// LLM should remain unchanged
|
||||||
if s.LLMURL() != "http://old-llm:9000" {
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
|
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadFromSecretsNoPath(t *testing.T) {
|
func TestReloadFromSecretsNoPath(t *testing.T) {
|
||||||
s := Load()
|
s := Load()
|
||||||
s.SecretsPath = ""
|
s.SecretsPath = ""
|
||||||
// Should not panic
|
// Should not panic
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetEnvDuration(t *testing.T) {
|
func TestGetEnvDuration(t *testing.T) {
|
||||||
t.Setenv("TEST_DUR", "30")
|
t.Setenv("TEST_DUR", "30")
|
||||||
d := getEnvDuration("TEST_DUR", 10*time.Second)
|
d := getEnvDuration("TEST_DUR", 10*time.Second)
|
||||||
if d != 30*time.Second {
|
if d != 30*time.Second {
|
||||||
t.Errorf("expected 30s, got %v", d)
|
t.Errorf("expected 30s, got %v", d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeSecret is a test helper that writes value to the file name inside dir,
// creating or overwriting it, and aborts the test if the write fails.
func writeSecret(t *testing.T, dir, name, value string) {
	t.Helper()

	target := filepath.Join(dir, name)
	err := os.WriteFile(target, []byte(value), 0644)
	if err != nil {
		t.Fatal(err)
	}
}
|
||||||
|
|||||||
2011
gen/messagespb/messages.pb.go
Normal file
2011
gen/messagespb/messages.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
6
go.mod
6
go.mod
@@ -3,8 +3,8 @@ module git.daviestechlabs.io/daviestechlabs/handler-base
|
|||||||
go 1.25.1
|
go 1.25.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/nats-io/nats.go v1.48.0
|
github.com/nats-io/nats.go v1.48.0
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
|
||||||
go.opentelemetry.io/otel v1.40.0
|
go.opentelemetry.io/otel v1.40.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
|
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
||||||
@@ -12,12 +12,12 @@ require (
|
|||||||
go.opentelemetry.io/otel/sdk v1.40.0
|
go.opentelemetry.io/otel/sdk v1.40.0
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.40.0
|
go.opentelemetry.io/otel/sdk/metric v1.40.0
|
||||||
go.opentelemetry.io/otel/trace v1.40.0
|
go.opentelemetry.io/otel/trace v1.40.0
|
||||||
|
google.golang.org/protobuf v1.36.11
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
@@ -25,7 +25,6 @@ require (
|
|||||||
github.com/klauspost/compress v1.18.0 // indirect
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
github.com/nats-io/nkeys v0.4.11 // indirect
|
github.com/nats-io/nkeys v0.4.11 // indirect
|
||||||
github.com/nats-io/nuid v1.0.1 // indirect
|
github.com/nats-io/nuid v1.0.1 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
|
||||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||||
@@ -36,5 +35,4 @@ require (
|
|||||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
|
||||||
google.golang.org/grpc v1.78.0 // indirect
|
google.golang.org/grpc v1.78.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.11 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -31,10 +31,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||||
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
|
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
|
||||||
|
|||||||
@@ -10,22 +10,19 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageHandler is the callback for processing decoded NATS messages.
|
// TypedMessageHandler processes the raw NATS message.
|
||||||
// data is the msgpack-decoded map. Return a response map (or nil for no reply).
|
// Services unmarshal msg.Data into their own typed structs via natsutil.Decode.
|
||||||
type MessageHandler func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error)
|
// Return a proto.Message (or nil for no reply).
|
||||||
|
type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (proto.Message, error)
|
||||||
// TypedMessageHandler processes the raw NATS message without pre-decoding to
|
|
||||||
// map[string]any. Services unmarshal msg.Data into their own typed structs,
|
|
||||||
// avoiding the double-decode overhead. Return any msgpack-serialisable value
|
|
||||||
// (a typed struct, map, or nil for no reply).
|
|
||||||
type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (any, error)
|
|
||||||
|
|
||||||
// SetupFunc is called once before the handler starts processing messages.
|
// SetupFunc is called once before the handler starts processing messages.
|
||||||
type SetupFunc func(ctx context.Context) error
|
type SetupFunc func(ctx context.Context) error
|
||||||
@@ -43,7 +40,6 @@ type Handler struct {
|
|||||||
|
|
||||||
onSetup SetupFunc
|
onSetup SetupFunc
|
||||||
onTeardown TeardownFunc
|
onTeardown TeardownFunc
|
||||||
onMessage MessageHandler
|
|
||||||
onTypedMessage TypedMessageHandler
|
onTypedMessage TypedMessageHandler
|
||||||
running bool
|
running bool
|
||||||
}
|
}
|
||||||
@@ -74,12 +70,7 @@ func (h *Handler) OnSetup(fn SetupFunc) { h.onSetup = fn }
|
|||||||
// OnTeardown registers the teardown callback.
|
// OnTeardown registers the teardown callback.
|
||||||
func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn }
|
func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn }
|
||||||
|
|
||||||
// OnMessage registers the message handler callback.
|
// OnTypedMessage registers the message handler callback.
|
||||||
func (h *Handler) OnMessage(fn MessageHandler) { h.onMessage = fn }
|
|
||||||
|
|
||||||
// OnTypedMessage registers a typed message handler. It replaces OnMessage —
|
|
||||||
// wrapHandler will skip the map[string]any decode and let the callback
|
|
||||||
// unmarshal msg.Data directly.
|
|
||||||
func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn }
|
func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn }
|
||||||
|
|
||||||
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
|
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
|
||||||
@@ -131,7 +122,7 @@ func (h *Handler) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe
|
// Subscribe
|
||||||
if h.onMessage == nil && h.onTypedMessage == nil {
|
if h.onTypedMessage == nil {
|
||||||
return fmt.Errorf("no message handler registered")
|
return fmt.Errorf("no message handler registered")
|
||||||
}
|
}
|
||||||
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
|
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
|
||||||
@@ -161,26 +152,16 @@ func (h *Handler) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
|
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
|
||||||
// If OnTypedMessage was used, msg.Data is passed directly without map decode.
|
|
||||||
// If OnMessage was used, msg.Data is decoded to map[string]any first.
|
|
||||||
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
|
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
|
||||||
if h.onTypedMessage != nil {
|
|
||||||
return h.wrapTypedHandler(ctx)
|
|
||||||
}
|
|
||||||
return h.wrapMapHandler(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wrapTypedHandler dispatches to the TypedMessageHandler (no map decode).
|
|
||||||
func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler {
|
|
||||||
return func(msg *nats.Msg) {
|
return func(msg *nats.Msg) {
|
||||||
response, err := h.onTypedMessage(ctx, msg)
|
response, err := h.onTypedMessage(ctx, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
||||||
if msg.Reply != "" {
|
if msg.Reply != "" {
|
||||||
_ = h.NATS.Publish(msg.Reply, map[string]any{
|
_ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{
|
||||||
"error": true,
|
Error: true,
|
||||||
"message": err.Error(),
|
Message: err.Error(),
|
||||||
"type": fmt.Sprintf("%T", err),
|
Type: fmt.Sprintf("%T", err),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -192,40 +173,3 @@ func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapMapHandler dispatches to the legacy MessageHandler (decodes to map first).
|
|
||||||
func (h *Handler) wrapMapHandler(ctx context.Context) nats.MsgHandler {
|
|
||||||
return func(msg *nats.Msg) {
|
|
||||||
data, err := natsutil.DecodeMsgpackMap(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to decode message", "subject", msg.Subject, "error", err)
|
|
||||||
if msg.Reply != "" {
|
|
||||||
_ = h.NATS.Publish(msg.Reply, map[string]any{
|
|
||||||
"error": true,
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": "DecodeError",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := h.onMessage(ctx, msg, data)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
|
||||||
if msg.Reply != "" {
|
|
||||||
_ = h.NATS.Publish(msg.Reply, map[string]any{
|
|
||||||
"error": true,
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": fmt.Sprintf("%T", err),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if response != nil && msg.Reply != "" {
|
|
||||||
if err := h.NATS.Publish(msg.Reply, response); err != nil {
|
|
||||||
slog.Error("failed to publish reply", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -57,17 +59,17 @@ func TestCallbackRegistration(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil {
|
if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil {
|
||||||
t.Error("callbacks should not be nil after registration")
|
t.Error("callbacks should not be nil after registration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify setup/teardown work when called directly.
|
// Verify setup/teardown work when called directly.
|
||||||
h.onSetup(context.Background())
|
_ = h.onSetup(context.Background())
|
||||||
h.onTeardown(context.Background())
|
_ = h.onTeardown(context.Background())
|
||||||
if !setupCalled || !teardownCalled {
|
if !setupCalled || !teardownCalled {
|
||||||
t.Error("callbacks should have been invoked")
|
t.Error("callbacks should have been invoked")
|
||||||
}
|
}
|
||||||
@@ -77,8 +79,8 @@ func TestTypedMessageRegistration(t *testing.T) {
|
|||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return map[string]any{"ok": true}, nil
|
return &pb.ChatResponse{Response: "ok"}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if h.onTypedMessage == nil {
|
if h.onTypedMessage == nil {
|
||||||
@@ -94,19 +96,20 @@ func TestWrapHandler_ValidMessage(t *testing.T) {
|
|||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
var receivedData map[string]any
|
var receivedReq pb.ChatRequest
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
receivedData = data
|
if err := natsutil.Decode(msg.Data, &receivedReq); err != nil {
|
||||||
return map[string]any{"status": "ok"}, nil
|
return nil, err
|
||||||
|
}
|
||||||
|
return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
// Encode a message the same way services would.
|
// Encode a message the same way services would.
|
||||||
payload := map[string]any{
|
encoded, err := proto.Marshal(&pb.ChatRequest{
|
||||||
"request_id": "test-001",
|
RequestId: "test-001",
|
||||||
"message": "hello",
|
Message: "hello",
|
||||||
"premium": true,
|
Premium: true,
|
||||||
}
|
})
|
||||||
encoded, err := msgpack.Marshal(payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -118,47 +121,48 @@ func TestWrapHandler_ValidMessage(t *testing.T) {
|
|||||||
Data: encoded,
|
Data: encoded,
|
||||||
})
|
})
|
||||||
|
|
||||||
if receivedData == nil {
|
if receivedReq.GetRequestId() != "test-001" {
|
||||||
t.Fatal("handler was not called")
|
t.Errorf("request_id = %v", receivedReq.GetRequestId())
|
||||||
}
|
}
|
||||||
if receivedData["request_id"] != "test-001" {
|
if receivedReq.GetPremium() != true {
|
||||||
t.Errorf("request_id = %v", receivedData["request_id"])
|
t.Errorf("premium = %v", receivedReq.GetPremium())
|
||||||
}
|
|
||||||
if receivedData["premium"] != true {
|
|
||||||
t.Errorf("premium = %v", receivedData["premium"])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_InvalidMsgpack(t *testing.T) {
|
func TestWrapHandler_InvalidMessage(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
handlerCalled = true
|
handlerCalled = true
|
||||||
return nil, nil
|
var req pb.ChatRequest
|
||||||
|
if err := natsutil.Decode(msg.Data, &req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &pb.ChatResponse{}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{
|
handler(&nats.Msg{
|
||||||
Subject: "ai.test",
|
Subject: "ai.test",
|
||||||
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack
|
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf
|
||||||
})
|
})
|
||||||
|
|
||||||
if handlerCalled {
|
// The handler IS called (wrapHandler doesn't pre-decode), but it should
|
||||||
t.Error("handler should not be called for invalid msgpack")
|
// return an error from Decode. Either way no panic.
|
||||||
}
|
_ = handlerCalled
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_HandlerError(t *testing.T) {
|
func TestWrapHandler_HandlerError(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, context.DeadlineExceeded
|
return nil, context.DeadlineExceeded
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
// Should not panic even when handler returns error.
|
// Should not panic even when handler returns error.
|
||||||
@@ -172,11 +176,11 @@ func TestWrapHandler_NilResponse(t *testing.T) {
|
|||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, nil // fire-and-forget style
|
return nil, nil // fire-and-forget style
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
// Should not panic with nil response and no reply subject.
|
// Should not panic with nil response and no reply subject.
|
||||||
@@ -190,63 +194,58 @@ func TestWrapHandler_NilResponse(t *testing.T) {
|
|||||||
// wrapHandler dispatch tests — typed handler path
|
// wrapHandler dispatch tests — typed handler path
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestWrapTypedHandler_ValidMessage(t *testing.T) {
|
func TestWrapHandler_Typed(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
type testReq struct {
|
var received pb.ChatRequest
|
||||||
RequestID string `msgpack:"request_id"`
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
Message string `msgpack:"message"`
|
if err := natsutil.Decode(msg.Data, &received); err != nil {
|
||||||
}
|
|
||||||
|
|
||||||
var received testReq
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
|
||||||
if err := msgpack.Unmarshal(msg.Data, &received); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return map[string]any{"status": "ok"}, nil
|
return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := msgpack.Marshal(map[string]any{
|
encoded, _ := proto.Marshal(&pb.ChatRequest{
|
||||||
"request_id": "typed-001",
|
RequestId: "typed-001",
|
||||||
"message": "hello typed",
|
Message: "hello typed",
|
||||||
})
|
})
|
||||||
|
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
|
|
||||||
if received.RequestID != "typed-001" {
|
if received.GetRequestId() != "typed-001" {
|
||||||
t.Errorf("RequestID = %q", received.RequestID)
|
t.Errorf("RequestId = %q", received.GetRequestId())
|
||||||
}
|
}
|
||||||
if received.Message != "hello typed" {
|
if received.GetMessage() != "hello typed" {
|
||||||
t.Errorf("Message = %q", received.Message)
|
t.Errorf("Message = %q", received.GetMessage())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapTypedHandler_Error(t *testing.T) {
|
func TestWrapHandler_TypedError(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, context.DeadlineExceeded
|
return nil, context.DeadlineExceeded
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
// Should not panic.
|
// Should not panic.
|
||||||
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapTypedHandler_NilResponse(t *testing.T) {
|
func TestWrapHandler_TypedNilResponse(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
}
|
}
|
||||||
@@ -258,49 +257,18 @@ func TestWrapTypedHandler_NilResponse(t *testing.T) {
|
|||||||
func BenchmarkWrapHandler(b *testing.B) {
|
func BenchmarkWrapHandler(b *testing.B) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return map[string]any{"ok": true}, nil
|
var req pb.ChatRequest
|
||||||
|
_ = natsutil.Decode(msg.Data, &req)
|
||||||
|
return &pb.ChatResponse{Response: "ok"}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
payload := map[string]any{
|
encoded, _ := proto.Marshal(&pb.ChatRequest{
|
||||||
"request_id": "bench-001",
|
RequestId: "bench-001",
|
||||||
"message": "What is the capital of France?",
|
Message: "What is the capital of France?",
|
||||||
"premium": true,
|
Premium: true,
|
||||||
"top_k": 10,
|
TopK: 10,
|
||||||
}
|
})
|
||||||
encoded, _ := msgpack.Marshal(payload)
|
|
||||||
handler := h.wrapHandler(context.Background())
|
|
||||||
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
handler(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkWrapTypedHandler(b *testing.B) {
|
|
||||||
type benchReq struct {
|
|
||||||
RequestID string `msgpack:"request_id"`
|
|
||||||
Message string `msgpack:"message"`
|
|
||||||
Premium bool `msgpack:"premium"`
|
|
||||||
TopK int `msgpack:"top_k"`
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := config.Load()
|
|
||||||
h := New("ai.test", cfg)
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
|
||||||
var req benchReq
|
|
||||||
msgpack.Unmarshal(msg.Data, &req)
|
|
||||||
return map[string]any{"ok": true}, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
payload := map[string]any{
|
|
||||||
"request_id": "bench-001",
|
|
||||||
"message": "What is the capital of France?",
|
|
||||||
"premium": true,
|
|
||||||
"top_k": 10,
|
|
||||||
}
|
|
||||||
encoded, _ := msgpack.Marshal(payload)
|
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +21,6 @@ type Server struct {
|
|||||||
readyPath string
|
readyPath string
|
||||||
readyCheck ReadyFunc
|
readyCheck ReadyFunc
|
||||||
srv *http.Server
|
srv *http.Server
|
||||||
ready atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a health server on the given port.
|
// New creates a health server on the given port.
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func TestHealthEndpoint(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("health request failed: %v", err)
|
t.Fatalf("health request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
t.Errorf("expected 200, got %d", resp.StatusCode)
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
@@ -42,7 +42,7 @@ func TestReadyEndpointDefault(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ready request failed: %v", err)
|
t.Fatalf("ready request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
t.Errorf("expected 200, got %d", resp.StatusCode)
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
@@ -60,7 +60,7 @@ func TestReadyEndpointNotReady(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ready request failed: %v", err)
|
t.Fatalf("ready request failed: %v", err)
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if resp.StatusCode != 503 {
|
if resp.StatusCode != 503 {
|
||||||
t.Errorf("expected 503 when not ready, got %d", resp.StatusCode)
|
t.Errorf("expected 503 when not ready, got %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
@@ -70,7 +70,7 @@ func TestReadyEndpointNotReady(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ready request failed: %v", err)
|
t.Fatalf("ready request failed: %v", err)
|
||||||
}
|
}
|
||||||
resp2.Body.Close()
|
_ = resp2.Body.Close()
|
||||||
if resp2.StatusCode != 200 {
|
if resp2.StatusCode != 200 {
|
||||||
t.Errorf("expected 200 when ready, got %d", resp2.StatusCode)
|
t.Errorf("expected 200 when ready, got %d", resp2.StatusCode)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,4 @@
|
|||||||
// Package messages benchmarks compare three serialization strategies:
|
// Package messages benchmarks protobuf encoding/decoding of all message types.
|
||||||
//
|
|
||||||
// 1. msgpack map[string]any — the old approach (dynamic, no types)
|
|
||||||
// 2. msgpack typed struct — the new approach (compile-time safe, short keys)
|
|
||||||
// 3. protobuf — optional future migration
|
|
||||||
//
|
//
|
||||||
// Run with:
|
// Run with:
|
||||||
//
|
//
|
||||||
@@ -14,52 +10,15 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto"
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// Test fixtures — equivalent data across all three encodings
|
// Test fixtures — proto message constructors
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// chatRequestMap is the legacy map[string]any representation.
|
|
||||||
func chatRequestMap() map[string]any {
|
|
||||||
return map[string]any{
|
|
||||||
"request_id": "req-abc-123",
|
|
||||||
"user_id": "user-42",
|
|
||||||
"message": "What is the capital of France?",
|
|
||||||
"query": "",
|
|
||||||
"premium": true,
|
|
||||||
"enable_rag": true,
|
|
||||||
"enable_reranker": true,
|
|
||||||
"enable_streaming": false,
|
|
||||||
"top_k": 10,
|
|
||||||
"collection": "documents",
|
|
||||||
"enable_tts": false,
|
|
||||||
"system_prompt": "You are a helpful assistant.",
|
|
||||||
"response_subject": "ai.chat.response.req-abc-123",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// chatRequestStruct is the typed struct representation.
|
|
||||||
func chatRequestStruct() ChatRequest {
|
|
||||||
return ChatRequest{
|
|
||||||
RequestID: "req-abc-123",
|
|
||||||
UserID: "user-42",
|
|
||||||
Message: "What is the capital of France?",
|
|
||||||
Premium: true,
|
|
||||||
EnableRAG: true,
|
|
||||||
EnableReranker: true,
|
|
||||||
TopK: 10,
|
|
||||||
Collection: "documents",
|
|
||||||
SystemPrompt: "You are a helpful assistant.",
|
|
||||||
ResponseSubject: "ai.chat.response.req-abc-123",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// chatRequestProto is the protobuf representation.
|
|
||||||
func chatRequestProto() *pb.ChatRequest {
|
func chatRequestProto() *pb.ChatRequest {
|
||||||
return &pb.ChatRequest{
|
return &pb.ChatRequest{
|
||||||
RequestId: "req-abc-123",
|
RequestId: "req-abc-123",
|
||||||
@@ -75,25 +34,6 @@ func chatRequestProto() *pb.ChatRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// voiceResponseMap is a voice response with a 16 KB audio payload.
|
|
||||||
func voiceResponseMap() map[string]any {
|
|
||||||
return map[string]any{
|
|
||||||
"request_id": "vr-001",
|
|
||||||
"response": "The capital of France is Paris.",
|
|
||||||
"audio": make([]byte, 16384),
|
|
||||||
"transcription": "What is the capital of France?",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func voiceResponseStruct() VoiceResponse {
|
|
||||||
return VoiceResponse{
|
|
||||||
RequestID: "vr-001",
|
|
||||||
Response: "The capital of France is Paris.",
|
|
||||||
Audio: make([]byte, 16384),
|
|
||||||
Transcription: "What is the capital of France?",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func voiceResponseProto() *pb.VoiceResponse {
|
func voiceResponseProto() *pb.VoiceResponse {
|
||||||
return &pb.VoiceResponse{
|
return &pb.VoiceResponse{
|
||||||
RequestId: "vr-001",
|
RequestId: "vr-001",
|
||||||
@@ -103,31 +43,6 @@ func voiceResponseProto() *pb.VoiceResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ttsChunkMap simulates a streaming audio chunk (~32 KB).
|
|
||||||
func ttsChunkMap() map[string]any {
|
|
||||||
return map[string]any{
|
|
||||||
"session_id": "tts-sess-99",
|
|
||||||
"chunk_index": 3,
|
|
||||||
"total_chunks": 12,
|
|
||||||
"audio_b64": string(make([]byte, 32768)), // old: base64 string
|
|
||||||
"is_last": false,
|
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
"sample_rate": 24000,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ttsChunkStruct() TTSAudioChunk {
|
|
||||||
return TTSAudioChunk{
|
|
||||||
SessionID: "tts-sess-99",
|
|
||||||
ChunkIndex: 3,
|
|
||||||
TotalChunks: 12,
|
|
||||||
Audio: make([]byte, 32768), // new: raw bytes
|
|
||||||
IsLast: false,
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
SampleRate: 24000,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ttsChunkProto() *pb.TTSAudioChunk {
|
func ttsChunkProto() *pb.TTSAudioChunk {
|
||||||
return &pb.TTSAudioChunk{
|
return &pb.TTSAudioChunk{
|
||||||
SessionId: "tts-sess-99",
|
SessionId: "tts-sess-99",
|
||||||
@@ -147,26 +62,16 @@ func ttsChunkProto() *pb.TTSAudioChunk {
|
|||||||
func TestWireSize(t *testing.T) {
|
func TestWireSize(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
mapData any
|
|
||||||
structVal any
|
|
||||||
protoMsg proto.Message
|
protoMsg proto.Message
|
||||||
}{
|
}{
|
||||||
{"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()},
|
{"ChatRequest", chatRequestProto()},
|
||||||
{"VoiceResponse", voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()},
|
{"VoiceResponse", voiceResponseProto()},
|
||||||
{"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()},
|
{"TTSAudioChunk", ttsChunkProto()},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
mapBytes, _ := msgpack.Marshal(tt.mapData)
|
|
||||||
structBytes, _ := msgpack.Marshal(tt.structVal)
|
|
||||||
protoBytes, _ := proto.Marshal(tt.protoMsg)
|
protoBytes, _ := proto.Marshal(tt.protoMsg)
|
||||||
|
t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
|
||||||
t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)",
|
|
||||||
tt.name,
|
|
||||||
len(mapBytes), len(structBytes), len(protoBytes),
|
|
||||||
100*(1-float64(len(structBytes))/float64(len(mapBytes))),
|
|
||||||
100*(1-float64(len(protoBytes))/float64(len(mapBytes))),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,75 +79,27 @@ func TestWireSize(t *testing.T) {
|
|||||||
// Encode benchmarks
|
// Encode benchmarks
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) {
|
func BenchmarkEncode_ChatRequest(b *testing.B) {
|
||||||
data := chatRequestMap()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) {
|
|
||||||
data := chatRequestStruct()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
|
|
||||||
data := chatRequestProto()
|
data := chatRequestProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
proto.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) {
|
func BenchmarkEncode_VoiceResponse(b *testing.B) {
|
||||||
data := voiceResponseMap()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) {
|
|
||||||
data := voiceResponseStruct()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
|
|
||||||
data := voiceResponseProto()
|
data := voiceResponseProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
proto.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) {
|
func BenchmarkEncode_TTSChunk(b *testing.B) {
|
||||||
data := ttsChunkMap()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) {
|
|
||||||
data := ttsChunkStruct()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
|
|
||||||
data := ttsChunkProto()
|
data := ttsChunkProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
proto.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,84 +107,30 @@ func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
|
|||||||
// Decode benchmarks
|
// Decode benchmarks
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) {
|
func BenchmarkDecode_ChatRequest(b *testing.B) {
|
||||||
encoded, _ := msgpack.Marshal(chatRequestMap())
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
var m map[string]any
|
|
||||||
msgpack.Unmarshal(encoded, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) {
|
|
||||||
encoded, _ := msgpack.Marshal(chatRequestStruct())
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
var m ChatRequest
|
|
||||||
msgpack.Unmarshal(encoded, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
|
|
||||||
encoded, _ := proto.Marshal(chatRequestProto())
|
encoded, _ := proto.Marshal(chatRequestProto())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m pb.ChatRequest
|
var m pb.ChatRequest
|
||||||
proto.Unmarshal(encoded, &m)
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) {
|
func BenchmarkDecode_VoiceResponse(b *testing.B) {
|
||||||
encoded, _ := msgpack.Marshal(voiceResponseMap())
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
var m map[string]any
|
|
||||||
msgpack.Unmarshal(encoded, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) {
|
|
||||||
encoded, _ := msgpack.Marshal(voiceResponseStruct())
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
var m VoiceResponse
|
|
||||||
msgpack.Unmarshal(encoded, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
|
|
||||||
encoded, _ := proto.Marshal(voiceResponseProto())
|
encoded, _ := proto.Marshal(voiceResponseProto())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m pb.VoiceResponse
|
var m pb.VoiceResponse
|
||||||
proto.Unmarshal(encoded, &m)
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) {
|
func BenchmarkDecode_TTSChunk(b *testing.B) {
|
||||||
encoded, _ := msgpack.Marshal(ttsChunkMap())
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
var m map[string]any
|
|
||||||
msgpack.Unmarshal(encoded, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) {
|
|
||||||
encoded, _ := msgpack.Marshal(ttsChunkStruct())
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
var m TTSAudioChunk
|
|
||||||
msgpack.Unmarshal(encoded, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
|
|
||||||
encoded, _ := proto.Marshal(ttsChunkProto())
|
encoded, _ := proto.Marshal(ttsChunkProto())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m pb.TTSAudioChunk
|
var m pb.TTSAudioChunk
|
||||||
proto.Unmarshal(encoded, &m)
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,174 +138,168 @@ func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
|
|||||||
// Roundtrip benchmarks (encode + decode)
|
// Roundtrip benchmarks (encode + decode)
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) {
|
func BenchmarkRoundtrip_ChatRequest(b *testing.B) {
|
||||||
data := chatRequestMap()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
enc, _ := msgpack.Marshal(data)
|
|
||||||
var dec map[string]any
|
|
||||||
msgpack.Unmarshal(enc, &dec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) {
|
|
||||||
data := chatRequestStruct()
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
enc, _ := msgpack.Marshal(data)
|
|
||||||
var dec ChatRequest
|
|
||||||
msgpack.Unmarshal(enc, &dec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
|
|
||||||
data := chatRequestProto()
|
data := chatRequestProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
enc, _ := proto.Marshal(data)
|
enc, _ := proto.Marshal(data)
|
||||||
var dec pb.ChatRequest
|
var dec pb.ChatRequest
|
||||||
proto.Unmarshal(enc, &dec)
|
_ = proto.Unmarshal(enc, &dec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// Typed struct unit tests — verify roundtrip correctness
|
// Correctness tests — verify proto roundtrip
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestRoundtrip_ChatRequest(t *testing.T) {
|
func TestRoundtrip_ChatRequest(t *testing.T) {
|
||||||
orig := chatRequestStruct()
|
orig := chatRequestProto()
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec ChatRequest
|
var dec pb.ChatRequest
|
||||||
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.RequestID != orig.RequestID {
|
if dec.GetRequestId() != orig.GetRequestId() {
|
||||||
t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID)
|
t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId())
|
||||||
}
|
}
|
||||||
if dec.Message != orig.Message {
|
if dec.GetMessage() != orig.GetMessage() {
|
||||||
t.Errorf("Message = %q, want %q", dec.Message, orig.Message)
|
t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage())
|
||||||
}
|
}
|
||||||
if dec.TopK != orig.TopK {
|
if dec.GetTopK() != orig.GetTopK() {
|
||||||
t.Errorf("TopK = %d, want %d", dec.TopK, orig.TopK)
|
t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK())
|
||||||
}
|
}
|
||||||
if dec.Premium != orig.Premium {
|
if dec.GetPremium() != orig.GetPremium() {
|
||||||
t.Errorf("Premium = %v, want %v", dec.Premium, orig.Premium)
|
t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium())
|
||||||
}
|
}
|
||||||
if dec.EffectiveQuery() != orig.Message {
|
if EffectiveQuery(&dec) != orig.GetMessage() {
|
||||||
t.Errorf("EffectiveQuery() = %q, want %q", dec.EffectiveQuery(), orig.Message)
|
t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_VoiceResponse(t *testing.T) {
|
func TestRoundtrip_VoiceResponse(t *testing.T) {
|
||||||
orig := voiceResponseStruct()
|
orig := voiceResponseProto()
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec VoiceResponse
|
var dec pb.VoiceResponse
|
||||||
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.RequestID != orig.RequestID {
|
if dec.GetRequestId() != orig.GetRequestId() {
|
||||||
t.Errorf("RequestID mismatch")
|
t.Errorf("RequestId mismatch")
|
||||||
}
|
}
|
||||||
if len(dec.Audio) != len(orig.Audio) {
|
if len(dec.GetAudio()) != len(orig.GetAudio()) {
|
||||||
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
|
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
|
||||||
}
|
}
|
||||||
if dec.Transcription != orig.Transcription {
|
if dec.GetTranscription() != orig.GetTranscription() {
|
||||||
t.Errorf("Transcription mismatch")
|
t.Errorf("Transcription mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
|
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
|
||||||
orig := ttsChunkStruct()
|
orig := ttsChunkProto()
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec TTSAudioChunk
|
var dec pb.TTSAudioChunk
|
||||||
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.SessionID != orig.SessionID {
|
if dec.GetSessionId() != orig.GetSessionId() {
|
||||||
t.Errorf("SessionID mismatch")
|
t.Errorf("SessionId mismatch")
|
||||||
}
|
}
|
||||||
if dec.ChunkIndex != orig.ChunkIndex {
|
if dec.GetChunkIndex() != orig.GetChunkIndex() {
|
||||||
t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex)
|
t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex())
|
||||||
}
|
}
|
||||||
if len(dec.Audio) != len(orig.Audio) {
|
if len(dec.GetAudio()) != len(orig.GetAudio()) {
|
||||||
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
|
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
|
||||||
}
|
}
|
||||||
if dec.SampleRate != orig.SampleRate {
|
if dec.GetSampleRate() != orig.GetSampleRate() {
|
||||||
t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate)
|
t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_PipelineTrigger(t *testing.T) {
|
func TestRoundtrip_PipelineTrigger(t *testing.T) {
|
||||||
orig := PipelineTrigger{
|
orig := &pb.PipelineTrigger{
|
||||||
RequestID: "pip-001",
|
RequestId: "pip-001",
|
||||||
Pipeline: "document-ingestion",
|
Pipeline: "document-ingestion",
|
||||||
Parameters: map[string]any{"source": "s3://bucket/data"},
|
Parameters: map[string]string{"source": "s3://bucket/data"},
|
||||||
}
|
}
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec PipelineTrigger
|
var dec pb.PipelineTrigger
|
||||||
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.Pipeline != orig.Pipeline {
|
if dec.GetPipeline() != orig.GetPipeline() {
|
||||||
t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline)
|
t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline())
|
||||||
}
|
}
|
||||||
if dec.Parameters["source"] != orig.Parameters["source"] {
|
if dec.GetParameters()["source"] != orig.GetParameters()["source"] {
|
||||||
t.Errorf("Parameters[source] mismatch")
|
t.Errorf("Parameters[source] mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_STTTranscription(t *testing.T) {
|
func TestRoundtrip_STTTranscription(t *testing.T) {
|
||||||
orig := STTTranscription{
|
orig := &pb.STTTranscription{
|
||||||
SessionID: "stt-001",
|
SessionId: "stt-001",
|
||||||
Transcript: "hello world",
|
Transcript: "hello world",
|
||||||
Sequence: 5,
|
Sequence: 5,
|
||||||
IsPartial: false,
|
IsPartial: false,
|
||||||
IsFinal: true,
|
IsFinal: true,
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
SpeakerID: "speaker-1",
|
SpeakerId: "speaker-1",
|
||||||
HasVoiceActivity: true,
|
HasVoiceActivity: true,
|
||||||
State: "listening",
|
State: "listening",
|
||||||
}
|
}
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec STTTranscription
|
var dec pb.STTTranscription
|
||||||
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.Transcript != orig.Transcript {
|
if dec.GetTranscript() != orig.GetTranscript() {
|
||||||
t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript)
|
t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript())
|
||||||
}
|
}
|
||||||
if dec.IsFinal != orig.IsFinal {
|
if dec.GetIsFinal() != orig.GetIsFinal() {
|
||||||
t.Error("IsFinal mismatch")
|
t.Error("IsFinal mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_ErrorResponse(t *testing.T) {
|
func TestRoundtrip_ErrorResponse(t *testing.T) {
|
||||||
orig := ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
|
orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec ErrorResponse
|
var dec pb.ErrorResponse
|
||||||
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !dec.Error || dec.Message != "something broke" || dec.Type != "InternalError" {
|
if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" {
|
||||||
t.Errorf("ErrorResponse roundtrip mismatch: %+v", dec)
|
t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveQuery_MessageSet(t *testing.T) {
|
||||||
|
req := &pb.ChatRequest{Message: "hello", Query: "world"}
|
||||||
|
if got := EffectiveQuery(req); got != "hello" {
|
||||||
|
t.Errorf("EffectiveQuery() = %q, want %q", got, "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveQuery_FallbackToQuery(t *testing.T) {
|
||||||
|
req := &pb.ChatRequest{Query: "world"}
|
||||||
|
if got := EffectiveQuery(req); got != "world" {
|
||||||
|
t.Errorf("EffectiveQuery() = %q, want %q", got, "world")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,224 +1,69 @@
|
|||||||
// Package messages defines typed NATS message structs for all services.
|
// Package messages re-exports protobuf message types and provides NATS
|
||||||
|
// subject constants plus helper functions.
|
||||||
//
|
//
|
||||||
// Using typed structs with short msgpack field tags instead of map[string]any
|
// The canonical type definitions live in the generated package
|
||||||
// provides compile-time safety, smaller wire size (integer-like short keys vs
|
// gen/messagespb (from proto/messages/v1/messages.proto).
|
||||||
// full string keys), and faster encode/decode by avoiding interface{} boxing.
|
// This package provides type aliases so existing callers can keep using
|
||||||
//
|
// messages.ChatRequest, etc., while the wire format is now protobuf.
|
||||||
// Audio data uses raw []byte instead of base64-encoded strings — msgpack
|
|
||||||
// supports binary natively, eliminating the 33% base64 overhead.
|
|
||||||
package messages
|
package messages
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
// Pipeline Bridge
|
)
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
// PipelineTrigger is the request to start a pipeline.
|
// ════════════════════════════════════════════════════════════════════════════
|
||||||
type PipelineTrigger struct {
|
// Type aliases — use these or import gen/messagespb directly.
|
||||||
RequestID string `msgpack:"request_id" json:"request_id"`
|
// ════════════════════════════════════════════════════════════════════════════
|
||||||
Pipeline string `msgpack:"pipeline" json:"pipeline"`
|
|
||||||
Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PipelineStatus is the response / status update for a pipeline run.
|
// Common
|
||||||
type PipelineStatus struct {
|
type ErrorResponse = pb.ErrorResponse
|
||||||
RequestID string `msgpack:"request_id" json:"request_id"`
|
|
||||||
Status string `msgpack:"status" json:"status"`
|
|
||||||
RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"`
|
|
||||||
Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"`
|
|
||||||
Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"`
|
|
||||||
SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"`
|
|
||||||
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
|
||||||
AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// Chat
|
||||||
// Chat Handler
|
type LoginEvent = pb.LoginEvent
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
type GreetingRequest = pb.GreetingRequest
|
||||||
|
type GreetingResponse = pb.GreetingResponse
|
||||||
|
type ChatRequest = pb.ChatRequest
|
||||||
|
type ChatResponse = pb.ChatResponse
|
||||||
|
type ChatStreamChunk = pb.ChatStreamChunk
|
||||||
|
|
||||||
// ChatRequest is an incoming chat message.
|
// Voice
|
||||||
type ChatRequest struct {
|
type VoiceRequest = pb.VoiceRequest
|
||||||
RequestID string `msgpack:"request_id" json:"request_id"`
|
type VoiceResponse = pb.VoiceResponse
|
||||||
UserID string `msgpack:"user_id" json:"user_id"`
|
type DocumentSource = pb.DocumentSource
|
||||||
Message string `msgpack:"message" json:"message"`
|
|
||||||
Query string `msgpack:"query,omitempty" json:"query,omitempty"`
|
// TTS
|
||||||
Premium bool `msgpack:"premium,omitempty" json:"premium,omitempty"`
|
type TTSRequest = pb.TTSRequest
|
||||||
EnableRAG bool `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"`
|
type TTSAudioChunk = pb.TTSAudioChunk
|
||||||
EnableReranker bool `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"`
|
type TTSFullResponse = pb.TTSFullResponse
|
||||||
EnableStreaming bool `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"`
|
type TTSStatus = pb.TTSStatus
|
||||||
TopK int `msgpack:"top_k,omitempty" json:"top_k,omitempty"`
|
type TTSVoiceInfo = pb.TTSVoiceInfo
|
||||||
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
|
type TTSVoiceListResponse = pb.TTSVoiceListResponse
|
||||||
EnableTTS bool `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"`
|
type TTSVoiceRefreshResponse = pb.TTSVoiceRefreshResponse
|
||||||
SystemPrompt string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
|
|
||||||
ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"`
|
// STT
|
||||||
}
|
type STTStreamMessage = pb.STTStreamMessage
|
||||||
|
type STTTranscription = pb.STTTranscription
|
||||||
|
type STTInterrupt = pb.STTInterrupt
|
||||||
|
|
||||||
|
// Pipeline
|
||||||
|
type PipelineTrigger = pb.PipelineTrigger
|
||||||
|
type PipelineStatus = pb.PipelineStatus
|
||||||
|
|
||||||
|
// ════════════════════════════════════════════════════════════════════════════
|
||||||
|
// Helpers
|
||||||
|
// ════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
// EffectiveQuery returns Message or falls back to Query.
|
// EffectiveQuery returns Message or falls back to Query.
|
||||||
func (c *ChatRequest) EffectiveQuery() string {
|
func EffectiveQuery(c *ChatRequest) string {
|
||||||
if c.Message != "" {
|
if c.GetMessage() != "" {
|
||||||
return c.Message
|
return c.GetMessage()
|
||||||
}
|
}
|
||||||
return c.Query
|
return c.GetQuery()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatResponse is the full reply to a chat request.
|
// Timestamp returns the current Unix timestamp.
|
||||||
type ChatResponse struct {
|
|
||||||
UserID string `msgpack:"user_id" json:"user_id"`
|
|
||||||
Response string `msgpack:"response" json:"response"`
|
|
||||||
ResponseText string `msgpack:"response_text" json:"response_text"`
|
|
||||||
UsedRAG bool `msgpack:"used_rag" json:"used_rag"`
|
|
||||||
RAGSources []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
|
|
||||||
Success bool `msgpack:"success" json:"success"`
|
|
||||||
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
|
|
||||||
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatStreamChunk is a single streaming chunk from an LLM response.
|
|
||||||
type ChatStreamChunk struct {
|
|
||||||
RequestID string `msgpack:"request_id" json:"request_id"`
|
|
||||||
Type string `msgpack:"type" json:"type"`
|
|
||||||
Content string `msgpack:"content" json:"content"`
|
|
||||||
Done bool `msgpack:"done" json:"done"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
// Voice Assistant
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
// VoiceRequest is an incoming voice-to-voice request.
|
|
||||||
type VoiceRequest struct {
|
|
||||||
RequestID string `msgpack:"request_id" json:"request_id"`
|
|
||||||
Audio []byte `msgpack:"audio" json:"audio"`
|
|
||||||
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
|
|
||||||
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// VoiceResponse is the reply to a voice request.
|
|
||||||
type VoiceResponse struct {
|
|
||||||
RequestID string `msgpack:"request_id" json:"request_id"`
|
|
||||||
Response string `msgpack:"response" json:"response"`
|
|
||||||
Audio []byte `msgpack:"audio" json:"audio"`
|
|
||||||
Transcription string `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
|
|
||||||
Sources []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
|
|
||||||
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DocumentSource is a RAG search result source.
|
|
||||||
type DocumentSource struct {
|
|
||||||
Text string `msgpack:"text" json:"text"`
|
|
||||||
Score float64 `msgpack:"score" json:"score"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
// TTS Module
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
// TTSRequest is a text-to-speech synthesis request.
|
|
||||||
type TTSRequest struct {
|
|
||||||
Text string `msgpack:"text" json:"text"`
|
|
||||||
Speaker string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
|
|
||||||
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
|
|
||||||
SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
|
|
||||||
Stream bool `msgpack:"stream,omitempty" json:"stream,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
|
|
||||||
type TTSAudioChunk struct {
|
|
||||||
SessionID string `msgpack:"session_id" json:"session_id"`
|
|
||||||
ChunkIndex int `msgpack:"chunk_index" json:"chunk_index"`
|
|
||||||
TotalChunks int `msgpack:"total_chunks" json:"total_chunks"`
|
|
||||||
Audio []byte `msgpack:"audio" json:"audio"`
|
|
||||||
IsLast bool `msgpack:"is_last" json:"is_last"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TTSFullResponse is a non-streamed TTS response (whole audio).
|
|
||||||
type TTSFullResponse struct {
|
|
||||||
SessionID string `msgpack:"session_id" json:"session_id"`
|
|
||||||
Audio []byte `msgpack:"audio" json:"audio"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TTSStatus is a TTS processing status update.
|
|
||||||
type TTSStatus struct {
|
|
||||||
SessionID string `msgpack:"session_id" json:"session_id"`
|
|
||||||
Status string `msgpack:"status" json:"status"`
|
|
||||||
Message string `msgpack:"message" json:"message"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TTSVoiceListResponse is the reply to a voice list request.
|
|
||||||
type TTSVoiceListResponse struct {
|
|
||||||
DefaultSpeaker string `msgpack:"default_speaker" json:"default_speaker"`
|
|
||||||
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
|
|
||||||
LastRefresh int64 `msgpack:"last_refresh" json:"last_refresh"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TTSVoiceInfo is summary info about a custom voice.
|
|
||||||
type TTSVoiceInfo struct {
|
|
||||||
Name string `msgpack:"name" json:"name"`
|
|
||||||
Language string `msgpack:"language" json:"language"`
|
|
||||||
ModelType string `msgpack:"model_type" json:"model_type"`
|
|
||||||
CreatedAt string `msgpack:"created_at" json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
|
|
||||||
type TTSVoiceRefreshResponse struct {
|
|
||||||
Count int `msgpack:"count" json:"count"`
|
|
||||||
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
// STT Module
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
|
|
||||||
type STTStreamMessage struct {
|
|
||||||
Type string `msgpack:"type" json:"type"`
|
|
||||||
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
|
|
||||||
State string `msgpack:"state,omitempty" json:"state,omitempty"`
|
|
||||||
SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// STTTranscription is the transcription result published by the STT module.
|
|
||||||
type STTTranscription struct {
|
|
||||||
SessionID string `msgpack:"session_id" json:"session_id"`
|
|
||||||
Transcript string `msgpack:"transcript" json:"transcript"`
|
|
||||||
Sequence int `msgpack:"sequence" json:"sequence"`
|
|
||||||
IsPartial bool `msgpack:"is_partial" json:"is_partial"`
|
|
||||||
IsFinal bool `msgpack:"is_final" json:"is_final"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
|
|
||||||
HasVoiceActivity bool `msgpack:"has_voice_activity" json:"has_voice_activity"`
|
|
||||||
State string `msgpack:"state" json:"state"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// STTInterrupt is published when the STT module detects a user interrupt.
|
|
||||||
type STTInterrupt struct {
|
|
||||||
SessionID string `msgpack:"session_id" json:"session_id"`
|
|
||||||
Type string `msgpack:"type" json:"type"`
|
|
||||||
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
|
||||||
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
// Common / Error
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
// ErrorResponse is the standard error reply from any handler.
|
|
||||||
type ErrorResponse struct {
|
|
||||||
Error bool `msgpack:"error" json:"error"`
|
|
||||||
Message string `msgpack:"message" json:"message"`
|
|
||||||
Type string `msgpack:"type" json:"type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Timestamp returns the current Unix timestamp (helper for message construction).
|
|
||||||
func Timestamp() int64 {
|
func Timestamp() int64 {
|
||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Package natsutil provides a NATS/JetStream client with msgpack serialization.
|
// Package natsutil provides a NATS/JetStream client with protobuf serialization.
|
||||||
package natsutil
|
package natsutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,10 +7,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client wraps a NATS connection with msgpack helpers.
|
// Client wraps a NATS connection with protobuf helpers.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
nc *nats.Conn
|
nc *nats.Conn
|
||||||
js nats.JetStreamContext
|
js nats.JetStreamContext
|
||||||
@@ -97,46 +97,34 @@ func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish encodes data as msgpack and publishes to the subject.
|
// Publish encodes data as protobuf and publishes to the subject.
|
||||||
func (c *Client) Publish(subject string, data any) error {
|
func (c *Client) Publish(subject string, data proto.Message) error {
|
||||||
payload, err := msgpack.Marshal(data)
|
payload, err := proto.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("msgpack marshal: %w", err)
|
return fmt.Errorf("proto marshal: %w", err)
|
||||||
}
|
}
|
||||||
return c.nc.Publish(subject, payload)
|
return c.nc.Publish(subject, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request sends a msgpack-encoded request and decodes the response into result.
|
// PublishRaw publishes pre-encoded bytes to the subject.
|
||||||
func (c *Client) Request(subject string, data any, result any, timeout time.Duration) error {
|
func (c *Client) PublishRaw(subject string, data []byte) error {
|
||||||
payload, err := msgpack.Marshal(data)
|
return c.nc.Publish(subject, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request sends a protobuf-encoded request and decodes the response into result.
|
||||||
|
func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error {
|
||||||
|
payload, err := proto.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("msgpack marshal: %w", err)
|
return fmt.Errorf("proto marshal: %w", err)
|
||||||
}
|
}
|
||||||
msg, err := c.nc.Request(subject, payload, timeout)
|
msg, err := c.nc.Request(subject, payload, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nats request: %w", err)
|
return fmt.Errorf("nats request: %w", err)
|
||||||
}
|
}
|
||||||
return msgpack.Unmarshal(msg.Data, result)
|
return proto.Unmarshal(msg.Data, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeMsgpack decodes msgpack-encoded NATS message data into dest.
|
// Decode unmarshals protobuf bytes into dest.
|
||||||
func DecodeMsgpack(msg *nats.Msg, dest any) error {
|
func Decode(data []byte, dest proto.Message) error {
|
||||||
return msgpack.Unmarshal(msg.Data, dest)
|
return proto.Unmarshal(data, dest)
|
||||||
}
|
|
||||||
|
|
||||||
// Decode is a generic helper that unmarshals msgpack bytes into T.
|
|
||||||
// Usage: req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
|
|
||||||
func Decode[T any](data []byte) (T, error) {
|
|
||||||
var v T
|
|
||||||
err := msgpack.Unmarshal(data, &v)
|
|
||||||
return v, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeMsgpackMap decodes msgpack data into a generic map.
|
|
||||||
func DecodeMsgpackMap(data []byte) (map[string]any, error) {
|
|
||||||
var m map[string]any
|
|
||||||
if err := msgpack.Unmarshal(data, &m); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,254 +3,212 @@ package natsutil
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// DecodeMsgpackMap tests
|
// Decode tests
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestDecodeMsgpackMap_Roundtrip(t *testing.T) {
|
func TestDecode_ChatRequest_Roundtrip(t *testing.T) {
|
||||||
orig := map[string]any{
|
orig := &pb.ChatRequest{
|
||||||
"request_id": "req-001",
|
RequestId: "req-001",
|
||||||
"user_id": "user-42",
|
UserId: "user-42",
|
||||||
"premium": true,
|
Premium: true,
|
||||||
"top_k": int64(10), // msgpack decodes ints as int64
|
TopK: 10,
|
||||||
}
|
}
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoded, err := DecodeMsgpackMap(data)
|
var decoded pb.ChatRequest
|
||||||
if err != nil {
|
if err := Decode(data, &decoded); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if decoded["request_id"] != "req-001" {
|
if decoded.GetRequestId() != "req-001" {
|
||||||
t.Errorf("request_id = %v", decoded["request_id"])
|
t.Errorf("RequestId = %v", decoded.GetRequestId())
|
||||||
}
|
}
|
||||||
if decoded["premium"] != true {
|
if decoded.GetUserId() != "user-42" {
|
||||||
t.Errorf("premium = %v", decoded["premium"])
|
t.Errorf("UserId = %v", decoded.GetUserId())
|
||||||
|
}
|
||||||
|
if decoded.GetPremium() != true {
|
||||||
|
t.Errorf("Premium = %v", decoded.GetPremium())
|
||||||
|
}
|
||||||
|
if decoded.GetTopK() != 10 {
|
||||||
|
t.Errorf("TopK = %v", decoded.GetTopK())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecodeMsgpackMap_Empty(t *testing.T) {
|
func TestDecode_EmptyMessage(t *testing.T) {
|
||||||
data, _ := msgpack.Marshal(map[string]any{})
|
data, err := proto.Marshal(&pb.ChatRequest{})
|
||||||
m, err := DecodeMsgpackMap(data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(m) != 0 {
|
var decoded pb.ChatRequest
|
||||||
t.Errorf("expected empty map, got %v", m)
|
if err := Decode(data, &decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if decoded.GetRequestId() != "" {
|
||||||
|
t.Errorf("expected empty RequestId, got %q", decoded.GetRequestId())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecodeMsgpackMap_InvalidData(t *testing.T) {
|
func TestDecode_InvalidData(t *testing.T) {
|
||||||
_, err := DecodeMsgpackMap([]byte{0xFF, 0xFE})
|
err := Decode([]byte{0xFF, 0xFE}, &pb.ChatRequest{})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for invalid msgpack data")
|
t.Error("expected error for invalid protobuf data")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// DecodeMsgpack (typed struct) tests
|
// Typed struct roundtrip tests
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
type testMessage struct {
|
func TestDecode_VoiceResponse_Roundtrip(t *testing.T) {
|
||||||
RequestID string `msgpack:"request_id"`
|
orig := &pb.VoiceResponse{
|
||||||
UserID string `msgpack:"user_id"`
|
RequestId: "vr-001",
|
||||||
Count int `msgpack:"count"`
|
Response: "The capital of France is Paris.",
|
||||||
Active bool `msgpack:"active"`
|
Transcription: "What is the capital of France?",
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
|
|
||||||
orig := testMessage{
|
|
||||||
RequestID: "req-typed-001",
|
|
||||||
UserID: "user-7",
|
|
||||||
Count: 42,
|
|
||||||
Active: true,
|
|
||||||
}
|
}
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate nats.Msg data decoding.
|
var decoded pb.VoiceResponse
|
||||||
var decoded testMessage
|
if err := Decode(data, &decoded); err != nil {
|
||||||
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if decoded.RequestID != orig.RequestID {
|
if decoded.GetRequestId() != orig.GetRequestId() {
|
||||||
t.Errorf("RequestID = %q, want %q", decoded.RequestID, orig.RequestID)
|
t.Errorf("RequestId = %q, want %q", decoded.GetRequestId(), orig.GetRequestId())
|
||||||
}
|
}
|
||||||
if decoded.Count != orig.Count {
|
if decoded.GetResponse() != orig.GetResponse() {
|
||||||
t.Errorf("Count = %d, want %d", decoded.Count, orig.Count)
|
t.Errorf("Response = %q, want %q", decoded.GetResponse(), orig.GetResponse())
|
||||||
}
|
}
|
||||||
if decoded.Active != orig.Active {
|
if decoded.GetTranscription() != orig.GetTranscription() {
|
||||||
t.Errorf("Active = %v, want %v", decoded.Active, orig.Active)
|
t.Errorf("Transcription = %q, want %q", decoded.GetTranscription(), orig.GetTranscription())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestTypedStructDecodesMapEncoding verifies that a typed struct can be
|
func TestDecode_ErrorResponse_Roundtrip(t *testing.T) {
|
||||||
// decoded from data that was encoded as map[string]any (backwards compat).
|
orig := &pb.ErrorResponse{
|
||||||
func TestTypedStructDecodesMapEncoding(t *testing.T) {
|
Error: true,
|
||||||
// Encode as map (the old way).
|
Message: "something broke",
|
||||||
mapData := map[string]any{
|
Type: "InternalError",
|
||||||
"request_id": "req-compat",
|
|
||||||
"user_id": "user-compat",
|
|
||||||
"count": int64(99),
|
|
||||||
"active": false,
|
|
||||||
}
|
}
|
||||||
data, err := msgpack.Marshal(mapData)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode into typed struct (the new way).
|
var decoded pb.ErrorResponse
|
||||||
var msg testMessage
|
if err := Decode(data, &decoded); err != nil {
|
||||||
if err := msgpack.Unmarshal(data, &msg); err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg.RequestID != "req-compat" {
|
if !decoded.GetError() {
|
||||||
t.Errorf("RequestID = %q", msg.RequestID)
|
t.Error("expected Error=true")
|
||||||
}
|
}
|
||||||
if msg.Count != 99 {
|
if decoded.GetMessage() != "something broke" {
|
||||||
t.Errorf("Count = %d, want 99", msg.Count)
|
t.Errorf("Message = %q", decoded.GetMessage())
|
||||||
|
}
|
||||||
|
if decoded.GetType() != "InternalError" {
|
||||||
|
t.Errorf("Type = %q", decoded.GetType())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// Binary data tests (audio []byte in msgpack)
|
// Binary data tests (audio []byte in protobuf)
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
type audioMessage struct {
|
|
||||||
SessionID string `msgpack:"session_id"`
|
|
||||||
Audio []byte `msgpack:"audio"`
|
|
||||||
SampleRate int `msgpack:"sample_rate"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBinaryDataRoundtrip(t *testing.T) {
|
func TestBinaryDataRoundtrip(t *testing.T) {
|
||||||
audio := make([]byte, 32768)
|
audio := make([]byte, 32768)
|
||||||
for i := range audio {
|
for i := range audio {
|
||||||
audio[i] = byte(i % 256)
|
audio[i] = byte(i % 256)
|
||||||
}
|
}
|
||||||
|
|
||||||
orig := audioMessage{
|
orig := &pb.TTSAudioChunk{
|
||||||
SessionID: "sess-audio-001",
|
SessionId: "sess-audio-001",
|
||||||
Audio: audio,
|
Audio: audio,
|
||||||
SampleRate: 24000,
|
SampleRate: 24000,
|
||||||
}
|
}
|
||||||
data, err := msgpack.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var decoded audioMessage
|
var decoded pb.TTSAudioChunk
|
||||||
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
if err := Decode(data, &decoded); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(decoded.Audio) != len(orig.Audio) {
|
if len(decoded.GetAudio()) != len(orig.GetAudio()) {
|
||||||
t.Fatalf("audio len = %d, want %d", len(decoded.Audio), len(orig.Audio))
|
t.Fatalf("audio len = %d, want %d", len(decoded.GetAudio()), len(orig.GetAudio()))
|
||||||
}
|
}
|
||||||
for i := range decoded.Audio {
|
for i := range decoded.GetAudio() {
|
||||||
if decoded.Audio[i] != orig.Audio[i] {
|
if decoded.GetAudio()[i] != orig.GetAudio()[i] {
|
||||||
t.Fatalf("audio[%d] = %d, want %d", i, decoded.Audio[i], orig.Audio[i])
|
t.Fatalf("audio[%d] = %d, want %d", i, decoded.GetAudio()[i], orig.GetAudio()[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string.
|
// TestProtoWireSize shows protobuf wire size for binary payloads.
|
||||||
func TestBinaryVsBase64Size(t *testing.T) {
|
func TestProtoWireSize(t *testing.T) {
|
||||||
audio := make([]byte, 16384)
|
audio := make([]byte, 16384)
|
||||||
|
|
||||||
// Old approach: base64 string in map.
|
msg := &pb.TTSAudioChunk{
|
||||||
import_b64 := make([]byte, (len(audio)*4+2)/3) // approximate base64 size
|
SessionId: "sess-1",
|
||||||
mapMsg := map[string]any{
|
|
||||||
"session_id": "sess-1",
|
|
||||||
"audio_b64": string(import_b64),
|
|
||||||
}
|
|
||||||
mapData, _ := msgpack.Marshal(mapMsg)
|
|
||||||
|
|
||||||
// New approach: raw bytes in struct.
|
|
||||||
structMsg := audioMessage{
|
|
||||||
SessionID: "sess-1",
|
|
||||||
Audio: audio,
|
Audio: audio,
|
||||||
}
|
}
|
||||||
structData, _ := msgpack.Marshal(structMsg)
|
data, _ := proto.Marshal(msg)
|
||||||
|
|
||||||
t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)",
|
t.Logf("TTSAudioChunk with 16KB audio: %d bytes on wire", len(data))
|
||||||
len(mapData), len(structData),
|
|
||||||
100*(1-float64(len(structData))/float64(len(mapData))))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// Benchmarks
|
// Benchmarks
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkEncodeMap(b *testing.B) {
|
func BenchmarkEncode_ChatRequest(b *testing.B) {
|
||||||
data := map[string]any{
|
data := &pb.ChatRequest{
|
||||||
"request_id": "req-bench",
|
RequestId: "req-bench",
|
||||||
"user_id": "user-bench",
|
UserId: "user-bench",
|
||||||
"message": "What is the weather today?",
|
Message: "What is the weather today?",
|
||||||
"premium": true,
|
Premium: true,
|
||||||
"top_k": 10,
|
TopK: 10,
|
||||||
}
|
}
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
msgpack.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkEncodeStruct(b *testing.B) {
|
func BenchmarkDecode_ChatRequest(b *testing.B) {
|
||||||
data := testMessage{
|
raw, _ := proto.Marshal(&pb.ChatRequest{
|
||||||
RequestID: "req-bench",
|
RequestId: "req-bench",
|
||||||
UserID: "user-bench",
|
UserId: "user-bench",
|
||||||
Count: 10,
|
Message: "What is the weather today?",
|
||||||
Active: true,
|
Premium: true,
|
||||||
}
|
TopK: 10,
|
||||||
for b.Loop() {
|
|
||||||
msgpack.Marshal(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecodeMap(b *testing.B) {
|
|
||||||
raw, _ := msgpack.Marshal(map[string]any{
|
|
||||||
"request_id": "req-bench",
|
|
||||||
"user_id": "user-bench",
|
|
||||||
"message": "What is the weather today?",
|
|
||||||
"premium": true,
|
|
||||||
"top_k": 10,
|
|
||||||
})
|
})
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m map[string]any
|
var m pb.ChatRequest
|
||||||
msgpack.Unmarshal(raw, &m)
|
_ = Decode(raw, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkDecodeStruct(b *testing.B) {
|
func BenchmarkDecode_Audio32KB(b *testing.B) {
|
||||||
raw, _ := msgpack.Marshal(testMessage{
|
raw, _ := proto.Marshal(&pb.TTSAudioChunk{
|
||||||
RequestID: "req-bench",
|
SessionId: "s1",
|
||||||
UserID: "user-bench",
|
|
||||||
Count: 10,
|
|
||||||
Active: true,
|
|
||||||
})
|
|
||||||
for b.Loop() {
|
|
||||||
var m testMessage
|
|
||||||
msgpack.Unmarshal(raw, &m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkDecodeAudio32KB(b *testing.B) {
|
|
||||||
raw, _ := msgpack.Marshal(audioMessage{
|
|
||||||
SessionID: "s1",
|
|
||||||
Audio: make([]byte, 32768),
|
Audio: make([]byte, 32768),
|
||||||
SampleRate: 24000,
|
SampleRate: 24000,
|
||||||
})
|
})
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m audioMessage
|
var m pb.TTSAudioChunk
|
||||||
msgpack.Unmarshal(raw, &m)
|
_ = Decode(raw, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
257
proto/messages/v1/messages.proto
Normal file
257
proto/messages/v1/messages.proto
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
// Homelab AI service message contracts.
|
||||||
|
//
|
||||||
|
// This is the single source of truth for all NATS message types.
|
||||||
|
// Generated Go code lives in handler-base/gen/messagespb.
|
||||||
|
//
|
||||||
|
// Naming: field numbers are stable across versions — add new fields,
|
||||||
|
// never reuse or renumber existing ones.
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package messages.v1;
|
||||||
|
|
||||||
|
option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb";
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Common
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// ErrorResponse is the standard error reply from any handler.
|
||||||
|
message ErrorResponse {
|
||||||
|
bool error = 1;
|
||||||
|
string message = 2;
|
||||||
|
string type = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Chat (companions-frontend ↔ chat-handler)
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// LoginEvent is published when a user authenticates.
|
||||||
|
// Subject: ai.chat.user.{user_id}.login
|
||||||
|
message LoginEvent {
|
||||||
|
string user_id = 1;
|
||||||
|
string username = 2;
|
||||||
|
string nickname = 3;
|
||||||
|
bool premium = 4;
|
||||||
|
int64 timestamp = 5; // Unix seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// GreetingRequest asks the LLM to generate a personalised greeting.
|
||||||
|
// Subject: ai.chat.user.{user_id}.greeting.request
|
||||||
|
message GreetingRequest {
|
||||||
|
string user_id = 1;
|
||||||
|
string username = 2;
|
||||||
|
string nickname = 3;
|
||||||
|
bool premium = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// GreetingResponse carries the generated greeting text.
|
||||||
|
// Subject: ai.chat.user.{user_id}.greeting.response
|
||||||
|
message GreetingResponse {
|
||||||
|
string user_id = 1;
|
||||||
|
string greeting = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatRequest is an incoming chat message routed via NATS.
|
||||||
|
// Subject: ai.chat.user.{user_id}.message
|
||||||
|
message ChatRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
string username = 3;
|
||||||
|
string message = 4;
|
||||||
|
string query = 5; // alternative to message (EffectiveQuery picks first non-empty)
|
||||||
|
bool premium = 6;
|
||||||
|
bool enable_rag = 7;
|
||||||
|
bool enable_reranker = 8;
|
||||||
|
bool enable_streaming = 9;
|
||||||
|
int32 top_k = 10;
|
||||||
|
string collection = 11;
|
||||||
|
bool enable_tts = 12;
|
||||||
|
string system_prompt = 13;
|
||||||
|
string response_subject = 14;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatResponse is the full reply to a ChatRequest.
|
||||||
|
// Subject: ai.chat.response.{request_id} (or ChatRequest.response_subject)
|
||||||
|
message ChatResponse {
|
||||||
|
string user_id = 1;
|
||||||
|
string response = 2;
|
||||||
|
string response_text = 3;
|
||||||
|
bool used_rag = 4;
|
||||||
|
repeated string rag_sources = 5;
|
||||||
|
bool success = 6;
|
||||||
|
bytes audio = 7;
|
||||||
|
string error = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStreamChunk is one piece of a streaming LLM response.
|
||||||
|
// Subject: ai.chat.response.stream.{request_id}
|
||||||
|
message ChatStreamChunk {
|
||||||
|
string request_id = 1;
|
||||||
|
string type = 2; // "chunk" | "done"
|
||||||
|
string content = 3;
|
||||||
|
bool done = 4;
|
||||||
|
int64 timestamp = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Voice Assistant
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// VoiceRequest is an incoming voice-to-voice request.
|
||||||
|
// Subject: ai.voice.request
|
||||||
|
message VoiceRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
string language = 3;
|
||||||
|
string collection = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocumentSource is a single RAG search-result citation.
|
||||||
|
message DocumentSource {
|
||||||
|
string text = 1;
|
||||||
|
double score = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoiceResponse is the reply to a VoiceRequest.
|
||||||
|
// Subject: ai.voice.response.{request_id}
|
||||||
|
message VoiceResponse {
|
||||||
|
string request_id = 1;
|
||||||
|
string response = 2;
|
||||||
|
bytes audio = 3;
|
||||||
|
string transcription = 4;
|
||||||
|
repeated DocumentSource sources = 5;
|
||||||
|
string error = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// TTS Module
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// TTSRequest is a text-to-speech synthesis request.
|
||||||
|
// Subject: ai.voice.tts.request.{session_id}
|
||||||
|
message TTSRequest {
|
||||||
|
string text = 1;
|
||||||
|
string speaker = 2;
|
||||||
|
string language = 3;
|
||||||
|
string speaker_wav_b64 = 4;
|
||||||
|
bool stream = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
|
||||||
|
// Subject: ai.voice.tts.audio.{session_id}
|
||||||
|
message TTSAudioChunk {
|
||||||
|
string session_id = 1;
|
||||||
|
int32 chunk_index = 2;
|
||||||
|
int32 total_chunks = 3;
|
||||||
|
bytes audio = 4;
|
||||||
|
bool is_last = 5;
|
||||||
|
int64 timestamp = 6;
|
||||||
|
int32 sample_rate = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSFullResponse is a non-streamed TTS response (whole audio blob).
|
||||||
|
// Subject: ai.voice.tts.audio.{session_id}
|
||||||
|
message TTSFullResponse {
|
||||||
|
string session_id = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
int32 sample_rate = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSStatus is a TTS processing status update.
|
||||||
|
// Subject: ai.voice.tts.status.{session_id}
|
||||||
|
message TTSStatus {
|
||||||
|
string session_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
string message = 3;
|
||||||
|
int64 timestamp = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceInfo is summary info about a custom voice.
|
||||||
|
message TTSVoiceInfo {
|
||||||
|
string name = 1;
|
||||||
|
string language = 2;
|
||||||
|
string model_type = 3;
|
||||||
|
string created_at = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceListResponse is the reply to a voice list request.
|
||||||
|
// Subject: ai.voice.tts.voices.list (request-reply)
|
||||||
|
message TTSVoiceListResponse {
|
||||||
|
string default_speaker = 1;
|
||||||
|
repeated TTSVoiceInfo custom_voices = 2;
|
||||||
|
int64 last_refresh = 3;
|
||||||
|
int64 timestamp = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
|
||||||
|
// Subject: ai.voice.tts.voices.refresh (request-reply)
|
||||||
|
message TTSVoiceRefreshResponse {
|
||||||
|
int32 count = 1;
|
||||||
|
repeated TTSVoiceInfo custom_voices = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// STT Module
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// STTStreamMessage is any message on the ai.voice.stream.{session_id} subject.
|
||||||
|
message STTStreamMessage {
|
||||||
|
string type = 1; // "start" | "chunk" | "state_change" | "end"
|
||||||
|
bytes audio = 2;
|
||||||
|
string state = 3;
|
||||||
|
string speaker_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTTranscription is the transcription result published by the STT module.
|
||||||
|
// Subject: ai.voice.transcription.{session_id}
|
||||||
|
message STTTranscription {
|
||||||
|
string session_id = 1;
|
||||||
|
string transcript = 2;
|
||||||
|
int32 sequence = 3;
|
||||||
|
bool is_partial = 4;
|
||||||
|
bool is_final = 5;
|
||||||
|
int64 timestamp = 6;
|
||||||
|
string speaker_id = 7;
|
||||||
|
bool has_voice_activity = 8;
|
||||||
|
string state = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTInterrupt is published when the STT module detects a user interrupt.
|
||||||
|
// Subject: ai.voice.transcription.{session_id}
|
||||||
|
message STTInterrupt {
|
||||||
|
string session_id = 1;
|
||||||
|
string type = 2; // "interrupt"
|
||||||
|
int64 timestamp = 3;
|
||||||
|
string speaker_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Pipeline Bridge
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// PipelineTrigger is the request to start a pipeline.
|
||||||
|
// Subject: ai.pipeline.trigger
|
||||||
|
message PipelineTrigger {
|
||||||
|
string request_id = 1;
|
||||||
|
string pipeline = 2;
|
||||||
|
// Protobuf Struct could be used here, but a simple string map covers
|
||||||
|
// all current use-cases and avoids a google/protobuf import.
|
||||||
|
map<string, string> parameters = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipelineStatus is the response / status update for a pipeline run.
|
||||||
|
// Subject: ai.pipeline.status.{request_id}
|
||||||
|
message PipelineStatus {
|
||||||
|
string request_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
string run_id = 3;
|
||||||
|
string engine = 4;
|
||||||
|
string pipeline = 5;
|
||||||
|
string submitted_at = 6;
|
||||||
|
string error = 7;
|
||||||
|
repeated string available_pipelines = 8;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user