Files
tts-module/main.go
Billy D. 85b481b6c4
Some checks failed
CI / Lint (pull_request) Failing after 1m1s
CI / Test (pull_request) Failing after 1m21s
CI / Release (pull_request) Has been skipped
CI / Docker Build & Push (pull_request) Has been skipped
CI / Notify (pull_request) Successful in 1s
feat: migrate to typed messages, drop base64
- Decode TTSRequest via natsutil.Decode[messages.TTSRequest]
- Stream audio as raw bytes via messages.TTSAudioChunk (no base64)
- Non-stream response uses messages.TTSFullResponse
- Status updates use messages.TTSStatus
- Voice list/refresh use messages.TTSVoiceListResponse/TTSVoiceRefreshResponse
- Registry returns []messages.TTSVoiceInfo (not []map[string]any)
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (13 tests pass)
2026-02-20 07:11:13 -05:00

447 lines
12 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
)
// NATS subjects used by this service. The three *Prefix values are completed
// with a ".<sessionID>" suffix at runtime; the two voices subjects are used
// exactly as written.
const (
// requestSubjectPrefix is subscribed with a ".>" wildcard; the token that
// follows the prefix is taken as the session ID.
requestSubjectPrefix = "ai.voice.tts.request"
// audioSubjectPrefix is where synthesized audio (chunked or full) is published.
audioSubjectPrefix = "ai.voice.tts.audio"
// statusSubjectPrefix is where per-session status updates are published.
statusSubjectPrefix = "ai.voice.tts.status"
// voicesListSubject is a request/reply subject returning the known custom voices.
voicesListSubject = "ai.voice.tts.voices.list"
// voicesRefreshSubject triggers a registry rescan and replies with the result.
voicesRefreshSubject = "ai.voice.tts.voices.refresh"
)
// CustomVoice is a trained voice from the coqui-voice-training pipeline.
// Instances are built by VoiceRegistry.refresh from a voice directory in the
// model store.
type CustomVoice struct {
// Name is the "name" value from model_info.json, or the directory name when
// that value is empty.
Name string `json:"name"`
// ModelPath is the path to the voice's model.pth file.
ModelPath string `json:"model_path"`
// ConfigPath is the path to config.json, or "" when that file is absent.
ConfigPath string `json:"config_path"`
// CreatedAt is copied verbatim from model_info.json ("created_at").
CreatedAt string `json:"created_at"`
// Language comes from model_info.json ("language"), defaulting to "en".
Language string `json:"language"`
// ModelType comes from model_info.json ("type"), defaulting to "coqui-tts".
ModelType string `json:"model_type"`
}
// VoiceRegistry discovers custom voices from the model store directory.
// It must be used through a pointer because it embeds a mutex.
type VoiceRegistry struct {
// modelStore is the directory scanned for voice sub-directories.
modelStore string
// mu guards voices and lastRefresh; refresh writes both while holding it.
mu sync.RWMutex
// voices maps a voice name to its metadata; refresh replaces the whole map.
voices map[string]*CustomVoice
// lastRefresh is the time of the last successful scan (zero until then).
lastRefresh time.Time
}
// newVoiceRegistry builds an empty registry rooted at modelStore. It performs
// no disk access; call refresh to populate the voice set.
func newVoiceRegistry(modelStore string) *VoiceRegistry {
	reg := new(VoiceRegistry)
	reg.modelStore = modelStore
	reg.voices = map[string]*CustomVoice{}
	return reg
}
// refresh rescans the model store directory and replaces the registry's voice
// set with whatever it finds. It returns the number of voices discovered.
//
// A sub-directory is accepted as a voice when it contains a parseable
// model_info.json and a model.pth file; config.json is optional. The voice is
// keyed by the "name" value from model_info.json, falling back to the
// directory name. If the model store itself cannot be read, the previous voice
// set is left untouched and 0 is returned.
func (vr *VoiceRegistry) refresh() int {
	// The directory walk runs without holding vr.mu so that lookups are not
	// blocked on disk I/O. modelStore is only assigned in the constructor in
	// this file, so reading it unlocked is safe.
	entries, err := os.ReadDir(vr.modelStore)
	if err != nil {
		slog.Warn("voice model store not found", "path", vr.modelStore, "error", err)
		return 0
	}

	discovered := make(map[string]*CustomVoice, len(entries))
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		voiceDir := filepath.Join(vr.modelStore, entry.Name())

		// Directories without a model_info.json are not voices; skip quietly.
		infoData, err := os.ReadFile(filepath.Join(voiceDir, "model_info.json"))
		if err != nil {
			continue
		}

		// Decode only the fields that are used. Decoding into a
		// map[string]string would reject the whole file as soon as it held a
		// single non-string value (a number, a list, ...), hiding the voice.
		var info struct {
			Name      string `json:"name"`
			CreatedAt string `json:"created_at"`
			Language  string `json:"language"`
			Type      string `json:"type"`
		}
		if err := json.Unmarshal(infoData, &info); err != nil {
			slog.Error("bad model_info.json", "dir", voiceDir, "error", err)
			continue
		}

		modelPath := filepath.Join(voiceDir, "model.pth")
		if _, err := os.Stat(modelPath); err != nil {
			slog.Warn("model file missing", "voice", entry.Name())
			continue
		}

		// The config file is optional; an empty path marks it as absent.
		configPath := filepath.Join(voiceDir, "config.json")
		if _, err := os.Stat(configPath); err != nil {
			configPath = ""
		}

		name := orDefault(info.Name, entry.Name())
		discovered[name] = &CustomVoice{
			Name:       name,
			ModelPath:  modelPath,
			ConfigPath: configPath,
			CreatedAt:  info.CreatedAt,
			Language:   orDefault(info.Language, "en"),
			ModelType:  orDefault(info.Type, "coqui-tts"),
		}
	}

	// Swap in the new set; the write lock is only held for the assignment.
	vr.mu.Lock()
	vr.voices = discovered
	vr.lastRefresh = time.Now()
	vr.mu.Unlock()

	slog.Info("voice registry refreshed", "count", len(discovered))
	return len(discovered)
}
// get looks up a custom voice by name while holding the read lock. The result
// is nil when no voice is registered under that name.
func (vr *VoiceRegistry) get(name string) *CustomVoice {
	vr.mu.RLock()
	voice := vr.voices[name]
	vr.mu.RUnlock()
	return voice
}
// listVoices returns a snapshot of all registered voices as wire-level
// TTSVoiceInfo values. The slice is never nil, and its order follows Go's map
// iteration order, so it is not stable between calls.
func (vr *VoiceRegistry) listVoices() []messages.TTSVoiceInfo {
	vr.mu.RLock()
	defer vr.mu.RUnlock()

	infos := make([]messages.TTSVoiceInfo, len(vr.voices))
	idx := 0
	for _, voice := range vr.voices {
		infos[idx] = messages.TTSVoiceInfo{
			Name:      voice.Name,
			Language:  voice.Language,
			ModelType: voice.ModelType,
			CreatedAt: voice.CreatedAt,
		}
		idx++
	}
	return infos
}
// main wires up and runs the TTS streaming service:
//
//   - loads configuration from the shared config package and the environment,
//   - sets up telemetry, the custom voice registry, NATS and a JetStream stream,
//   - serves health/readiness endpoints,
//   - handles TTS requests by calling the XTTS HTTP API and publishing the
//     resulting audio and status updates on NATS,
//   - answers voice list / refresh requests and rescans voices periodically,
//   - blocks until SIGTERM/SIGINT, then cancels in-flight work and cleans up.
func main() {
	// Fallbacks for settings that must be strictly positive, plus the time
	// budget granted to deferred cleanup.
	const (
		defaultAudioChunkSize  = 32768
		defaultRefreshInterval = 300
		shutdownTimeout        = 5 * time.Second
	)

	slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))

	cfg := config.Load()
	cfg.ServiceName = "tts-streaming"

	xttsURL := getEnv("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
	defaultSpeaker := getEnv("TTS_DEFAULT_SPEAKER", "default")
	defaultLanguage := getEnv("TTS_DEFAULT_LANGUAGE", "en")
	audioChunkSize := getEnvInt("TTS_AUDIO_CHUNK_SIZE", defaultAudioChunkSize)
	sampleRate := getEnvInt("TTS_SAMPLE_RATE", 24000)
	modelStore := getEnv("VOICE_MODEL_STORE", "/models/tts/custom")
	refreshInterval := getEnvInt("VOICE_REGISTRY_REFRESH_SECONDS", defaultRefreshInterval)

	// A zero chunk size divides by zero when counting chunks and a negative one
	// never advances the chunk loop, so neither may reach streamAudio.
	if audioChunkSize <= 0 {
		slog.Warn("invalid TTS_AUDIO_CHUNK_SIZE, using default", "value", audioChunkSize, "default", defaultAudioChunkSize)
		audioChunkSize = defaultAudioChunkSize
	}
	// time.NewTicker panics on a non-positive interval.
	if refreshInterval <= 0 {
		slog.Warn("invalid VOICE_REGISTRY_REFRESH_SECONDS, using default", "value", refreshInterval, "default", defaultRefreshInterval)
		refreshInterval = defaultRefreshInterval
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// cleanupCtx hands out a fresh, bounded context for deferred cleanup. ctx
	// itself is already cancelled by the time deferred calls run, so passing
	// it to a shutdown routine would give that routine no time to finish.
	cleanupCtx := func() (context.Context, context.CancelFunc) {
		return context.WithTimeout(context.Background(), shutdownTimeout)
	}

	// Telemetry
	tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
		ServiceName:      cfg.ServiceName,
		ServiceVersion:   cfg.ServiceVersion,
		ServiceNamespace: cfg.ServiceNamespace,
		DeploymentEnv:    cfg.DeploymentEnv,
		Enabled:          cfg.OTELEnabled,
		Endpoint:         cfg.OTELEndpoint,
	})
	if err != nil {
		// Telemetry is optional: log and keep going without it.
		slog.Error("telemetry setup failed", "error", err)
	}
	if shutdown != nil {
		defer func() {
			sctx, done := cleanupCtx()
			defer done()
			shutdown(sctx)
		}()
	}
	_ = tp // available for future span creation

	// Voice registry
	registry := newVoiceRegistry(modelStore)
	registry.refresh()

	// NATS
	natsOpts := []nats.Option{}
	if cfg.NATSUser != "" && cfg.NATSPassword != "" {
		natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
	}
	nc := natsutil.New(cfg.NATSURL, natsOpts...)
	if err := nc.Connect(); err != nil {
		slog.Error("NATS connect failed", "error", err)
		os.Exit(1)
	}
	defer nc.Close()

	// JetStream stream setup. Failures are logged and otherwise ignored; the
	// service carries on and subscribes to its subjects regardless.
	if js, jsErr := nc.Conn().JetStream(); jsErr != nil {
		slog.Warn("JetStream unavailable, skipping stream setup", "error", jsErr)
	} else if _, addErr := js.AddStream(&nats.StreamConfig{
		Name:      "AI_VOICE_TTS",
		Subjects:  []string{"ai.voice.tts.>"},
		Retention: nats.LimitsPolicy,
		MaxAge:    5 * time.Minute,
		Storage:   nats.MemoryStorage,
	}); addErr != nil {
		slog.Info("JetStream stream setup", "msg", addErr)
	} else {
		slog.Info("created/updated JetStream stream AI_VOICE_TTS")
	}

	httpClient := &http.Client{Timeout: 180 * time.Second}

	// Health server. Readiness is derived from ctx instead of a shared bool:
	// ctx is cancelled on shutdown and is safe to read from the health
	// server's goroutines, whereas an unsynchronised flag is a data race.
	healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
		return ctx.Err() == nil && nc.IsConnected()
	})
	healthSrv.Start()
	defer func() {
		sctx, done := cleanupCtx()
		defer done()
		healthSrv.Stop(sctx)
	}()

	// Helper: publish a status update for a session. Delivery is best-effort;
	// a failed publish is logged, never propagated.
	publishStatus := func(sessionID, status, message string) {
		statusMsg := &messages.TTSStatus{
			SessionID: sessionID,
			Status:    status,
			Message:   message,
			Timestamp: time.Now().Unix(),
		}
		if err := nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg); err != nil {
			slog.Warn("status publish failed", "session", sessionID, "status", status, "error", err)
		}
	}

	// Helper: synthesize via the XTTS HTTP API and return the response body.
	// A registered custom voice takes precedence over a caller-supplied
	// reference WAV.
	synthesize := func(ctx context.Context, text, speaker, language, speakerWavB64 string) ([]byte, error) {
		payload := map[string]any{
			"text":     text,
			"speaker":  speaker,
			"language": language,
		}
		if customVoice := registry.get(speaker); customVoice != nil {
			payload["model_path"] = customVoice.ModelPath
			if customVoice.ConfigPath != "" {
				payload["config_path"] = customVoice.ConfigPath
			}
			payload["language"] = orDefault(language, customVoice.Language)
			slog.Info("using custom voice", "voice", customVoice.Name)
		} else if speakerWavB64 != "" {
			payload["speaker_wav"] = speakerWavB64
		}

		body, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("encode xtts payload: %w", err)
		}
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, xttsURL+"/v1/audio/speech", bytes.NewReader(body))
		if err != nil {
			return nil, err
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := httpClient.Do(req)
		if err != nil {
			return nil, fmt.Errorf("xtts request: %w", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode >= 400 {
			// Best-effort read: the body is only used to enrich the error.
			respBody, _ := io.ReadAll(resp.Body)
			return nil, fmt.Errorf("xtts %d: %s", resp.StatusCode, string(respBody))
		}
		return io.ReadAll(resp.Body)
	}

	// Helper: stream audio chunks — raw bytes, no base64. Chunks are published
	// in order; a failed publish is logged and the remaining chunks are still
	// attempted.
	streamAudio := func(sessionID string, audioBytes []byte) {
		subject := fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID)
		totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
		for i := 0; i < len(audioBytes); i += audioChunkSize {
			end := i + audioChunkSize
			if end > len(audioBytes) {
				end = len(audioBytes)
			}
			chunkIndex := i / audioChunkSize
			msg := &messages.TTSAudioChunk{
				SessionID:   sessionID,
				ChunkIndex:  chunkIndex,
				TotalChunks: totalChunks,
				Audio:       audioBytes[i:end],
				IsLast:      end >= len(audioBytes),
				Timestamp:   time.Now().Unix(),
				SampleRate:  sampleRate,
			}
			if err := nc.Publish(subject, msg); err != nil {
				slog.Warn("audio chunk publish failed", "session", sessionID, "chunk", chunkIndex, "error", err)
			}
		}
		slog.Info("streamed audio", "session", sessionID, "chunks", totalChunks)
	}

	// Subscribe: TTS requests
	handleRequest := func(natMsg *nats.Msg) {
		// Subjects look like ai.voice.tts.request.<session>; the fifth token
		// is the session ID.
		parts := strings.Split(natMsg.Subject, ".")
		if len(parts) < 5 {
			slog.Warn("invalid subject format", "subject", natMsg.Subject)
			return
		}
		sessionID := parts[4]

		req, err := natsutil.Decode[messages.TTSRequest](natMsg.Data)
		if err != nil {
			slog.Error("decode error", "error", err)
			return
		}

		text := req.Text
		speaker := orDefault(req.Speaker, defaultSpeaker)
		language := orDefault(req.Language, defaultLanguage)
		speakerWavB64 := req.SpeakerWavB64
		stream := req.Stream

		// Default to streaming if not explicitly set (zero-value is false).
		// NOTE(review): a bool cannot distinguish "unset" from an explicit
		// false, so every non-empty request ends up streamed and the
		// full-response branch below is effectively unreachable. Making it
		// reachable needs a tri-state field on the request type — confirm the
		// intended behaviour before changing it.
		if !stream && text != "" {
			stream = true
		}

		if text == "" {
			slog.Warn("empty text", "session", sessionID)
			publishStatus(sessionID, "error", "Empty text provided")
			return
		}

		slog.Info("processing TTS request", "session", sessionID, "text_len", len(text))
		publishStatus(sessionID, "processing", fmt.Sprintf("Synthesizing %d characters", len(text)))

		// ctx is the service-wide context, so shutdown aborts in-flight
		// synthesis instead of waiting for the HTTP client timeout.
		audioBytes, err := synthesize(ctx, text, speaker, language, speakerWavB64)
		if err != nil {
			slog.Error("synthesis failed", "session", sessionID, "error", err)
			publishStatus(sessionID, "error", err.Error())
			return
		}

		if stream {
			streamAudio(sessionID, audioBytes)
		} else {
			msg := &messages.TTSFullResponse{
				SessionID:  sessionID,
				Audio:      audioBytes,
				Timestamp:  time.Now().Unix(),
				SampleRate: sampleRate,
			}
			if err := nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg); err != nil {
				slog.Warn("audio publish failed", "session", sessionID, "error", err)
			}
		}
		publishStatus(sessionID, "completed", fmt.Sprintf("Audio size: %d bytes", len(audioBytes)))
	}
	if _, err := nc.Conn().Subscribe(requestSubjectPrefix+".>", handleRequest); err != nil {
		slog.Error("subscribe request failed", "error", err)
		os.Exit(1)
	}
	slog.Info("subscribed", "subject", requestSubjectPrefix+".>")

	// Subscribe: list voices
	if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
		// Without a reply subject there is nobody to answer.
		if msg.Reply == "" {
			return
		}
		// lastRefresh is written by refresh under the registry lock, so it
		// has to be read under the lock as well.
		registry.mu.RLock()
		lastRefresh := registry.lastRefresh
		registry.mu.RUnlock()

		resp := &messages.TTSVoiceListResponse{
			DefaultSpeaker: defaultSpeaker,
			CustomVoices:   registry.listVoices(),
			LastRefresh:    lastRefresh.Unix(),
			Timestamp:      time.Now().Unix(),
		}
		packed, err := msgpack.Marshal(resp)
		if err != nil {
			slog.Error("encode voice list failed", "error", err)
			return
		}
		if err := msg.Respond(packed); err != nil {
			slog.Warn("voice list reply failed", "error", err)
		}
	}); err != nil {
		slog.Error("subscribe voices list failed", "error", err)
	}

	// Subscribe: refresh voices. The rescan always runs; a reply is only sent
	// when the requester supplied a reply subject.
	if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
		count := registry.refresh()
		if msg.Reply != "" {
			resp := &messages.TTSVoiceRefreshResponse{
				Count:        count,
				CustomVoices: registry.listVoices(),
				Timestamp:    time.Now().Unix(),
			}
			if packed, err := msgpack.Marshal(resp); err != nil {
				slog.Error("encode voice refresh failed", "error", err)
			} else if err := msg.Respond(packed); err != nil {
				slog.Warn("voice refresh reply failed", "error", err)
			}
		}
		slog.Info("voice registry refreshed on demand", "count", count)
	}); err != nil {
		slog.Error("subscribe voices refresh failed", "error", err)
	}

	// Periodic voice refresh; the goroutine exits when ctx is cancelled.
	go func() {
		ticker := time.NewTicker(time.Duration(refreshInterval) * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				registry.refresh()
			case <-ctx.Done():
				return
			}
		}
	}()

	slog.Info("tts-streaming ready")

	// Wait for shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
	<-sigCh

	slog.Info("shutting down")
	// Cancelling ctx flips readiness to false, aborts in-flight synthesis and
	// stops the periodic refresh; deferred calls then tear down the rest.
	cancel()
	slog.Info("shutdown complete")
}
// Helpers
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable key parsed as a base-10 int. It
// returns fallback when the variable is unset, empty, or not a valid integer.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
// orDefault returns val unless it is the empty string, in which case it
// returns def.
func orDefault(val, def string) string {
	if val != "" {
		return val
	}
	return def
}