- Decode TTSRequest via natsutil.Decode[messages.TTSRequest]
- Stream audio as raw bytes via messages.TTSAudioChunk (no base64)
- Non-stream response uses messages.TTSFullResponse
- Status updates use messages.TTSStatus
- Voice list/refresh use messages.TTSVoiceListResponse/TTSVoiceRefreshResponse
- Registry returns []messages.TTSVoiceInfo (not []map[string]any)
- Remove strVal/boolVal helpers
- Add .dockerignore, GOAMD64=v3 in Dockerfile
- Update tests for typed structs (13 tests pass)
190 lines
4.7 KiB
Go
190 lines
4.7 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
)
|
|
|
|
func TestVoiceRegistryRefresh(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create a voice directory
|
|
voiceDir := filepath.Join(dir, "test-voice")
|
|
os.MkdirAll(voiceDir, 0o755)
|
|
info := map[string]string{"name": "test-voice", "language": "en", "type": "coqui-tts", "created_at": "2024-01-01"}
|
|
infoData, _ := json.Marshal(info)
|
|
os.WriteFile(filepath.Join(voiceDir, "model_info.json"), infoData, 0o644)
|
|
os.WriteFile(filepath.Join(voiceDir, "model.pth"), []byte("fake"), 0o644)
|
|
|
|
vr := newVoiceRegistry(dir)
|
|
count := vr.refresh()
|
|
if count != 1 {
|
|
t.Errorf("refresh() = %d, want 1", count)
|
|
}
|
|
|
|
voice := vr.get("test-voice")
|
|
if voice == nil {
|
|
t.Fatal("expected voice 'test-voice'")
|
|
}
|
|
if voice.Language != "en" {
|
|
t.Errorf("language = %q, want %q", voice.Language, "en")
|
|
}
|
|
if voice.ModelPath != filepath.Join(voiceDir, "model.pth") {
|
|
t.Errorf("model_path = %q", voice.ModelPath)
|
|
}
|
|
|
|
voices := vr.listVoices()
|
|
if len(voices) != 1 {
|
|
t.Errorf("listVoices() len = %d, want 1", len(voices))
|
|
}
|
|
}
|
|
|
|
func TestVoiceRegistryMissing(t *testing.T) {
|
|
vr := newVoiceRegistry("/nonexistent/path")
|
|
count := vr.refresh()
|
|
if count != 0 {
|
|
t.Errorf("refresh() = %d, want 0", count)
|
|
}
|
|
}
|
|
|
|
func TestVoiceRegistryNoModel(t *testing.T) {
|
|
dir := t.TempDir()
|
|
voiceDir := filepath.Join(dir, "bad-voice")
|
|
os.MkdirAll(voiceDir, 0o755)
|
|
info := map[string]string{"name": "bad-voice"}
|
|
infoData, _ := json.Marshal(info)
|
|
os.WriteFile(filepath.Join(voiceDir, "model_info.json"), infoData, 0o644)
|
|
// No model.pth
|
|
|
|
vr := newVoiceRegistry(dir)
|
|
count := vr.refresh()
|
|
if count != 0 {
|
|
t.Errorf("refresh() = %d, want 0", count)
|
|
}
|
|
}
|
|
|
|
// TestSynthesizeHTTP starts a stub speech endpoint and posts a JSON synthesis
// request to it with a plain HTTP client. The stub verifies the request path,
// method and "text" field; the test verifies the response status.
func TestSynthesizeHTTP(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so it reports problems with
		// t.Errorf only; FailNow-style calls belong to the test goroutine.
		if r.URL.Path != "/v1/audio/speech" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}

		var payload map[string]any
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			// A malformed body is reported explicitly instead of showing up
			// as a confusing "unexpected text: <nil>" below.
			t.Errorf("decoding request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if payload["text"] != "hello" {
			t.Errorf("unexpected text: %v", payload["text"])
		}

		if _, err := w.Write([]byte{0x01, 0x02, 0x03, 0x04}); err != nil {
			t.Errorf("writing response: %v", err)
		}
	}))
	defer ts.Close()

	// Test the XTTS synthesis HTTP call directly
	client := ts.Client()
	body := `{"text":"hello","speaker":"default","language":"en"}`
	resp, err := client.Post(ts.URL+"/v1/audio/speech", "application/json", strings.NewReader(body))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("status = %d, want 200", resp.StatusCode)
	}
}
|
|
|
|
func TestTTSRequestDecode(t *testing.T) {
|
|
req := messages.TTSRequest{
|
|
Text: "hello world",
|
|
Speaker: "custom-en",
|
|
Language: "en",
|
|
Stream: true,
|
|
}
|
|
data, err := msgpack.Marshal(&req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
decoded, err := natsutil.Decode[messages.TTSRequest](data)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if decoded.Text != "hello world" {
|
|
t.Errorf("Text = %q", decoded.Text)
|
|
}
|
|
if decoded.Speaker != "custom-en" {
|
|
t.Errorf("Speaker = %q", decoded.Speaker)
|
|
}
|
|
if !decoded.Stream {
|
|
t.Error("Stream should be true")
|
|
}
|
|
}
|
|
|
|
func TestTTSAudioChunkRoundtrip(t *testing.T) {
|
|
chunk := messages.TTSAudioChunk{
|
|
SessionID: "sess-001",
|
|
ChunkIndex: 0,
|
|
TotalChunks: 2,
|
|
Audio: make([]byte, 32768),
|
|
IsLast: false,
|
|
Timestamp: 1234567890,
|
|
SampleRate: 24000,
|
|
}
|
|
data, err := msgpack.Marshal(&chunk)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var got messages.TTSAudioChunk
|
|
if err := msgpack.Unmarshal(data, &got); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got.SessionID != "sess-001" {
|
|
t.Errorf("SessionID = %q", got.SessionID)
|
|
}
|
|
if len(got.Audio) != 32768 {
|
|
t.Errorf("Audio len = %d", len(got.Audio))
|
|
}
|
|
if got.SampleRate != 24000 {
|
|
t.Errorf("SampleRate = %d", got.SampleRate)
|
|
}
|
|
}
|
|
|
|
func TestGetEnv(t *testing.T) {
|
|
t.Setenv("MY_TEST_VAR", "value")
|
|
if got := getEnv("MY_TEST_VAR", "fallback"); got != "value" {
|
|
t.Errorf("getEnv = %q", got)
|
|
}
|
|
if got := getEnv("NONEXISTENT_XYZ", "fallback"); got != "fallback" {
|
|
t.Errorf("getEnv = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestGetEnvInt(t *testing.T) {
|
|
t.Setenv("MY_PORT", "8080")
|
|
if got := getEnvInt("MY_PORT", 3000); got != 8080 {
|
|
t.Errorf("getEnvInt = %d", got)
|
|
}
|
|
if got := getEnvInt("NONEXISTENT_XYZ", 3000); got != 3000 {
|
|
t.Errorf("getEnvInt = %d", got)
|
|
}
|
|
}
|
|
|
|
func TestOrDefault(t *testing.T) {
|
|
if got := orDefault("", "en"); got != "en" {
|
|
t.Errorf("orDefault('', 'en') = %q", got)
|
|
}
|
|
if got := orDefault("fr", "en"); got != "fr" {
|
|
t.Errorf("orDefault('fr', 'en') = %q", got)
|
|
}
|
|
}
|