diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7ea5baa --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.git +.gitignore +*.md +LICENSE +renovate.json +*_test.go +e2e_test.go +__pycache__ +.env* diff --git a/Dockerfile b/Dockerfile index 2a662f3..eac4d88 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN go mod download COPY . . -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -o /stt-module . +RUN CGO_ENABLED=0 GOOS=linux GOAMD64=v3 go build -ldflags="-w -s" -o /stt-module . # Runtime stage FROM scratch diff --git a/main.go b/main.go index 6117214..63bf3a2 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "bytes" "context" - "encoding/base64" "encoding/binary" "encoding/json" "fmt" @@ -25,6 +24,7 @@ import ( "git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/health" + "git.daviestechlabs.io/daviestechlabs/handler-base/messages" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" ) @@ -137,11 +137,7 @@ func (ab *AudioBuffer) shouldProcess(bufferSize, maxBufferSize int, chunkTimeout func (ab *AudioBuffer) getAudio() []byte { ab.mu.Lock() defer ab.mu.Unlock() - var total int - for _, c := range ab.chunks { - total += len(c) - } - result := make([]byte, 0, total) + result := make([]byte, 0, ab.totalBytes) for _, c := range ab.chunks { result = append(result, c...) } @@ -327,16 +323,16 @@ func main() { } if transcript != "" { - result := map[string]any{ - "session_id": sessionID, - "transcript": transcript, - "sequence": seq, - "is_partial": !complete, - "is_final": complete, - "timestamp": time.Now().Unix(), - "speaker_id": speakerID, - "has_voice_activity": hasVoice, - "state": state, + result := &messages.STTTranscription{ + SessionID: sessionID, + Transcript: transcript, + Sequence: seq, + IsPartial: !complete, + IsFinal: complete, + Timestamp: time.Now().Unix(), + SpeakerID: speakerID, + HasVoiceActivity: hasVoice, + State: state, } packed, _ := msgpack.Marshal(result) nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed) @@ -382,26 +378,21 @@ func main() { } sessionID := parts[3] - data, err := natsutil.DecodeMsgpackMap(natMsg.Data) + streamMsg, err := natsutil.Decode[messages.STTStreamMessage](natMsg.Data) if err != nil { slog.Error("decode error", "error", err) return } - msgType := "" - if t, ok := data["type"].(string); ok { - msgType = t - } - - switch msgType { + switch streamMsg.Type { case "start": slog.Info("starting stream session", "session", sessionID) buf := newAudioBuffer(sessionID) - if s, ok := data["state"].(string); ok { - buf.setState(s) + if streamMsg.State != "" { + buf.setState(streamMsg.State) } - if s, ok := data["speaker_id"].(string); ok { - buf.speakerID = s + if streamMsg.SpeakerID != "" { + buf.speakerID = streamMsg.SpeakerID } sessionsMu.Lock() sessions[sessionID] = buf @@ -412,10 +403,8 @@ func main() { sessionsMu.RLock() buffer, ok := sessions[sessionID] sessionsMu.RUnlock() - if ok { - if s, ok := data["state"].(string); ok { - buffer.setState(s) - } + if ok && streamMsg.State != "" { + buffer.setState(streamMsg.State) } case "end": @@ -434,16 +423,8 @@ func main() { } case "chunk": - audioB64 := "" - if s, ok := data["audio_b64"].(string); ok { - audioB64 = s - } - if audioB64 == "" { - return - } - audioBytes, err := base64.StdEncoding.DecodeString(audioB64) - if err != nil { - slog.Error("base64 decode failed", "error", err) + // Audio arrives as raw bytes — no base64 decode needed + if len(streamMsg.Audio) == 0 { return } @@ -459,12 +440,12 @@ func main() { sessionsMu.Unlock() // Check for interrupt - if buffer.checkInterrupt(audioBytes, enableInterrupt, audioLevelThreshold, interruptDuration) { - interruptMsg := map[string]any{ - "session_id": sessionID, - "type": "interrupt", - "timestamp": time.Now().Unix(), - "speaker_id": buffer.speakerID, + if buffer.checkInterrupt(streamMsg.Audio, enableInterrupt, audioLevelThreshold, interruptDuration) { + interruptMsg := &messages.STTInterrupt{ + SessionID: sessionID, + Type: "interrupt", + Timestamp: time.Now().Unix(), + SpeakerID: buffer.speakerID, } packed, _ := msgpack.Marshal(interruptMsg) nc.Conn().Publish(fmt.Sprintf("%s.%s", transcriptionSubjectPrefix, sessionID), packed) @@ -472,7 +453,7 @@ func main() { buffer.setState(stateListening) } - buffer.addChunk(audioBytes) + buffer.addChunk(streamMsg.Audio) } }