Files
stt-module/main.go
Billy D. dbaabe1f65
Some checks failed
CI / Test (push) Has been cancelled
CI / Release (push) Has been cancelled
CI / Docker Build & Push (push) Has been cancelled
CI / Notify (push) Has been cancelled
CI / Lint (push) Has been cancelled
fix: resolve golangci-lint errcheck warnings
- Add error checks for unchecked return values (errcheck)
- Remove unused struct fields (unused)
- Fix gofmt formatting issues
2026-02-20 08:45:44 -05:00

511 lines
13 KiB
Go

package main
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"math"
"mime/multipart"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
)
const (
	// NATS subject prefixes; the session ID is the token that follows the prefix.
	streamSubjectPrefix        = "ai.voice.stream"        // inbound audio: ai.voice.stream.<session>
	transcriptionSubjectPrefix = "ai.voice.transcription" // outbound results: ai.voice.transcription.<session>

	// Session states stored in AudioBuffer.state.
	stateListening  = "listening"  // interrupt detection inactive
	stateResponding = "responding" // loud, voiced user audio may trigger an interrupt
)
// AudioBuffer manages audio chunks for a streaming session.
type AudioBuffer struct {
mu sync.Mutex
sessionID string
chunks [][]byte
totalBytes int
lastChunkTime time.Time
isComplete bool
sequence int
state string
speakerID string
interruptStartTime *time.Time
hasVoiceActivity bool
}
// newAudioBuffer returns an empty buffer for sessionID in the listening state,
// with its idle clock (lastChunkTime) starting now.
func newAudioBuffer(sessionID string) *AudioBuffer {
	ab := new(AudioBuffer)
	ab.sessionID = sessionID
	ab.state = stateListening
	ab.lastChunkTime = time.Now()
	return ab
}
// addChunk appends one audio chunk, updates the byte count and arrival time,
// and records whether this particular chunk contains voice activity.
func (ab *AudioBuffer) addChunk(audio []byte) {
	// The VAD check is a pure function of the chunk, so it needs no lock.
	voiced := detectVoiceActivity(audio)

	ab.mu.Lock()
	ab.chunks = append(ab.chunks, audio)
	ab.totalBytes += len(audio)
	ab.lastChunkTime = time.Now()
	ab.hasVoiceActivity = voiced
	ab.mu.Unlock()
}
// checkInterrupt reports whether the user has been talking over the assistant
// long enough to count as an interrupt. It only acts when enableInterrupt is
// set and the session is in stateResponding.
//
// A chunk whose RMS level is at least audioThreshold and that passes the voice
// check starts (or continues) a candidate interrupt. Once the candidate has
// lasted interruptDuration seconds, it returns true. A chunk failing either
// test discards the candidate.
func (ab *AudioBuffer) checkInterrupt(audio []byte, enableInterrupt bool, audioThreshold, interruptDuration float64) bool {
	if !enableInterrupt {
		return false
	}

	ab.mu.Lock()
	defer ab.mu.Unlock()

	if ab.state != stateResponding {
		return false
	}

	level := calculateAudioRMS(audio)
	voiced := detectVoiceActivity(audio)
	if !(level >= audioThreshold && voiced) {
		ab.interruptStartTime = nil
		return false
	}

	if ab.interruptStartTime == nil {
		began := time.Now()
		ab.interruptStartTime = &began
		slog.Info("potential interrupt detected", "session", ab.sessionID, "rms", level)
	}

	elapsed := time.Since(*ab.interruptStartTime).Seconds()
	if elapsed >= interruptDuration {
		slog.Info("interrupt confirmed", "session", ab.sessionID, "elapsed", elapsed)
		return true
	}
	return false
}
// setState moves the session to stateListening or stateResponding. It logs the
// transition and discards any in-progress interrupt candidate. Unknown states
// and transitions to the current state are ignored.
func (ab *AudioBuffer) setState(state string) {
	switch state {
	case stateListening, stateResponding:
		// accepted
	default:
		return
	}

	ab.mu.Lock()
	defer ab.mu.Unlock()

	if ab.state == state {
		return
	}
	slog.Info("state changed", "session", ab.sessionID, "from", ab.state, "to", state)
	ab.state = state
	ab.interruptStartTime = nil
}
// shouldProcess reports whether the buffered audio should be sent for
// transcription now. In order:
//
//   - at least bufferSize bytes are buffered: yes;
//   - the latest chunk had no voice activity and the stream has been idle for
//     less than chunkTimeout seconds: no;
//   - the stream has been idle for more than chunkTimeout seconds and anything
//     is buffered: yes;
//   - otherwise: only if at least maxBufferSize bytes are buffered.
//
// Because the first rule already covers totalBytes >= bufferSize, the last
// rule can only fire when maxBufferSize < bufferSize. With maxBufferSize >=
// bufferSize, the voice-activity rule therefore never changes the result.
func (ab *AudioBuffer) shouldProcess(bufferSize, maxBufferSize int, chunkTimeout float64) bool {
	ab.mu.Lock()
	defer ab.mu.Unlock()

	idle := time.Since(ab.lastChunkTime).Seconds()
	switch {
	case ab.totalBytes >= bufferSize:
		return true
	case !ab.hasVoiceActivity && idle < chunkTimeout:
		return false
	case idle > chunkTimeout && ab.totalBytes > 0:
		return true
	default:
		return ab.totalBytes >= maxBufferSize
	}
}
// getAudio returns a fresh slice holding every buffered chunk concatenated in
// arrival order. The buffer itself is left unchanged.
func (ab *AudioBuffer) getAudio() []byte {
	ab.mu.Lock()
	defer ab.mu.Unlock()

	joined := make([]byte, 0, ab.totalBytes)
	for i := range ab.chunks {
		joined = append(joined, ab.chunks[i]...)
	}
	return joined
}
// clear drops every buffered chunk, resets the byte count and advances the
// sequence counter.
func (ab *AudioBuffer) clear() {
	ab.mu.Lock()
	ab.chunks, ab.totalBytes = nil, 0
	ab.sequence++
	ab.mu.Unlock()
}
// markComplete flags the session as finished by setting isComplete.
func (ab *AudioBuffer) markComplete() {
	ab.mu.Lock()
	ab.isComplete = true
	ab.mu.Unlock()
}
// calculateAudioRMS returns the root-mean-square level of little-endian signed
// 16-bit PCM audio, divided by 32768 so a full-scale signal yields 1.0.
// A trailing odd byte is ignored; input shorter than one sample yields 0.
func calculateAudioRMS(audio []byte) float64 {
	samples := len(audio) / 2
	if samples == 0 {
		return 0.0
	}

	var energy float64
	for off := 0; off+1 < len(audio); off += 2 {
		s := float64(int16(binary.LittleEndian.Uint16(audio[off:])))
		energy += s * s
	}
	return math.Sqrt(energy/float64(samples)) / 32768.0
}
// detectVoiceActivity is a simple energy gate (pure Go, no cgo). It reports
// voice when the chunk's normalized RMS level, as computed by
// calculateAudioRMS, exceeds a fixed threshold.
func detectVoiceActivity(audio []byte) bool {
	const speechRMSThreshold = 0.01 // reasonable threshold for speech
	return calculateAudioRMS(audio) > speechRMSThreshold
}
// main runs the streaming speech-to-text service. It subscribes to
// <streamSubjectPrefix>.<session>, buffers each session's audio and sends it to
// Whisper over HTTP when a flush condition is met. Transcripts and interrupt
// events are published on <transcriptionSubjectPrefix>.<session>.
//
// Each session's audio is transcribed only by that session's monitor goroutine.
// The NATS callback buffers audio, updates flags and publishes interrupts, but
// never waits on Whisper, so one slow transcription cannot stall other sessions.
func main() {
	slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))

	cfg := config.Load()
	cfg.ServiceName = "stt-streaming"

	whisperURL := getEnv("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
	bufferSizeBytes := getEnvInt("STT_BUFFER_SIZE_BYTES", 512000)
	chunkTimeout := getEnvFloat("STT_CHUNK_TIMEOUT", 2.0)
	maxBufferSize := getEnvInt("STT_MAX_BUFFER_SIZE", 5120000)
	enableInterrupt := getEnvBool("STT_ENABLE_INTERRUPT_DETECTION", true)
	audioLevelThreshold := getEnvFloat("STT_AUDIO_LEVEL_THRESHOLD", 0.02)
	interruptDuration := getEnvFloat("STT_INTERRUPT_DURATION", 0.5)

	// ctx is cancelled on SIGTERM/SIGINT. It also acts as the "service is running"
	// flag for the health check and the session monitors. Unlike a plain bool
	// written here and read elsewhere, a context is safe to share between goroutines.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Telemetry
	tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
		ServiceName:      cfg.ServiceName,
		ServiceVersion:   cfg.ServiceVersion,
		ServiceNamespace: cfg.ServiceNamespace,
		DeploymentEnv:    cfg.DeploymentEnv,
		Enabled:          cfg.OTELEnabled,
		Endpoint:         cfg.OTELEndpoint,
	})
	if err != nil {
		slog.Error("telemetry setup failed", "error", err)
	}
	if shutdown != nil {
		// Deferred calls run after ctx has been cancelled, so the flush gets its
		// own bounded context instead of an already-cancelled one.
		defer func() {
			flushCtx, flushCancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer flushCancel()
			shutdown(flushCtx)
		}()
	}
	_ = tp

	// NATS
	natsOpts := []nats.Option{}
	if cfg.NATSUser != "" && cfg.NATSPassword != "" {
		natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
	}
	nc := natsutil.New(cfg.NATSURL, natsOpts...)
	if err := nc.Connect(); err != nil {
		slog.Error("NATS connect failed", "error", err)
		os.Exit(1)
	}
	defer nc.Close()

	// JetStream stream setup. This is best effort: the service itself only uses
	// core NATS Subscribe/Publish.
	js, err := nc.Conn().JetStream()
	if err != nil {
		slog.Warn("JetStream unavailable, skipping stream setup", "error", err)
	} else if _, err = js.AddStream(&nats.StreamConfig{
		Name:      "AI_VOICE_STREAM",
		Subjects:  []string{"ai.voice.stream.>", "ai.voice.transcription.>"},
		Retention: nats.LimitsPolicy,
		MaxAge:    5 * time.Minute,
		Storage:   nats.MemoryStorage,
	}); err != nil {
		slog.Info("JetStream stream setup", "msg", err)
	} else {
		slog.Info("created/updated JetStream stream AI_VOICE_STREAM")
	}

	httpClient := &http.Client{Timeout: 180 * time.Second}

	// Session management
	var sessionsMu sync.RWMutex
	sessions := make(map[string]*AudioBuffer)

	// Health server
	healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
		return ctx.Err() == nil && nc.IsConnected()
	})
	healthSrv.Start()
	defer func() {
		stopCtx, stopCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer stopCancel()
		healthSrv.Stop(stopCtx)
	}()

	// transcribe uploads audioData to Whisper's /v1/audio/transcriptions endpoint
	// as a multipart file and returns the "text" field of the JSON reply.
	// NOTE(review): the concatenated chunks are uploaded as-is under the name
	// audio.wav. If clients send headerless PCM (as calculateAudioRMS assumes),
	// Whisper may not decode it. Confirm the chunk format.
	transcribe := func(audioData []byte) (string, error) {
		var body bytes.Buffer
		w := multipart.NewWriter(&body)
		part, err := w.CreateFormFile("file", "audio.wav")
		if err != nil {
			return "", fmt.Errorf("create form file: %w", err)
		}
		if _, err := part.Write(audioData); err != nil {
			return "", fmt.Errorf("write form file: %w", err)
		}
		if err := w.Close(); err != nil {
			return "", fmt.Errorf("close multipart writer: %w", err)
		}
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, whisperURL+"/v1/audio/transcriptions", &body)
		if err != nil {
			return "", fmt.Errorf("build whisper request: %w", err)
		}
		req.Header.Set("Content-Type", w.FormDataContentType())
		resp, err := httpClient.Do(req)
		if err != nil {
			return "", fmt.Errorf("whisper request: %w", err)
		}
		defer func() { _ = resp.Body.Close() }()
		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return "", fmt.Errorf("read whisper response: %w", err)
		}
		if resp.StatusCode >= 400 {
			return "", fmt.Errorf("whisper %d: %s", resp.StatusCode, string(respBody))
		}
		var result struct {
			Text string `json:"text"`
		}
		if err := json.Unmarshal(respBody, &result); err != nil {
			return "", fmt.Errorf("decode whisper response: %w", err)
		}
		return result.Text, nil
	}

	// publish msgpack-encodes v and sends it on the session's transcription
	// subject. Failures are logged instead of silently dropped. It reports
	// whether the message was handed to NATS.
	publish := func(sessionID string, v any) bool {
		packed, err := msgpack.Marshal(v)
		if err != nil {
			slog.Error("encode message failed", "session", sessionID, "error", err)
			return false
		}
		subject := fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID)
		if err := nc.Conn().Publish(subject, packed); err != nil {
			slog.Error("publish failed", "session", sessionID, "subject", subject, "error", err)
			return false
		}
		return true
	}

	// scheduleCleanup removes a finished session after a short grace period.
	// It deletes only if the map still holds this exact buffer, so a newer
	// "start" that reused the session ID is left alone.
	scheduleCleanup := func(sessionID string, buffer *AudioBuffer) {
		go func() {
			select {
			case <-time.After(5 * time.Second):
			case <-ctx.Done():
				return
			}
			sessionsMu.Lock()
			removed := sessions[sessionID] == buffer
			if removed {
				delete(sessions, sessionID)
			}
			sessionsMu.Unlock()
			if removed {
				slog.Info("cleaned up session", "session", sessionID)
			}
		}()
	}

	// processBuffer transcribes and publishes whatever audio is buffered for a
	// session. It reports whether the session is finished (ended and fully
	// drained); in that case cleanup has been scheduled.
	processBuffer := func(sessionID string, buffer *AudioBuffer) bool {
		// Take the audio and reset the buffer in one critical section. Chunks
		// that arrive while Whisper is busy then stay queued for the next pass
		// instead of being wiped by a later clear().
		buffer.mu.Lock()
		audioData := make([]byte, 0, buffer.totalBytes)
		for _, c := range buffer.chunks {
			audioData = append(audioData, c...)
		}
		seq := buffer.sequence
		complete := buffer.isComplete
		speakerID := buffer.speakerID
		hasVoice := buffer.hasVoiceActivity
		state := buffer.state
		if len(audioData) > 0 {
			buffer.chunks = nil
			buffer.totalBytes = 0
			buffer.sequence++
		}
		buffer.mu.Unlock()

		if len(audioData) > 0 {
			slog.Info("processing buffer", "session", sessionID, "bytes", len(audioData), "seq", seq)
			transcript, err := transcribe(audioData)
			if err != nil {
				slog.Error("transcription failed", "session", sessionID, "error", err)
			}

			if !complete {
				// "end" may have arrived while Whisper was busy. If nothing new was
				// buffered since, this transcript is the session's last: mark it final.
				buffer.mu.Lock()
				complete = buffer.isComplete && buffer.totalBytes == 0
				buffer.mu.Unlock()
			}

			if transcript != "" {
				result := &messages.STTTranscription{
					SessionID:        sessionID,
					Transcript:       transcript,
					Sequence:         seq,
					IsPartial:        !complete,
					IsFinal:          complete,
					Timestamp:        time.Now().Unix(),
					SpeakerID:        speakerID,
					HasVoiceActivity: hasVoice,
					State:            state,
				}
				if publish(sessionID, result) {
					slog.Info("published transcription", "session", sessionID, "seq", seq)
				}
			}
		}

		if complete {
			slog.Info("session completed", "session", sessionID)
			scheduleCleanup(sessionID, buffer)
		}
		return complete
	}

	// monitorBuffer is the single goroutine that transcribes audio for one
	// session buffer. It polls every 100 ms until the service stops, the session
	// finishes, or the buffer is replaced in the map.
	monitorBuffer := func(sessionID string, buffer *AudioBuffer) {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}

			sessionsMu.RLock()
			current, ok := sessions[sessionID]
			sessionsMu.RUnlock()
			replaced := !ok || current != buffer

			buffer.mu.Lock()
			ended := buffer.isComplete
			buffer.mu.Unlock()

			switch {
			case ended:
				// Flush the remainder as the final transcript, even if a newer
				// "start" has already replaced this buffer in the map.
				processBuffer(sessionID, buffer)
				return
			case replaced:
				return
			case buffer.shouldProcess(bufferSizeBytes, maxBufferSize, chunkTimeout):
				if processBuffer(sessionID, buffer) {
					return
				}
			}
		}
	}

	// Handle incoming stream messages
	handleStreamMsg := func(natMsg *nats.Msg) {
		parts := strings.Split(natMsg.Subject, ".")
		if len(parts) < 4 {
			slog.Warn("invalid subject", "subject", natMsg.Subject)
			return
		}
		sessionID := parts[3]

		streamMsg, err := natsutil.Decode[messages.STTStreamMessage](natMsg.Data)
		if err != nil {
			slog.Error("decode error", "error", err)
			return
		}

		switch streamMsg.Type {
		case "start":
			slog.Info("starting stream session", "session", sessionID)
			buf := newAudioBuffer(sessionID)
			if streamMsg.State != "" {
				buf.setState(streamMsg.State)
			}
			if streamMsg.SpeakerID != "" {
				buf.speakerID = streamMsg.SpeakerID
			}
			sessionsMu.Lock()
			sessions[sessionID] = buf
			sessionsMu.Unlock()
			go monitorBuffer(sessionID, buf)

		case "state_change":
			sessionsMu.RLock()
			buffer, ok := sessions[sessionID]
			sessionsMu.RUnlock()
			if ok && streamMsg.State != "" {
				buffer.setState(streamMsg.State)
			}

		case "end":
			slog.Info("ending stream session", "session", sessionID)
			sessionsMu.RLock()
			buffer, ok := sessions[sessionID]
			sessionsMu.RUnlock()
			if ok {
				// The session's monitor sees the flag on its next tick. It flushes
				// the remaining audio as the final transcript and schedules cleanup,
				// including when nothing is left to transcribe. Transcribing here
				// would block this NATS callback (and so every other session) on
				// Whisper. It could also race the monitor into sending the same
				// audio twice.
				buffer.markComplete()
			}

		case "chunk":
			// Audio arrives as raw bytes — no base64 decode needed
			if len(streamMsg.Audio) == 0 {
				return
			}
			// Auto-create session if missing
			sessionsMu.Lock()
			buffer, ok := sessions[sessionID]
			if !ok {
				slog.Info("auto-creating session", "session", sessionID)
				buffer = newAudioBuffer(sessionID)
				sessions[sessionID] = buffer
				go monitorBuffer(sessionID, buffer)
			}
			sessionsMu.Unlock()

			// Check for interrupt
			if buffer.checkInterrupt(streamMsg.Audio, enableInterrupt, audioLevelThreshold, interruptDuration) {
				interruptMsg := &messages.STTInterrupt{
					SessionID: sessionID,
					Type:      "interrupt",
					Timestamp: time.Now().Unix(),
					SpeakerID: buffer.speakerID,
				}
				if publish(sessionID, interruptMsg) {
					slog.Info("published interrupt", "session", sessionID)
				}
				buffer.setState(stateListening)
			}
			buffer.addChunk(streamMsg.Audio)
		}
	}

	if _, err := nc.Conn().Subscribe(streamSubjectPrefix+".>", handleStreamMsg); err != nil {
		slog.Error("subscribe failed", "error", err)
		os.Exit(1)
	}
	slog.Info("subscribed", "subject", streamSubjectPrefix+".>")
	slog.Info("stt-streaming ready")

	// Wait for shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
	<-sigCh

	slog.Info("shutting down")
	// Stops the session monitors and in-flight Whisper requests, and makes the
	// health check report unhealthy.
	cancel()
	slog.Info("shutdown complete")
}
// Helpers
// getEnv returns the value of environment variable key, or fallback when the
// variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns environment variable key parsed as a base-10 int. It
// returns fallback when the variable is unset, empty or not a valid integer.
// Surrounding whitespace is ignored. An invalid value is logged at warn level,
// so a typo in deployment config is visible instead of silently reverting to
// the default.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	i, err := strconv.Atoi(raw)
	if err != nil {
		slog.Warn("ignoring invalid integer environment variable",
			"key", key, "value", raw, "fallback", fallback, "error", err)
		return fallback
	}
	return i
}
// getEnvFloat returns environment variable key parsed as a float64. It returns
// fallback when the variable is unset, empty, unparsable, NaN or infinite.
// strconv.ParseFloat accepts "NaN" and "Inf"; such values would make every
// threshold/timeout comparison false, so they are rejected. Surrounding
// whitespace is ignored, and invalid values are logged at warn level.
func getEnvFloat(key string, fallback float64) float64 {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	f, err := strconv.ParseFloat(raw, 64)
	if err != nil || math.IsNaN(f) || math.IsInf(f, 0) {
		slog.Warn("ignoring invalid float environment variable",
			"key", key, "value", raw, "fallback", fallback)
		return fallback
	}
	return f
}
// getEnvBool interprets environment variable key as a flag. "1" or any casing
// of "true" means true; any other non-empty value means false. An unset or
// empty variable yields fallback.
func getEnvBool(key string, fallback bool) bool {
	switch v := os.Getenv(key); {
	case v == "":
		return fallback
	case v == "1":
		return true
	default:
		return strings.EqualFold(v, "true")
	}
}