feat: rewrite tts-module in Go
Replace Python streaming TTS service with Go for smaller container images. - VoiceRegistry: discovers custom voices from model store - NATS subscriptions: TTS requests, voice list, voice refresh - JetStream AI_VOICE_TTS stream setup - Chunked audio streaming over NATS - Dockerfile: multi-stage golang:1.25-alpine → scratch - CI: Gitea Actions with lint/test/release/docker/notify
This commit is contained in:
460
main.go
Normal file
460
main.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
||||
)
|
||||
|
||||
// NATS subject prefixes
|
||||
const (
|
||||
requestSubjectPrefix = "ai.voice.tts.request"
|
||||
audioSubjectPrefix = "ai.voice.tts.audio"
|
||||
statusSubjectPrefix = "ai.voice.tts.status"
|
||||
voicesListSubject = "ai.voice.tts.voices.list"
|
||||
voicesRefreshSubject = "ai.voice.tts.voices.refresh"
|
||||
)
|
||||
|
||||
// CustomVoice is a trained voice from the coqui-voice-training pipeline.
|
||||
type CustomVoice struct {
|
||||
Name string `json:"name"`
|
||||
ModelPath string `json:"model_path"`
|
||||
ConfigPath string `json:"config_path"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Language string `json:"language"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
|
||||
// VoiceRegistry discovers custom voices from the model store directory.
|
||||
type VoiceRegistry struct {
|
||||
modelStore string
|
||||
mu sync.RWMutex
|
||||
voices map[string]*CustomVoice
|
||||
lastRefresh time.Time
|
||||
}
|
||||
|
||||
func newVoiceRegistry(modelStore string) *VoiceRegistry {
|
||||
return &VoiceRegistry{
|
||||
modelStore: modelStore,
|
||||
voices: make(map[string]*CustomVoice),
|
||||
}
|
||||
}
|
||||
|
||||
func (vr *VoiceRegistry) refresh() int {
|
||||
vr.mu.Lock()
|
||||
defer vr.mu.Unlock()
|
||||
|
||||
entries, err := os.ReadDir(vr.modelStore)
|
||||
if err != nil {
|
||||
slog.Warn("voice model store not found", "path", vr.modelStore, "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
discovered := make(map[string]*CustomVoice)
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
voiceDir := filepath.Join(vr.modelStore, entry.Name())
|
||||
infoPath := filepath.Join(voiceDir, "model_info.json")
|
||||
|
||||
infoData, err := os.ReadFile(infoPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var info map[string]string
|
||||
if err := json.Unmarshal(infoData, &info); err != nil {
|
||||
slog.Error("bad model_info.json", "dir", voiceDir, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
modelPath := filepath.Join(voiceDir, "model.pth")
|
||||
if _, err := os.Stat(modelPath); err != nil {
|
||||
slog.Warn("model file missing", "voice", entry.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
configPath := filepath.Join(voiceDir, "config.json")
|
||||
if _, err := os.Stat(configPath); err != nil {
|
||||
configPath = ""
|
||||
}
|
||||
|
||||
name := info["name"]
|
||||
if name == "" {
|
||||
name = entry.Name()
|
||||
}
|
||||
|
||||
discovered[name] = &CustomVoice{
|
||||
Name: name,
|
||||
ModelPath: modelPath,
|
||||
ConfigPath: configPath,
|
||||
CreatedAt: info["created_at"],
|
||||
Language: orDefault(info["language"], "en"),
|
||||
ModelType: orDefault(info["type"], "coqui-tts"),
|
||||
}
|
||||
}
|
||||
|
||||
vr.voices = discovered
|
||||
vr.lastRefresh = time.Now()
|
||||
slog.Info("voice registry refreshed", "count", len(discovered))
|
||||
return len(discovered)
|
||||
}
|
||||
|
||||
func (vr *VoiceRegistry) get(name string) *CustomVoice {
|
||||
vr.mu.RLock()
|
||||
defer vr.mu.RUnlock()
|
||||
return vr.voices[name]
|
||||
}
|
||||
|
||||
func (vr *VoiceRegistry) listVoices() []map[string]any {
|
||||
vr.mu.RLock()
|
||||
defer vr.mu.RUnlock()
|
||||
result := make([]map[string]any, 0, len(vr.voices))
|
||||
for _, v := range vr.voices {
|
||||
result = append(result, map[string]any{
|
||||
"name": v.Name,
|
||||
"language": v.Language,
|
||||
"model_type": v.ModelType,
|
||||
"created_at": v.CreatedAt,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func main() {
|
||||
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
||||
|
||||
cfg := config.Load()
|
||||
cfg.ServiceName = "tts-streaming"
|
||||
|
||||
xttsURL := getEnv("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
|
||||
defaultSpeaker := getEnv("TTS_DEFAULT_SPEAKER", "default")
|
||||
defaultLanguage := getEnv("TTS_DEFAULT_LANGUAGE", "en")
|
||||
audioChunkSize := getEnvInt("TTS_AUDIO_CHUNK_SIZE", 32768)
|
||||
sampleRate := getEnvInt("TTS_SAMPLE_RATE", 24000)
|
||||
modelStore := getEnv("VOICE_MODEL_STORE", "/models/tts/custom")
|
||||
refreshInterval := getEnvInt("VOICE_REGISTRY_REFRESH_SECONDS", 300)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Telemetry
|
||||
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
||||
ServiceName: cfg.ServiceName,
|
||||
ServiceVersion: cfg.ServiceVersion,
|
||||
ServiceNamespace: cfg.ServiceNamespace,
|
||||
DeploymentEnv: cfg.DeploymentEnv,
|
||||
Enabled: cfg.OTELEnabled,
|
||||
Endpoint: cfg.OTELEndpoint,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("telemetry setup failed", "error", err)
|
||||
}
|
||||
if shutdown != nil {
|
||||
defer shutdown(ctx)
|
||||
}
|
||||
_ = tp // available for future span creation
|
||||
|
||||
// Voice registry
|
||||
registry := newVoiceRegistry(modelStore)
|
||||
registry.refresh()
|
||||
|
||||
// NATS
|
||||
natsOpts := []nats.Option{}
|
||||
if cfg.NATSUser != "" && cfg.NATSPassword != "" {
|
||||
natsOpts = append(natsOpts, nats.UserInfo(cfg.NATSUser, cfg.NATSPassword))
|
||||
}
|
||||
nc := natsutil.New(cfg.NATSURL, natsOpts...)
|
||||
if err := nc.Connect(); err != nil {
|
||||
slog.Error("NATS connect failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer nc.Close()
|
||||
|
||||
// JetStream stream setup
|
||||
js, err := nc.Conn().JetStream()
|
||||
if err == nil {
|
||||
_, err = js.AddStream(&nats.StreamConfig{
|
||||
Name: "AI_VOICE_TTS",
|
||||
Subjects: []string{"ai.voice.tts.>"},
|
||||
Retention: nats.LimitsPolicy,
|
||||
MaxAge: 5 * time.Minute,
|
||||
Storage: nats.MemoryStorage,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Info("JetStream stream setup", "msg", err)
|
||||
} else {
|
||||
slog.Info("created/updated JetStream stream AI_VOICE_TTS")
|
||||
}
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Timeout: 180 * time.Second}
|
||||
running := true
|
||||
|
||||
// Health server
|
||||
healthSrv := health.New(cfg.HealthPort, cfg.HealthPath, cfg.ReadyPath, func() bool {
|
||||
return running && nc.IsConnected()
|
||||
})
|
||||
healthSrv.Start()
|
||||
defer healthSrv.Stop(ctx)
|
||||
|
||||
// Helper: publish status
|
||||
publishStatus := func(sessionID, status, message string) {
|
||||
statusMsg := map[string]any{
|
||||
"session_id": sessionID,
|
||||
"status": status,
|
||||
"message": message,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
_ = nc.Publish(fmt.Sprintf("%s.%s", statusSubjectPrefix, sessionID), statusMsg)
|
||||
}
|
||||
|
||||
// Helper: synthesize via XTTS HTTP API
|
||||
synthesize := func(ctx context.Context, text, speaker, language, speakerWavB64 string) ([]byte, error) {
|
||||
payload := map[string]any{
|
||||
"text": text,
|
||||
"speaker": speaker,
|
||||
"language": language,
|
||||
}
|
||||
|
||||
customVoice := registry.get(speaker)
|
||||
if customVoice != nil {
|
||||
payload["model_path"] = customVoice.ModelPath
|
||||
if customVoice.ConfigPath != "" {
|
||||
payload["config_path"] = customVoice.ConfigPath
|
||||
}
|
||||
payload["language"] = orDefault(language, customVoice.Language)
|
||||
slog.Info("using custom voice", "voice", customVoice.Name)
|
||||
} else if speakerWavB64 != "" {
|
||||
payload["speaker_wav"] = speakerWavB64
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, xttsURL+"/v1/audio/speech", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xtts request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("xtts %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// Helper: stream audio chunks
|
||||
streamAudio := func(sessionID string, audioBytes []byte) {
|
||||
totalChunks := (len(audioBytes) + audioChunkSize - 1) / audioChunkSize
|
||||
for i := 0; i < len(audioBytes); i += audioChunkSize {
|
||||
end := i + audioChunkSize
|
||||
if end > len(audioBytes) {
|
||||
end = len(audioBytes)
|
||||
}
|
||||
chunk := audioBytes[i:end]
|
||||
chunkIndex := i / audioChunkSize
|
||||
isLast := end >= len(audioBytes)
|
||||
|
||||
msg := map[string]any{
|
||||
"session_id": sessionID,
|
||||
"chunk_index": chunkIndex,
|
||||
"total_chunks": totalChunks,
|
||||
"audio_b64": base64.StdEncoding.EncodeToString(chunk),
|
||||
"is_last": isLast,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"sample_rate": sampleRate,
|
||||
}
|
||||
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
||||
}
|
||||
slog.Info("streamed audio", "session", sessionID, "chunks", totalChunks)
|
||||
}
|
||||
|
||||
// Subscribe: TTS requests
|
||||
handleRequest := func(natMsg *nats.Msg) {
|
||||
parts := strings.Split(natMsg.Subject, ".")
|
||||
if len(parts) < 5 {
|
||||
slog.Warn("invalid subject format", "subject", natMsg.Subject)
|
||||
return
|
||||
}
|
||||
sessionID := parts[4]
|
||||
|
||||
data, err := natsutil.DecodeMsgpackMap(natMsg.Data)
|
||||
if err != nil {
|
||||
slog.Error("decode error", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
text := strVal(data, "text", "")
|
||||
speaker := strVal(data, "speaker", defaultSpeaker)
|
||||
language := strVal(data, "language", defaultLanguage)
|
||||
speakerWavB64 := strVal(data, "speaker_wav_b64", "")
|
||||
stream := boolVal(data, "stream", true)
|
||||
|
||||
if text == "" {
|
||||
slog.Warn("empty text", "session", sessionID)
|
||||
publishStatus(sessionID, "error", "Empty text provided")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("processing TTS request", "session", sessionID, "text_len", len(text))
|
||||
publishStatus(sessionID, "processing", fmt.Sprintf("Synthesizing %d characters", len(text)))
|
||||
|
||||
audioBytes, err := synthesize(ctx, text, speaker, language, speakerWavB64)
|
||||
if err != nil {
|
||||
slog.Error("synthesis failed", "session", sessionID, "error", err)
|
||||
publishStatus(sessionID, "error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
streamAudio(sessionID, audioBytes)
|
||||
} else {
|
||||
msg := map[string]any{
|
||||
"session_id": sessionID,
|
||||
"audio_b64": base64.StdEncoding.EncodeToString(audioBytes),
|
||||
"timestamp": time.Now().Unix(),
|
||||
"sample_rate": sampleRate,
|
||||
}
|
||||
_ = nc.Publish(fmt.Sprintf("%s.%s", audioSubjectPrefix, sessionID), msg)
|
||||
}
|
||||
|
||||
publishStatus(sessionID, "completed", fmt.Sprintf("Audio size: %d bytes", len(audioBytes)))
|
||||
}
|
||||
|
||||
if _, err := nc.Conn().Subscribe(requestSubjectPrefix+".>", handleRequest); err != nil {
|
||||
slog.Error("subscribe request failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("subscribed", "subject", requestSubjectPrefix+".>")
|
||||
|
||||
// Subscribe: list voices
|
||||
if _, err := nc.Conn().Subscribe(voicesListSubject, func(msg *nats.Msg) {
|
||||
resp := map[string]any{
|
||||
"default_speaker": defaultSpeaker,
|
||||
"custom_voices": registry.listVoices(),
|
||||
"last_refresh": registry.lastRefresh.Unix(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
packed, _ := msgpack.Marshal(resp)
|
||||
if msg.Reply != "" {
|
||||
msg.Respond(packed)
|
||||
}
|
||||
}); err != nil {
|
||||
slog.Error("subscribe voices list failed", "error", err)
|
||||
}
|
||||
|
||||
// Subscribe: refresh voices
|
||||
if _, err := nc.Conn().Subscribe(voicesRefreshSubject, func(msg *nats.Msg) {
|
||||
count := registry.refresh()
|
||||
resp := map[string]any{
|
||||
"count": count,
|
||||
"custom_voices": registry.listVoices(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
packed, _ := msgpack.Marshal(resp)
|
||||
if msg.Reply != "" {
|
||||
msg.Respond(packed)
|
||||
}
|
||||
slog.Info("voice registry refreshed on demand", "count", count)
|
||||
}); err != nil {
|
||||
slog.Error("subscribe voices refresh failed", "error", err)
|
||||
}
|
||||
|
||||
// Periodic voice refresh
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Duration(refreshInterval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
registry.refresh()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("tts-streaming ready")
|
||||
|
||||
// Wait for shutdown
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
<-sigCh
|
||||
|
||||
slog.Info("shutting down")
|
||||
running = false
|
||||
cancel()
|
||||
slog.Info("shutdown complete")
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func strVal(m map[string]any, key, fallback string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func boolVal(m map[string]any, key string, fallback bool) bool {
|
||||
if v, ok := m[key]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func orDefault(val, def string) string {
|
||||
if val == "" {
|
||||
return def
|
||||
}
|
||||
return val
|
||||
}
|
||||
Reference in New Issue
Block a user