18 Commits

Author SHA1 Message Date
f61488a868 style: gofmt + fix copylocks lint warning
Some checks failed
CI / Test (push) Successful in 2m49s
CI / Lint (push) Successful in 3m16s
CI / Release (push) Successful in 2m8s
CI / Docker Build & Push (push) Failing after 6m13s
CI / Notify (push) Successful in 1s
2026-02-21 15:36:46 -05:00
e2176331c8 feat: migrate from msgpack to protobuf (handler-base v1.0.0)
Some checks failed
CI / Lint (push) Failing after 2m49s
CI / Test (push) Successful in 3m36s
CI / Notify (push) Has been cancelled
CI / Docker Build & Push (push) Has been cancelled
CI / Release (push) Has been cancelled
- Replace msgpack encoding with protobuf wire format
- Update field names to proto convention (UserId, RequestId, EnableRag, etc.)
- Use messages.EffectiveQuery() standalone function
- Cast TopK to int32 for proto compatibility
- Rewrite tests for proto round-trips
2026-02-21 15:30:04 -05:00
87d0545d2c feat: replace fake streaming with real SSE StreamGenerate
Some checks failed
CI / Lint (push) Successful in 3m0s
CI / Test (push) Successful in 3m23s
CI / Docker Build & Push (push) Failing after 4m55s
CI / Release (push) Successful in 1m4s
CI / Notify (push) Successful in 1s
Use handler-base StreamGenerate() to publish real token-by-token
ChatStreamChunk messages to NATS as they arrive from Ray Serve,
instead of calling Generate() and splitting into 4-word chunks.

Add 8 streaming tests: happy path, system prompt, RAG context,
nil callback, timeout, HTTP error, context canceled, fallback.
2026-02-21 09:23:57 -05:00
7678d911fe chore: bump handler-base to v0.1.5, add netrc secret mount to Dockerfile
Some checks failed
CI / Release (push) Successful in 1m11s
CI / Lint (push) Successful in 3m55s
CI / Test (push) Successful in 3m36s
CI / Docker Build & Push (push) Failing after 8m10s
CI / Notify (push) Successful in 1s
2026-02-20 18:15:47 -05:00
7d72400070 ci: retrigger build
Some checks failed
CI / Test (push) Successful in 2m51s
CI / Release (push) Successful in 1m3s
CI / Docker Build & Push (push) Failing after 28s
CI / Lint (push) Successful in 3m27s
CI / Notify (push) Successful in 1s
2026-02-20 10:05:21 -05:00
451e600985 fix: add GOPRIVATE and git auth for private handler-base module
Some checks failed
CI / Lint (push) Successful in 2m32s
CI / Test (push) Successful in 2m54s
CI / Release (push) Successful in 1m1s
CI / Docker Build & Push (push) Failing after 6m40s
CI / Notify (push) Successful in 1s
- Set GOPRIVATE=git.daviestechlabs.io to bypass public Go module proxy
- Configure git URL insteadOf with DISPATCH_TOKEN for private repo access
2026-02-20 09:22:35 -05:00
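
For reference, the two settings from this commit can be reproduced on a developer machine roughly like so (the token value is a placeholder; the CI sets GOPRIVATE via the workflow env instead):

```shell
# Bypass the public module proxy and checksum database for the private host.
go env -w GOPRIVATE=git.daviestechlabs.io

# Rewrite HTTPS module fetches to carry credentials.
git config --global \
  url."https://gitea-actions:${DISPATCH_TOKEN}@git.daviestechlabs.io/".insteadOf \
  "https://git.daviestechlabs.io/"
```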
4b7b434eb7 fix: rename GITEA_TOKEN to DISPATCH_TOKEN
Some checks failed
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release (push) Has been cancelled
CI / Docker Build & Push (push) Has been cancelled
CI / Notify (push) Has been cancelled
2026-02-20 09:10:20 -05:00
808f41bc90 ci: add handler-base auto-update workflow, remove old Python CI
Some checks failed
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release (push) Has been cancelled
CI / Docker Build & Push (push) Has been cancelled
CI / Notify (push) Has been cancelled
2026-02-20 09:05:58 -05:00
2fa3668e8a fix: use tagged handler-base v0.1.3, remove local replace directive
Some checks failed
CI / Docker Build & Push (push) Has been cancelled
CI / Notify (push) Has been cancelled
CI / Lint (push) Failing after 57s
CI / Test (push) Failing after 1m34s
CI / Release (push) Has been cancelled
2026-02-20 09:00:45 -05:00
b48fbb424d fix: resolve golangci-lint errcheck warnings
Some checks failed
CI / Lint (push) Failing after 57s
CI / Test (push) Failing after 1m22s
CI / Release (push) Has been cancelled
CI / Docker Build & Push (push) Has been cancelled
CI / Notify (push) Has been cancelled
- Add error checks for unchecked return values (errcheck)
- Remove unused struct fields (unused)
- Fix gofmt formatting issues
2026-02-20 08:45:24 -05:00
c27d192705 Merge pull request 'feature/go-handler-refactor' (#1) from feature/go-handler-refactor into main
Some checks failed
CI / Lint (push) Failing after 57s
CI / Test (push) Failing after 1m24s
CI / Notify (push) Successful in 1s
CI / Release (push) Has been skipped
CI / Docker Build & Push (push) Has been skipped
Reviewed-on: #1
2026-02-20 12:33:32 +00:00
4175e2070c feat: migrate to typed messages, drop base64
Some checks failed
CI / Lint (pull_request) Failing after 58s
CI / Test (pull_request) Failing after 1m22s
CI / Notify (pull_request) Successful in 1s
CI / Release (pull_request) Has been skipped
CI / Docker Build & Push (pull_request) Has been skipped
- Switch OnMessage → OnTypedMessage with natsutil.Decode[messages.ChatRequest]
- Return *messages.ChatResponse / *messages.ChatStreamChunk (not map[string]any)
- Audio as raw []byte in msgpack (25% wire savings vs base64)
- Remove strVal/boolVal/intVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (9 tests pass)
2026-02-20 07:10:43 -05:00
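
The `natsutil.Decode[messages.ChatRequest]` call mentioned above follows a generic decode pattern that can be sketched with the standard library (JSON stands in here for the real wire format, and `ChatRequest` is a simplified stand-in for the typed message):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Decode unmarshals raw message bytes into a typed struct T, mirroring the
// Decode[T] generic-helper pattern instead of returning map[string]any.
func Decode[T any](data []byte) (*T, error) {
	var v T
	if err := json.Unmarshal(data, &v); err != nil {
		return nil, err
	}
	return &v, nil
}

// ChatRequest is a simplified illustration of a typed request struct.
type ChatRequest struct {
	RequestID string `json:"request_id"`
	Query     string `json:"query"`
	EnableTTS bool   `json:"enable_tts"`
}

func main() {
	raw := []byte(`{"request_id":"req-1","query":"hello","enable_tts":true}`)
	req, err := Decode[ChatRequest](raw)
	fmt.Println(req.RequestID, req.Query, req.EnableTTS, err)
}
```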
609b44de83 feat: add e2e tests + benchmarks, fix config API
- e2e_test.go: full pipeline tests (LLM-only, RAG, TTS, timeout)
- main.go: fix config field->method references (EmbeddingsURL() etc.)
- Benchmarks: LLMOnly 136µs/op, RAGFlow 496µs/op
2026-02-20 06:45:21 -05:00
adcdb87b9a feat: rewrite chat-handler in Go
Replace Python chat handler with Go for smaller container images.
Uses handler-base Go module for NATS, health, telemetry, and service clients.

- RAG pipeline: embed → Milvus → rerank → LLM
- Streaming response chunks
- Optional TTS synthesis
- Custom response_subject support for companions-frontend
2026-02-19 17:58:52 -05:00
a1cf87909d fixing ruff suggestions and tests needed updating.
All checks were successful
CI / Lint (push) Successful in 1m39s
CI / Test (push) Successful in 1m37s
CI / Release (push) Successful in 6s
CI / Notify (push) Successful in 1s
2026-02-18 07:37:13 -05:00
24a4098c9a fixing up chat-handler.
Some checks failed
CI / Lint (push) Failing after 1m39s
CI / Test (push) Failing after 1m37s
CI / Release (push) Has been skipped
CI / Notify (push) Successful in 1s
2026-02-18 07:29:41 -05:00
b34e8d2e1c fix: replace astral-sh/setup-uv action with shell install
All checks were successful
CI / Lint (push) Successful in 58s
CI / Test (push) Successful in 1m50s
CI / Release (push) Successful in 25s
CI / Notify (push) Successful in 1s
The JS-based GitHub Action doesn't work on Gitea's act runner.
Use curl installer + GITHUB_PATH instead.
2026-02-13 19:40:53 -05:00
3a4a13f0de chore: add Renovate config for automated dependency updates
All checks were successful
CI / Notify (push) Successful in 1s
CI / Lint (push) Successful in 1m8s
CI / Test (push) Successful in 2m0s
CI / Release (push) Successful in 4s
Ref: ADR-0057
2026-02-13 15:33:43 -05:00
19 changed files with 1259 additions and 3428 deletions

.dockerignore Normal file

@@ -0,0 +1,9 @@
.git
.gitignore
*.md
LICENSE
renovate.json
*_test.go
e2e_test.go
__pycache__
.env*


@@ -0,0 +1,213 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
env:
NTFY_URL: http://ntfy.observability.svc.cluster.local:80
GOPRIVATE: git.daviestechlabs.io
REGISTRY: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs
REGISTRY_HOST: gitea-http.gitea.svc.cluster.local:3000
IMAGE_NAME: chat-handler
jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: Configure private modules
run: git config --global url."https://gitea-actions:${{ secrets.DISPATCH_TOKEN }}@git.daviestechlabs.io/".insteadOf "https://git.daviestechlabs.io/"
- name: Run go vet
run: go vet ./...
- name: Install golangci-lint
run: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b "$(go env GOPATH)/bin"
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
- name: Run golangci-lint
run: golangci-lint run ./...
test:
name: Test
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: Configure private modules
run: git config --global url."https://gitea-actions:${{ secrets.DISPATCH_TOKEN }}@git.daviestechlabs.io/".insteadOf "https://git.daviestechlabs.io/"
- name: Verify dependencies
run: go mod verify
- name: Build
run: go build -v ./...
- name: Run tests
run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
release:
name: Release
runs-on: ubuntu-latest
needs: [lint, test]
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
outputs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version bump
id: version
run: |
# Get latest tag or default to v0.0.0
LATEST=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
VERSION=${LATEST#v}
IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION"
# Check commit message for keywords
MSG="${{ gitea.event.head_commit.message }}"
if echo "$MSG" | grep -qiE "^major:|BREAKING CHANGE"; then
MAJOR=$((MAJOR + 1)); MINOR=0; PATCH=0
BUMP="major"
elif echo "$MSG" | grep -qiE "^(minor:|feat:)"; then
MINOR=$((MINOR + 1)); PATCH=0
BUMP="minor"
else
PATCH=$((PATCH + 1))
BUMP="patch"
fi
NEW_VERSION="v${MAJOR}.${MINOR}.${PATCH}"
echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT
echo "bump=$BUMP" >> $GITHUB_OUTPUT
echo "Bumping $LATEST → $NEW_VERSION ($BUMP)"
- name: Create and push tag
run: |
git config user.name "gitea-actions[bot]"
git config user.email "actions@git.daviestechlabs.io"
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
git push origin ${{ steps.version.outputs.version }}
docker:
name: Docker Build & Push
runs-on: ubuntu-latest
needs: [lint, test, release]
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
buildkitd-config-inline: |
[registry."gitea-http.gitea.svc.cluster.local:3000"]
http = true
insecure = true
- name: Login to Docker Hub
if: vars.DOCKERHUB_USERNAME != ''
uses: docker/login-action@v3
with:
username: ${{ vars.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Configure Docker for insecure registry
run: |
sudo mkdir -p /etc/docker
echo '{"insecure-registries": ["${{ env.REGISTRY_HOST }}"]}' | sudo tee /etc/docker/daemon.json
sudo systemctl restart docker || sudo service docker restart || true
sleep 2
- name: Login to Gitea Registry
run: |
AUTH=$(echo -n "${{ secrets.REGISTRY_USER }}:${{ secrets.REGISTRY_TOKEN }}" | base64 -w0)
mkdir -p ~/.docker
cat > ~/.docker/config.json << EOF
{
"auths": {
"${{ env.REGISTRY_HOST }}": {
"auth": "$AUTH"
}
}
}
EOF
echo "Auth configured for ${{ env.REGISTRY_HOST }}"
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=semver,pattern={{version}},value=${{ needs.release.outputs.version }}
type=semver,pattern={{major}}.{{minor}},value=${{ needs.release.outputs.version }}
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push
uses: docker/build-push-action@v5
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
notify:
name: Notify
runs-on: ubuntu-latest
needs: [lint, test, release, docker]
if: always()
steps:
- name: Notify on success
if: needs.lint.result == 'success' && needs.test.result == 'success'
run: |
curl -s \
-H "Title: ✅ CI Passed: ${{ gitea.repository }}" \
-H "Priority: default" \
-H "Tags: white_check_mark,github" \
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
-d "Branch: ${{ gitea.ref_name }}
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
Release: ${{ needs.release.result == 'success' && needs.release.outputs.version || 'skipped' }}
Docker: ${{ needs.docker.result }}" \
${{ env.NTFY_URL }}/gitea-ci
- name: Notify on failure
if: needs.lint.result == 'failure' || needs.test.result == 'failure'
run: |
curl -s \
-H "Title: ❌ CI Failed: ${{ gitea.repository }}" \
-H "Priority: high" \
-H "Tags: x,github" \
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
-d "Branch: ${{ gitea.ref_name }}
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
Lint: ${{ needs.lint.result }}
Test: ${{ needs.test.result }}" \
${{ env.NTFY_URL }}/gitea-ci
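
The version-bump rules in the release job above can be mirrored in a few lines of Go for local testing (a sketch of the shell script's logic, not code from the repo; it assumes a well-formed vX.Y.Z tag):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// bump applies the same rules as the release job: "major:"/"BREAKING CHANGE"
// bumps major, "minor:"/"feat:" bumps minor, anything else bumps patch.
func bump(latest, msg string) string {
	parts := strings.SplitN(strings.TrimPrefix(latest, "v"), ".", 3)
	major, _ := strconv.Atoi(parts[0])
	minor, _ := strconv.Atoi(parts[1])
	patch, _ := strconv.Atoi(parts[2])
	lower := strings.ToLower(msg)
	switch {
	case strings.HasPrefix(lower, "major:") || strings.Contains(lower, "breaking change"):
		major, minor, patch = major+1, 0, 0
	case strings.HasPrefix(lower, "minor:") || strings.HasPrefix(lower, "feat:"):
		minor, patch = minor+1, 0
	default:
		patch++
	}
	return fmt.Sprintf("v%d.%d.%d", major, minor, patch)
}

func main() {
	fmt.Println(bump("v0.1.5", "feat: migrate to protobuf")) // v0.2.0
	fmt.Println(bump("v0.2.0", "fix: lint warnings"))        // v0.2.1
}
```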


@@ -1,135 +0,0 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
env:
NTFY_URL: http://ntfy.observability.svc.cluster.local:80
jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
version: "latest"
activate-environment: false
- name: Set up Python
run: uv python install 3.13
- name: Install dependencies
run: uv sync --frozen --extra dev
- name: Run ruff check
run: uv run ruff check .
- name: Run ruff format check
run: uv run ruff format --check .
test:
name: Test
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
version: "latest"
activate-environment: false
- name: Set up Python
run: uv python install 3.13
- name: Install dependencies
run: uv sync --frozen --extra dev
- name: Run tests
run: uv run pytest -v
release:
name: Release
runs-on: ubuntu-latest
needs: [lint, test]
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version bump
id: version
run: |
# Get latest tag or default to v0.0.0
LATEST=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
VERSION=${LATEST#v}
IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION"
# Check commit message for keywords
MSG="${{ gitea.event.head_commit.message }}"
if echo "$MSG" | grep -qiE "^major:|BREAKING CHANGE"; then
MAJOR=$((MAJOR + 1)); MINOR=0; PATCH=0
BUMP="major"
elif echo "$MSG" | grep -qiE "^(minor:|feat:)"; then
MINOR=$((MINOR + 1)); PATCH=0
BUMP="minor"
else
PATCH=$((PATCH + 1))
BUMP="patch"
fi
NEW_VERSION="v${MAJOR}.${MINOR}.${PATCH}"
echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT
echo "bump=$BUMP" >> $GITHUB_OUTPUT
echo "Bumping $LATEST → $NEW_VERSION ($BUMP)"
- name: Create and push tag
run: |
git config user.name "gitea-actions[bot]"
git config user.email "actions@git.daviestechlabs.io"
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
git push origin ${{ steps.version.outputs.version }}
notify:
name: Notify
runs-on: ubuntu-latest
needs: [lint, test, release]
if: always()
steps:
- name: Notify on success
if: needs.lint.result == 'success' && needs.test.result == 'success'
run: |
curl -s \
-H "Title: ✅ CI Passed: ${{ gitea.repository }}" \
-H "Priority: default" \
-H "Tags: white_check_mark,github" \
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
-d "Branch: ${{ gitea.ref_name }}
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
Release: ${{ needs.release.result == 'success' && 'created' || 'skipped' }}" \
${{ env.NTFY_URL }}/gitea-ci
- name: Notify on failure
if: needs.lint.result == 'failure' || needs.test.result == 'failure'
run: |
curl -s \
-H "Title: ❌ CI Failed: ${{ gitea.repository }}" \
-H "Priority: high" \
-H "Tags: x,github" \
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
-d "Branch: ${{ gitea.ref_name }}
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
Lint: ${{ needs.lint.result }}
Test: ${{ needs.test.result }}" \
${{ env.NTFY_URL }}/gitea-ci


@@ -0,0 +1,60 @@
name: Update handler-base
on:
repository_dispatch:
types: [handler-base-release]
env:
NTFY_URL: http://ntfy.observability.svc.cluster.local:80
GOPRIVATE: git.daviestechlabs.io
jobs:
update:
name: Update handler-base dependency
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
token: ${{ secrets.DISPATCH_TOKEN }}
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: Configure Git
run: |
git config user.name "gitea-actions[bot]"
git config user.email "actions@git.daviestechlabs.io"
git config --global url."https://gitea-actions:${{ secrets.DISPATCH_TOKEN }}@git.daviestechlabs.io/".insteadOf "https://git.daviestechlabs.io/"
- name: Update handler-base
run: |
VERSION="${{ gitea.event.client_payload.version }}"
echo "Updating handler-base to ${VERSION}"
go get git.daviestechlabs.io/daviestechlabs/handler-base@${VERSION}
go mod tidy
- name: Commit and push
run: |
VERSION="${{ gitea.event.client_payload.version }}"
if git diff --quiet go.mod go.sum; then
echo "No changes to commit"
exit 0
fi
git add go.mod go.sum
git commit -m "chore(deps): bump handler-base to ${VERSION}"
git push
- name: Notify
if: success()
run: |
VERSION="${{ gitea.event.client_payload.version }}"
curl -s \
-H "Title: 📦 Dep Update: ${{ gitea.repository }}" \
-H "Priority: default" \
-H "Tags: package,github" \
-d "handler-base updated to ${VERSION}" \
${{ env.NTFY_URL }}/gitea-ci

.gitignore vendored

@@ -24,3 +24,4 @@ ENV/
.env
.env.local
*.log
chat-handler


@@ -1,32 +0,0 @@
# Pre-commit hooks for chat-handler
# Install: pip install pre-commit && pre-commit install
# Run: pre-commit run --all-files
repos:
# Ruff - fast Python linter and formatter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
# Standard pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
args: [--maxkb=500]
- id: check-merge-conflict
- id: detect-private-key
# Type checking (optional - uncomment when ready)
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.10.0
# hooks:
# - id: mypy
# additional_dependencies: [types-all]
# args: [--ignore-missing-imports]


@@ -1,9 +1,26 @@
-# Chat Handler - Using handler-base
-ARG BASE_TAG=latest
-FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
+# Build stage
+FROM golang:1.25-alpine AS builder
 WORKDIR /app
-COPY chat_handler.py .
+RUN apk add --no-cache ca-certificates git
-CMD ["python", "chat_handler.py"]
+ENV GOPRIVATE=git.daviestechlabs.io
+ENV GONOSUMCHECK=git.daviestechlabs.io
+COPY go.mod go.sum ./
+RUN --mount=type=secret,id=netrc,target=/root/.netrc go mod download
+COPY . .
+RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /chat-handler .
+# Runtime stage
+FROM scratch
+COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
+COPY --from=builder /chat-handler /chat-handler
+USER 65534:65534
+ENTRYPOINT ["/chat-handler"]
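
The `--mount=type=secret,id=netrc` line above expects the secret to be supplied at build time; with BuildKit that looks roughly like the following (the local `.netrc` path is illustrative):

```shell
# Provide the .netrc credentials to the build without baking them into a layer.
DOCKER_BUILDKIT=1 docker build \
  --secret id=netrc,src="$HOME/.netrc" \
  -t chat-handler:dev .
```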


@@ -1,227 +0,0 @@
#!/usr/bin/env python3
"""
Chat Handler Service (Refactored)
Text-based chat pipeline using handler-base:
1. Listen for text on NATS subject "ai.chat.request"
2. Generate embeddings for RAG
3. Retrieve context from Milvus
4. Rerank with BGE reranker
5. Generate response with vLLM
6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base import Handler, Settings
from handler_base.clients import (
EmbeddingsClient,
RerankerClient,
LLMClient,
TTSClient,
MilvusClient,
)
from handler_base.telemetry import create_span
logger = logging.getLogger("chat-handler")
class ChatSettings(Settings):
"""Chat handler specific settings."""
service_name: str = "chat-handler"
# RAG settings
rag_top_k: int = 10
rag_rerank_top_k: int = 5
rag_collection: str = "documents"
# Response settings
include_sources: bool = True
enable_tts: bool = False
tts_language: str = "en"
class ChatHandler(Handler):
"""
Chat request handler with RAG pipeline.
Request format:
{
"request_id": "uuid",
"query": "user question",
"collection": "optional collection name",
"enable_tts": false,
"system_prompt": "optional custom system prompt"
}
Response format:
{
"request_id": "uuid",
"response": "generated response",
"sources": [{"text": "...", "score": 0.95}],
"audio": "base64 encoded audio (if tts enabled)"
}
"""
def __init__(self):
self.chat_settings = ChatSettings()
super().__init__(
subject="ai.chat.request",
settings=self.chat_settings,
queue_group="chat-handlers",
)
async def setup(self) -> None:
"""Initialize service clients."""
logger.info("Initializing service clients...")
self.embeddings = EmbeddingsClient(self.chat_settings)
self.reranker = RerankerClient(self.chat_settings)
self.llm = LLMClient(self.chat_settings)
self.milvus = MilvusClient(self.chat_settings)
# TTS is optional
if self.chat_settings.enable_tts:
self.tts = TTSClient(self.chat_settings)
else:
self.tts = None
# Connect to Milvus
await self.milvus.connect(self.chat_settings.rag_collection)
logger.info("Service clients initialized")
async def teardown(self) -> None:
"""Clean up service clients."""
logger.info("Closing service clients...")
await self.embeddings.close()
await self.reranker.close()
await self.llm.close()
await self.milvus.close()
if self.tts:
await self.tts.close()
logger.info("Service clients closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle incoming chat request."""
request_id = data.get("request_id", "unknown")
query = data.get("query", "")
collection = data.get("collection", self.chat_settings.rag_collection)
enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
system_prompt = data.get("system_prompt")
logger.info(f"Processing request {request_id}: {query[:50]}...")
with create_span("chat.process") as span:
if span:
span.set_attribute("request.id", request_id)
span.set_attribute("query.length", len(query))
# 1. Generate query embedding
embedding = await self._get_embedding(query)
# 2. Search Milvus for context
documents = await self._search_context(embedding, collection)
# 3. Rerank documents
reranked = await self._rerank_documents(query, documents)
# 4. Build context from top documents
context = self._build_context(reranked)
# 5. Generate LLM response
response_text = await self._generate_response(query, context, system_prompt)
# 6. Optionally synthesize speech
audio_b64 = None
if enable_tts and self.tts:
audio_b64 = await self._synthesize_speech(response_text)
# Build response
result = {
"request_id": request_id,
"response": response_text,
}
if self.chat_settings.include_sources:
result["sources"] = [
{"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
]
if audio_b64:
result["audio"] = audio_b64
logger.info(f"Completed request {request_id}")
# Publish to response subject
response_subject = f"ai.chat.response.{request_id}"
await self.nats.publish(response_subject, result)
return result
async def _get_embedding(self, text: str) -> list[float]:
"""Generate embedding for query text."""
with create_span("chat.embedding"):
return await self.embeddings.embed_single(text)
async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
"""Search Milvus for relevant documents."""
with create_span("chat.search"):
return await self.milvus.search_with_texts(
embedding,
limit=self.chat_settings.rag_top_k,
text_field="text",
metadata_fields=["source", "title"],
)
async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
"""Rerank documents by relevance to query."""
with create_span("chat.rerank"):
texts = [d.get("text", "") for d in documents]
return await self.reranker.rerank(
query, texts, top_k=self.chat_settings.rag_rerank_top_k
)
def _build_context(self, documents: list[dict]) -> str:
"""Build context string from ranked documents."""
context_parts = []
for i, doc in enumerate(documents, 1):
text = doc.get("document", "")
context_parts.append(f"[{i}] {text}")
return "\n\n".join(context_parts)
async def _generate_response(
self,
query: str,
context: str,
system_prompt: Optional[str] = None,
) -> str:
"""Generate LLM response with context."""
with create_span("chat.generate"):
return await self.llm.generate(
query,
context=context,
system_prompt=system_prompt,
)
async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64 encoded audio."""
with create_span("chat.tts"):
audio_bytes = await self.tts.synthesize(
text,
language=self.chat_settings.tts_language,
)
return base64.b64encode(audio_bytes).decode()
if __name__ == "__main__":
ChatHandler().run()

e2e_test.go Normal file

@@ -0,0 +1,484 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"google.golang.org/protobuf/proto"
)
// ────────────────────────────────────────────────────────────────────────────
// E2E tests: exercise the full chat pipeline with mock backends
// ────────────────────────────────────────────────────────────────────────────
// mockBackends starts httptest servers simulating all downstream services.
type mockBackends struct {
Embeddings *httptest.Server
Reranker *httptest.Server
LLM *httptest.Server
TTS *httptest.Server
}
func newMockBackends(t *testing.T) *mockBackends {
t.Helper()
m := &mockBackends{}
m.Embeddings = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"data": []map[string]any{
{"embedding": []float64{0.1, 0.2, 0.3, 0.4}},
},
})
}))
t.Cleanup(m.Embeddings.Close)
m.Reranker = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"index": 0, "relevance_score": 0.95},
},
})
}))
t.Cleanup(m.Reranker.Close)
m.LLM = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{
"content": "Paris is the capital of France.",
}},
},
})
}))
t.Cleanup(m.LLM.Close)
m.TTS = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte{0xDE, 0xAD, 0xBE, 0xEF})
}))
t.Cleanup(m.TTS.Close)
return m
}
func TestChatPipeline_LLMOnly(t *testing.T) {
m := newMockBackends(t)
llm := clients.NewLLMClient(m.LLM.URL, 5*time.Second)
// Simulate what main.go does for a non-RAG request.
response, err := llm.Generate(context.Background(), "What is the capital of France?", "", "")
if err != nil {
t.Fatal(err)
}
if response != "Paris is the capital of France." {
t.Errorf("response = %q", response)
}
}
func TestChatPipeline_WithRAG(t *testing.T) {
m := newMockBackends(t)
embeddings := clients.NewEmbeddingsClient(m.Embeddings.URL, 5*time.Second, "bge")
reranker := clients.NewRerankerClient(m.Reranker.URL, 5*time.Second)
llm := clients.NewLLMClient(m.LLM.URL, 5*time.Second)
ctx := context.Background()
// 1. Embed query
embedding, err := embeddings.EmbedSingle(ctx, "What is the capital of France?")
if err != nil {
t.Fatal(err)
}
if len(embedding) == 0 {
t.Fatal("empty embedding")
}
// 2. Rerank (with mock documents)
docs := []string{"France is a country in Europe", "Paris is its capital"}
results, err := reranker.Rerank(ctx, "capital of France", docs, 2)
if err != nil {
t.Fatal(err)
}
if len(results) == 0 {
t.Fatal("no rerank results")
}
if results[0].Score == 0 {
t.Error("expected non-zero score")
}
// 3. Generate with context
contextText := results[0].Document
response, err := llm.Generate(ctx, "capital of France?", contextText, "")
if err != nil {
t.Fatal(err)
}
if response == "" {
t.Error("empty response")
}
}
func TestChatPipeline_WithTTS(t *testing.T) {
m := newMockBackends(t)
llm := clients.NewLLMClient(m.LLM.URL, 5*time.Second)
tts := clients.NewTTSClient(m.TTS.URL, 5*time.Second, "en")
ctx := context.Background()
response, err := llm.Generate(ctx, "hello", "", "")
if err != nil {
t.Fatal(err)
}
audio, err := tts.Synthesize(ctx, response, "en", "")
if err != nil {
t.Fatal(err)
}
if len(audio) == 0 {
t.Error("empty audio")
}
}
func TestChatPipeline_LLMTimeout(t *testing.T) {
// Simulate slow LLM.
slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "late response"}},
},
})
}))
defer slow.Close()
llm := clients.NewLLMClient(slow.URL, 100*time.Millisecond)
_, err := llm.Generate(context.Background(), "hello", "", "")
if err == nil {
t.Error("expected timeout error")
}
}
func TestChatPipeline_TypedDecoding(t *testing.T) {
// Verify typed struct decoding from proto (same path as OnTypedMessage).
original := &messages.ChatRequest{
RequestId: "req-e2e-001",
UserId: "user-1",
Message: "hello",
Premium: true,
EnableRag: false,
EnableStreaming: false,
SystemPrompt: "Be brief.",
}
data, _ := proto.Marshal(original)
var req messages.ChatRequest
if err := proto.Unmarshal(data, &req); err != nil {
t.Fatal(err)
}
if req.RequestId != "req-e2e-001" {
t.Errorf("RequestID = %q", req.RequestId)
}
if req.UserId != "user-1" {
t.Errorf("UserID = %q", req.UserId)
}
if messages.EffectiveQuery(&req) != "hello" {
t.Errorf("query = %q", messages.EffectiveQuery(&req))
}
if req.EnableRag {
t.Error("EnableRAG should be false")
}
if req.SystemPrompt != "Be brief." {
t.Errorf("SystemPrompt = %q", req.SystemPrompt)
}
}
// ────────────────────────────────────────────────────────────────────────────
// Streaming tests: exercise StreamGenerate path (the real SSE pipeline)
// ────────────────────────────────────────────────────────────────────────────
// sseChunk builds an OpenAI-compatible SSE data line.
func sseChunk(content string) string {
return fmt.Sprintf("data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", content)
}
// newStreamingLLM creates a mock LLM server that responds with SSE-streamed tokens.
func newStreamingLLM(t *testing.T, tokens []string) *httptest.Server {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
// Verify stream=true was requested.
if stream, ok := req["stream"].(bool); !ok || !stream {
t.Error("expected stream=true in request body")
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
flusher, _ := w.(http.Flusher)
// Role-only chunk (should be skipped by StreamGenerate)
_, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n")
if flusher != nil {
flusher.Flush()
}
for _, tok := range tokens {
_, _ = fmt.Fprint(w, sseChunk(tok))
if flusher != nil {
flusher.Flush()
}
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
}))
t.Cleanup(srv.Close)
return srv
}
func TestChatPipeline_StreamGenerate(t *testing.T) {
tokens := []string{"Paris", " is", " the", " capital", " of", " France", "."}
srv := newStreamingLLM(t, tokens)
llm := clients.NewLLMClient(srv.URL, 5*time.Second)
var mu sync.Mutex
var received []string
full, err := llm.StreamGenerate(context.Background(), "capital of France?", "", "", func(token string) {
mu.Lock()
defer mu.Unlock()
received = append(received, token)
})
if err != nil {
t.Fatal(err)
}
if full != "Paris is the capital of France." {
t.Errorf("full = %q", full)
}
if len(received) != len(tokens) {
t.Fatalf("callback count = %d, want %d", len(received), len(tokens))
}
for i, tok := range tokens {
if received[i] != tok {
t.Errorf("token[%d] = %q, want %q", i, received[i], tok)
}
}
}
func TestChatPipeline_StreamWithSystemPrompt(t *testing.T) {
srv := newStreamingLLM(t, []string{"Hello", "!"})
llm := clients.NewLLMClient(srv.URL, 5*time.Second)
full, err := llm.StreamGenerate(context.Background(), "greet me", "", "You are a friendly assistant.", func(string) {})
if err != nil {
t.Fatal(err)
}
if full != "Hello!" {
t.Errorf("full = %q", full)
}
}
func TestChatPipeline_StreamWithRAGContext(t *testing.T) {
m := newMockBackends(t)
srv := newStreamingLLM(t, []string{"The", " answer", " is", " 42"})
embeddings := clients.NewEmbeddingsClient(m.Embeddings.URL, 5*time.Second, "bge")
llm := clients.NewLLMClient(srv.URL, 5*time.Second)
ctx := context.Background()
// 1. Embed
embedding, err := embeddings.EmbedSingle(ctx, "deep thought")
if err != nil {
t.Fatal(err)
}
if len(embedding) == 0 {
t.Fatal("empty embedding")
}
// 2. Stream with context
var tokens []string
full, err := llm.StreamGenerate(ctx, "deep thought", "The answer to everything is 42.", "", func(tok string) {
tokens = append(tokens, tok)
})
if err != nil {
t.Fatal(err)
}
if full != "The answer is 42" {
t.Errorf("full = %q", full)
}
if len(tokens) != 4 {
t.Errorf("token count = %d, want 4", len(tokens))
}
}
func TestChatPipeline_StreamNilCallback(t *testing.T) {
srv := newStreamingLLM(t, []string{"ok"})
llm := clients.NewLLMClient(srv.URL, 5*time.Second)
full, err := llm.StreamGenerate(context.Background(), "test", "", "", nil)
if err != nil {
t.Fatal(err)
}
if full != "ok" {
t.Errorf("full = %q", full)
}
}
func TestChatPipeline_StreamTimeout(t *testing.T) {
slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, sseChunk("late"))
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
}))
defer slow.Close()
llm := clients.NewLLMClient(slow.URL, 100*time.Millisecond)
_, err := llm.StreamGenerate(context.Background(), "hello", "", "", nil)
if err == nil {
t.Error("expected timeout error")
}
}
func TestChatPipeline_StreamHTTPError(t *testing.T) {
errSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("internal error"))
}))
defer errSrv.Close()
llm := clients.NewLLMClient(errSrv.URL, 5*time.Second)
_, err := llm.StreamGenerate(context.Background(), "hello", "", "", nil)
if err == nil {
t.Error("expected error for HTTP 500")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("error = %q, should mention status 500", err)
}
}
func TestChatPipeline_StreamContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
srv := newStreamingLLM(t, []string{"should", "not", "arrive"})
llm := clients.NewLLMClient(srv.URL, 5*time.Second)
_, err := llm.StreamGenerate(ctx, "hello", "", "", nil)
if err == nil {
t.Error("expected context canceled error")
}
}
func TestChatPipeline_StreamFallbackToNonStreaming(t *testing.T) {
// Simulate the branching in main.go: non-streaming uses Generate(),
// streaming uses StreamGenerate(). Verify both paths work from same mock.
m := newMockBackends(t)
streamSrv := newStreamingLLM(t, []string{"streamed", " answer"})
nonStreamLLM := clients.NewLLMClient(m.LLM.URL, 5*time.Second)
streamLLM := clients.NewLLMClient(streamSrv.URL, 5*time.Second)
ctx := context.Background()
// Non-streaming path
resp1, err := nonStreamLLM.Generate(ctx, "hello", "", "")
if err != nil {
t.Fatal(err)
}
if resp1 != "Paris is the capital of France." {
t.Errorf("non-stream = %q", resp1)
}
// Streaming path
var tokens []string
resp2, err := streamLLM.StreamGenerate(ctx, "hello", "", "", func(tok string) {
tokens = append(tokens, tok)
})
if err != nil {
t.Fatal(err)
}
if resp2 != "streamed answer" {
t.Errorf("stream = %q", resp2)
}
if len(tokens) != 2 {
t.Errorf("token count = %d", len(tokens))
}
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmark: full chat pipeline overhead (mock backends)
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkChatPipeline_LLMOnly(b *testing.B) {
llmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"answer"}}]}`))
}))
defer llmSrv.Close()
llm := clients.NewLLMClient(llmSrv.URL, 10*time.Second)
ctx := context.Background()
b.ResetTimer()
for b.Loop() {
_, _ = llm.Generate(ctx, "question", "", "")
}
}
func BenchmarkChatPipeline_RAGFlow(b *testing.B) {
embedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"data":[{"embedding":[0.1,0.2]}]}`))
}))
defer embedSrv.Close()
rerankSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"results":[{"index":0,"relevance_score":0.9}]}`))
}))
defer rerankSrv.Close()
llmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"answer"}}]}`))
}))
defer llmSrv.Close()
embed := clients.NewEmbeddingsClient(embedSrv.URL, 10*time.Second, "bge")
rerank := clients.NewRerankerClient(rerankSrv.URL, 10*time.Second)
llm := clients.NewLLMClient(llmSrv.URL, 10*time.Second)
ctx := context.Background()
b.ResetTimer()
for b.Loop() {
_, _ = embed.EmbedSingle(ctx, "question")
_, _ = rerank.Rerank(ctx, "question", []string{"doc1", "doc2"}, 2)
_, _ = llm.Generate(ctx, "question", "context", "")
}
}
func BenchmarkChatPipeline_StreamGenerate(b *testing.B) {
tokens := []string{"one", " two", " three", " four", " five"}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
for _, tok := range tokens {
_, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", tok)
}
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
}))
defer srv.Close()
llm := clients.NewLLMClient(srv.URL, 10*time.Second)
ctx := context.Background()
b.ResetTimer()
for b.Loop() {
_, _ = llm.StreamGenerate(ctx, "question", "", "", func(string) {})
}
}
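The mock servers above emit OpenAI-style streaming frames; on the consumer side, handler-base's StreamGenerate is expected to walk those frames line by line, skipping role-only deltas and stopping at the `[DONE]` sentinel. A self-contained stdlib sketch of that loop (illustrative only — `collectSSE` is not handler-base's actual API):

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// collectSSE drains an OpenAI-style SSE body and returns the concatenated
// delta contents. Role-only chunks contribute an empty string, and the
// stream terminates at the [DONE] sentinel.
func collectSSE(body string) string {
	var sb strings.Builder
	sc := bufio.NewScanner(strings.NewReader(body))
	for sc.Scan() {
		line := strings.TrimSpace(sc.Text())
		payload, ok := strings.CutPrefix(line, "data: ")
		if !ok {
			continue // skip blank separators between events
		}
		if payload == "[DONE]" {
			break
		}
		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
			} `json:"choices"`
		}
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			continue // tolerate malformed frames
		}
		for _, c := range chunk.Choices {
			sb.WriteString(c.Delta.Content)
		}
	}
	return sb.String()
}

func main() {
	body := "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"Paris\"}}]}\n\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\" is\"}}]}\n\n" +
		"data: [DONE]\n\n"
	fmt.Println(collectSSE(body)) // Paris is
}
```

This is why the role-only chunk emitted by newStreamingLLM never reaches the token callback: its delta has no `content` field, so it concatenates as the empty string.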

go.mod Normal file

@@ -0,0 +1,39 @@
module git.daviestechlabs.io/daviestechlabs/chat-handler
go 1.25.1
require (
git.daviestechlabs.io/daviestechlabs/handler-base v1.0.0
github.com/nats-io/nats.go v1.48.0
google.golang.org/protobuf v1.36.11
)
require (
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/nats-io/nkeys v0.4.11 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel v1.40.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 // indirect
go.opentelemetry.io/otel/metric v1.40.0 // indirect
go.opentelemetry.io/otel/sdk v1.40.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect
go.opentelemetry.io/otel/trace v1.40.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/net v0.49.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/grpc v1.78.0 // indirect
)

go.sum Normal file

@@ -0,0 +1,77 @@
git.daviestechlabs.io/daviestechlabs/handler-base v1.0.0 h1:pB3ehOKaDYQfbyRBKQXrB9curqSFteLrDveoElRKnBY=
git.daviestechlabs.io/daviestechlabs/handler-base v1.0.0/go.mod h1:zocOHFt8yY3cW4+Xi37sNr5Tw7KcjGFSZqgWYxPWyqA=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 h1:NOyNnS19BF2SUDApbOKbDtWZ0IK7b8FJ2uAGdIWOGb0=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0/go.mod h1:VL6EgVikRLcJa9ftukrHu/ZkkhFBSo1lzvdBC9CF1ss=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs=
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

main.go Normal file

@@ -0,0 +1,251 @@
package main
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"
"github.com/nats-io/nats.go"
"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"google.golang.org/protobuf/proto"
)
func main() {
cfg := config.Load()
cfg.ServiceName = "chat-handler"
cfg.NATSQueueGroup = "chat-handlers"
// Chat-specific settings
ragTopK := getEnvInt("RAG_TOP_K", 10)
ragRerankTopK := getEnvInt("RAG_RERANK_TOP_K", 5)
ragCollection := getEnv("RAG_COLLECTION", "documents")
includeSources := getEnvBool("INCLUDE_SOURCES", true)
enableTTS := getEnvBool("ENABLE_TTS", false)
ttsLanguage := getEnv("TTS_LANGUAGE", "en")
// Service clients
timeout := 60 * time.Second
embeddings := clients.NewEmbeddingsClient(cfg.EmbeddingsURL(), timeout, "")
reranker := clients.NewRerankerClient(cfg.RerankerURL(), timeout)
llm := clients.NewLLMClient(cfg.LLMURL(), timeout)
milvus := clients.NewMilvusClient(cfg.MilvusHost, cfg.MilvusPort, ragCollection)
var tts *clients.TTSClient
if enableTTS {
tts = clients.NewTTSClient(cfg.TTSURL(), timeout, ttsLanguage)
}
h := handler.New("ai.chat.user.*.message", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
var req messages.ChatRequest
if err := natsutil.Decode(msg.Data, &req); err != nil {
slog.Error("decode failed", "error", err)
return &messages.ErrorResponse{Error: true, Message: err.Error(), Type: "DecodeError"}, nil
}
query := messages.EffectiveQuery(&req)
requestID := req.RequestId
if requestID == "" {
requestID = "unknown"
}
userID := req.UserId
if userID == "" {
userID = "unknown"
}
// Premium users always get RAG, and RAG implies reranking.
enableRAG := req.EnableRag || req.Premium
enableReranker := req.EnableReranker || enableRAG
topK := req.TopK
if topK == 0 {
topK = int32(ragTopK)
}
collection := req.Collection
if collection == "" {
collection = ragCollection
}
reqEnableTTS := req.EnableTts || enableTTS
systemPrompt := req.SystemPrompt
responseSubject := req.ResponseSubject
if responseSubject == "" {
responseSubject = fmt.Sprintf("ai.chat.response.%s", requestID)
}
slog.Info("processing request", "request_id", requestID, "query_len", len(query))
contextText := ""
var ragSources []string
usedRAG := false
// RAG pipeline
if enableRAG {
// 1. Embed query
embedding, err := embeddings.EmbedSingle(ctx, query)
if err != nil {
slog.Error("embedding failed", "error", err)
} else {
// 2. Search Milvus
_ = milvus
_ = collection
_ = topK
_ = embedding
// NOTE: Milvus search requires the gRPC SDK (milvus-sdk-go), which is
// not yet integrated. For now we pass through without search; the
// Milvus client will be wired up once the SDK lands.
// documents := milvus.Search(ctx, embedding, topK)
var documents []map[string]any // placeholder for Milvus results
// 3. Rerank
if enableReranker && len(documents) > 0 {
texts := make([]string, len(documents))
for i, d := range documents {
if t, ok := d["text"].(string); ok {
texts[i] = t
}
}
reranked, err := reranker.Rerank(ctx, query, texts, ragRerankTopK)
if err != nil {
slog.Error("rerank failed", "error", err)
} else {
documents = make([]map[string]any, len(reranked))
for i, r := range reranked {
documents[i] = map[string]any{"document": r.Document, "score": r.Score}
}
}
}
// 4. Build context
if len(documents) > 0 {
var parts []string
for i, d := range documents {
text := ""
if t, ok := d["document"].(string); ok {
text = t
}
parts = append(parts, fmt.Sprintf("[%d] %s", i+1, text))
}
contextText = strings.Join(parts, "\n\n")
for _, d := range documents {
if len(ragSources) >= 3 {
break
}
src := ""
if s, ok := d["source"].(string); ok {
src = s
} else if s, ok := d["document"].(string); ok && len(s) > 80 {
src = s[:80]
}
ragSources = append(ragSources, src)
}
usedRAG = true
}
}
}
// 5. Generate LLM response (streaming when requested)
var responseText string
var err error
if req.EnableStreaming {
streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID)
responseText, err = llm.StreamGenerate(ctx, query, contextText, systemPrompt, func(token string) {
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
RequestId: requestID,
Type: "chunk",
Content: token,
Timestamp: messages.Timestamp(),
})
})
// Publish a terminal "done" chunk even if streaming errored, so
// subscribers on the stream subject are released rather than left waiting.
_ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{
RequestId: requestID,
Type: "done",
Done: true,
Timestamp: messages.Timestamp(),
})
} else {
responseText, err = llm.Generate(ctx, query, contextText, systemPrompt)
}
if err != nil {
slog.Error("LLM generation failed", "error", err)
return &messages.ChatResponse{
UserId: userID,
Success: false,
Error: err.Error(),
}, nil
}
// 6. Optional TTS — audio as raw bytes (no base64)
var audio []byte
if reqEnableTTS && tts != nil {
audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")
if err != nil {
slog.Error("TTS failed", "error", err)
} else {
audio = audioBytes
}
}
result := &messages.ChatResponse{
UserId: userID,
Response: responseText,
ResponseText: responseText,
UsedRag: usedRAG,
Success: true,
Audio: audio,
}
if includeSources {
result.RagSources = ragSources
}
// Publish to the response subject the frontend is waiting on
_ = h.NATS.Publish(responseSubject, result)
slog.Info("completed request", "request_id", requestID, "rag", usedRAG)
return result, nil
})
if err := h.Run(); err != nil {
slog.Error("handler failed", "error", err)
os.Exit(1)
}
}
// Helpers
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
return fallback
}
func getEnvBool(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" {
return strings.EqualFold(v, "true") || v == "1"
}
return fallback
}

main_test.go Normal file

@@ -0,0 +1,96 @@
package main
import (
"os"
"testing"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"google.golang.org/protobuf/proto"
)
func TestChatRequestDecode(t *testing.T) {
// Verify a proto-encoded struct round-trips cleanly.
original := &messages.ChatRequest{
RequestId: "req-1",
UserId: "user-1",
Message: "hello",
Premium: true,
TopK: 10,
}
data, err := proto.Marshal(original)
if err != nil {
t.Fatal(err)
}
var req messages.ChatRequest
if err := proto.Unmarshal(data, &req); err != nil {
t.Fatal(err)
}
if req.RequestId != "req-1" {
t.Errorf("RequestID = %q", req.RequestId)
}
if messages.EffectiveQuery(&req) != "hello" {
t.Errorf("EffectiveQuery = %q", messages.EffectiveQuery(&req))
}
if !req.Premium {
t.Error("Premium should be true")
}
if req.TopK != 10 {
t.Errorf("TopK = %d", req.TopK)
}
}
func TestChatResponseRoundtrip(t *testing.T) {
resp := &messages.ChatResponse{
UserId: "user-1",
Response: "answer",
Success: true,
Audio: []byte{0x01, 0x02, 0x03},
}
data, err := proto.Marshal(resp)
if err != nil {
t.Fatal(err)
}
var decoded messages.ChatResponse
if err := proto.Unmarshal(data, &decoded); err != nil {
t.Fatal(err)
}
if decoded.UserId != "user-1" || !decoded.Success {
t.Errorf("decoded: UserId=%s, Success=%v", decoded.UserId, decoded.Success)
}
if len(decoded.Audio) != 3 {
t.Errorf("audio len = %d", len(decoded.Audio))
}
}
func TestGetEnvHelpers(t *testing.T) {
t.Setenv("CHAT_TEST", "hello")
if got := getEnv("CHAT_TEST", "x"); got != "hello" {
t.Errorf("getEnv = %q", got)
}
if got := getEnv("NO_SUCH_VAR", "x"); got != "x" {
t.Errorf("getEnv fallback = %q", got)
}
t.Setenv("CHAT_PORT", "9090")
if got := getEnvInt("CHAT_PORT", 0); got != 9090 {
t.Errorf("getEnvInt = %d", got)
}
if got := getEnvInt("NO_SUCH_VAR", 80); got != 80 {
t.Errorf("getEnvInt fallback = %d", got)
}
t.Setenv("CHAT_FLAG", "true")
if got := getEnvBool("CHAT_FLAG", false); !got {
t.Error("getEnvBool should be true")
}
if got := getEnvBool("NO_SUCH_VAR", false); got {
t.Error("getEnvBool fallback should be false")
}
}
func TestMainBinaryBuilds(t *testing.T) {
// Smoke check only: confirm the source file is present. The actual
// build is verified in CI via `go build`, not by this test.
if _, err := os.Stat("main.go"); err != nil {
t.Skip("main.go not found")
}
}


@@ -1,43 +0,0 @@
[project]
name = "chat-handler"
version = "1.0.0"
description = "Text chat pipeline with RAG - Query → Embeddings → Milvus → Rerank → LLM"
readme = "README.md"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Davies Tech Labs" }]
dependencies = [
"handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"ruff>=0.1.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["."]
only-include = ["chat_handler.py"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = ["ignore::DeprecationWarning"]

renovate.json Normal file

@@ -0,0 +1,7 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"local>daviestechlabs/renovate-config",
"local>daviestechlabs/renovate-config:python"
]
}


@@ -1 +0,0 @@
# Chat Handler Tests


@@ -1,82 +0,0 @@
"""
Pytest configuration and fixtures for chat-handler tests.
"""
import asyncio
import os
from unittest.mock import MagicMock
import pytest
# Set test environment variables before importing
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
os.environ.setdefault("MILVUS_HOST", "localhost")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def sample_embedding():
"""Sample embedding vector."""
return [0.1] * 1024
@pytest.fixture
def sample_documents():
"""Sample search results."""
return [
{"text": "Machine learning is a subset of AI.", "score": 0.95},
{"text": "Deep learning uses neural networks.", "score": 0.90},
{"text": "AI enables intelligent automation.", "score": 0.85},
]
@pytest.fixture
def sample_reranked():
"""Sample reranked results."""
return [
{"document": "Machine learning is a subset of AI.", "score": 0.98},
{"document": "Deep learning uses neural networks.", "score": 0.85},
]
@pytest.fixture
def mock_nats_message():
"""Create a mock NATS message."""
msg = MagicMock()
msg.subject = "ai.chat.request"
msg.reply = "ai.chat.response.test-123"
return msg
@pytest.fixture
def mock_chat_request():
"""Sample chat request payload."""
return {
"request_id": "test-request-123",
"query": "What is machine learning?",
"collection": "test_collection",
"enable_tts": False,
"system_prompt": None,
}
@pytest.fixture
def mock_chat_request_with_tts():
"""Sample chat request with TTS enabled."""
return {
"request_id": "test-request-456",
"query": "Tell me about AI",
"collection": "documents",
"enable_tts": True,
"system_prompt": "You are a helpful assistant.",
}


@@ -1,265 +0,0 @@
"""
Unit tests for ChatHandler.
"""
import pytest
from unittest.mock import AsyncMock, patch
from chat_handler import ChatHandler, ChatSettings
class TestChatSettings:
"""Tests for ChatSettings configuration."""
def test_default_settings(self):
"""Test default settings values."""
settings = ChatSettings()
assert settings.service_name == "chat-handler"
assert settings.rag_top_k == 10
assert settings.rag_rerank_top_k == 5
assert settings.rag_collection == "documents"
assert settings.include_sources is True
assert settings.enable_tts is False
assert settings.tts_language == "en"
def test_custom_settings(self):
"""Test custom settings."""
settings = ChatSettings(
rag_top_k=20,
rag_collection="custom_docs",
enable_tts=True,
)
assert settings.rag_top_k == 20
assert settings.rag_collection == "custom_docs"
assert settings.enable_tts is True
class TestChatHandler:
"""Tests for ChatHandler."""
@pytest.fixture
def handler(self):
"""Create handler with mocked clients."""
with (
patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.RerankerClient"),
patch("chat_handler.LLMClient"),
patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler()
# Setup mock clients
handler.embeddings = AsyncMock()
handler.reranker = AsyncMock()
handler.llm = AsyncMock()
handler.milvus = AsyncMock()
handler.tts = None # TTS disabled by default
handler.nats = AsyncMock()
yield handler
@pytest.fixture
def handler_with_tts(self):
"""Create handler with TTS enabled."""
with (
patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.RerankerClient"),
patch("chat_handler.LLMClient"),
patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler()
handler.chat_settings.enable_tts = True
# Setup mock clients
handler.embeddings = AsyncMock()
handler.reranker = AsyncMock()
handler.llm = AsyncMock()
handler.milvus = AsyncMock()
handler.tts = AsyncMock()
handler.nats = AsyncMock()
yield handler
def test_init(self, handler):
"""Test handler initialization."""
assert handler.subject == "ai.chat.request"
assert handler.queue_group == "chat-handlers"
assert handler.chat_settings.service_name == "chat-handler"
@pytest.mark.asyncio
async def test_handle_message_success(
self,
handler,
mock_nats_message,
mock_chat_request,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test successful chat request handling."""
# Setup mocks
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Machine learning is a subset of AI that..."
# Execute
result = await handler.handle_message(mock_nats_message, mock_chat_request)
# Verify
assert result["request_id"] == "test-request-123"
assert "response" in result
assert result["response"] == "Machine learning is a subset of AI that..."
assert "sources" in result # include_sources is True by default
# Verify pipeline was called
handler.embeddings.embed_single.assert_called_once()
handler.milvus.search_with_texts.assert_called_once()
handler.reranker.rerank.assert_called_once()
handler.llm.generate.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_without_sources(
self,
handler,
mock_nats_message,
mock_chat_request,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test response without sources when disabled."""
handler.chat_settings.include_sources = False
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Response text"
result = await handler.handle_message(mock_nats_message, mock_chat_request)
assert "sources" not in result
@pytest.mark.asyncio
async def test_handle_message_with_tts(
self,
handler_with_tts,
mock_nats_message,
mock_chat_request_with_tts,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test response with TTS audio."""
handler = handler_with_tts
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "AI response"
handler.tts.synthesize.return_value = b"audio_bytes"
result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts)
assert "audio" in result
handler.tts.synthesize.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_with_custom_system_prompt(
self,
handler,
mock_nats_message,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test LLM is called with custom system prompt."""
request = {
"request_id": "test-123",
"query": "Hello",
"system_prompt": "You are a pirate. Respond like one.",
}
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Ahoy!"
await handler.handle_message(mock_nats_message, request)
# Verify system_prompt was passed to LLM
handler.llm.generate.assert_called_once()
call_kwargs = handler.llm.generate.call_args.kwargs
assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one."
def test_build_context(self, handler):
"""Test context building with numbered sources."""
documents = [
{"document": "First doc content"},
{"document": "Second doc content"},
]
context = handler._build_context(documents)
assert "[1]" in context
assert "[2]" in context
assert "First doc content" in context
assert "Second doc content" in context
@pytest.mark.asyncio
async def test_setup_initializes_clients(self):
"""Test that setup initializes all required clients."""
with (
patch("chat_handler.EmbeddingsClient") as emb_cls,
patch("chat_handler.RerankerClient") as rer_cls,
patch("chat_handler.LLMClient") as llm_cls,
patch("chat_handler.TTSClient") as tts_cls,
patch("chat_handler.MilvusClient") as mil_cls,
):
mil_cls.return_value.connect = AsyncMock()
handler = ChatHandler()
await handler.setup()
emb_cls.assert_called_once()
rer_cls.assert_called_once()
llm_cls.assert_called_once()
mil_cls.assert_called_once()
# TTS should not be initialized when disabled
tts_cls.assert_not_called()
@pytest.mark.asyncio
async def test_teardown_closes_clients(self, handler):
"""Test that teardown closes all clients."""
await handler.teardown()
handler.embeddings.close.assert_called_once()
handler.reranker.close.assert_called_once()
handler.llm.close.assert_called_once()
handler.milvus.close.assert_called_once()
@pytest.mark.asyncio
async def test_publishes_to_response_subject(
self,
handler,
mock_nats_message,
mock_chat_request,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test that result is published to response subject."""
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Response"
await handler.handle_message(mock_nats_message, mock_chat_request)
handler.nats.publish.assert_called_once()
call_args = handler.nats.publish.call_args
assert "ai.chat.response.test-request-123" in str(call_args)

uv.lock generated
File diff suppressed because it is too large