Replace the Python streaming TTS service with a Go implementation for smaller container images. Includes: a VoiceRegistry that discovers custom voices from the model store; NATS subscriptions for TTS requests, voice listing, and voice refresh; JetStream AI_VOICE_TTS stream setup; chunked audio streaming over NATS; a multi-stage Dockerfile (golang:1.25-alpine → scratch); and Gitea Actions CI with lint/test/release/docker/notify jobs.
461 lines
12 KiB
Go
461 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
|
)
|
|
|
|
// ttsSubjectRoot is the common root of every NATS subject this service uses.
const ttsSubjectRoot = "ai.voice.tts"

// NATS subjects. The *Prefix subjects are followed by ".<session_id>";
// the voices.* subjects are request-reply endpoints.
const (
	requestSubjectPrefix = ttsSubjectRoot + ".request" // inbound synthesis requests
	audioSubjectPrefix   = ttsSubjectRoot + ".audio"   // outbound audio (chunked or whole)
	statusSubjectPrefix  = ttsSubjectRoot + ".status"  // outbound processing/completed/error updates

	voicesListSubject    = ttsSubjectRoot + ".voices.list"
	voicesRefreshSubject = ttsSubjectRoot + ".voices.refresh"
)
|
|
|
|
// CustomVoice is a trained voice from the coqui-voice-training pipeline,
// as discovered on disk by VoiceRegistry.refresh.
type CustomVoice struct {
	Name       string `json:"name"`        // "name" from model_info.json, else the directory name
	ModelPath  string `json:"model_path"`  // <voice dir>/model.pth (must exist)
	ConfigPath string `json:"config_path"` // <voice dir>/config.json, or "" if absent
	CreatedAt  string `json:"created_at"`  // "created_at" from model_info.json, or ""
	Language   string `json:"language"`    // "language" from model_info.json, default "en"
	ModelType  string `json:"model_type"`  // "type" from model_info.json, default "coqui-tts"
}

// VoiceRegistry discovers custom voices from the model store directory.
//
// Each voice lives in its own subdirectory of modelStore:
//
//	<modelStore>/<dir>/model_info.json  required, JSON object
//	<modelStore>/<dir>/model.pth        required
//	<modelStore>/<dir>/config.json      optional
//
// mu guards voices and lastRefresh; the methods below take it as needed.
type VoiceRegistry struct {
	modelStore string

	mu          sync.RWMutex
	voices      map[string]*CustomVoice // keyed by CustomVoice.Name
	lastRefresh time.Time               // time of the last successful scan
}

// newVoiceRegistry returns an empty registry rooted at modelStore.
// Call refresh to populate it.
func newVoiceRegistry(modelStore string) *VoiceRegistry {
	return &VoiceRegistry{
		modelStore: modelStore,
		voices:     make(map[string]*CustomVoice),
	}
}

// refresh rescans the model store and replaces the voice set with what it
// finds, returning the number of voices discovered. If the store directory
// cannot be read, the previous voice set and lastRefresh are kept and 0 is
// returned.
//
// The scan runs without holding mu, so get/listVoices (called on every TTS
// request) are not blocked behind filesystem I/O; only the final swap takes
// the write lock.
func (vr *VoiceRegistry) refresh() int {
	entries, err := os.ReadDir(vr.modelStore)
	if err != nil {
		slog.Warn("voice model store not found", "path", vr.modelStore, "error", err)
		return 0
	}

	discovered := make(map[string]*CustomVoice)
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		voice := loadCustomVoice(filepath.Join(vr.modelStore, entry.Name()), entry.Name())
		if voice == nil {
			continue
		}
		// os.ReadDir returns entries sorted by name, so the later directory wins.
		if prev, dup := discovered[voice.Name]; dup {
			slog.Warn("duplicate custom voice name, later directory wins",
				"voice", voice.Name, "kept", voice.ModelPath, "replaced", prev.ModelPath)
		}
		discovered[voice.Name] = voice
	}

	vr.mu.Lock()
	vr.voices = discovered
	vr.lastRefresh = time.Now()
	vr.mu.Unlock()

	slog.Info("voice registry refreshed", "count", len(discovered))
	return len(discovered)
}

// loadCustomVoice builds a CustomVoice from voiceDir, whose base name is
// dirName. It returns nil if model_info.json is missing or not valid JSON, or
// if model.pth is missing.
func loadCustomVoice(voiceDir, dirName string) *CustomVoice {
	infoData, err := os.ReadFile(filepath.Join(voiceDir, "model_info.json"))
	if err != nil {
		return nil // not a voice directory
	}

	// Decode into map[string]any (not map[string]string) so a non-string
	// field, e.g. numeric training metadata, does not reject the whole voice.
	// Only string values are used below.
	var info map[string]any
	if err := json.Unmarshal(infoData, &info); err != nil {
		slog.Error("bad model_info.json", "dir", voiceDir, "error", err)
		return nil
	}
	field := func(key, fallback string) string {
		if s, ok := info[key].(string); ok && s != "" {
			return s
		}
		return fallback
	}

	modelPath := filepath.Join(voiceDir, "model.pth")
	if _, err := os.Stat(modelPath); err != nil {
		slog.Warn("model file missing", "voice", dirName)
		return nil
	}

	configPath := filepath.Join(voiceDir, "config.json")
	if _, err := os.Stat(configPath); err != nil {
		configPath = "" // optional
	}

	return &CustomVoice{
		Name:       field("name", dirName),
		ModelPath:  modelPath,
		ConfigPath: configPath,
		CreatedAt:  field("created_at", ""),
		Language:   field("language", "en"),
		ModelType:  field("type", "coqui-tts"),
	}
}

// get returns the voice registered under name, or nil if there is none.
func (vr *VoiceRegistry) get(name string) *CustomVoice {
	vr.mu.RLock()
	defer vr.mu.RUnlock()
	return vr.voices[name]
}

// listVoices returns a summary (name, language, model_type, created_at) of
// every registered voice. Model paths are not included. Order is unspecified
// because it follows map iteration.
func (vr *VoiceRegistry) listVoices() []map[string]any {
	vr.mu.RLock()
	defer vr.mu.RUnlock()
	result := make([]map[string]any, 0, len(vr.voices))
	for _, v := range vr.voices {
		result = append(result, map[string]any{
			"name":       v.Name,
			"language":   v.Language,
			"model_type": v.ModelType,
			"created_at": v.CreatedAt,
		})
	}
	return result
}
|
|
|
|
func main() {
|
|
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
cfg := config.Load()
|
|
cfg.ServiceName = "tts-streaming"
|
|
|
|
xttsURL := getEnv("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
|
|
defaultSpeaker := getEnv("TTS_DEFAULT_SPEAKER", "default")
|
|
defaultLanguage := getEnv("TTS_DEFAULT_LANGUAGE", "en")
|
|
audioChunkSize := getEnvInt("TTS_AUDIO_CHUNK_SIZE", 32768)
|
|
sampleRate := getEnvInt("TTS_SAMPLE_RATE", 24000)
|
|
modelStore := getEnv("VOICE_MODEL_STORE", "/models/tts/custom")
|
|
refreshInterval := getEnvInt("VOICE_REGISTRY_REFRESH_SECONDS", 300)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Telemetry
|
|
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
|
ServiceName: cfg.ServiceName,
|
|
ServiceVersion: cfg.ServiceVersion,
|
|
ServiceNamespace: cfg.ServiceNamespace,
|
|
DeploymentEnv: cfg.DeploymentEnv,
|
|
Enabled: cfg.OTELEnabled,
|
|
Endpoint: cfg.OTELEndpoint,
|
|
})
|
|
if err != nil {
|
|
slog.Error("telemetry setup failed", "error", err)
|
|
}
|
|
if shutdown != nil {
|
|
defer shutdown(ctx)
|
|
}
|
|
_ = tp // available for future span creation
|
|
|
|
// Voice registry
|
|
registry := newVoiceRegistry(modelStore)
|
|
registry.refresh()
|
|
|
|
// NATS
|
|
natsOpts := []nats.Option{}
|
|
if cfg.NATSUser != "" && cfg.NATSPassword != "" {
|
|
natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
|
|
}
|
|
nc := natsutil.New(cfg.NATSURL, natsOpts...)
|
|
if err := nc.Connect(); err != nil {
|
|
slog.Error("NATS connect failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
defer nc.Close()
|
|
|
|
// JetStream stream setup
|
|
js, err := nc.Conn().JetStream()
|
|
if err == nil {
|
|
_, err = js.AddStream(&nats.StreamConfig{
|
|
Name: "AI_VOICE_TTS",
|
|
Subjects: []string{"ai.voice.tts.>"},
|
|
Retention: nats.LimitsPolicy,
|
|
MaxAge: 5 * time.Minute,
|
|
Storage: nats.MemoryStorage,
|
|
})
|
|
if err != nil {
|
|
slog.Info("JetStream stream setup", "msg", err)
|
|
} else {
|
|
slog.Info("created/updated JetStream stream AI_VOICE_TTS")
|
|
}
|
|
}
|
|
|
|
httpClient := &http.Client{Timeout: 180 * time.Second}
|
|
running := true
|
|
|
|
// Health server
|
|
healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
|
|
return running && nc.IsConnected()
|
|
})
|
|
healthSrv.Start()
|
|
defer healthSrv.Stop(ctx)
|
|
|
|
// Helper: publish status
|
|
publishStatus := func(sessionID, status, message string) {
|
|
statusMsg := map[string]any{
|
|
"session_id": sessionID,
|
|
"status": status,
|
|
"message": message,
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
_ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg)
|
|
}
|
|
|
|
// Helper: synthesize via XTTS HTTP API
|
|
synthesize := func(ctx context.Context, text, speaker, language, speakerWavB64 string) ([]byte, error) {
|
|
payload := map[string]any{
|
|
"text": text,
|
|
"speaker": speaker,
|
|
"language": language,
|
|
}
|
|
|
|
customVoice := registry.get(speaker)
|
|
if customVoice != nil {
|
|
payload["model_path"] = customVoice.ModelPath
|
|
if customVoice.ConfigPath != "" {
|
|
payload["config_path"] = customVoice.ConfigPath
|
|
}
|
|
payload["language"] = orDefault(language, customVoice.Language)
|
|
slog.Info("using custom voice", "voice", customVoice.Name)
|
|
} else if speakerWavB64 != "" {
|
|
payload["speaker_wav"] = speakerWavB64
|
|
}
|
|
|
|
body, _ := json.Marshal(payload)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, xttsURL+"/v1/audio/speech", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xtts request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("xtts %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
return io.ReadAll(resp.Body)
|
|
}
|
|
|
|
// Helper: stream audio chunks
|
|
streamAudio := func(sessionID string, audioBytes []byte) {
|
|
totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
|
|
for i := 0; i < len(audioBytes); i += audioChunkSize {
|
|
end := i + audioChunkSize
|
|
if end > len(audioBytes) {
|
|
end = len(audioBytes)
|
|
}
|
|
chunk := audioBytes[i:end]
|
|
chunkIndex := i / audioChunkSize
|
|
isLast := end >= len(audioBytes)
|
|
|
|
msg := map[string]any{
|
|
"session_id": sessionID,
|
|
"chunk_index": chunkIndex,
|
|
"total_chunks": totalChunks,
|
|
"audio_b64": base64.StdEncoding.EncodeToString(chunk),
|
|
"is_last": isLast,
|
|
"timestamp": time.Now().Unix(),
|
|
"sample_rate": sampleRate,
|
|
}
|
|
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
|
}
|
|
slog.Info("streamed audio", "session", sessionID, "chunks", totalChunks)
|
|
}
|
|
|
|
// Subscribe: TTS requests
|
|
handleRequest := func(natMsg *nats.Msg) {
|
|
parts := strings.Split(natMsg.Subject, ".")
|
|
if len(parts) < 5 {
|
|
slog.Warn("invalid subject format", "subject", natMsg.Subject)
|
|
return
|
|
}
|
|
sessionID := parts[4]
|
|
|
|
data, err := natsutil.DecodeMsgpackMap(natMsg.Data)
|
|
if err != nil {
|
|
slog.Error("decode error", "error", err)
|
|
return
|
|
}
|
|
|
|
text := strVal(data, "text", "")
|
|
speaker := strVal(data, "speaker", defaultSpeaker)
|
|
language := strVal(data, "language", defaultLanguage)
|
|
speakerWavB64 := strVal(data, "speaker_wav_b64", "")
|
|
stream := boolVal(data, "stream", true)
|
|
|
|
if text == "" {
|
|
slog.Warn("empty text", "session", sessionID)
|
|
publishStatus(sessionID, "error", "Empty text provided")
|
|
return
|
|
}
|
|
|
|
slog.Info("processing TTS request", "session", sessionID, "text_len", len(text))
|
|
publishStatus(sessionID, "processing", fmt.Sprintf("Synthesizing %d characters", len(text)))
|
|
|
|
audioBytes, err := synthesize(ctx, text, speaker, language, speakerWavB64)
|
|
if err != nil {
|
|
slog.Error("synthesis failed", "session", sessionID, "error", err)
|
|
publishStatus(sessionID, "error", err.Error())
|
|
return
|
|
}
|
|
|
|
if stream {
|
|
streamAudio(sessionID, audioBytes)
|
|
} else {
|
|
msg := map[string]any{
|
|
"session_id": sessionID,
|
|
"audio_b64": base64.StdEncoding.EncodeToString(audioBytes),
|
|
"timestamp": time.Now().Unix(),
|
|
"sample_rate": sampleRate,
|
|
}
|
|
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
|
}
|
|
|
|
publishStatus(sessionID, "completed", fmt.Sprintf("Audio size: %d bytes", len(audioBytes)))
|
|
}
|
|
|
|
if _, err := nc.Conn().Subscribe(requestSubjectPrefix+".>", handleRequest); err != nil {
|
|
slog.Error("subscribe request failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
slog.Info("subscribed", "subject", requestSubjectPrefix+".>")
|
|
|
|
// Subscribe: list voices
|
|
if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
|
|
resp := map[string]any{
|
|
"default_speaker": defaultSpeaker,
|
|
"custom_voices": registry.listVoices(),
|
|
"last_refresh": registry.lastRefresh.Unix(),
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
packed, _ := msgpack.Marshal(resp)
|
|
if msg.Reply != "" {
|
|
msg.Respond(packed)
|
|
}
|
|
}); err != nil {
|
|
slog.Error("subscribe voices list failed", "error", err)
|
|
}
|
|
|
|
// Subscribe: refresh voices
|
|
if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
|
|
count := registry.refresh()
|
|
resp := map[string]any{
|
|
"count": count,
|
|
"custom_voices": registry.listVoices(),
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
packed, _ := msgpack.Marshal(resp)
|
|
if msg.Reply != "" {
|
|
msg.Respond(packed)
|
|
}
|
|
slog.Info("voice registry refreshed on demand", "count", count)
|
|
}); err != nil {
|
|
slog.Error("subscribe voices refresh failed", "error", err)
|
|
}
|
|
|
|
// Periodic voice refresh
|
|
go func() {
|
|
ticker := time.NewTicker(time.Duration(refreshInterval) * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
registry.refresh()
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
slog.Info("tts-streaming ready")
|
|
|
|
// Wait for shutdown
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
|
<-sigCh
|
|
|
|
slog.Info("shutting down")
|
|
running = false
|
|
cancel()
|
|
slog.Info("shutdown complete")
|
|
}
|
|
|
|
// Helpers
|
|
|
|
// strVal returns m[key] when it exists and holds a string, otherwise
// fallback. A present-but-empty string is returned as is, not replaced by
// fallback. A nil map is treated as empty.
func strVal(m map[string]any, key, fallback string) string {
	s, isString := m[key].(string)
	if !isString {
		return fallback
	}
	return s
}
|
|
|
|
// boolVal returns m[key] when it exists and holds a bool, otherwise
// fallback. Values of other types (e.g. "true" as a string) are ignored.
// A nil map is treated as empty.
func boolVal(m map[string]any, key string, fallback bool) bool {
	b, isBool := m[key].(bool)
	if !isBool {
		return fallback
	}
	return b
}
|
|
|
|
// getEnv returns the value of environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns environment variable key parsed as a base-10 int, or
// fallback when it is unset, blank, or not a valid integer. Surrounding
// whitespace is ignored. An unparsable value is logged, so a misconfiguration
// does not silently fall back to the default.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	i, err := strconv.Atoi(raw)
	if err != nil {
		slog.Warn("ignoring invalid integer environment variable", "key", key, "value", raw, "default", fallback)
		return fallback
	}
	return i
}
|
|
|
|
// orDefault returns val, or def when val is the empty string.
func orDefault(val, def string) string {
	if val != "" {
		return val
	}
	return def
}
|