Files
stt-module/main.go
Billy D. 9d4d48e693 feat: rewrite stt-module (HTTP variant) in Go
Replace Python streaming STT service with Go for smaller container images.
Local Whisper/ROCm variant (stt_streaming_local.py, Dockerfile.rocm) stays Python.

- AudioBuffer with session state management (listening/responding)
- RMS-based voice activity detection (pure Go, no cgo)
- Interrupt detection during LLM response playback
- JetStream AI_VOICE_STREAM setup
- Session auto-creation and cleanup
- Dockerfile: multi-stage golang:1.25-alpine → scratch
- CI: Gitea Actions with lint/test/release/docker/notify
2026-02-19 18:04:15 -05:00

530 lines
13 KiB
Go

package main
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"math"
"mime/multipart"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
)
// NATS subjects
//
// Both prefixes are extended with a session ID as the final subject token.
const (
// streamSubjectPrefix is the prefix of the inbound subjects; main subscribes
// to it with a ".>" wildcard and reads the session ID from the fourth token.
streamSubjectPrefix = "ai.voice.stream"
// transcriptionSubjectPrefix is the prefix of the outbound subjects on which
// transcription results and interrupt notifications are published.
transcriptionSubjectPrefix = "ai.voice.transcription"
)
// Session states
//
// These are the only values AudioBuffer.setState accepts.
const (
// stateListening is the state every new AudioBuffer starts in.
stateListening = "listening"
// stateResponding is the only state in which checkInterrupt evaluates
// incoming audio for an interrupt.
stateResponding = "responding"
)
// AudioBuffer holds the pending audio and the per-session bookkeeping for one
// streaming session. The methods on this type take mu before touching the
// other fields.
type AudioBuffer struct {
	mu sync.Mutex

	// Identity of the session.
	sessionID string
	speakerID string

	// Audio accumulated since the last clear.
	chunks     [][]byte
	totalBytes int // running sum of len(chunk) over chunks

	// Progress tracking.
	lastChunkTime    time.Time // set at creation and on every addChunk
	hasVoiceActivity bool      // voice-detection result for the most recent chunk
	isComplete       bool      // set by markComplete
	sequence         int       // incremented by clear

	// Conversation state and interrupt tracking.
	state              string     // stateListening or stateResponding
	interruptStartTime *time.Time // start of the current candidate interrupt, nil when none
}
// newAudioBuffer returns an empty buffer for sessionID. The buffer starts in
// the listening state with lastChunkTime set to the time of creation.
func newAudioBuffer(sessionID string) *AudioBuffer {
	ab := new(AudioBuffer)
	ab.sessionID = sessionID
	ab.lastChunkTime = time.Now()
	ab.state = stateListening
	return ab
}
// addChunk appends audio to the buffer, refreshes lastChunkTime, and records
// whether this chunk (and only this chunk) shows voice activity.
//
// The slice is stored as-is, not copied.
func (ab *AudioBuffer) addChunk(audio []byte) {
	voiced := detectVoiceActivity(audio)

	ab.mu.Lock()
	ab.chunks = append(ab.chunks, audio)
	ab.totalBytes += len(audio)
	ab.lastChunkTime = time.Now()
	ab.hasVoiceActivity = voiced
	ab.mu.Unlock()
}
// checkInterrupt reports whether audio completes a confirmed interrupt.
//
// It always returns false when enableInterrupt is false or the buffer is not
// in the responding state. Otherwise a chunk counts toward an interrupt when
// its RMS level is at least audioThreshold and voice activity is detected;
// the first such chunk starts a timer, and the interrupt is confirmed once
// qualifying chunks have continued for interruptDuration seconds. Any
// non-qualifying chunk resets the timer.
func (ab *AudioBuffer) checkInterrupt(audio []byte, enableInterrupt bool, audioThreshold, interruptDuration float64) bool {
	if !enableInterrupt {
		return false
	}

	ab.mu.Lock()
	defer ab.mu.Unlock()

	if ab.state != stateResponding {
		return false
	}

	level := calculateAudioRMS(audio)
	voiced := detectVoiceActivity(audio)

	// Comparisons are kept in ">=" form so that a NaN threshold behaves as
	// "never qualifies".
	if loud := level >= audioThreshold; !loud || !voiced {
		ab.interruptStartTime = nil
		return false
	}

	if ab.interruptStartTime == nil {
		started := time.Now()
		ab.interruptStartTime = &started
		slog.Info("potential interrupt detected", "session", ab.sessionID, "rms", level)
	}

	held := time.Since(*ab.interruptStartTime).Seconds()
	if confirmed := held >= interruptDuration; !confirmed {
		return false
	}
	slog.Info("interrupt confirmed", "session", ab.sessionID, "elapsed", held)
	return true
}
// setState switches the buffer to state. Values other than stateListening and
// stateResponding are ignored, as is a request for the state the buffer is
// already in. An actual transition is logged and resets interrupt tracking.
func (ab *AudioBuffer) setState(state string) {
	switch state {
	case stateListening, stateResponding:
		// valid target state
	default:
		return
	}

	ab.mu.Lock()
	defer ab.mu.Unlock()

	if ab.state == state {
		return
	}
	slog.Info("state changed", "session", ab.sessionID, "from", ab.state, "to", state)
	ab.state = state
	ab.interruptStartTime = nil
}
// shouldProcess reports whether the buffered audio is ready for transcription.
//
// It returns true when the buffer holds at least bufferSize bytes, when it is
// non-empty and more than chunkTimeout seconds have passed since the last
// chunk, or when it holds at least maxBufferSize bytes. A buffer below
// bufferSize whose latest chunk had no voice activity, and which has not yet
// timed out, is never ready.
func (ab *AudioBuffer) shouldProcess(bufferSize, maxBufferSize int, chunkTimeout float64) bool {
	ab.mu.Lock()
	defer ab.mu.Unlock()

	idleSeconds := func() float64 {
		return time.Since(ab.lastChunkTime).Seconds()
	}
	belowTarget := ab.totalBytes < bufferSize

	if !ab.hasVoiceActivity && belowTarget && idleSeconds() < chunkTimeout {
		return false
	}

	switch {
	case !belowTarget:
		return true
	case ab.totalBytes > 0 && idleSeconds() > chunkTimeout:
		return true
	default:
		return ab.totalBytes >= maxBufferSize
	}
}
// getAudio returns a newly allocated slice holding all buffered chunks
// concatenated in arrival order. The buffer itself is left unchanged.
func (ab *AudioBuffer) getAudio() []byte {
	ab.mu.Lock()
	defer ab.mu.Unlock()

	size := 0
	for i := range ab.chunks {
		size += len(ab.chunks[i])
	}

	joined := make([]byte, size)
	offset := 0
	for i := range ab.chunks {
		offset += copy(joined[offset:], ab.chunks[i])
	}
	return joined
}
// clear drops all buffered audio and advances the sequence number by one.
func (ab *AudioBuffer) clear() {
	ab.mu.Lock()
	ab.chunks, ab.totalBytes = nil, 0
	ab.sequence++
	ab.mu.Unlock()
}
// markComplete flags the session as complete.
func (ab *AudioBuffer) markComplete() {
	ab.mu.Lock()
	ab.isComplete = true
	ab.mu.Unlock()
}
// calculateAudioRMS returns the root-mean-square level of audio, interpreted
// as little-endian signed 16-bit samples, divided by 32768 so that the result
// lies in 0.0-1.0. Input shorter than one sample yields 0, and a trailing odd
// byte is ignored.
func calculateAudioRMS(audio []byte) float64 {
	sampleCount := len(audio) / 2
	if sampleCount == 0 {
		return 0.0
	}

	var energy float64
	for rest := audio[:sampleCount*2]; len(rest) > 0; rest = rest[2:] {
		v := float64(int16(binary.LittleEndian.Uint16(rest)))
		energy += v * v
	}

	return math.Sqrt(energy/float64(sampleCount)) / 32768.0
}
// detectVoiceActivity reports whether audio looks like speech, judged purely
// by its normalized RMS level (pure Go, no cgo).
func detectVoiceActivity(audio []byte) bool {
	// Levels strictly above this are treated as speech.
	const speechThreshold = 0.01

	level := calculateAudioRMS(audio)
	return level > speechThreshold
}
// main runs the streaming STT service.
//
// It loads configuration, sets up telemetry, connects to NATS, ensures the
// AI_VOICE_STREAM JetStream stream exists, starts the health server, and then
// subscribes to streamSubjectPrefix+".>". Incoming messages are msgpack maps
// with a "type" of "start", "state_change", "end" or "chunk"; audio chunks are
// buffered per session and sent to the Whisper HTTP endpoint once
// AudioBuffer.shouldProcess says the buffer is ready. Results and interrupt
// notifications are published on transcriptionSubjectPrefix+"."+sessionID.
// The process runs until SIGTERM or SIGINT.
func main() {
	slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))

	cfg := config.Load()
	cfg.ServiceName = "stt-streaming"

	whisperURL := getEnv("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
	bufferSizeBytes := getEnvInt("STT_BUFFER_SIZE_BYTES", 512000)
	chunkTimeout := getEnvFloat("STT_CHUNK_TIMEOUT", 2.0)
	maxBufferSize := getEnvInt("STT_MAX_BUFFER_SIZE", 5120000)
	enableInterrupt := getEnvBool("STT_ENABLE_INTERRUPT_DETECTION", true)
	audioLevelThreshold := getEnvFloat("STT_AUDIO_LEVEL_THRESHOLD", 0.02)
	interruptDuration := getEnvFloat("STT_INTERRUPT_DURATION", 0.5)

	const (
		shutdownTimeout = 5 * time.Second        // budget for each deferred shutdown step
		cleanupDelay    = 5 * time.Second        // grace period before a finished session is forgotten
		monitorInterval = 100 * time.Millisecond // how often a session buffer is polled
	)

	// ctx is the service lifetime: it is cancelled on shutdown and stops the
	// monitor goroutines, pending cleanups and in-flight Whisper requests.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Deferred shutdown steps run after ctx has been cancelled, so each one
	// gets its own fresh, time-limited context instead of the dead one.
	newShutdownCtx := func() (context.Context, context.CancelFunc) {
		return context.WithTimeout(context.Background(), shutdownTimeout)
	}

	// Telemetry
	tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
		ServiceName:      cfg.ServiceName,
		ServiceVersion:   cfg.ServiceVersion,
		ServiceNamespace: cfg.ServiceNamespace,
		DeploymentEnv:    cfg.DeploymentEnv,
		Enabled:          cfg.OTELEnabled,
		Endpoint:         cfg.OTELEndpoint,
	})
	if err != nil {
		// Telemetry is best-effort: the service keeps running without it.
		slog.Error("telemetry setup failed", "error", err)
	}
	if shutdown != nil {
		defer func() {
			shutdownCtx, shutdownCancel := newShutdownCtx()
			defer shutdownCancel()
			shutdown(shutdownCtx)
		}()
	}
	_ = tp

	// NATS
	natsOpts := []nats.Option{}
	if cfg.NATSUser != "" && cfg.NATSPassword != "" {
		natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
	}
	nc := natsutil.New(cfg.NATSURL, natsOpts...)
	if err := nc.Connect(); err != nil {
		slog.Error("NATS connect failed", "error", err)
		os.Exit(1)
	}
	defer nc.Close()

	// JetStream stream setup. Failures are logged but not fatal.
	js, err := nc.Conn().JetStream()
	if err != nil {
		slog.Warn("JetStream unavailable, skipping stream setup", "error", err)
	} else if _, err := js.AddStream(&nats.StreamConfig{
		Name:      "AI_VOICE_STREAM",
		Subjects:  []string{"ai.voice.stream.>", "ai.voice.transcription.>"},
		Retention: nats.LimitsPolicy,
		MaxAge:    5 * time.Minute,
		Storage:   nats.MemoryStorage,
	}); err != nil {
		// Logged under "error": a "msg" attribute would duplicate the key
		// slog's JSON handler already uses for the message itself.
		slog.Warn("JetStream stream setup", "error", err)
	} else {
		slog.Info("created/updated JetStream stream AI_VOICE_STREAM")
	}

	httpClient := &http.Client{Timeout: 180 * time.Second}

	// Session management
	var sessionsMu sync.RWMutex
	sessions := make(map[string]*AudioBuffer)

	// Health server: ready while the service has not begun shutting down and
	// NATS is connected.
	healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
		return ctx.Err() == nil && nc.IsConnected()
	})
	healthSrv.Start()
	defer func() {
		stopCtx, stopCancel := newShutdownCtx()
		defer stopCancel()
		healthSrv.Stop(stopCtx)
	}()

	// publish msgpack-encodes payload and publishes it on the session's
	// transcription subject. It reports whether the message was handed to NATS.
	publish := func(sessionID, kind string, payload map[string]any) bool {
		packed, err := msgpack.Marshal(payload)
		if err != nil {
			slog.Error("encode failed", "session", sessionID, "kind", kind, "error", err)
			return false
		}
		subject := fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID)
		if err := nc.Conn().Publish(subject, packed); err != nil {
			slog.Error("publish failed", "session", sessionID, "kind", kind, "error", err)
			return false
		}
		return true
	}

	// Transcribe via Whisper HTTP: posts audioData as multipart field "file"
	// and returns the "text" field of the JSON response.
	transcribe := func(audioData []byte) (string, error) {
		var buf bytes.Buffer
		w := multipart.NewWriter(&buf)
		part, err := w.CreateFormFile("file", "audio.wav")
		if err != nil {
			return "", fmt.Errorf("create form file: %w", err)
		}
		if _, err := part.Write(audioData); err != nil {
			return "", fmt.Errorf("write audio: %w", err)
		}
		// Close writes the terminating boundary; without it the body is invalid.
		if err := w.Close(); err != nil {
			return "", fmt.Errorf("finalize multipart body: %w", err)
		}
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, whisperURL+"/v1/audio/transcriptions", &buf)
		if err != nil {
			return "", err
		}
		req.Header.Set("Content-Type", w.FormDataContentType())
		resp, err := httpClient.Do(req)
		if err != nil {
			return "", fmt.Errorf("whisper request: %w", err)
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return "", fmt.Errorf("read whisper response: %w", err)
		}
		if resp.StatusCode >= 400 {
			return "", fmt.Errorf("whisper %d: %s", resp.StatusCode, string(body))
		}
		var result struct {
			Text string `json:"text"`
		}
		if err := json.Unmarshal(body, &result); err != nil {
			return "", err
		}
		return result.Text, nil
	}

	// scheduleCleanup forgets the session after cleanupDelay, but only if the
	// map still holds this exact buffer: a session restarted under the same ID
	// in the meantime must not be deleted.
	scheduleCleanup := func(sessionID string, buffer *AudioBuffer) {
		go func() {
			select {
			case <-ctx.Done():
				return
			case <-time.After(cleanupDelay):
			}
			removed := false
			sessionsMu.Lock()
			if current, ok := sessions[sessionID]; ok && current == buffer {
				delete(sessions, sessionID)
				removed = true
			}
			sessionsMu.Unlock()
			if removed {
				slog.Info("cleaned up session", "session", sessionID)
			}
		}()
	}

	// Process buffer for a session: drain the pending audio, transcribe it and
	// publish the transcript.
	processBuffer := func(sessionID string, buffer *AudioBuffer) {
		// Take the audio and reset the buffer in one critical section. Chunks
		// that arrive while Whisper is working then land in the next segment
		// instead of being wiped by a later clear, and two concurrent callers
		// can never transcribe the same audio.
		buffer.mu.Lock()
		audioData := make([]byte, 0, buffer.totalBytes)
		for _, c := range buffer.chunks {
			audioData = append(audioData, c...)
		}
		if len(audioData) == 0 {
			buffer.mu.Unlock()
			return
		}
		seq := buffer.sequence
		complete := buffer.isComplete
		speakerID := buffer.speakerID
		hasVoice := buffer.hasVoiceActivity
		state := buffer.state
		buffer.chunks = nil
		buffer.totalBytes = 0
		buffer.sequence++
		buffer.mu.Unlock()

		slog.Info("processing buffer", "session", sessionID, "bytes", len(audioData), "seq", seq)
		transcript, err := transcribe(audioData)
		if err != nil {
			// The segment is dropped; transcript is empty so nothing is published.
			slog.Error("transcription failed", "session", sessionID, "error", err)
		}
		if transcript != "" {
			result := map[string]any{
				"session_id":         sessionID,
				"transcript":         transcript,
				"sequence":           seq,
				"is_partial":         !complete,
				"is_final":           complete,
				"timestamp":          time.Now().Unix(),
				"speaker_id":         speakerID,
				"has_voice_activity": hasVoice,
				"state":              state,
			}
			if publish(sessionID, "transcription", result) {
				slog.Info("published transcription", "session", sessionID, "seq", seq)
			}
		}
		if complete {
			slog.Info("session completed", "session", sessionID)
			scheduleCleanup(sessionID, buffer)
		}
	}

	// Monitor goroutine for a session: polls the buffer until the session is
	// removed or replaced, or the service shuts down.
	monitorBuffer := func(sessionID string, buffer *AudioBuffer) {
		ticker := time.NewTicker(monitorInterval)
		defer ticker.Stop()
		for {
			sessionsMu.RLock()
			current, ok := sessions[sessionID]
			sessionsMu.RUnlock()
			// A different buffer under the same ID has its own monitor.
			if !ok || current != buffer {
				return
			}
			if buffer.shouldProcess(bufferSizeBytes, maxBufferSize, chunkTimeout) {
				processBuffer(sessionID, buffer)
			}
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
		}
	}

	// Handle incoming stream messages
	handleStreamMsg := func(natMsg *nats.Msg) {
		parts := strings.Split(natMsg.Subject, ".")
		if len(parts) < 4 {
			slog.Warn("invalid subject", "subject", natMsg.Subject)
			return
		}
		sessionID := parts[3]
		data, err := natsutil.DecodeMsgpackMap(natMsg.Data)
		if err != nil {
			slog.Error("decode error", "error", err)
			return
		}
		msgType := ""
		if t, ok := data["type"].(string); ok {
			msgType = t
		}
		switch msgType {
		case "start":
			slog.Info("starting stream session", "session", sessionID)
			session := newAudioBuffer(sessionID)
			if s, ok := data["state"].(string); ok {
				session.setState(s)
			}
			if s, ok := data["speaker_id"].(string); ok {
				// Not yet shared with any other goroutine, so no lock is needed.
				session.speakerID = s
			}
			sessionsMu.Lock()
			sessions[sessionID] = session
			sessionsMu.Unlock()
			go monitorBuffer(sessionID, session)
		case "state_change":
			sessionsMu.RLock()
			buffer, ok := sessions[sessionID]
			sessionsMu.RUnlock()
			if ok {
				if s, ok := data["state"].(string); ok {
					buffer.setState(s)
				}
			}
		case "end":
			slog.Info("ending stream session", "session", sessionID)
			sessionsMu.RLock()
			buffer, ok := sessions[sessionID]
			sessionsMu.RUnlock()
			if !ok {
				return
			}
			buffer.markComplete()
			buffer.mu.Lock()
			hasData := buffer.totalBytes > 0
			buffer.mu.Unlock()
			if hasData {
				// Run off the subscription callback so a slow Whisper call does
				// not hold up the handling of subsequent messages.
				go processBuffer(sessionID, buffer)
			} else {
				// Nothing left to transcribe: still release the session so it
				// and its monitor goroutine do not live forever.
				scheduleCleanup(sessionID, buffer)
			}
		case "chunk":
			audioB64 := ""
			if s, ok := data["audio_b64"].(string); ok {
				audioB64 = s
			}
			if audioB64 == "" {
				return
			}
			audioBytes, err := base64.StdEncoding.DecodeString(audioB64)
			if err != nil {
				slog.Error("base64 decode failed", "error", err)
				return
			}
			// Auto-create session if missing
			sessionsMu.Lock()
			buffer, ok := sessions[sessionID]
			if !ok {
				slog.Info("auto-creating session", "session", sessionID)
				buffer = newAudioBuffer(sessionID)
				sessions[sessionID] = buffer
				go monitorBuffer(sessionID, buffer)
			}
			sessionsMu.Unlock()
			// Check for interrupt
			if buffer.checkInterrupt(audioBytes, enableInterrupt, audioLevelThreshold, interruptDuration) {
				interruptMsg := map[string]any{
					"session_id": sessionID,
					"type":       "interrupt",
					"timestamp":  time.Now().Unix(),
					"speaker_id": buffer.speakerID,
				}
				if publish(sessionID, "interrupt", interruptMsg) {
					slog.Info("published interrupt", "session", sessionID)
				}
				buffer.setState(stateListening)
			}
			buffer.addChunk(audioBytes)
		}
	}

	if _, err := nc.Conn().Subscribe(streamSubjectPrefix+".>", handleStreamMsg); err != nil {
		slog.Error("subscribe failed", "error", err)
		os.Exit(1)
	}
	slog.Info("subscribed", "subject", streamSubjectPrefix+".>")
	slog.Info("stt-streaming ready")

	// Wait for shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
	<-sigCh
	slog.Info("shutting down")
	cancel()
	slog.Info("shutdown complete")
}
// Helpers

// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or empty.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable key parsed as a decimal int, or
// fallback when the variable is unset, empty, or not a valid integer.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
// getEnvFloat returns the environment variable key parsed as a float64, or
// fallback when the variable is unset, empty, or cannot be parsed.
func getEnvFloat(key string, fallback float64) float64 {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		return fallback
	}
	return parsed
}
// getEnvBool returns fallback when the environment variable key is unset or
// empty. Otherwise it returns true only for "1" or a case-insensitive "true";
// every other non-empty value yields false.
func getEnvBool(key string, fallback bool) bool {
	raw, present := os.LookupEnv(key)
	if !present || raw == "" {
		return fallback
	}
	return raw == "1" || strings.EqualFold(raw, "true")
}