package main import ( "bytes" "context" "encoding/binary" "encoding/json" "fmt" "io" "log/slog" "math" "mime/multipart" "net/http" "os" "os/signal" "strconv" "strings" "sync" "syscall" "time" "github.com/nats-io/nats.go" "github.com/vmihailenco/msgpack/v5" "git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/health" "git.daviestechlabs.io/daviestechlabs/handler-base/messages" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" ) // NATS subjects const ( streamSubjectPrefix = "ai.voice.stream" transcriptionSubjectPrefix = "ai.voice.transcription" ) // Session states const ( stateListening = "listening" stateResponding = "responding" ) // AudioBuffer manages audio chunks for a streaming session. type AudioBuffer struct { mu sync.Mutex sessionID string chunks [][]byte totalBytes int lastChunkTime time.Time isComplete bool sequence int state string speakerID string interruptStartTime *time.Time hasVoiceActivity bool } func newAudioBuffer(sessionID string) *AudioBuffer { return &AudioBuffer{ sessionID: sessionID, lastChunkTime: time.Now(), state: stateListening, } } func (ab *AudioBuffer) addChunk(audio []byte) { ab.mu.Lock() defer ab.mu.Unlock() ab.chunks = append(ab.chunks, audio) ab.totalBytes += len(audio) ab.lastChunkTime = time.Now() ab.hasVoiceActivity = detectVoiceActivity(audio) } func (ab *AudioBuffer) checkInterrupt(audio []byte, enableInterrupt bool, audioThreshold, interruptDuration float64) bool { if !enableInterrupt { return false } ab.mu.Lock() defer ab.mu.Unlock() if ab.state != stateResponding { return false } rms := calculateAudioRMS(audio) hasVoice := detectVoiceActivity(audio) if rms >= audioThreshold && hasVoice { if ab.interruptStartTime == nil { now := time.Now() ab.interruptStartTime = &now slog.Info("potential interrupt detected", "session", ab.sessionID, "rms", rms) } elapsed := time.Since(*ab.interruptStartTime).Seconds() if elapsed >= interruptDuration { slog.Info("interrupt confirmed", "session", ab.sessionID, "elapsed", elapsed) return true } } else { ab.interruptStartTime = nil } return false } func (ab *AudioBuffer) setState(state string) { ab.mu.Lock() defer ab.mu.Unlock() if state != stateListening && state != stateResponding { return } if ab.state != state { slog.Info("state changed", "session", ab.sessionID, "from", ab.state, "to", state) ab.state = state ab.interruptStartTime = nil } } func (ab *AudioBuffer) shouldProcess(bufferSize, maxBufferSize int, chunkTimeout float64) bool { ab.mu.Lock() defer ab.mu.Unlock() if !ab.hasVoiceActivity && ab.totalBytes < bufferSize && time.Since(ab.lastChunkTime).Seconds() < chunkTimeout { return false } if ab.totalBytes >= bufferSize { return true } if time.Since(ab.lastChunkTime).Seconds() > chunkTimeout && ab.totalBytes > 0 { return true } return ab.totalBytes >= maxBufferSize } func (ab *AudioBuffer) getAudio() []byte { ab.mu.Lock() defer ab.mu.Unlock() result := make([]byte, 0, ab.totalBytes) for _, c := range ab.chunks { result = append(result, c...) } return result } func (ab *AudioBuffer) clear() { ab.mu.Lock() defer ab.mu.Unlock() ab.chunks = nil ab.totalBytes = 0 ab.sequence++ } func (ab *AudioBuffer) markComplete() { ab.mu.Lock() defer ab.mu.Unlock() ab.isComplete = true } // calculateAudioRMS computes RMS level for 16-bit PCM audio, normalized to 0.0-1.0. func calculateAudioRMS(audio []byte) float64 { if len(audio) < 2 { return 0.0 } numSamples := len(audio) / 2 var sumSquares float64 for i := 0; i < numSamples; i++ { sample := int16(binary.LittleEndian.Uint16(audio[i*2 : i*2+2])) sumSquares += float64(sample) * float64(sample) } rms := math.Sqrt(sumSquares / float64(numSamples)) return rms / 32768.0 } // detectVoiceActivity uses RMS-based voice detection (pure Go, no cgo). func detectVoiceActivity(audio []byte) bool { rms := calculateAudioRMS(audio) return rms > 0.01 // reasonable threshold for speech } func main() { slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))) cfg := config.Load() cfg.ServiceName = "stt-streaming" whisperURL := getEnv("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local") bufferSizeBytes := getEnvInt("STT_BUFFER_SIZE_BYTES", 512000) chunkTimeout := getEnvFloat("STT_CHUNK_TIMEOUT", 2.0) maxBufferSize := getEnvInt("STT_MAX_BUFFER_SIZE", 5120000) enableInterrupt := getEnvBool("STT_ENABLE_INTERRUPT_DETECTION", true) audioLevelThreshold := getEnvFloat("STT_AUDIO_LEVEL_THRESHOLD", 0.02) interruptDuration := getEnvFloat("STT_INTERRUPT_DURATION", 0.5) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Telemetry tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{ ServiceName: cfg.ServiceName, ServiceVersion: cfg.ServiceVersion, ServiceNamespace: cfg.ServiceNamespace, DeploymentEnv: cfg.DeploymentEnv, Enabled: cfg.OTELEnabled, Endpoint: cfg.OTELEndpoint, }) if err != nil { slog.Error("telemetry setup failed", "error", err) } if shutdown != nil { defer shutdown(ctx) } _ = tp // NATS natsOpts := []nats.Option{} if cfg.NATSUser != "" && cfg.NATSPassword != "" { natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword)) } nc := natsutil.New(cfg.NATSURL, natsOpts...) if err := nc.Connect(); err != nil { slog.Error("NATS connect failed", "error", err) os.Exit(1) } defer nc.Close() // JetStream stream setup js, err := nc.Conn().JetStream() if err == nil { _, err = js.AddStream(&nats.StreamConfig{ Name: "AI_VOICE_STREAM", Subjects: []string{"ai.voice.stream.>", "ai.voice.transcription.>"}, Retention: nats.LimitsPolicy, MaxAge: 5 * time.Minute, Storage: nats.MemoryStorage, }) if err != nil { slog.Info("JetStream stream setup", "msg", err) } else { slog.Info("created/updated JetStream stream AI_VOICE_STREAM") } } httpClient := &http.Client{Timeout: 180 * time.Second} running := true // Session management var sessionsMu sync.RWMutex sessions := make(map[string]*AudioBuffer) // Health server healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool { return running && nc.IsConnected() }) healthSrv.Start() defer healthSrv.Stop(ctx) // Transcribe via Whisper HTTP transcribe := func(audioData []byte) (string, error) { var buf bytes.Buffer w := multipart.NewWriter(&buf) part, err := w.CreateFormFile("file", "audio.wav") if err != nil { return "", err } _, _ = part.Write(audioData) _ = w.Close() req, err := http.NewRequestWithContext(ctx, http.MethodPost, whisperURL+"/v1/audio/transcriptions", &buf) if err != nil { return "", err } req.Header.Set("Content-Type", w.FormDataContentType()) resp, err := httpClient.Do(req) if err != nil { return "", fmt.Errorf("whisper request: %w", err) } defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) if resp.StatusCode >= 400 { return "", fmt.Errorf("whisper %d: %s", resp.StatusCode, string(body)) } var result struct { Text string `json:"text"` } if err := json.Unmarshal(body, &result); err != nil { return "", err } return result.Text, nil } // Process buffer for a session processBuffer := func(sessionID string) { sessionsMu.RLock() buffer, ok := sessions[sessionID] sessionsMu.RUnlock() if !ok { return } audioData := buffer.getAudio() if len(audioData) == 0 { return } buffer.mu.Lock() seq := buffer.sequence complete := buffer.isComplete speakerID := buffer.speakerID hasVoice := buffer.hasVoiceActivity state := buffer.state buffer.mu.Unlock() slog.Info("processing buffer", "session", sessionID, "bytes", len(audioData), "seq", seq) transcript, err := transcribe(audioData) if err != nil { slog.Error("transcription failed", "session", sessionID, "error", err) } if transcript != "" { result := &messages.STTTranscription{ SessionID: sessionID, Transcript: transcript, Sequence: seq, IsPartial: !complete, IsFinal: complete, Timestamp: time.Now().Unix(), SpeakerID: speakerID, HasVoiceActivity: hasVoice, State: state, } packed, _ := msgpack.Marshal(result) _ = nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed) slog.Info("published transcription", "session", sessionID, "seq", seq) } buffer.clear() if complete { slog.Info("session completed", "session", sessionID) go func() { time.Sleep(5 * time.Second) sessionsMu.Lock() delete(sessions, sessionID) sessionsMu.Unlock() slog.Info("cleaned up session", "session", sessionID) }() } } // Monitor goroutine for a session monitorBuffer := func(sessionID string) { for running { sessionsMu.RLock() buffer, ok := sessions[sessionID] sessionsMu.RUnlock() if !ok { return } if buffer.shouldProcess(bufferSizeBytes, maxBufferSize, chunkTimeout) { processBuffer(sessionID) } time.Sleep(100 * time.Millisecond) } } // Handle incoming stream messages handleStreamMsg := func(natMsg *nats.Msg) { parts := strings.Split(natMsg.Subject, ".") if len(parts) < 4 { slog.Warn("invalid subject", "subject", natMsg.Subject) return } sessionID := parts[3] streamMsg, err := natsutil.Decode[messages.STTStreamMessage](natMsg.Data) if err != nil { slog.Error("decode error", "error", err) return } switch streamMsg.Type { case "start": slog.Info("starting stream session", "session", sessionID) buf := newAudioBuffer(sessionID) if streamMsg.State != "" { buf.setState(streamMsg.State) } if streamMsg.SpeakerID != "" { buf.speakerID = streamMsg.SpeakerID } sessionsMu.Lock() sessions[sessionID] = buf sessionsMu.Unlock() go monitorBuffer(sessionID) case "state_change": sessionsMu.RLock() buffer, ok := sessions[sessionID] sessionsMu.RUnlock() if ok && streamMsg.State != "" { buffer.setState(streamMsg.State) } case "end": slog.Info("ending stream session", "session", sessionID) sessionsMu.RLock() buffer, ok := sessions[sessionID] sessionsMu.RUnlock() if ok { buffer.markComplete() buffer.mu.Lock() hasData := buffer.totalBytes > 0 buffer.mu.Unlock() if hasData { processBuffer(sessionID) } } case "chunk": // Audio arrives as raw bytes — no base64 decode needed if len(streamMsg.Audio) == 0 { return } // Auto-create session if missing sessionsMu.Lock() buffer, ok := sessions[sessionID] if !ok { slog.Info("auto-creating session", "session", sessionID) buffer = newAudioBuffer(sessionID) sessions[sessionID] = buffer go monitorBuffer(sessionID) } sessionsMu.Unlock() // Check for interrupt if buffer.checkInterrupt(streamMsg.Audio, enableInterrupt, audioLevelThreshold, interruptDuration) { interruptMsg := &messages.STTInterrupt{ SessionID: sessionID, Type: "interrupt", Timestamp: time.Now().Unix(), SpeakerID: buffer.speakerID, } packed, _ := msgpack.Marshal(interruptMsg) _ = nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed) slog.Info("published interrupt", "session", sessionID) buffer.setState(stateListening) } buffer.addChunk(streamMsg.Audio) } } if _, err := nc.Conn().Subscribe(streamSubjectPrefix+".>", handleStreamMsg); err != nil { slog.Error("subscribe failed", "error", err) os.Exit(1) } slog.Info("subscribed", "subject", streamSubjectPrefix+".>") slog.Info("stt-streaming ready") // Wait for shutdown sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) <-sigCh slog.Info("shutting down") running = false cancel() slog.Info("shutdown complete") } // Helpers func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v } return fallback } func getEnvInt(key string, fallback int) int { if v := os.Getenv(key); v != "" { if i, err := strconv.Atoi(v); err == nil { return i } } return fallback } func getEnvFloat(key string, fallback float64) float64 { if v := os.Getenv(key); v != "" { if f, err := strconv.ParseFloat(v, 64); err == nil { return f } } return fallback } func getEnvBool(key string, fallback bool) bool { if v := os.Getenv(key); v != "" { return strings.EqualFold(v, "true") || v == "1" } return fallback }