// Package clients provides HTTP client wrappers for AI/ML backend services.
//
// All clients share a single [http.Transport] for connection pooling across
// the process. Request and response bodies are serialized through pooled
// [bytes.Buffer]s to reduce GC pressure.
package clients

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)

// ─── Shared transport & buffer pool ─────────────────────────────────────────

// SharedTransport is the process-wide HTTP transport used by every service
// client. Tweak pool sizes here rather than creating per-client transports.
var SharedTransport = &http.Transport{
	MaxIdleConns:        100,
	MaxIdleConnsPerHost: 10,
	IdleConnTimeout:     90 * time.Second,
	DisableCompression:  true, // in-cluster traffic; skip gzip overhead
}

// bufPool recycles *bytes.Buffer to avoid per-request allocations.
var bufPool = sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}

// getBuf fetches a buffer from the pool, already Reset and ready to write.
// Callers must hand it back via putBuf once its bytes are no longer referenced.
func getBuf() *bytes.Buffer {
	buf := bufPool.Get().(*bytes.Buffer)
	buf.Reset()
	return buf
}

// putBuf returns buf to the pool. Oversized buffers are deliberately dropped
// so one large response cannot pin memory in the pool indefinitely.
func putBuf(buf *bytes.Buffer) {
	if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
		return
	}
	bufPool.Put(buf)
}

// ─── httpClient base ────────────────────────────────────────────────────────

// httpClient is the shared base for all service clients.
type httpClient struct { client *http.Client baseURL string } func newHTTPClient(baseURL string, timeout time.Duration) *httpClient { return &httpClient{ client: &http.Client{ Timeout: timeout, Transport: SharedTransport, }, baseURL: baseURL, } } func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) { buf := getBuf() defer putBuf(buf) if err := json.NewEncoder(buf).Encode(body); err != nil { return nil, fmt.Errorf("marshal: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") return h.do(req) } func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) { u := h.baseURL + path if len(params) > 0 { u += "?" + params.Encode() } req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) if err != nil { return nil, err } return h.do(req) } func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) { return h.get(ctx, path, params) } func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) { buf := getBuf() defer putBuf(buf) w := multipart.NewWriter(buf) part, err := w.CreateFormFile(fieldName, fileName) if err != nil { return nil, err } if _, err := part.Write(fileData); err != nil { return nil, err } for k, v := range fields { _ = w.WriteField(k, v) } _ = w.Close() req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf) if err != nil { return nil, err } req.Header.Set("Content-Type", w.FormDataContentType()) return h.do(req) } func (h *httpClient) do(req *http.Request) ([]byte, error) { resp, err := h.client.Do(req) if err != nil { return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err) } defer func() { _ = resp.Body.Close() }() buf := getBuf() defer putBuf(buf) if _, err := 
io.Copy(buf, resp.Body); err != nil { return nil, fmt.Errorf("read body: %w", err) } // Return a copy so the pooled buffer can be safely recycled. body := make([]byte, buf.Len()) copy(body, buf.Bytes()) if resp.StatusCode >= 400 { return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body)) } return body, nil } // postJSONStream sends a JSON POST and returns the raw *http.Response so the // caller can read the body incrementally (e.g. for SSE streaming). The caller // is responsible for closing resp.Body. func (h *httpClient) postJSONStream(ctx context.Context, path string, body any) (*http.Response, error) { buf := getBuf() defer putBuf(buf) if err := json.NewEncoder(buf).Encode(body); err != nil { return nil, fmt.Errorf("marshal: %w", err) } // Copy to a non-pooled buffer so we can safely return the pool buffer. payload := make([]byte, buf.Len()) copy(payload, buf.Bytes()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(payload)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") resp, err := h.client.Do(req) if err != nil { return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err) } if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(respBody)) } return resp, nil } func (h *httpClient) healthCheck(ctx context.Context) bool { data, err := h.get(ctx, "/health", nil) _ = data return err == nil } // ─── Embeddings Client ────────────────────────────────────────────────────── // EmbeddingsClient calls the embeddings service (Infinity/BGE). type EmbeddingsClient struct { *httpClient Model string } // NewEmbeddingsClient creates an embeddings client. 
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient { if model == "" { model = "bge" } return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model} } // Embed generates embeddings for a list of texts. func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) { body, err := c.postJSON(ctx, "/embeddings", map[string]any{ "input": texts, "model": c.Model, }) if err != nil { return nil, err } var resp struct { Data []struct { Embedding []float64 `json:"embedding"` } `json:"data"` } if err := json.Unmarshal(body, &resp); err != nil { return nil, err } result := make([][]float64, len(resp.Data)) for i, d := range resp.Data { result[i] = d.Embedding } return result, nil } // EmbedSingle generates an embedding for a single text. func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) { results, err := c.Embed(ctx, []string{text}) if err != nil { return nil, err } if len(results) == 0 { return nil, fmt.Errorf("empty embedding result") } return results[0], nil } // Health checks if the embeddings service is healthy. func (c *EmbeddingsClient) Health(ctx context.Context) bool { return c.healthCheck(ctx) } // ─── Reranker Client ──────────────────────────────────────────────────────── // RerankerClient calls the reranker service (BGE Reranker). type RerankerClient struct { *httpClient } // NewRerankerClient creates a reranker client. func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient { return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)} } // RerankResult represents a reranked document. type RerankResult struct { Index int `json:"index"` Score float64 `json:"score"` Document string `json:"document"` } // Rerank reranks documents by relevance to the query. 
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) { payload := map[string]any{ "query": query, "documents": documents, } if topK > 0 { payload["top_n"] = topK } body, err := c.postJSON(ctx, "/rerank", payload) if err != nil { return nil, err } var resp struct { Results []struct { Index int `json:"index"` RelevanceScore float64 `json:"relevance_score"` Score float64 `json:"score"` } `json:"results"` } if err := json.Unmarshal(body, &resp); err != nil { return nil, err } results := make([]RerankResult, len(resp.Results)) for i, r := range resp.Results { score := r.RelevanceScore if score == 0 { score = r.Score } doc := "" if r.Index < len(documents) { doc = documents[r.Index] } results[i] = RerankResult{Index: r.Index, Score: score, Document: doc} } return results, nil } // ─── LLM Client ───────────────────────────────────────────────────────────── // LLMClient calls the vLLM-compatible LLM service. type LLMClient struct { *httpClient Model string MaxTokens int Temperature float64 TopP float64 } // NewLLMClient creates an LLM client. func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient { return &LLMClient{ httpClient: newHTTPClient(baseURL, timeout), Model: "default", MaxTokens: 2048, Temperature: 0.7, TopP: 0.9, } } // ChatMessage is an OpenAI-compatible message. type ChatMessage struct { Role string `json:"role"` Content string `json:"content"` } // Generate sends a chat completion request and returns the response text. 
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) { messages := buildMessages(prompt, context_, systemPrompt) payload := map[string]any{ "model": c.Model, "messages": messages, "max_tokens": c.MaxTokens, "temperature": c.Temperature, "top_p": c.TopP, } body, err := c.postJSON(ctx, "/v1/chat/completions", payload) if err != nil { return "", err } var resp struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` } `json:"choices"` } if err := json.Unmarshal(body, &resp); err != nil { return "", err } if len(resp.Choices) == 0 { return "", fmt.Errorf("no choices in LLM response") } return resp.Choices[0].Message.Content, nil } // StreamGenerate sends a streaming chat completion request and calls onToken // for each content delta received via SSE. Returns the fully assembled text. // The onToken callback is invoked synchronously on the calling goroutine; it // should be fast (e.g. publish a NATS message). func (c *LLMClient) StreamGenerate(ctx context.Context, prompt string, context_ string, systemPrompt string, onToken func(token string)) (string, error) { msgs := buildMessages(prompt, context_, systemPrompt) payload := map[string]any{ "model": c.Model, "messages": msgs, "max_tokens": c.MaxTokens, "temperature": c.Temperature, "top_p": c.TopP, "stream": true, } resp, err := c.postJSONStream(ctx, "/v1/chat/completions", payload) if err != nil { return "", err } defer func() { _ = resp.Body.Close() }() var full strings.Builder scanner := bufio.NewScanner(resp.Body) // SSE lines can be up to 64 KiB for large token batches. 
scanner.Buffer(make([]byte, 0, 64*1024), 64*1024) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chunk struct { Choices []struct { Delta struct { Content string `json:"content"` } `json:"delta"` } `json:"choices"` } if err := json.Unmarshal([]byte(data), &chunk); err != nil { continue // skip malformed chunks } if len(chunk.Choices) == 0 { continue } token := chunk.Choices[0].Delta.Content if token == "" { continue } full.WriteString(token) if onToken != nil { onToken(token) } } if err := scanner.Err(); err != nil { // If we already collected some text, return it with the error. if full.Len() > 0 { return full.String(), fmt.Errorf("stream interrupted: %w", err) } return "", fmt.Errorf("stream read: %w", err) } return full.String(), nil } func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage { var msgs []ChatMessage if systemPrompt != "" { msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt}) } else if ctx != "" { msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."}) } if ctx != "" { msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)}) } else { msgs = append(msgs, ChatMessage{Role: "user", Content: prompt}) } return msgs } // ─── TTS Client ───────────────────────────────────────────────────────────── // TTSClient calls the TTS service (Coqui XTTS). type TTSClient struct { *httpClient Language string } // NewTTSClient creates a TTS client. func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient { if language == "" { language = "en" } return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language} } // Synthesize generates audio bytes from text. 
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) { if language == "" { language = c.Language } params := url.Values{ "text": {text}, "language_id": {language}, } if speaker != "" { params.Set("speaker_id", speaker) } return c.getRaw(ctx, "/api/tts", params) } // ─── STT Client ───────────────────────────────────────────────────────────── // STTClient calls the Whisper STT service. type STTClient struct { *httpClient Language string Task string } // NewSTTClient creates an STT client. func NewSTTClient(baseURL string, timeout time.Duration) *STTClient { return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"} } // TranscribeResult holds transcription output. type TranscribeResult struct { Text string `json:"text"` Language string `json:"language,omitempty"` } // Transcribe sends audio to Whisper and returns the transcription. func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) { if language == "" { language = c.Language } fields := map[string]string{ "response_format": "json", } if language != "" { fields["language"] = language } endpoint := "/v1/audio/transcriptions" if c.Task == "translate" { endpoint = "/v1/audio/translations" } body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields) if err != nil { return nil, err } var result TranscribeResult if err := json.Unmarshal(body, &result); err != nil { return nil, err } return &result, nil } // ─── Milvus Client ────────────────────────────────────────────────────────── // MilvusClient provides vector search via the Milvus HTTP/gRPC API. // For the Go port we use the Milvus Go SDK. type MilvusClient struct { Host string Port int Collection string } // NewMilvusClient creates a Milvus client. 
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
	client := new(MilvusClient)
	client.Host = host
	client.Port = port
	client.Collection = collection
	return client
}

// SearchResult holds a single vector search hit.
type SearchResult struct {
	ID       int64          `json:"id"`
	Distance float64        `json:"distance"`
	Score    float64        `json:"score"`
	Fields   map[string]any `json:"fields,omitempty"`
}