package main import ( "bytes" "context" "encoding/base64" "encoding/json" "fmt" "io" "log/slog" "net/http" "os" "os/signal" "path/filepath" "strconv" "strings" "sync" "syscall" "time" "github.com/nats-io/nats.go" "github.com/vmihailenco/msgpack/v5" "git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/health" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" ) // NATS subject prefixes const ( requestSubjectPrefix = "ai.voice.tts.request" audioSubjectPrefix = "ai.voice.tts.audio" statusSubjectPrefix = "ai.voice.tts.status" voicesListSubject = "ai.voice.tts.voices.list" voicesRefreshSubject = "ai.voice.tts.voices.refresh" ) // CustomVoice is a trained voice from the coqui-voice-training pipeline. type CustomVoice struct { Name string `json:"name"` ModelPath string `json:"model_path"` ConfigPath string `json:"config_path"` CreatedAt string `json:"created_at"` Language string `json:"language"` ModelType string `json:"model_type"` } // VoiceRegistry discovers custom voices from the model store directory. type VoiceRegistry struct { modelStore string mu sync.RWMutex voices map[string]*CustomVoice lastRefresh time.Time } func newVoiceRegistry(modelStore string) *VoiceRegistry { return &VoiceRegistry{ modelStore: modelStore, voices: make(map[string]*CustomVoice), } } func (vr *VoiceRegistry) refresh() int { vr.mu.Lock() defer vr.mu.Unlock() entries, err := os.ReadDir(vr.modelStore) if err != nil { slog.Warn("voice model store not found", "path", vr.modelStore, "error", err) return 0 } discovered := make(map[string]*CustomVoice) for _, entry := range entries { if !entry.IsDir() { continue } voiceDir := filepath.Join(vr.modelStore, entry.Name()) infoPath := filepath.Join(voiceDir, "model_info.json") infoData, err := os.ReadFile(infoPath) if err != nil { continue } var info map[string]string if err := json.Unmarshal(infoData, &info); err != nil { slog.Error("bad model_info.json", "dir", voiceDir, "error", err) continue } modelPath := filepath.Join(voiceDir, "model.pth") if _, err := os.Stat(modelPath); err != nil { slog.Warn("model file missing", "voice", entry.Name()) continue } configPath := filepath.Join(voiceDir, "config.json") if _, err := os.Stat(configPath); err != nil { configPath = "" } name := info["name"] if name == "" { name = entry.Name() } discovered[name] = &CustomVoice{ Name: name, ModelPath: modelPath, ConfigPath: configPath, CreatedAt: info["created_at"], Language: orDefault(info["language"], "en"), ModelType: orDefault(info["type"], "coqui-tts"), } } vr.voices = discovered vr.lastRefresh = time.Now() slog.Info("voice registry refreshed", "count", len(discovered)) return len(discovered) } func (vr *VoiceRegistry) get(name string) *CustomVoice { vr.mu.RLock() defer vr.mu.RUnlock() return vr.voices[name] } func (vr *VoiceRegistry) listVoices() []map[string]any { vr.mu.RLock() defer vr.mu.RUnlock() result := make([]map[string]any, 0, len(vr.voices)) for _, v := range vr.voices { result = append(result, map[string]any{ "name": v.Name, "language": v.Language, "model_type": v.ModelType, "created_at": v.CreatedAt, }) } return result } func main() { slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))) cfg := config.Load() cfg.ServiceName = "tts-streaming" xttsURL := getEnv("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local") defaultSpeaker := getEnv("TTS_DEFAULT_SPEAKER", "default") defaultLanguage := getEnv("TTS_DEFAULT_LANGUAGE", "en") audioChunkSize := getEnvInt("TTS_AUDIO_CHUNK_SIZE", 32768) sampleRate := getEnvInt("TTS_SAMPLE_RATE", 24000) modelStore := getEnv("VOICE_MODEL_STORE", "/models/tts/custom") refreshInterval := getEnvInt("VOICE_REGISTRY_REFRESH_SECONDS", 300) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Telemetry tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{ ServiceName: cfg.ServiceName, ServiceVersion: cfg.ServiceVersion, ServiceNamespace: cfg.ServiceNamespace, DeploymentEnv: cfg.DeploymentEnv, Enabled: cfg.OTELEnabled, Endpoint: cfg.OTELEndpoint, }) if err != nil { slog.Error("telemetry setup failed", "error", err) } if shutdown != nil { defer shutdown(ctx) } _ = tp // available for future span creation // Voice registry registry := newVoiceRegistry(modelStore) registry.refresh() // NATS natsOpts := []nats.Option{} if cfg.NATSUser != "" && cfg.NATSPassword != "" { natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword)) } nc := natsutil.New(cfg.NATSURL, natsOpts...) if err := nc.Connect(); err != nil { slog.Error("NATS connect failed", "error", err) os.Exit(1) } defer nc.Close() // JetStream stream setup js, err := nc.Conn().JetStream() if err == nil { _, err = js.AddStream(&nats.StreamConfig{ Name: "AI_VOICE_TTS", Subjects: []string{"ai.voice.tts.>"}, Retention: nats.LimitsPolicy, MaxAge: 5 * time.Minute, Storage: nats.MemoryStorage, }) if err != nil { slog.Info("JetStream stream setup", "msg", err) } else { slog.Info("created/updated JetStream stream AI_VOICE_TTS") } } httpClient := &http.Client{Timeout: 180 * time.Second} running := true // Health server healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool { return running && nc.IsConnected() }) healthSrv.Start() defer healthSrv.Stop(ctx) // Helper: publish status publishStatus := func(sessionID, status, message string) { statusMsg := map[string]any{ "session_id": sessionID, "status": status, "message": message, "timestamp": time.Now().Unix(), } _ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg) } // Helper: synthesize via XTTS HTTP API synthesize := func(ctx context.Context, text, speaker, language, speakerWavB64 string) ([]byte, error) { payload := map[string]any{ "text": text, "speaker": speaker, "language": language, } customVoice := registry.get(speaker) if customVoice != nil { payload["model_path"] = customVoice.ModelPath if customVoice.ConfigPath != "" { payload["config_path"] = customVoice.ConfigPath } payload["language"] = orDefault(language, customVoice.Language) slog.Info("using custom voice", "voice", customVoice.Name) } else if speakerWavB64 != "" { payload["speaker_wav"] = speakerWavB64 } body, _ := json.Marshal(payload) req, err := http.NewRequestWithContext(ctx, http.MethodPost, xttsURL+"/v1/audio/speech", bytes.NewReader(body)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("xtts request: %w", err) } defer resp.Body.Close() if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("xtts %d: %s", resp.StatusCode, string(respBody)) } return io.ReadAll(resp.Body) } // Helper: stream audio chunks streamAudio := func(sessionID string, audioBytes []byte) { totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize for i := 0; i < len(audioBytes); i += audioChunkSize { end := i + audioChunkSize if end > len(audioBytes) { end = len(audioBytes) } chunk := audioBytes[i:end] chunkIndex := i / audioChunkSize isLast := end >= len(audioBytes) msg := map[string]any{ "session_id": sessionID, "chunk_index": chunkIndex, "total_chunks": totalChunks, "audio_b64": base64.StdEncoding.EncodeToString(chunk), "is_last": isLast, "timestamp": time.Now().Unix(), "sample_rate": sampleRate, } _ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg) } slog.Info("streamed audio", "session", sessionID, "chunks", totalChunks) } // Subscribe: TTS requests handleRequest := func(natMsg *nats.Msg) { parts := strings.Split(natMsg.Subject, ".") if len(parts) < 5 { slog.Warn("invalid subject format", "subject", natMsg.Subject) return } sessionID := parts[4] data, err := natsutil.DecodeMsgpackMap(natMsg.Data) if err != nil { slog.Error("decode error", "error", err) return } text := strVal(data, "text", "") speaker := strVal(data, "speaker", defaultSpeaker) language := strVal(data, "language", defaultLanguage) speakerWavB64 := strVal(data, "speaker_wav_b64", "") stream := boolVal(data, "stream", true) if text == "" { slog.Warn("empty text", "session", sessionID) publishStatus(sessionID, "error", "Empty text provided") return } slog.Info("processing TTS request", "session", sessionID, "text_len", len(text)) publishStatus(sessionID, "processing", fmt.Sprintf("Synthesizing %d characters", len(text))) audioBytes, err := synthesize(ctx, text, speaker, language, speakerWavB64) if err != nil { slog.Error("synthesis failed", "session", sessionID, "error", err) publishStatus(sessionID, "error", err.Error()) return } if stream { streamAudio(sessionID, audioBytes) } else { msg := map[string]any{ "session_id": sessionID, "audio_b64": base64.StdEncoding.EncodeToString(audioBytes), "timestamp": time.Now().Unix(), "sample_rate": sampleRate, } _ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg) } publishStatus(sessionID, "completed", fmt.Sprintf("Audio size: %d bytes", len(audioBytes))) } if _, err := nc.Conn().Subscribe(requestSubjectPrefix+".>", handleRequest); err != nil { slog.Error("subscribe request failed", "error", err) os.Exit(1) } slog.Info("subscribed", "subject", requestSubjectPrefix+".>") // Subscribe: list voices if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) { resp := map[string]any{ "default_speaker": defaultSpeaker, "custom_voices": registry.listVoices(), "last_refresh": registry.lastRefresh.Unix(), "timestamp": time.Now().Unix(), } packed, _ := msgpack.Marshal(resp) if msg.Reply != "" { msg.Respond(packed) } }); err != nil { slog.Error("subscribe voices list failed", "error", err) } // Subscribe: refresh voices if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) { count := registry.refresh() resp := map[string]any{ "count": count, "custom_voices": registry.listVoices(), "timestamp": time.Now().Unix(), } packed, _ := msgpack.Marshal(resp) if msg.Reply != "" { msg.Respond(packed) } slog.Info("voice registry refreshed on demand", "count", count) }); err != nil { slog.Error("subscribe voices refresh failed", "error", err) } // Periodic voice refresh go func() { ticker := time.NewTicker(time.Duration(refreshInterval) * time.Second) defer ticker.Stop() for { select { case <-ticker.C: registry.refresh() case <-ctx.Done(): return } } }() slog.Info("tts-streaming ready") // Wait for shutdown sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) <-sigCh slog.Info("shutting down") running = false cancel() slog.Info("shutdown complete") } // Helpers func strVal(m map[string]any, key, fallback string) string { if v, ok := m[key]; ok { if s, ok := v.(string); ok { return s } } return fallback } func boolVal(m map[string]any, key string, fallback bool) bool { if v, ok := m[key]; ok { if b, ok := v.(bool); ok { return b } } return fallback } func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v } return fallback } func getEnvInt(key string, fallback int) int { if v := os.Getenv(key); v != "" { if i, err := strconv.Atoi(v); err == nil { return i } } return fallback } func orDefault(val, def string) string { if val == "" { return def } return val }