- Decode TTSRequest via natsutil.Decode[messages.TTSRequest]
- Stream audio as raw bytes via messages.TTSAudioChunk (no base64)
- Non-stream response uses messages.TTSFullResponse
- Status updates use messages.TTSStatus
- Voice list/refresh use messages.TTSVoiceListResponse/TTSVoiceRefreshResponse
- Registry returns []messages.TTSVoiceInfo (not []map[string]any)
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (13 tests pass)
447 lines
12 KiB
Go
447 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
|
)
|
|
|
|
// NATS subject prefixes. Every subject lives under the shared "ai.voice.tts"
// root; per-session subjects are formed by appending "." + sessionID to one of
// the *Prefix constants.
const (
	ttsSubjectRoot = "ai.voice.tts"

	requestSubjectPrefix = ttsSubjectRoot + ".request"
	audioSubjectPrefix   = ttsSubjectRoot + ".audio"
	statusSubjectPrefix  = ttsSubjectRoot + ".status"
	voicesListSubject    = ttsSubjectRoot + ".voices.list"
	voicesRefreshSubject = ttsSubjectRoot + ".voices.refresh"
)
|
|
|
|
// CustomVoice is a trained voice from the coqui-voice-training pipeline.
//
// Instances are built by VoiceRegistry.refresh from a voice directory's
// model_info.json plus the model files found alongside it.
type CustomVoice struct {
	// Name is the registry lookup key: the "name" field of model_info.json,
	// or the voice directory's name when that field is empty.
	Name string `json:"name"`
	// ModelPath is the path to the voice directory's model.pth (required for
	// the voice to be registered).
	ModelPath string `json:"model_path"`
	// ConfigPath is the path to the directory's config.json, or "" when the
	// directory has none.
	ConfigPath string `json:"config_path"`
	// CreatedAt is copied verbatim from model_info.json "created_at"; its
	// format is not validated here.
	CreatedAt string `json:"created_at"`
	// Language comes from model_info.json "language", defaulting to "en".
	Language string `json:"language"`
	// ModelType comes from model_info.json "type", defaulting to "coqui-tts".
	ModelType string `json:"model_type"`
}
|
|
|
|
// VoiceRegistry discovers custom voices from the model store directory.
//
// refresh rescans the directory and replaces the known voice set; get looks a
// voice up by name and listVoices summarises all known voices.
type VoiceRegistry struct {
	// modelStore is the root directory scanned by refresh; each immediate
	// subdirectory is a candidate voice. It is set once at construction.
	modelStore string
	// mu guards voices and lastRefresh: refresh writes them under the write
	// lock, while get and listVoices read voices under the read lock.
	mu sync.RWMutex
	// voices maps a voice name to its discovered metadata.
	voices map[string]*CustomVoice
	// lastRefresh records when refresh last completed a scan of a readable
	// model store; a failed ReadDir leaves it unchanged.
	lastRefresh time.Time
}
|
|
|
|
func newVoiceRegistry(modelStore string) *VoiceRegistry {
|
|
return &VoiceRegistry{
|
|
modelStore: modelStore,
|
|
voices: make(map[string]*CustomVoice),
|
|
}
|
|
}
|
|
|
|
func (vr *VoiceRegistry) refresh() int {
|
|
vr.mu.Lock()
|
|
defer vr.mu.Unlock()
|
|
|
|
entries, err := os.ReadDir(vr.modelStore)
|
|
if err != nil {
|
|
slog.Warn("voice model store not found", "path", vr.modelStore, "error", err)
|
|
return 0
|
|
}
|
|
|
|
discovered := make(map[string]*CustomVoice)
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() {
|
|
continue
|
|
}
|
|
voiceDir := filepath.Join(vr.modelStore, entry.Name())
|
|
infoPath := filepath.Join(voiceDir, "model_info.json")
|
|
|
|
infoData, err := os.ReadFile(infoPath)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
var info map[string]string
|
|
if err := json.Unmarshal(infoData, &info); err != nil {
|
|
slog.Error("bad model_info.json", "dir", voiceDir, "error", err)
|
|
continue
|
|
}
|
|
|
|
modelPath := filepath.Join(voiceDir, "model.pth")
|
|
if _, err := os.Stat(modelPath); err != nil {
|
|
slog.Warn("model file missing", "voice", entry.Name())
|
|
continue
|
|
}
|
|
|
|
configPath := filepath.Join(voiceDir, "config.json")
|
|
if _, err := os.Stat(configPath); err != nil {
|
|
configPath = ""
|
|
}
|
|
|
|
name := info["name"]
|
|
if name == "" {
|
|
name = entry.Name()
|
|
}
|
|
|
|
discovered[name] = &CustomVoice{
|
|
Name: name,
|
|
ModelPath: modelPath,
|
|
ConfigPath: configPath,
|
|
CreatedAt: info["created_at"],
|
|
Language: orDefault(info["language"], "en"),
|
|
ModelType: orDefault(info["type"], "coqui-tts"),
|
|
}
|
|
}
|
|
|
|
vr.voices = discovered
|
|
vr.lastRefresh = time.Now()
|
|
slog.Info("voice registry refreshed", "count", len(discovered))
|
|
return len(discovered)
|
|
}
|
|
|
|
func (vr *VoiceRegistry) get(name string) *CustomVoice {
|
|
vr.mu.RLock()
|
|
defer vr.mu.RUnlock()
|
|
return vr.voices[name]
|
|
}
|
|
|
|
func (vr *VoiceRegistry) listVoices() []messages.TTSVoiceInfo {
|
|
vr.mu.RLock()
|
|
defer vr.mu.RUnlock()
|
|
result := make([]messages.TTSVoiceInfo, 0, len(vr.voices))
|
|
for _, v := range vr.voices {
|
|
result = append(result, messages.TTSVoiceInfo{
|
|
Name: v.Name,
|
|
Language: v.Language,
|
|
ModelType: v.ModelType,
|
|
CreatedAt: v.CreatedAt,
|
|
})
|
|
}
|
|
return result
|
|
}
|
|
|
|
func main() {
|
|
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
cfg := config.Load()
|
|
cfg.ServiceName = "tts-streaming"
|
|
|
|
xttsURL := getEnv("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
|
|
defaultSpeaker := getEnv("TTS_DEFAULT_SPEAKER", "default")
|
|
defaultLanguage := getEnv("TTS_DEFAULT_LANGUAGE", "en")
|
|
audioChunkSize := getEnvInt("TTS_AUDIO_CHUNK_SIZE", 32768)
|
|
sampleRate := getEnvInt("TTS_SAMPLE_RATE", 24000)
|
|
modelStore := getEnv("VOICE_MODEL_STORE", "/models/tts/custom")
|
|
refreshInterval := getEnvInt("VOICE_REGISTRY_REFRESH_SECONDS", 300)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Telemetry
|
|
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
|
ServiceName: cfg.ServiceName,
|
|
ServiceVersion: cfg.ServiceVersion,
|
|
ServiceNamespace: cfg.ServiceNamespace,
|
|
DeploymentEnv: cfg.DeploymentEnv,
|
|
Enabled: cfg.OTELEnabled,
|
|
Endpoint: cfg.OTELEndpoint,
|
|
})
|
|
if err != nil {
|
|
slog.Error("telemetry setup failed", "error", err)
|
|
}
|
|
if shutdown != nil {
|
|
defer shutdown(ctx)
|
|
}
|
|
_ = tp // available for future span creation
|
|
|
|
// Voice registry
|
|
registry := newVoiceRegistry(modelStore)
|
|
registry.refresh()
|
|
|
|
// NATS
|
|
natsOpts := []nats.Option{}
|
|
if cfg.NATSUser != "" && cfg.NATSPassword != "" {
|
|
natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
|
|
}
|
|
nc := natsutil.New(cfg.NATSURL, natsOpts...)
|
|
if err := nc.Connect(); err != nil {
|
|
slog.Error("NATS connect failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
defer nc.Close()
|
|
|
|
// JetStream stream setup
|
|
js, err := nc.Conn().JetStream()
|
|
if err == nil {
|
|
_, err = js.AddStream(&nats.StreamConfig{
|
|
Name: "AI_VOICE_TTS",
|
|
Subjects: []string{"ai.voice.tts.>"},
|
|
Retention: nats.LimitsPolicy,
|
|
MaxAge: 5 * time.Minute,
|
|
Storage: nats.MemoryStorage,
|
|
})
|
|
if err != nil {
|
|
slog.Info("JetStream stream setup", "msg", err)
|
|
} else {
|
|
slog.Info("created/updated JetStream stream AI_VOICE_TTS")
|
|
}
|
|
}
|
|
|
|
httpClient := &http.Client{Timeout: 180 * time.Second}
|
|
running := true
|
|
|
|
// Health server
|
|
healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
|
|
return running && nc.IsConnected()
|
|
})
|
|
healthSrv.Start()
|
|
defer healthSrv.Stop(ctx)
|
|
|
|
// Helper: publish status
|
|
publishStatus := func(sessionID, status, message string) {
|
|
statusMsg := &messages.TTSStatus{
|
|
SessionID: sessionID,
|
|
Status: status,
|
|
Message: message,
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
_ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg)
|
|
}
|
|
|
|
// Helper: synthesize via XTTS HTTP API
|
|
synthesize := func(ctx context.Context, text, speaker, language, speakerWavB64 string) ([]byte, error) {
|
|
payload := map[string]any{
|
|
"text": text,
|
|
"speaker": speaker,
|
|
"language": language,
|
|
}
|
|
|
|
customVoice := registry.get(speaker)
|
|
if customVoice != nil {
|
|
payload["model_path"] = customVoice.ModelPath
|
|
if customVoice.ConfigPath != "" {
|
|
payload["config_path"] = customVoice.ConfigPath
|
|
}
|
|
payload["language"] = orDefault(language, customVoice.Language)
|
|
slog.Info("using custom voice", "voice", customVoice.Name)
|
|
} else if speakerWavB64 != "" {
|
|
payload["speaker_wav"] = speakerWavB64
|
|
}
|
|
|
|
body, _ := json.Marshal(payload)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, xttsURL+"/v1/audio/speech", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xtts request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("xtts %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
return io.ReadAll(resp.Body)
|
|
}
|
|
|
|
// Helper: stream audio chunks — raw bytes, no base64
|
|
streamAudio := func(sessionID string, audioBytes []byte) {
|
|
totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
|
|
for i := 0; i < len(audioBytes); i += audioChunkSize {
|
|
end := i + audioChunkSize
|
|
if end > len(audioBytes) {
|
|
end = len(audioBytes)
|
|
}
|
|
chunk := audioBytes[i:end]
|
|
chunkIndex := i / audioChunkSize
|
|
isLast := end >= len(audioBytes)
|
|
|
|
msg := &messages.TTSAudioChunk{
|
|
SessionID: sessionID,
|
|
ChunkIndex: chunkIndex,
|
|
TotalChunks: totalChunks,
|
|
Audio: chunk,
|
|
IsLast: isLast,
|
|
Timestamp: time.Now().Unix(),
|
|
SampleRate: sampleRate,
|
|
}
|
|
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
|
}
|
|
slog.Info("streamed audio", "session", sessionID, "chunks", totalChunks)
|
|
}
|
|
|
|
// Subscribe: TTS requests
|
|
handleRequest := func(natMsg *nats.Msg) {
|
|
parts := strings.Split(natMsg.Subject, ".")
|
|
if len(parts) < 5 {
|
|
slog.Warn("invalid subject format", "subject", natMsg.Subject)
|
|
return
|
|
}
|
|
sessionID := parts[4]
|
|
|
|
req, err := natsutil.Decode[messages.TTSRequest](natMsg.Data)
|
|
if err != nil {
|
|
slog.Error("decode error", "error", err)
|
|
return
|
|
}
|
|
|
|
text := req.Text
|
|
speaker := orDefault(req.Speaker, defaultSpeaker)
|
|
language := orDefault(req.Language, defaultLanguage)
|
|
speakerWavB64 := req.SpeakerWavB64
|
|
stream := req.Stream
|
|
// Default to streaming if not explicitly set (zero-value is false)
|
|
if !stream && text != "" {
|
|
stream = true
|
|
}
|
|
|
|
if text == "" {
|
|
slog.Warn("empty text", "session", sessionID)
|
|
publishStatus(sessionID, "error", "Empty text provided")
|
|
return
|
|
}
|
|
|
|
slog.Info("processing TTS request", "session", sessionID, "text_len", len(text))
|
|
publishStatus(sessionID, "processing", fmt.Sprintf("Synthesizing %d characters", len(text)))
|
|
|
|
audioBytes, err := synthesize(ctx, text, speaker, language, speakerWavB64)
|
|
if err != nil {
|
|
slog.Error("synthesis failed", "session", sessionID, "error", err)
|
|
publishStatus(sessionID, "error", err.Error())
|
|
return
|
|
}
|
|
|
|
if stream {
|
|
streamAudio(sessionID, audioBytes)
|
|
} else {
|
|
msg := &messages.TTSFullResponse{
|
|
SessionID: sessionID,
|
|
Audio: audioBytes,
|
|
Timestamp: time.Now().Unix(),
|
|
SampleRate: sampleRate,
|
|
}
|
|
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
|
}
|
|
|
|
publishStatus(sessionID, "completed", fmt.Sprintf("Audio size: %d bytes", len(audioBytes)))
|
|
}
|
|
|
|
if _, err := nc.Conn().Subscribe(requestSubjectPrefix+".>", handleRequest); err != nil {
|
|
slog.Error("subscribe request failed", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
slog.Info("subscribed", "subject", requestSubjectPrefix+".>")
|
|
|
|
// Subscribe: list voices
|
|
if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
|
|
resp := &messages.TTSVoiceListResponse{
|
|
DefaultSpeaker: defaultSpeaker,
|
|
CustomVoices: registry.listVoices(),
|
|
LastRefresh: registry.lastRefresh.Unix(),
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
packed, _ := msgpack.Marshal(resp)
|
|
if msg.Reply != "" {
|
|
msg.Respond(packed)
|
|
}
|
|
}); err != nil {
|
|
slog.Error("subscribe voices list failed", "error", err)
|
|
}
|
|
|
|
// Subscribe: refresh voices
|
|
if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
|
|
count := registry.refresh()
|
|
resp := &messages.TTSVoiceRefreshResponse{
|
|
Count: count,
|
|
CustomVoices: registry.listVoices(),
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
packed, _ := msgpack.Marshal(resp)
|
|
if msg.Reply != "" {
|
|
msg.Respond(packed)
|
|
}
|
|
slog.Info("voice registry refreshed on demand", "count", count)
|
|
}); err != nil {
|
|
slog.Error("subscribe voices refresh failed", "error", err)
|
|
}
|
|
|
|
// Periodic voice refresh
|
|
go func() {
|
|
ticker := time.NewTicker(time.Duration(refreshInterval) * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
registry.refresh()
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
slog.Info("tts-streaming ready")
|
|
|
|
// Wait for shutdown
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
|
<-sigCh
|
|
|
|
slog.Info("shutting down")
|
|
running = false
|
|
cancel()
|
|
slog.Info("shutdown complete")
|
|
}
|
|
|
|
// Helpers
|
|
|
|
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns the integer value of the environment variable key, or
// fallback when it is unset, empty or not a valid base-10 integer.
// Surrounding whitespace is ignored. A malformed value is logged, so that a
// typo in deployment config does not silently fall back to the default.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		slog.Warn("ignoring invalid integer environment variable",
			"key", key, "value", raw, "fallback", fallback, "error", err)
		return fallback
	}
	return n
}
|
|
|
|
// orDefault returns val unless it is the empty string, in which case it
// returns def instead.
func orDefault(val, def string) string {
	if val != "" {
		return val
	}
	return def
}
|