- AudioBuffer getAudio(): use ab.totalBytes directly (eliminates triple-copy) - Decode STTStreamMessage via natsutil.Decode[messages.STTStreamMessage] - Audio chunks arrive as raw []byte (no base64 decode needed) - Publish STTTranscription struct (not map[string]any) - Interrupts use messages.STTInterrupt - Remove encoding/base64 import - Add .dockerignore, GOAMD64=v3 in Dockerfile - All 15 tests pass
511 lines
13 KiB
Go
511 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"math"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
|
)
|
|
|
|
// streamSubjectPrefix is the NATS subject prefix on which clients publish
// stream messages (start, chunk, state_change, end); the session ID follows
// as the next subject token.
const streamSubjectPrefix = "ai.voice.stream"

// transcriptionSubjectPrefix is the NATS subject prefix on which
// transcriptions and interrupt events are published, followed by the
// session ID.
const transcriptionSubjectPrefix = "ai.voice.transcription"

// stateListening is the session state in which interrupt detection is idle.
const stateListening = "listening"

// stateResponding is the session state in which sustained loud, voiced audio
// is treated as a potential interrupt.
const stateResponding = "responding"
|
|
|
|
// AudioBuffer holds the pending audio and per-session bookkeeping for one
// streaming STT session. mu guards the fields below; the AudioBuffer methods
// acquire it themselves.
type AudioBuffer struct {
	mu sync.Mutex

	sessionID string // session identifier taken from the NATS subject
	speakerID string // optional speaker ID supplied with the "start" message

	chunks        [][]byte  // pending audio chunks in arrival order
	totalBytes    int       // sum of the lengths of chunks
	lastChunkTime time.Time // when the most recent chunk was added

	sequence         int    // sequence number attached to the next transcription
	state            string // stateListening or stateResponding
	isComplete       bool   // set once the client has ended the stream
	hasVoiceActivity bool   // voice-activity result for the most recently added chunk

	// interruptStartTime is when the current run of loud, voiced audio began
	// while responding; nil when no candidate interrupt is in progress.
	interruptStartTime *time.Time
}
|
|
|
|
func newAudioBuffer(sessionID string) *AudioBuffer {
|
|
return &AudioBuffer{
|
|
sessionID: sessionID,
|
|
lastChunkTime: time.Now(),
|
|
state: stateListening,
|
|
}
|
|
}
|
|
|
|
func (ab *AudioBuffer) addChunk(audio []byte) {
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
ab.chunks = append(ab.chunks, audio)
|
|
ab.totalBytes += len(audio)
|
|
ab.lastChunkTime = time.Now()
|
|
ab.hasVoiceActivity = detectVoiceActivity(audio)
|
|
}
|
|
|
|
func (ab *AudioBuffer) checkInterrupt(audio []byte, enableInterrupt bool, audioThreshold, interruptDuration float64) bool {
|
|
if !enableInterrupt {
|
|
return false
|
|
}
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
|
|
if ab.state != stateResponding {
|
|
return false
|
|
}
|
|
|
|
rms := calculateAudioRMS(audio)
|
|
hasVoice := detectVoiceActivity(audio)
|
|
|
|
if rms >= audioThreshold && hasVoice {
|
|
if ab.interruptStartTime == nil {
|
|
now := time.Now()
|
|
ab.interruptStartTime = &now
|
|
slog.Info("potential interrupt detected", "session", ab.sessionID, "rms", rms)
|
|
}
|
|
elapsed := time.Since(*ab.interruptStartTime).Seconds()
|
|
if elapsed >= interruptDuration {
|
|
slog.Info("interrupt confirmed", "session", ab.sessionID, "elapsed", elapsed)
|
|
return true
|
|
}
|
|
} else {
|
|
ab.interruptStartTime = nil
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (ab *AudioBuffer) setState(state string) {
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
if state != stateListening && state != stateResponding {
|
|
return
|
|
}
|
|
if ab.state != state {
|
|
slog.Info("state changed", "session", ab.sessionID, "from", ab.state, "to", state)
|
|
ab.state = state
|
|
ab.interruptStartTime = nil
|
|
}
|
|
}
|
|
|
|
func (ab *AudioBuffer) shouldProcess(bufferSize, maxBufferSize int, chunkTimeout float64) bool {
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
|
|
if !ab.hasVoiceActivity && ab.totalBytes < bufferSize &&
|
|
time.Since(ab.lastChunkTime).Seconds() < chunkTimeout {
|
|
return false
|
|
}
|
|
if ab.totalBytes >= bufferSize {
|
|
return true
|
|
}
|
|
if time.Since(ab.lastChunkTime).Seconds() > chunkTimeout && ab.totalBytes > 0 {
|
|
return true
|
|
}
|
|
return ab.totalBytes >= maxBufferSize
|
|
}
|
|
|
|
func (ab *AudioBuffer) getAudio() []byte {
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
result := make([]byte, 0, ab.totalBytes)
|
|
for _, c := range ab.chunks {
|
|
result = append(result, c...)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (ab *AudioBuffer) clear() {
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
ab.chunks = nil
|
|
ab.totalBytes = 0
|
|
ab.sequence++
|
|
}
|
|
|
|
func (ab *AudioBuffer) markComplete() {
|
|
ab.mu.Lock()
|
|
defer ab.mu.Unlock()
|
|
ab.isComplete = true
|
|
}
|
|
|
|
// calculateAudioRMS returns the root-mean-square level of audio, interpreted
// as signed 16-bit little-endian PCM samples, divided by 32768 so that a
// full-scale signal maps to 1.0. Input shorter than one sample yields 0, and
// a trailing odd byte is ignored.
func calculateAudioRMS(audio []byte) float64 {
	sampleCount := len(audio) / 2
	if sampleCount == 0 {
		return 0.0
	}

	var energy float64
	for off := 0; off+1 < len(audio); off += 2 {
		s := float64(int16(binary.LittleEndian.Uint16(audio[off:])))
		energy += s * s
	}

	return math.Sqrt(energy/float64(sampleCount)) / 32768.0
}
|
|
|
|
// detectVoiceActivity uses RMS-based voice detection (pure Go, no cgo).
|
|
func detectVoiceActivity(audio []byte) bool {
|
|
rms := calculateAudioRMS(audio)
|
|
return rms > 0.01 // reasonable threshold for speech
|
|
}
|
|
|
|
func main() {
|
|
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
cfg := config.Load()
|
|
cfg.ServiceName = "stt-streaming"
|
|
|
|
whisperURL := getEnv("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
|
|
bufferSizeBytes := getEnvInt("STT_BUFFER_SIZE_BYTES", 512000)
|
|
chunkTimeout := getEnvFloat("STT_CHUNK_TIMEOUT", 2.0)
|
|
maxBufferSize := getEnvInt("STT_MAX_BUFFER_SIZE", 5120000)
|
|
enableInterrupt := getEnvBool("STT_ENABLE_INTERRUPT_DETECTION", true)
|
|
audioLevelThreshold := getEnvFloat("STT_AUDIO_LEVEL_THRESHOLD", 0.02)
|
|
interruptDuration := getEnvFloat("STT_INTERRUPT_DURATION", 0.5)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Telemetry
|
|
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
|
ServiceName: cfg.ServiceName,
|
|
ServiceVersion: cfg.ServiceVersion,
|
|
ServiceNamespace: cfg.ServiceNamespace,
|
|
DeploymentEnv: cfg.DeploymentEnv,
|
|
Enabled: cfg.OTELEnabled,
|
|
Endpoint: cfg.OTELEndpoint,
|
|
})
|
|
if err != nil {
|
|
slog.Error("telemetry setup failed", "error", err)
|
|
}
|
|
if shutdown != nil {
|
|
defer shutdown(ctx)
|
|
}
|
|
_ = tp
|
|
|
|
// NATS
|
|
natsOpts := []nats.Option{}
|
|
if cfg.NATSUser != "" && cfg.NATSPassword != "" {
|
|
natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
|
|
}
|
|
nc := natsutil.New(cfg.NATSURL, natsOpts...)
|
|
if err := nc.Connect(); err != nil {
|
|
slog.Error("NATS connect failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
defer nc.Close()
|
|
|
|
// JetStream stream setup
|
|
js, err := nc.Conn().JetStream()
|
|
if err == nil {
|
|
_, err = js.AddStream(&nats.StreamConfig{
|
|
Name: "AI_VOICE_STREAM",
|
|
Subjects: []string{"ai.voice.stream.>", "ai.voice.transcription.>"},
|
|
Retention: nats.LimitsPolicy,
|
|
MaxAge: 5 * time.Minute,
|
|
Storage: nats.MemoryStorage,
|
|
})
|
|
if err != nil {
|
|
slog.Info("JetStream stream setup", "msg", err)
|
|
} else {
|
|
slog.Info("created/updated JetStream stream AI_VOICE_STREAM")
|
|
}
|
|
}
|
|
|
|
httpClient := &http.Client{Timeout: 180 * time.Second}
|
|
running := true
|
|
|
|
// Session management
|
|
var sessionsMu sync.RWMutex
|
|
sessions := make(map[string]*AudioBuffer)
|
|
|
|
// Health server
|
|
healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
|
|
return running && nc.IsConnected()
|
|
})
|
|
healthSrv.Start()
|
|
defer healthSrv.Stop(ctx)
|
|
|
|
// Transcribe via Whisper HTTP
|
|
transcribe := func(audioData []byte) (string, error) {
|
|
var buf bytes.Buffer
|
|
w := multipart.NewWriter(&buf)
|
|
part, err := w.CreateFormFile("file", "audio.wav")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
part.Write(audioData)
|
|
w.Close()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, whisperURL+"/v1/audio/transcriptions", &buf)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("whisper request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
if resp.StatusCode >= 400 {
|
|
return "", fmt.Errorf("whisper %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var result struct {
|
|
Text string `json:"text"`
|
|
}
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
return "", err
|
|
}
|
|
return result.Text, nil
|
|
}
|
|
|
|
// Process buffer for a session
|
|
processBuffer := func(sessionID string) {
|
|
sessionsMu.RLock()
|
|
buffer, ok := sessions[sessionID]
|
|
sessionsMu.RUnlock()
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
audioData := buffer.getAudio()
|
|
if len(audioData) == 0 {
|
|
return
|
|
}
|
|
|
|
buffer.mu.Lock()
|
|
seq := buffer.sequence
|
|
complete := buffer.isComplete
|
|
speakerID := buffer.speakerID
|
|
hasVoice := buffer.hasVoiceActivity
|
|
state := buffer.state
|
|
buffer.mu.Unlock()
|
|
|
|
slog.Info("processing buffer", "session", sessionID, "bytes", len(audioData), "seq", seq)
|
|
|
|
transcript, err := transcribe(audioData)
|
|
if err != nil {
|
|
slog.Error("transcription failed", "session", sessionID, "error", err)
|
|
}
|
|
|
|
if transcript != "" {
|
|
result := &messages.STTTranscription{
|
|
SessionID: sessionID,
|
|
Transcript: transcript,
|
|
Sequence: seq,
|
|
IsPartial: !complete,
|
|
IsFinal: complete,
|
|
Timestamp: time.Now().Unix(),
|
|
SpeakerID: speakerID,
|
|
HasVoiceActivity: hasVoice,
|
|
State: state,
|
|
}
|
|
packed, _ := msgpack.Marshal(result)
|
|
nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed)
|
|
slog.Info("published transcription", "session", sessionID, "seq", seq)
|
|
}
|
|
|
|
buffer.clear()
|
|
|
|
if complete {
|
|
slog.Info("session completed", "session", sessionID)
|
|
go func() {
|
|
time.Sleep(5 * time.Second)
|
|
sessionsMu.Lock()
|
|
delete(sessions, sessionID)
|
|
sessionsMu.Unlock()
|
|
slog.Info("cleaned up session", "session", sessionID)
|
|
}()
|
|
}
|
|
}
|
|
|
|
// Monitor goroutine for a session
|
|
monitorBuffer := func(sessionID string) {
|
|
for running {
|
|
sessionsMu.RLock()
|
|
buffer, ok := sessions[sessionID]
|
|
sessionsMu.RUnlock()
|
|
if !ok {
|
|
return
|
|
}
|
|
if buffer.shouldProcess(bufferSizeBytes, maxBufferSize, chunkTimeout) {
|
|
processBuffer(sessionID)
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
// Handle incoming stream messages
|
|
handleStreamMsg := func(natMsg *nats.Msg) {
|
|
parts := strings.Split(natMsg.Subject, ".")
|
|
if len(parts) < 4 {
|
|
slog.Warn("invalid subject", "subject", natMsg.Subject)
|
|
return
|
|
}
|
|
sessionID := parts[3]
|
|
|
|
streamMsg, err := natsutil.Decode[messages.STTStreamMessage](natMsg.Data)
|
|
if err != nil {
|
|
slog.Error("decode error", "error", err)
|
|
return
|
|
}
|
|
|
|
switch streamMsg.Type {
|
|
case "start":
|
|
slog.Info("starting stream session", "session", sessionID)
|
|
buf := newAudioBuffer(sessionID)
|
|
if streamMsg.State != "" {
|
|
buf.setState(streamMsg.State)
|
|
}
|
|
if streamMsg.SpeakerID != "" {
|
|
buf.speakerID = streamMsg.SpeakerID
|
|
}
|
|
sessionsMu.Lock()
|
|
sessions[sessionID] = buf
|
|
sessionsMu.Unlock()
|
|
go monitorBuffer(sessionID)
|
|
|
|
case "state_change":
|
|
sessionsMu.RLock()
|
|
buffer, ok := sessions[sessionID]
|
|
sessionsMu.RUnlock()
|
|
if ok && streamMsg.State != "" {
|
|
buffer.setState(streamMsg.State)
|
|
}
|
|
|
|
case "end":
|
|
slog.Info("ending stream session", "session", sessionID)
|
|
sessionsMu.RLock()
|
|
buffer, ok := sessions[sessionID]
|
|
sessionsMu.RUnlock()
|
|
if ok {
|
|
buffer.markComplete()
|
|
buffer.mu.Lock()
|
|
hasData := buffer.totalBytes > 0
|
|
buffer.mu.Unlock()
|
|
if hasData {
|
|
processBuffer(sessionID)
|
|
}
|
|
}
|
|
|
|
case "chunk":
|
|
// Audio arrives as raw bytes — no base64 decode needed
|
|
if len(streamMsg.Audio) == 0 {
|
|
return
|
|
}
|
|
|
|
// Auto-create session if missing
|
|
sessionsMu.Lock()
|
|
buffer, ok := sessions[sessionID]
|
|
if !ok {
|
|
slog.Info("auto-creating session", "session", sessionID)
|
|
buffer = newAudioBuffer(sessionID)
|
|
sessions[sessionID] = buffer
|
|
go monitorBuffer(sessionID)
|
|
}
|
|
sessionsMu.Unlock()
|
|
|
|
// Check for interrupt
|
|
if buffer.checkInterrupt(streamMsg.Audio, enableInterrupt, audioLevelThreshold, interruptDuration) {
|
|
interruptMsg := &messages.STTInterrupt{
|
|
SessionID: sessionID,
|
|
Type: "interrupt",
|
|
Timestamp: time.Now().Unix(),
|
|
SpeakerID: buffer.speakerID,
|
|
}
|
|
packed, _ := msgpack.Marshal(interruptMsg)
|
|
nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed)
|
|
slog.Info("published interrupt", "session", sessionID)
|
|
buffer.setState(stateListening)
|
|
}
|
|
|
|
buffer.addChunk(streamMsg.Audio)
|
|
}
|
|
}
|
|
|
|
if _, err := nc.Conn().Subscribe(streamSubjectPrefix+".>", handleStreamMsg); err != nil {
|
|
slog.Error("subscribe failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
slog.Info("subscribed", "subject", streamSubjectPrefix+".>")
|
|
slog.Info("stt-streaming ready")
|
|
|
|
// Wait for shutdown
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
|
<-sigCh
|
|
|
|
slog.Info("shutting down")
|
|
running = false
|
|
cancel()
|
|
slog.Info("shutdown complete")
|
|
}
|
|
|
|
// Helpers
|
|
|
|
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or empty.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns the environment variable key parsed as a base-10 int.
// Surrounding whitespace (e.g. a trailing newline from a mounted secret) is
// ignored. An unset or empty variable yields fallback. A value that does not
// parse also yields fallback, but is logged as a warning so a typo in
// deployment config does not go unnoticed.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		slog.Warn("invalid integer environment variable, using default",
			"key", key, "value", raw, "default", fallback, "error", err)
		return fallback
	}
	return n
}
|
|
|
|
// getEnvFloat returns the environment variable key parsed as a float64.
// Surrounding whitespace is ignored. An unset or empty variable yields
// fallback. An unparsable or non-finite value (NaN, ±Inf) is logged as a
// warning and also yields fallback. Non-finite values are rejected because
// NaN makes every threshold/timeout comparison false (e.g. the chunk timeout
// would never fire).
func getEnvFloat(key string, fallback float64) float64 {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	f, err := strconv.ParseFloat(raw, 64)
	switch {
	case err != nil:
		slog.Warn("invalid float environment variable, using default",
			"key", key, "value", raw, "default", fallback, "error", err)
		return fallback
	case math.IsNaN(f) || math.IsInf(f, 0):
		slog.Warn("non-finite float environment variable, using default",
			"key", key, "value", raw, "default", fallback)
		return fallback
	}
	return f
}
|
|
|
|
// getEnvBool reads the environment variable key as a boolean. The values
// "1" and "true" (any letter case) are true; any other non-empty value is
// false. An unset or empty variable yields fallback.
func getEnvBool(key string, fallback bool) bool {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	return raw == "1" || strings.EqualFold(raw, "true")
}
|