- Replace msgpack encoding with protobuf wire format
- Update field names to proto convention
- Use pointer slices for repeated message fields ([]*DocumentSource)
- Rewrite tests for proto round-trips
84 lines
2.1 KiB
Go
84 lines
2.1 KiB
Go
package main
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
|
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
func TestVoiceRequestDecode(t *testing.T) {
|
|
req := &messages.VoiceRequest{
|
|
RequestId: "req-123",
|
|
Audio: []byte{0x01, 0x02, 0x03},
|
|
Language: "en",
|
|
Collection: "docs",
|
|
}
|
|
data, err := proto.Marshal(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var decoded messages.VoiceRequest
|
|
if err := natsutil.Decode(data, &decoded); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if decoded.RequestId != "req-123" {
|
|
t.Errorf("RequestId = %q", decoded.RequestId)
|
|
}
|
|
if len(decoded.Audio) != 3 {
|
|
t.Errorf("Audio len = %d", len(decoded.Audio))
|
|
}
|
|
}
|
|
|
|
func TestVoiceResponseRoundtrip(t *testing.T) {
|
|
resp := &messages.VoiceResponse{
|
|
RequestId: "req-456",
|
|
Response: "It is sunny today.",
|
|
Audio: make([]byte, 8000),
|
|
Transcription: "What is the weather?",
|
|
Sources: []*messages.DocumentSource{{Text: "weather doc", Score: 0.9}},
|
|
}
|
|
data, err := proto.Marshal(resp)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var got messages.VoiceResponse
|
|
if err := proto.Unmarshal(data, &got); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got.Response != "It is sunny today." {
|
|
t.Errorf("Response = %q", got.Response)
|
|
}
|
|
if len(got.Audio) != 8000 {
|
|
t.Errorf("Audio len = %d", len(got.Audio))
|
|
}
|
|
if len(got.Sources) != 1 || got.Sources[0].Text != "weather doc" {
|
|
t.Errorf("Sources = %v", got.Sources)
|
|
}
|
|
}
|
|
|
|
func TestGetEnvHelpers(t *testing.T) {
|
|
t.Setenv("VA_TEST", "hello")
|
|
if got := getEnv("VA_TEST", "x"); got != "hello" {
|
|
t.Errorf("getEnv = %q", got)
|
|
}
|
|
t.Setenv("VA_PORT", "9090")
|
|
if got := getEnvInt("VA_PORT", 0); got != 9090 {
|
|
t.Errorf("getEnvInt = %d", got)
|
|
}
|
|
t.Setenv("VA_FLAG", "true")
|
|
if got := getEnvBool("VA_FLAG", false); !got {
|
|
t.Error("getEnvBool should be true")
|
|
}
|
|
}
|
|
|
|
func TestTruncate(t *testing.T) {
|
|
if got := truncate("hello world", 5); got != "hello..." {
|
|
t.Errorf("truncate = %q", got)
|
|
}
|
|
if got := truncate("hi", 10); got != "hi" {
|
|
t.Errorf("truncate short = %q", got)
|
|
}
|
|
}
|