feat: rewrite stt-module (HTTP variant) in Go
Replace Python streaming STT service with Go for smaller container images. Local Whisper/ROCm variant (stt_streaming_local.py, Dockerfile.rocm) stays Python. - AudioBuffer with session state management (listening/responding) - RMS-based voice activity detection (pure Go, no cgo) - Interrupt detection during LLM response playback - JetStream AI_VOICE_STREAM setup - Session auto-creation and cleanup - Dockerfile: multi-stage golang:1.25-alpine → scratch - CI: Gitea Actions with lint/test/release/docker/notify
This commit is contained in:
529
main.go
Normal file
529
main.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
||||
)
|
||||
|
||||
// NATS subjects
|
||||
const (
|
||||
streamSubjectPrefix = "ai.voice.stream"
|
||||
transcriptionSubjectPrefix = "ai.voice.transcription"
|
||||
)
|
||||
|
||||
// Session states
|
||||
const (
|
||||
stateListening = "listening"
|
||||
stateResponding = "responding"
|
||||
)
|
||||
|
||||
// AudioBuffer manages audio chunks for a streaming session.
|
||||
type AudioBuffer struct {
|
||||
mu sync.Mutex
|
||||
sessionID string
|
||||
chunks [][]byte
|
||||
totalBytes int
|
||||
lastChunkTime time.Time
|
||||
isComplete bool
|
||||
sequence int
|
||||
state string
|
||||
speakerID string
|
||||
interruptStartTime *time.Time
|
||||
hasVoiceActivity bool
|
||||
}
|
||||
|
||||
func newAudioBuffer(sessionID string) *AudioBuffer {
|
||||
return &AudioBuffer{
|
||||
sessionID: sessionID,
|
||||
lastChunkTime: time.Now(),
|
||||
state: stateListening,
|
||||
}
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) addChunk(audio []byte) {
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
ab.chunks = append(ab.chunks, audio)
|
||||
ab.totalBytes += len(audio)
|
||||
ab.lastChunkTime = time.Now()
|
||||
ab.hasVoiceActivity = detectVoiceActivity(audio)
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) checkInterrupt(audio []byte, enableInterrupt bool, audioThreshold, interruptDuration float64) bool {
|
||||
if !enableInterrupt {
|
||||
return false
|
||||
}
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
|
||||
if ab.state != stateResponding {
|
||||
return false
|
||||
}
|
||||
|
||||
rms := calculateAudioRMS(audio)
|
||||
hasVoice := detectVoiceActivity(audio)
|
||||
|
||||
if rms >= audioThreshold && hasVoice {
|
||||
if ab.interruptStartTime == nil {
|
||||
now := time.Now()
|
||||
ab.interruptStartTime = &now
|
||||
slog.Info("potential interrupt detected", "session", ab.sessionID, "rms", rms)
|
||||
}
|
||||
elapsed := time.Since(*ab.interruptStartTime).Seconds()
|
||||
if elapsed >= interruptDuration {
|
||||
slog.Info("interrupt confirmed", "session", ab.sessionID, "elapsed", elapsed)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
ab.interruptStartTime = nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) setState(state string) {
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
if state != stateListening && state != stateResponding {
|
||||
return
|
||||
}
|
||||
if ab.state != state {
|
||||
slog.Info("state changed", "session", ab.sessionID, "from", ab.state, "to", state)
|
||||
ab.state = state
|
||||
ab.interruptStartTime = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) shouldProcess(bufferSize, maxBufferSize int, chunkTimeout float64) bool {
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
|
||||
if !ab.hasVoiceActivity && ab.totalBytes < bufferSize &&
|
||||
time.Since(ab.lastChunkTime).Seconds() < chunkTimeout {
|
||||
return false
|
||||
}
|
||||
if ab.totalBytes >= bufferSize {
|
||||
return true
|
||||
}
|
||||
if time.Since(ab.lastChunkTime).Seconds() > chunkTimeout && ab.totalBytes > 0 {
|
||||
return true
|
||||
}
|
||||
return ab.totalBytes >= maxBufferSize
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) getAudio() []byte {
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
var total int
|
||||
for _, c := range ab.chunks {
|
||||
total += len(c)
|
||||
}
|
||||
result := make([]byte, 0, total)
|
||||
for _, c := range ab.chunks {
|
||||
result = append(result, c...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) clear() {
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
ab.chunks = nil
|
||||
ab.totalBytes = 0
|
||||
ab.sequence++
|
||||
}
|
||||
|
||||
func (ab *AudioBuffer) markComplete() {
|
||||
ab.mu.Lock()
|
||||
defer ab.mu.Unlock()
|
||||
ab.isComplete = true
|
||||
}
|
||||
|
||||
// calculateAudioRMS computes RMS level for 16-bit PCM audio, normalized to 0.0-1.0.
|
||||
func calculateAudioRMS(audio []byte) float64 {
|
||||
if len(audio) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
numSamples := len(audio) / 2
|
||||
var sumSquares float64
|
||||
for i := 0; i < numSamples; i++ {
|
||||
sample := int16(binary.LittleEndian.Uint16(audio[i*2 : i*2+2]))
|
||||
sumSquares += float64(sample) * float64(sample)
|
||||
}
|
||||
rms := math.Sqrt(sumSquares / float64(numSamples))
|
||||
return rms / 32768.0
|
||||
}
|
||||
|
||||
// detectVoiceActivity uses RMS-based voice detection (pure Go, no cgo).
|
||||
func detectVoiceActivity(audio []byte) bool {
|
||||
rms := calculateAudioRMS(audio)
|
||||
return rms > 0.01 // reasonable threshold for speech
|
||||
}
|
||||
|
||||
func main() {
|
||||
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
||||
|
||||
cfg := config.Load()
|
||||
cfg.ServiceName = "stt-streaming"
|
||||
|
||||
whisperURL := getEnv("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
|
||||
bufferSizeBytes := getEnvInt("STT_BUFFER_SIZE_BYTES", 512000)
|
||||
chunkTimeout := getEnvFloat("STT_CHUNK_TIMEOUT", 2.0)
|
||||
maxBufferSize := getEnvInt("STT_MAX_BUFFER_SIZE", 5120000)
|
||||
enableInterrupt := getEnvBool("STT_ENABLE_INTERRUPT_DETECTION", true)
|
||||
audioLevelThreshold := getEnvFloat("STT_AUDIO_LEVEL_THRESHOLD", 0.02)
|
||||
interruptDuration := getEnvFloat("STT_INTERRUPT_DURATION", 0.5)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Telemetry
|
||||
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
||||
ServiceName: cfg.ServiceName,
|
||||
ServiceVersion: cfg.ServiceVersion,
|
||||
ServiceNamespace: cfg.ServiceNamespace,
|
||||
DeploymentEnv: cfg.DeploymentEnv,
|
||||
Enabled: cfg.OTELEnabled,
|
||||
Endpoint: cfg.OTELEndpoint,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("telemetry setup failed", "error", err)
|
||||
}
|
||||
if shutdown != nil {
|
||||
defer shutdown(ctx)
|
||||
}
|
||||
_ = tp
|
||||
|
||||
// NATS
|
||||
natsOpts := []nats.Option{}
|
||||
if cfg.NATSUser != "" && cfg.NATSPassword != "" {
|
||||
natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
|
||||
}
|
||||
nc := natsutil.New(cfg.NATSURL, natsOpts...)
|
||||
if err := nc.Connect(); err != nil {
|
||||
slog.Error("NATS connect failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer nc.Close()
|
||||
|
||||
// JetStream stream setup
|
||||
js, err := nc.Conn().JetStream()
|
||||
if err == nil {
|
||||
_, err = js.AddStream(&nats.StreamConfig{
|
||||
Name: "AI_VOICE_STREAM",
|
||||
Subjects: []string{"ai.voice.stream.>", "ai.voice.transcription.>"},
|
||||
Retention: nats.LimitsPolicy,
|
||||
MaxAge: 5 * time.Minute,
|
||||
Storage: nats.MemoryStorage,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Info("JetStream stream setup", "msg", err)
|
||||
} else {
|
||||
slog.Info("created/updated JetStream stream AI_VOICE_STREAM")
|
||||
}
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Timeout: 180 * time.Second}
|
||||
running := true
|
||||
|
||||
// Session management
|
||||
var sessionsMu sync.RWMutex
|
||||
sessions := make(map[string]*AudioBuffer)
|
||||
|
||||
// Health server
|
||||
healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
|
||||
return running && nc.IsConnected()
|
||||
})
|
||||
healthSrv.Start()
|
||||
defer healthSrv.Stop(ctx)
|
||||
|
||||
// Transcribe via Whisper HTTP
|
||||
transcribe := func(audioData []byte) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
w := multipart.NewWriter(&buf)
|
||||
part, err := w.CreateFormFile("file", "audio.wav")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
part.Write(audioData)
|
||||
w.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, whisperURL+"/v1/audio/transcriptions", &buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("whisper request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", fmt.Errorf("whisper %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result.Text, nil
|
||||
}
|
||||
|
||||
// Process buffer for a session
|
||||
processBuffer := func(sessionID string) {
|
||||
sessionsMu.RLock()
|
||||
buffer, ok := sessions[sessionID]
|
||||
sessionsMu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
audioData := buffer.getAudio()
|
||||
if len(audioData) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
buffer.mu.Lock()
|
||||
seq := buffer.sequence
|
||||
complete := buffer.isComplete
|
||||
speakerID := buffer.speakerID
|
||||
hasVoice := buffer.hasVoiceActivity
|
||||
state := buffer.state
|
||||
buffer.mu.Unlock()
|
||||
|
||||
slog.Info("processing buffer", "session", sessionID, "bytes", len(audioData), "seq", seq)
|
||||
|
||||
transcript, err := transcribe(audioData)
|
||||
if err != nil {
|
||||
slog.Error("transcription failed", "session", sessionID, "error", err)
|
||||
}
|
||||
|
||||
if transcript != "" {
|
||||
result := map[string]any{
|
||||
"session_id": sessionID,
|
||||
"transcript": transcript,
|
||||
"sequence": seq,
|
||||
"is_partial": !complete,
|
||||
"is_final": complete,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"speaker_id": speakerID,
|
||||
"has_voice_activity": hasVoice,
|
||||
"state": state,
|
||||
}
|
||||
packed, _ := msgpack.Marshal(result)
|
||||
nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed)
|
||||
slog.Info("published transcription", "session", sessionID, "seq", seq)
|
||||
}
|
||||
|
||||
buffer.clear()
|
||||
|
||||
if complete {
|
||||
slog.Info("session completed", "session", sessionID)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Second)
|
||||
sessionsMu.Lock()
|
||||
delete(sessions, sessionID)
|
||||
sessionsMu.Unlock()
|
||||
slog.Info("cleaned up session", "session", sessionID)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Monitor goroutine for a session
|
||||
monitorBuffer := func(sessionID string) {
|
||||
for running {
|
||||
sessionsMu.RLock()
|
||||
buffer, ok := sessions[sessionID]
|
||||
sessionsMu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if buffer.shouldProcess(bufferSizeBytes, maxBufferSize, chunkTimeout) {
|
||||
processBuffer(sessionID)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle incoming stream messages
|
||||
handleStreamMsg := func(natMsg *nats.Msg) {
|
||||
parts := strings.Split(natMsg.Subject, ".")
|
||||
if len(parts) < 4 {
|
||||
slog.Warn("invalid subject", "subject", natMsg.Subject)
|
||||
return
|
||||
}
|
||||
sessionID := parts[3]
|
||||
|
||||
data, err := natsutil.DecodeMsgpackMap(natMsg.Data)
|
||||
if err != nil {
|
||||
slog.Error("decode error", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
msgType := ""
|
||||
if t, ok := data["type"].(string); ok {
|
||||
msgType = t
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case "start":
|
||||
slog.Info("starting stream session", "session", sessionID)
|
||||
buf := newAudioBuffer(sessionID)
|
||||
if s, ok := data["state"].(string); ok {
|
||||
buf.setState(s)
|
||||
}
|
||||
if s, ok := data["speaker_id"].(string); ok {
|
||||
buf.speakerID = s
|
||||
}
|
||||
sessionsMu.Lock()
|
||||
sessions[sessionID] = buf
|
||||
sessionsMu.Unlock()
|
||||
go monitorBuffer(sessionID)
|
||||
|
||||
case "state_change":
|
||||
sessionsMu.RLock()
|
||||
buffer, ok := sessions[sessionID]
|
||||
sessionsMu.RUnlock()
|
||||
if ok {
|
||||
if s, ok := data["state"].(string); ok {
|
||||
buffer.setState(s)
|
||||
}
|
||||
}
|
||||
|
||||
case "end":
|
||||
slog.Info("ending stream session", "session", sessionID)
|
||||
sessionsMu.RLock()
|
||||
buffer, ok := sessions[sessionID]
|
||||
sessionsMu.RUnlock()
|
||||
if ok {
|
||||
buffer.markComplete()
|
||||
buffer.mu.Lock()
|
||||
hasData := buffer.totalBytes > 0
|
||||
buffer.mu.Unlock()
|
||||
if hasData {
|
||||
processBuffer(sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
case "chunk":
|
||||
audioB64 := ""
|
||||
if s, ok := data["audio_b64"].(string); ok {
|
||||
audioB64 = s
|
||||
}
|
||||
if audioB64 == "" {
|
||||
return
|
||||
}
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(audioB64)
|
||||
if err != nil {
|
||||
slog.Error("base64 decode failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-create session if missing
|
||||
sessionsMu.Lock()
|
||||
buffer, ok := sessions[sessionID]
|
||||
if !ok {
|
||||
slog.Info("auto-creating session", "session", sessionID)
|
||||
buffer = newAudioBuffer(sessionID)
|
||||
sessions[sessionID] = buffer
|
||||
go monitorBuffer(sessionID)
|
||||
}
|
||||
sessionsMu.Unlock()
|
||||
|
||||
// Check for interrupt
|
||||
if buffer.checkInterrupt(audioBytes, enableInterrupt, audioLevelThreshold, interruptDuration) {
|
||||
interruptMsg := map[string]any{
|
||||
"session_id": sessionID,
|
||||
"type": "interrupt",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"speaker_id": buffer.speakerID,
|
||||
}
|
||||
packed, _ := msgpack.Marshal(interruptMsg)
|
||||
nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed)
|
||||
slog.Info("published interrupt", "session", sessionID)
|
||||
buffer.setState(stateListening)
|
||||
}
|
||||
|
||||
buffer.addChunk(audioBytes)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := nc.Conn().Subscribe(streamSubjectPrefix+".>", handleStreamMsg); err != nil {
|
||||
slog.Error("subscribe failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("subscribed", "subject", streamSubjectPrefix+".>")
|
||||
slog.Info("stt-streaming ready")
|
||||
|
||||
// Wait for shutdown
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
<-sigCh
|
||||
|
||||
slog.Info("shutting down")
|
||||
running = false
|
||||
cancel()
|
||||
slog.Info("shutdown complete")
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvFloat(key string, fallback float64) float64 {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return f
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvBool(key string, fallback bool) bool {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return strings.EqualFold(v, "true") || v == "1"
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
Reference in New Issue
Block a user